In [1]:
import matplotlib.pyplot as plt
import numpy as np

import pickle

from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_score, cross_val_predict, StratifiedKFold
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
from sklearn.base import clone

In [2]:
# Directory prefix (trailing slash included) that is string-concatenated with
# the 'data_batch_<n>' file names when the training batches are loaded.
# NOTE(review): the file names and the b'data'/b'labels' keys look like the
# CIFAR-10 python batches — confirm the expected dataset layout.
DATASET_PATH = "dataset/"

In [3]:
def unpickle(file):
    """Deserialize and return the single pickled object stored at `file`.

    The file is opened in binary mode and loaded with ``encoding='bytes'``,
    so 8-bit strings written by a Python 2 pickler come back as ``bytes``
    objects rather than being decoded to ``str``.

    SECURITY: ``pickle.load`` can execute arbitrary code embedded in the
    file; only call this on files from a trusted source.
    """
    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')

In [4]:
# Load the five training batches and stack them into one training set.
# Each batch's arrays are collected in a list and concatenated once at the
# end; calling np.append per batch re-copies the whole accumulated array on
# every iteration.
data_parts = []
label_parts = []
for i in range(1, 6):
    data_batch = unpickle(DATASET_PATH + 'data_batch_' + str(i))
    data_parts.append(data_batch[b'data'])
    label_parts.append(np.asarray(data_batch[b'labels']))

X_train = np.concatenate(data_parts, axis=0)  # rows of all batches, in batch order
y_train = np.concatenate(label_parts)         # matching 1-D label vector

# Shuffle features and labels with the same permutation so rows stay aligned.
# A dedicated seeded generator keeps the row order (and therefore any
# order-dependent cross-validation folds) identical on every fresh run,
# without touching numpy's global random state.
shuffle_index = np.random.RandomState(42).permutation(len(X_train))
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]

In [5]:
# Per-fold metric results. With average=None each appended entry is an array
# holding one score per class, so these become lists of per-class arrays.
precision_list = []
recall_list = []
f1_list = []

# random_state is intentionally not passed: it only has an effect when
# shuffle=True, and recent scikit-learn releases raise ValueError if it is
# supplied while shuffle is left at its default of False. The folds are the
# same deterministic, unshuffled stratified splits either way.
skfolds = StratifiedKFold(n_splits=10)

# Template estimator; it is never fitted itself, only cloned per fold.
sgd_clf = SGDClassifier(random_state=42)

for train_index, test_index in skfolds.split(X_train, y_train):
    # Fresh unfitted copy so nothing learned in one fold leaks into the next.
    clone_clf = clone(sgd_clf)
    X_train_folds = X_train[train_index]
    y_train_folds = y_train[train_index]
    X_test_fold = X_train[test_index]
    y_test_fold = y_train[test_index]

    clone_clf.fit(X_train_folds, y_train_folds)
    y_pred = clone_clf.predict(X_test_fold)
    # Number of correct predictions in this fold, counted in numpy rather
    # than by iterating the boolean array with the built-in sum().
    n_correct = np.sum(y_pred == y_test_fold)

    # NOTE(review): the recorded run emitted warnings mentioning
    # 'precision', 'predicted' — presumably some class received no
    # predictions in a fold; confirm before relying on per-class precision.
    precision_list.append(precision_score(y_test_fold, y_pred, average=None))
    recall_list.append(recall_score(y_test_fold, y_pred, average=None))
    f1_list.append(f1_score(y_test_fold, y_pred, average=None))

  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


In [6]:
def display_scores(measure, scores):
    """Print `measure` followed by the arithmetic mean of `scores`.

    The items are added with the built-in ``sum()`` and divided by the item
    count, so a sequence of equally shaped arrays prints their element-wise
    mean. An empty sequence raises ``ZeroDivisionError``.
    """
    total = sum(scores)
    count = float(len(scores))
    print(measure, total / count)
    
# Print the across-fold average of each metric, in the order
# precision, recall, f1.
for metric_name, fold_scores in (
    ("precision", precision_list),
    ("recall", recall_list),
    ("f1", f1_list),
):
    display_scores(metric_name, fold_scores)

precision [ 0.35747171  0.64999677  0.23749387  0.33556329  0.3883323   0.33203974
  0.50386818  0.27011323  0.46804054  0.51214358]
recall [ 0.3578  0.1986  0.278   0.022   0.0262  0.2574  0.1178  0.6066  0.2624
  0.0662]
f1 [ 0.28613993  0.27424253  0.14199439  0.03927279  0.04816039  0.18721858
  0.16833997  0.22656736  0.27878188  0.10264405]
