In [2]:
import pandas as pd
import warnings 
warnings.filterwarnings(action='ignore')
import numpy as np
import re
import nltk
from sklearn.model_selection import train_test_split
from nltk import word_tokenize
from wordcloud import WordCloud, STOPWORDS
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
# 파일 불러오기

train = pd.read_csv('./open/train.csv', encoding='utf-8')

In [4]:
X = train.loc[:, 'text']
y = train.loc[:, 'author']

In [5]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=13)

### tfidfvectorizer

In [6]:
from sklearn.feature_extraction.text import TfidfVectorizer, TfidfTransformer

# TF-IDF Vectorization 적용하여 학습 데이터셋과 테스트 데이터 셋 변환.
tfidf_vect = TfidfVectorizer(stop_words='english')
tfidf_vect.fit(X_train)
X_train_tfidf_vect = tfidf_vect.transform(X_train)
X_test_tfidf_vect = tfidf_vect.transform(X_test)

### SGDClassifier

In [18]:
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import accuracy_score

sgd_clf = SGDClassifier(random_state=13, loss='modified_huber')
sgd_clf.fit(X_train_tfidf_vect, y_train)

train_pred = sgd_clf.predict(X_train_tfidf_vect)
test_pred = sgd_clf.predict(X_test_tfidf_vect)

print('SGDClassifier train accuracy score:', accuracy_score(y_train, train_pred))
print('SGDClassifier test accuracy score:', accuracy_score(y_test, test_pred))

SGDClassifier train accuracy score: 0.8592351319955356
SGDClassifier test accuracy score: 0.7332361516034985


In [9]:
# 교차 검증

from sklearn.model_selection import cross_val_score, cross_validate, cross_val_predict
from sklearn.model_selection import StratifiedKFold

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=13)
cross_val_score(sgd_clf, X_train_tfidf_vect, y_train, cv=5, scoring='accuracy')

array([0.73397107, 0.73727366, 0.72679649, 0.73405467, 0.73075171])

In [22]:
sgd_clf = SGDClassifier(alpha=0.0001, loss='modified_huber', n_jobs=-1, random_state=13, penalty='l2')
cross_validate(sgd_clf, X_train_tfidf_vect, y_train, scoring=None, cv=skf, return_train_score=True)

{'fit_time': array([0.21991968, 0.2159369 , 0.11206222, 0.21130967, 0.11145711]),
 'score_time': array([0.00423908, 0.00444603, 0.00335908, 0.00521803, 0.00276184]),
 'test_score': array([0.72599932, 0.73351554, 0.73738754, 0.72790433, 0.73132118]),
 'train_score': array([0.87486476, 0.87660156, 0.87372587, 0.87466902, 0.87586482])}

In [23]:
sgd_clf = SGDClassifier(alpha=0.0001, loss='modified_huber', n_jobs=-1, random_state=13, penalty='l1')
cross_validate(sgd_clf, X_train_tfidf_vect, y_train, scoring=None, cv=skf, return_train_score=True)

{'fit_time': array([0.42394495, 0.31523204, 0.32226706, 0.32350588, 0.32055092]),
 'score_time': array([0.00490928, 0.0047822 , 0.00541806, 0.00539422, 0.00475407]),
 'test_score': array([0.67589113, 0.68169912, 0.6766883 , 0.67585421, 0.68189066]),
 'train_score': array([0.76057742, 0.76302602, 0.76333922, 0.76468411, 0.764798  ])}

In [10]:
y_scores = cross_val_predict(sgd_clf, X_train_tfidf_vect, y_train, cv=skf, method='decision_function')

In [11]:
y_scores

array([[ 0.08845517, -1.08782035, -1.39350539, -1.12419603, -0.34179337],
       [-0.17019869, -0.21915934, -0.25056359, -1.78923817, -1.08142273],
       [-0.69963142,  0.42057308, -1.00189905, -1.20256542, -1.11992828],
       ...,
       [-1.384899  , -1.45381016, -1.45764502, -1.09673068,  1.45954107],
       [-0.60249989, -1.09574016, -1.12078779,  0.63297133, -1.33737796],
       [-1.07331235, -1.0934612 , -1.0314122 ,  0.72254814, -0.95285561]])

In [15]:
from sklearn.multiclass import OneVsRestClassifier

ovr_clf = OneVsRestClassifier(SGDClassifier())
ovr_clf.fit(X_train_tfidf_vect, y_train)
ovr_clf.predict(X_train_tfidf_vect)

array([0, 0, 1, ..., 4, 3, 3])

In [19]:
sgd_clf.decision_function(X_train_tfidf_vect)

array([[ 0.00828216, -1.12094925, -1.4737611 , -1.12316962, -0.21027795],
       [-0.40528682,  0.02691771, -0.51867997, -1.56894528, -1.05199848],
       [-0.76149698,  0.52339667, -1.15703551, -1.09581385, -1.0103404 ],
       ...,
       [-1.33518992, -1.43957265, -1.52641655, -0.92931811,  1.38966945],
       [-0.5760874 , -1.05988333, -1.1350919 ,  0.71427824, -1.34419834],
       [-0.99451862, -1.12181347, -1.02081457,  0.83276161, -1.00751358]])