In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
# Display the installed PyTorch version; the saved output records the version this run used.
torch.__version__

'1.8.1+cu102'

In [4]:
import numpy as np
import torch
import torch.distributions as td
import matplotlib.pyplot as plt
import torch.nn as nn

from collections import namedtuple, OrderedDict, defaultdict
from tqdm.auto import tqdm
from itertools import chain
from tabulate import tabulate

In [5]:
import sys

# Make the project-level modules (e.g. `data`) importable from this notebook's directory.
# Guarded so that re-running the cell does not keep appending duplicate "../" entries.
if "../" not in sys.path:
    sys.path.append("../")

In [11]:
from data import load_mnist, Batcher

In [12]:
# Build the train / validation / test loaders at 28x28 resolution.
# NOTE(review): the first positional argument (100) is presumably the batch size, and
# save_to presumably a download/cache directory — confirm against data.load_mnist.
mnist_options = dict(save_to='../tmp', height=28, width=28)
train_loader, valid_loader, test_loader = load_mnist(100, **mnist_options)

In [13]:
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report

In [14]:
def get_batcher(data_loader, height=28, width=28, device='cpu', binarize=True,
                num_classes=10, onehot=False):
    """Wrap ``data_loader`` in the project's ``Batcher``.

    Iterating over the returned object yields ``(x_obs, c_obs)`` pairs, which the
    cells below flatten into NumPy arrays. The defaults reproduce the settings this
    notebook was originally run with (28x28 images, CPU, binarized, 10 classes,
    non-one-hot labels).

    Args:
        data_loader: one of the loaders returned by ``load_mnist``.
        height: image height forwarded to ``Batcher``.
        width: image width forwarded to ``Batcher``.
        device: device name or ``torch.device`` the batches are placed on.
        binarize: forwarded to ``Batcher``; presumably thresholds pixels to {0, 1} —
            confirm in ``data.Batcher``.
        num_classes: number of label classes forwarded to ``Batcher``.
        onehot: forwarded to ``Batcher``. False presumably yields integer labels, which
            the callers turn into a column via ``unsqueeze(1)``. Confirm in ``data.Batcher``.

    Returns:
        The constructed ``Batcher`` instance.
    """
    return Batcher(
        data_loader,
        height=height,
        width=width,
        # torch.device accepts either a string or an existing torch.device
        device=torch.device(device),
        binarize=binarize,
        num_classes=num_classes,
        onehot=onehot,
    )

In [15]:
# One row per training image: 28*28 flattened pixels, then the class label in the last column.
training = np.vstack([
    np.hstack((x_obs.reshape(-1, 28 * 28).numpy(), c_obs.unsqueeze(1).numpy()))
    for x_obs, c_obs in get_batcher(train_loader)
])
training.shape

(55000, 785)

In [16]:
# One row per validation image: 28*28 flattened pixels, then the class label in the last column.
valid = np.vstack([
    np.hstack((x_obs.reshape(-1, 28 * 28).numpy(), c_obs.unsqueeze(1).numpy()))
    for x_obs, c_obs in get_batcher(valid_loader)
])
valid.shape

(5000, 785)

In [17]:
# One row per test image: 28*28 flattened pixels, then the class label in the last column.
test = np.vstack([
    np.hstack((x_obs.reshape(-1, 28 * 28).numpy(), c_obs.unsqueeze(1).numpy()))
    for x_obs, c_obs in get_batcher(test_loader)
])
test.shape

(10000, 785)

In [None]:
# Grid search over the neighbour-search algorithm and k, scored on the validation split.
# Per the scikit-learn docs, `algorithm` only changes how neighbours are searched, so
# `k` is the hyper-parameter expected to move accuracy.
# Feature/label slices are loop-invariant, so take them once.
X_train, y_train = training[:, :-1], training[:, -1]
X_valid, y_valid = valid[:, :-1], valid[:, -1]

accuracies = {}  # (algorithm, k) -> validation accuracy
for alg in ['ball_tree', 'kd_tree']:
    for k in [1, 5, 10]:
        # A separate name, so re-running this cell never clobbers the `model` fitted below.
        candidate = KNeighborsClassifier(n_neighbors=k, algorithm=alg, p=2, n_jobs=10)
        candidate.fit(X_train, y_train)

        # evaluate the candidate and record its accuracy
        score = candidate.score(X_valid, y_valid)
        accuracies[(alg, k)] = score
        print(f"alg={alg} k={k}, accuracy={score*100:.2f}")

# Best (algorithm, k) by validation accuracy.
max(accuracies, key=accuracies.get)

In [18]:
# Final classifier: 5 nearest neighbours with a k-d tree.
# NOTE(review): k=5 / kd_tree is hard-coded here rather than taken from the grid search.
# Confirm it matches the best validation configuration.
X_train, y_train = training[:, :-1], training[:, -1]
model = KNeighborsClassifier(algorithm='kd_tree', n_neighbors=5, n_jobs=10)
model.fit(X_train, y_train)

KNeighborsClassifier(algorithm='kd_tree', n_jobs=10)

In [20]:
# Per-class precision / recall / F1 of the fitted model on the validation split.
X_valid, y_valid = valid[:, :-1], valid[:, -1]
val_pred = model.predict(X_valid)
print(classification_report(y_true=y_valid, y_pred=val_pred))

              precision    recall  f1-score   support

         0.0       0.98      0.99      0.98       489
         1.0       0.91      1.00      0.95       530
         2.0       0.98      0.96      0.97       493
         3.0       0.96      0.96      0.96       509
         4.0       0.98      0.95      0.96       499
         5.0       0.95      0.95      0.95       458
         6.0       0.98      0.99      0.98       482
         7.0       0.96      0.98      0.97       563
         8.0       0.99      0.89      0.94       494
         9.0       0.95      0.94      0.95       483

    accuracy                           0.96      5000
   macro avg       0.96      0.96      0.96      5000
weighted avg       0.96      0.96      0.96      5000



In [19]:
# Per-class precision / recall / F1 of the fitted model on the held-out test split.
X_test, y_test = test[:, :-1], test[:, -1]
predictions = model.predict(X_test)
print(classification_report(y_true=y_test, y_pred=predictions))

              precision    recall  f1-score   support

         0.0       0.96      0.99      0.98       980
         1.0       0.89      1.00      0.94      1135
         2.0       0.99      0.93      0.96      1032
         3.0       0.95      0.96      0.95      1010
         4.0       0.96      0.93      0.95       982
         5.0       0.94      0.95      0.94       892
         6.0       0.97      0.98      0.98       958
         7.0       0.93      0.95      0.94      1028
         8.0       0.99      0.87      0.92       974
         9.0       0.94      0.93      0.93      1009

    accuracy                           0.95     10000
   macro avg       0.95      0.95      0.95     10000
weighted avg       0.95      0.95      0.95     10000



In [21]:
import pickle

# save the fitted model to disk so it can be reloaded without refitting.
# The `with` block flushes and closes the file deterministically; a bare open() leaves
# both to the garbage collector.
with open('knnclassifier.pickle', 'wb') as model_file:
    pickle.dump(model, model_file)

In [22]:
# load the model from disk and confirm it still scores on the validation split.
# NOTE: pickle.load can execute arbitrary code — only load files this notebook wrote itself.
with open('knnclassifier.pickle', 'rb') as model_file:
    loaded_model = pickle.load(model_file)
result = loaded_model.score(valid[:,:-1], valid[:,-1])
print(result)

0.9628
