In [2]:
# Automatically re-import modules (including the project modules imported
# below) before each cell executes, so edits to their source are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

from datasets import DataCifar10
from models import KNN
from optimers import Optimer
from tuners import ParamTuner
from utils import check_accuracy

import numpy as np

In [3]:
# load data
# Dataset location and split sizes are named so they can be changed in one place.
CIFAR10_DIR = './datasets/cifar-10-batches-py'
NUM_TRAIN = 5000
NUM_VAL = 1000
NUM_TEST = 1000

dataloader = DataCifar10(CIFAR10_DIR,
                         num_val=NUM_VAL, num_train=NUM_TRAIN, num_test=NUM_TEST)
dataloader.show_info()

# Sanity check: the test split used by the evaluation cells has the requested size.
assert len(dataloader.x_test) == len(dataloader.y_test) == NUM_TEST

Training data shape:  (5000, 3, 32, 32)
Training labels shape:  (5000,)
Validating data shape:  (1000, 3, 32, 32)
Validating labels shape:  (1000,)
Testing data shape:  (1000, 3, 32, 32)
Testing labels shape:  (1000,)


In [3]:
# train model
# Build a KNN model with hyperparameter K=3, then fit it on the
# dataloader's data using a default-configured Optimer.
model = KNN(hyperparams=dict(K=3))
optimer = Optimer()

optimer.train(model, dataloader)

In [4]:
# check accuracy
# Score the model trained in the previous cell on the held-out test split.
scores = model.predict(dataloader.x_test)
accuracy = check_accuracy(scores, dataloader.y_test)

# print() already puts a space between its arguments, so the label must not
# end with one (otherwise the output contains a double space).
print('The accuracy on testing dataset is', accuracy)

The accuracy on testing dataset is  0.26


In [36]:
# tune hyperparameters
# Search K over 2..19 (range's end is exclusive). {'K': 3} is passed as the
# starting hyperparameter set; the tuner returns the best model, its
# hyperparameters and its (validation) accuracy.
# NOTE(review): with 1000 validation samples the standard error of an accuracy
# near 0.29 is roughly 0.014, so the spread across K >= 5 is largely noise;
# consider cross-validation before trusting a single "best" K.
K_CANDIDATES = list(range(2, 20))
tuner = ParamTuner(KNN, Optimer, dataloader)
model_best, param_best, acc_best = tuner.tune({'K': 3}, {'K': K_CANDIDATES})


Tune K in [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
With {'K': 2} accuracy: 0.231  - Best!
With {'K': 3} accuracy: 0.258  - Best!
With {'K': 4} accuracy: 0.275  - Best!
With {'K': 5} accuracy: 0.285  - Best!
With {'K': 6} accuracy: 0.292  - Best!
With {'K': 7} accuracy: 0.293  - Best!
With {'K': 8} accuracy: 0.282
With {'K': 9} accuracy: 0.287
With {'K': 10} accuracy: 0.291
With {'K': 11} accuracy: 0.281
With {'K': 12} accuracy: 0.284
With {'K': 13} accuracy: 0.288
With {'K': 14} accuracy: 0.289
With {'K': 15} accuracy: 0.292
With {'K': 16} accuracy: 0.282
With {'K': 17} accuracy: 0.28
With {'K': 18} accuracy: 0.289
With {'K': 19} accuracy: 0.29


In [39]:
# check accuracy
# Score the tuner's best model on the test split and report it next to the
# validation accuracy that (presumably) drove the choice of K.
scores = model_best.predict(dataloader.x_test)
accuracy = check_accuracy(scores, dataloader.y_test)

# print() already separates arguments with a space, so labels carry no trailing
# space (otherwise each line contains a double space).
print('The best hyperparameter is', param_best)
print('The best accuracy on validating dataset is', acc_best)
print('The accuracy on testing dataset is', accuracy)

The best hyperparameter is  {'K': 7}
The best accuracy on validating dataset is  0.293
The accuracy on testing dataset is  0.262
