In this notebook I'll train all the applicable models I've built on the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.

In [25]:
import pickle
import timeit
from pathlib import Path

import numpy as np
import pandas as pd

In [None]:
from src.knearest import KNearestNeighbors
from src.decision_tree import DecisionTree
from src.neural_network import NeuralNetwork, LabelHandler
from src.pca import PCA
from src.helpers import accuracy

In [43]:
# Raw CIFAR-10 batches. Only the `data_batch_1` training batch is loaded here.
# NOTE: pickle.load can execute arbitrary code -- only unpickle files obtained
# from the official CIFAR-10 download page.
CIFAR_DIR = Path('../data/cifar-10-batches-py')


def load_pickle(path, **kwargs):
    """Open `path` in binary mode and return the unpickled object.

    Extra keyword arguments (e.g. ``encoding``) are forwarded to ``pickle.load``.
    """
    with open(path, 'rb') as f:
        return pickle.load(f, **kwargs)


# encoding='bytes' leaves the (presumably Python-2-pickled) strings undecoded,
# so these dicts are keyed by bytes: b'data', b'labels'.
data = load_pickle(CIFAR_DIR / 'data_batch_1', encoding='bytes')
test_data = load_pickle(CIFAR_DIR / 'test_batch', encoding='bytes')
# Default encoding decodes the strings to str, so this dict has str keys
# ('label_names'), unlike the two batches above.
metadata = load_pickle(CIFAR_DIR / 'batches.meta')

In [64]:
# Training batch: one flattened image per row, plus its labels as an array.
X = data[b'data']
y = np.array(data[b'labels'])

# Same layout for the held-out test batch.
test_X, test_y = test_data[b'data'], np.array(test_data[b'labels'])

# Human-readable class names (presumably label i <-> label_names[i]; confirm).
label_names = np.array(metadata['label_names'])
label_names

array(['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
       'horse', 'ship', 'truck'],
      dtype='<U10')

In [47]:
np.shape(X)  # one row per image, 3 * 32 * 32 = 3072 values per row

(10000, 3072)

From the website: "Each row of the array stores a 32x32 colour image. The first 1024 entries contain the red channel values, the next 1024 the green, and the final 1024 the blue. The image is stored in row-major order, so that the first 32 entries of the array are the red channel values of the first row of the image."

In [58]:
# Shuffle rows with a fixed seed so the split is reproducible on
# Restart & Run All, then hold out the tail as a cross-validation set.
SEED = 0
TRAIN_FRACTION = 0.75

# Append each label as an extra last column so features and labels are
# shuffled together (column_stack promotes X and y to a common dtype).
joined = np.column_stack((X, y))
np.random.RandomState(SEED).shuffle(joined)  # shuffles rows in place

split = int(len(joined) * TRAIN_FRACTION)
train = joined[:split]
cv = joined[split:]

# Feature count comes from the data (3072 = 3 channels x 32 x 32 for CIFAR-10)
# instead of a hard-coded magic number.
n_features = X.shape[1]
train_X, train_y = train[:, :n_features], train[:, -1]
cv_X, cv_y = cv[:, :n_features], cv[:, -1]

assert len(train_X) + len(cv_X) == len(X), "split lost or duplicated rows"

The cell above shuffles the data and splits it into a training set (75%) and a cross-validation set (25%).

## k-Nearest Neighbors

In [72]:
knn = KNearestNeighbors()

# k-NN needs no fitting: the labelled training rows (features with the label
# appended as the last column) are handed straight to predict.
knn_predictions = knn.predict(
    cv_X[:10],
    np.column_stack((train_X, train_y)),
)

I don't recommend running the cell above: even on this small subset of the cross-validation data it is *exceptionally* slow.

In [74]:
# Element-wise check of the k-NN predictions against the true labels.
# Slicing by len(knn_predictions) keeps this in sync with however many
# cross-validation rows were predicted, instead of repeating the count.
knn_predictions == cv_y[:len(knn_predictions)]

array([ True,  True, False,  True, False, False, False,  True, False, False], dtype=bool)

## Decision Tree

In [69]:
# max_depth presumably caps how many levels the tree may grow
# (confirm in src.decision_tree).
dtree = DecisionTree(
    max_depth=20,
)

In [71]:
# Show the docstring of DecisionTree.fit. help() is plain Python (unlike the
# IPython-only `dtree.fit?`), so this cell also survives export to a .py script.
help(dtree.fit)