In [1]:
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import pickle

from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_score, cross_val_predict, StratifiedKFold
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score

from sklearn.base import clone

In [2]:
# Directory containing the pickled `data_batch_<n>` files loaded below.
# The trailing slash is required: file paths are built by plain string concatenation.
DATASET_PATH = "dataset/"

In [3]:
def unpickle(file):
    """Load and return the object stored in the pickle file at path `file`.

    The file is opened in binary mode and unpickled with ``encoding='bytes'``.

    Note: unpickling can execute arbitrary code, so only use this on files
    from a trusted source.
    """
    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')

In [4]:
# Load the five training batches and stack them into a single feature matrix
# (X_train) and label vector (y_train).
#
# Each batch is collected in a list and concatenated once at the end; the
# previous np.append-per-iteration approach re-copied the whole accumulated
# array on every pass and needed a special case for the first batch.
batch_features = []
batch_labels = []

for i in range(1, 6):
    data_batch = unpickle(DATASET_PATH + 'data_batch_' + str(i))
    batch_features.append(data_batch[b'data'])
    batch_labels.append(np.asarray(data_batch[b'labels']))

# Rows from batch 1 come first, then batch 2, ... (same order as before).
X_train = np.concatenate(batch_features, axis=0)
y_train = np.concatenate(batch_labels)

In [5]:
# Manual 3-fold stratified cross-validation of an SGD classifier.
# Each list receives one entry per fold; with average=None every entry is an
# array holding one score per class.
precision_list = []
recall_list = []
f1_list = []

# random_state is intentionally not passed: without shuffle=True it has no
# effect on the splits, and newer scikit-learn releases raise ValueError when
# random_state is set while shuffle is False. The resulting folds are unchanged.
skfolds = StratifiedKFold(n_splits=3)

sgd_clf = SGDClassifier(random_state=42)

for train_index, test_index in skfolds.split(X_train, y_train):
    # Fit a fresh, unfitted copy on every fold so no state carries over.
    clone_clf = clone(sgd_clf)
    X_train_folds = X_train[train_index]
    y_train_folds = y_train[train_index]
    X_test_fold = X_train[test_index]
    y_test_fold = y_train[test_index]

    clone_clf.fit(X_train_folds, y_train_folds)
    y_pred = clone_clf.predict(X_test_fold)
    # Vectorised count of correct predictions (builtin sum() would iterate
    # element by element in Python).
    n_correct = (y_pred == y_test_fold).sum()

    precision_list.append(precision_score(y_test_fold, y_pred, average=None))
    recall_list.append(recall_score(y_test_fold, y_pred, average=None))
    f1_list.append(f1_score(y_test_fold, y_pred, average=None))

In [6]:
def display_scores(measure, scores):
    """Print `measure` followed by the arithmetic mean of `scores`.

    Args:
        measure: Label printed in front of the mean.
        scores: Sized collection of values that can be added together with
            the builtin ``sum`` (plain numbers, or arrays that add
            element-wise).
    """
    fold_count = float(len(scores))
    average = sum(scores) / fold_count
    print(measure, average)
    
# Report the across-fold mean of each metric, in the order precision, recall, f1.
for metric_name, fold_scores in (
    ("precision", precision_list),
    ("recall", recall_list),
    ("f1", f1_list),
):
    display_scores(metric_name, fold_scores)

precision [ 0.32333872  0.67629835  0.29987558  0.24319235  0.38456332  0.07928618
  0.27522313  0.41473659  0.55113511  0.52588269]
recall [ 0.05760841  0.11118641  0.0144      0.14065518  0.02799872  0.5824835
  0.46706385  0.12877485  0.03359568  0.12960097]
f1 [ 0.09713944  0.15870399  0.02691202  0.08250934  0.04940836  0.13881346
  0.2269813   0.1241773   0.06174058  0.20409159]
