In [1]:
import os

import joblib
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import cross_val_score, cross_val_predict, GridSearchCV
%matplotlib notebook

# Import data

In [None]:
train_set = pd.read_csv('datasets/mnist_train.csv')
# Column 0 is the target label; every remaining column is an input feature
# (presumably the MNIST pixel intensities — TODO confirm against the CSV header).
y_train = train_set.iloc[:, 0].to_numpy()
X_train = train_set.iloc[:, 1:].to_numpy()

# Dimension reduction: project the features with a PCA transformer that was
# fitted elsewhere and saved to disk. joblib.load unpickles, so only load
# model files you trust.
pca = joblib.load('models/pca_transformer.pkl')
X_train_reduced = pca.transform(X_train)

# Import trained models

In [2]:
# Deserialize the classifiers that were trained and pickled elsewhere, in a
# fixed order that matches the names on the left-hand side.
# NOTE: the two ensemble pickles are stored as 'hve_clf.pkl' / 'sve_clf.pkl'
# on disk, while the notebook calls them ehv_clf / esv_clf.
# joblib.load unpickles, which can execute arbitrary code: load trusted files only.
(sgd_clf, logit_clf, knn_clf, gnb_clf, tree_clf,
 svc_clf, rf_clf, ehv_clf, esv_clf) = [
    joblib.load(path)
    for path in (
        'models/sgd_clf.pkl',
        'models/logit_clf.pkl',
        'models/knn_clf.pkl',
        'models/gnb_clf.pkl',
        'models/tree_clf.pkl',
        'models/svc_clf.pkl',
        'models/rf_clf.pkl',
        'models/hve_clf.pkl',
        'models/sve_clf.pkl',
    )
]

FileNotFoundError: [Errno 2] No such file or directory: 'models/ehv_clf.pkl'

# Compare the models' performance

In [None]:
# Models to compare, in display order.
models = [sgd_clf, logit_clf, knn_clf, gnb_clf, tree_clf, svc_clf, rf_clf, ehv_clf, esv_clf]

# 3-fold cross-validated (out-of-fold) predictions for every model.
# The accuracy of each model is computed from those same predictions, so it is
# consistent with the confusion matrices plotted below and does not rely on
# `*_acc` variables that are not defined anywhere in this notebook.
preds = []
accs = []
for model in models:
    print(type(model).__name__)  # progress indicator while cross-validating
    y_pred = cross_val_predict(model, X_train_reduced, y_train, cv=3)
    preds.append(y_pred)
    accs.append(np.mean(y_pred == y_train))  # fraction correct, in [0, 1]

In [None]:
from sklearn.metrics import confusion_matrix

n_models = len(models)
rows = n_models 
cols = 2

fig, axs = plt.subplots(rows, cols, figsize=(2*cols,1.8*rows), constrained_layout=True)

for row in range(rows):
    conf_mx = confusion_matrix(y_train, preds[row])
    row_sums = conf_mx.sum(axis=1, keepdims=True)
    norm_conf_mx = conf_mx / row_sums
    np.fill_diagonal(norm_conf_mx, 0)
    axs[row][0].matshow(conf_mx, cmap=plt.cm.gray)
    axs[row][1].matshow(norm_conf_mx, cmap=plt.cm.gray)
    axs[row][0].set_title(str(models[row])[:10])
    axs[row][1].set_title(f"acc: {100*accs[row]:.2f}%")
    
for ax in axs.flatten():
    ax.axis('off')

In [None]:
# Close the confusion-matrix figure explicitly. A bare plt.close() closes
# whichever figure is current, which need not be this one.
plt.close(fig)