In [None]:
from time import perf_counter
from time import time

import altair as alt
import matplotlib.pyplot as plt
import pandas as pd
from sklearn import tree
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC

from plotter import plotter
    
# Rows of the dataset (in file order) used for training; all remaining rows are the test set.
TRAIN_SIZE = 50_000
# Largest number of PCA components evaluated; the sweep covers 1..PCA_MAX inclusive.
PCA_MAX = 30

## Prepare testing & training datasets

In [None]:
# Divide the MNIST dataset into a pixel matrix and a label vector.
# Both are converted to NumPy arrays so the positional slicing in the next
# cell treats features and labels the same way.
mnist = pd.read_csv("data/mnist.csv")
mnist_data = mnist.drop(columns=["label"]).to_numpy()
mnist_labels = mnist["label"].to_numpy()

In [None]:
# The first TRAIN_SIZE rows are the training set; every later row is the test set.
# Fail early with a clear message rather than letting StandardScaler reject an empty split.
assert 0 < TRAIN_SIZE < len(mnist_data), (
    f"TRAIN_SIZE={TRAIN_SIZE} must leave at least one row in both splits "
    f"(dataset has {len(mnist_data)} rows)"
)

# Standardization statistics come from the training rows only, so nothing
# about the test rows leaks into preprocessing.
scaler = StandardScaler().fit(mnist_data[:TRAIN_SIZE])

train = {
    'data':  scaler.transform(mnist_data[:TRAIN_SIZE]),
    'label': mnist_labels[:TRAIN_SIZE]
}

test = {
    'data':  scaler.transform(mnist_data[TRAIN_SIZE:]),
    'label': mnist_labels[TRAIN_SIZE:]
}

## Test classifiers on PCA-reduced features

In [None]:
# Fixed seed so the stochastic models give reproducible results across runs.
# TODO: consider moving this to the config cell alongside TRAIN_SIZE / PCA_MAX.
RANDOM_STATE = 0

clf_test = []
# One estimator per model is enough: each is refit from scratch at every component count.
classifiers = [(RandomForestClassifier(random_state=RANDOM_STATE), "RandomForest"),
               (LinearSVC(random_state=RANDOM_STATE), "LinearSVC"),
               (KNeighborsClassifier(), "KNN"),
               (GaussianNB(), "GaussianNB"),
               (tree.DecisionTreeClassifier(random_state=RANDOM_STATE), "DecisionTree")]

train_ds = {
    'data':  train['data'],
    'label': train['label']
}

test_ds = {
    'data':  test['data'],
    'label': test['label']
}

# Fit PCA once with the largest component count instead of refitting it on every
# iteration. Components are sorted by decreasing explained variance. With the exact
# "full" SVD solver, the first k columns of this projection equal the projection
# of a k-component PCA fit.
pca = PCA(PCA_MAX, svd_solver="full").fit(train['data'])
train_projected = pca.transform(train['data'])
test_projected = pca.transform(test['data'])

for principal_axis in range(1, PCA_MAX + 1):
    train_ds['data'] = train_projected[:, :principal_axis]
    test_ds['data'] = test_projected[:, :principal_axis]
    for clf, name in classifiers:
        clf.fit(train_ds['data'], train_ds['label'])
        # Only scoring (prediction + accuracy) is timed; fitting happens before the clock starts.
        start = perf_counter()
        score = clf.score(test_ds['data'], test_ds['label'])
        elapsed = perf_counter() - start
        clf_test.append({'label': name, 'x': principal_axis, 'accuracy': score, 'time': elapsed})

## Compare <ins>Accuracy</ins>

In [9]:
# Plot test accuracy of every classifier against the number of PCA components.
# NOTE(review): the tuple is presumably (x-axis label, y-axis label, title). It mirrors
# the timing plot, whose second entry is "Time in seconds"; confirm against plotter.
alt.renderers.enable('default')
plotter(clf_test, "accuracy", ("Number of Components", "Accuracy", "PCAs"))

## Compare <ins>Scoring Time</ins>

Only `clf.score` on the test set is timed; model fitting is excluded from the measurement.

In [10]:
# Switch Altair to its default renderer, then plot the recorded "time"
# values per classifier against the number of PCA components.
alt.renderers.enable('default')
plotter(
    clf_test,
    "time",
    ("Number of Components", "Time in seconds", "PCAs"),
)