### Training the AlexNet benchmark model

I have chosen a pre-trained AlexNet (trained on ImageNet) as the benchmark model for this project. The model was trained with the following hyperparameters: 4 epochs, a learning rate of 2e-2, and a batch size of 6.

In [None]:
# magics
# autoreload: re-import changed modules automatically (mode 2 = all modules)
%reload_ext autoreload
%autoreload 2
# render matplotlib figures inline in the notebook output
%matplotlib inline

In [None]:
# imports
from fastai.vision import * # deeplearning library for computer computer vision
import os # navigates operating system

In [None]:
# Dataset root on Kaggle; the 'train' and 'test' sub-folders used below live here.
path = Path('../input/timenet').joinpath('tn_data', 'tN_data')
# echo the path so the notebook displays it
path

In [None]:
# Fix NumPy's global seed so the random train/validation split below is
# reproducible (presumably the split draws from NumPy's RNG when no seed is given).
np.random.seed(42)

# Training item list: images under path/'train', 20% held out at random for
# validation, each image labelled by the name of its parent folder.
src = (ImageList
       .from_folder(path / 'train')
       .split_by_rand_pct(valid_pct=0.2)
       .label_from_folder())
src

### Data augmentation?

In [None]:
# set transformations for data augmentation: flips, rotation and warping are
# disabled; only zoom (max_zoom=1.2) and lighting (max_lighting=0.15, applied
# with p_lighting=0.25) are configured.
# NOTE(review): in fastai v1 the zoom transform is applied with probability
# p_affine, so with p_affine=0 the max_zoom=1.2 setting likely never takes
# effect -- confirm whether zoom augmentation was intended.
tfms = get_transforms(do_flip=False, flip_vert=False, max_rotate=0, max_zoom=1.2,
                      max_lighting=0.15, max_warp=0, p_affine=0, p_lighting=0.25)

# create dataloader with batchsize=6 and resize images to IMG_SIZE x IMG_SIZE px.
# NOTE(review): 244 is neither 112 (half of a 224 original) nor the usual 224
# ImageNet input size; it may be a typo for 224 -- confirm the intended resolution.
IMG_SIZE = 244
BATCH_SIZE = 6
# normalize using imagenet stats because model was pre-trained on ImageNet
data = (src.transform(tfms, size=IMG_SIZE)
           .databunch(bs=BATCH_SIZE)
           .normalize(imagenet_stats))

In [None]:
# ImageNet-pretrained AlexNet backbone from fastai's model zoo
arch = models.alexnet

# Build the CNN learner on the DataBunch above, reporting accuracy as its metric.
# model_dir='...' lets models be saved from inside a Kaggle notebook; remove
# that argument when not working in a Kaggle notebook.
learn = cnn_learner(
    data,
    arch,
    metrics=accuracy,
    model_dir='../../../../working',
)

In [None]:
# Train for 4 epochs at a learning rate of 2e-2 with weight decay switched off.
# NOTE(review): cnn_learner presumably leaves the pretrained body frozen, so
# only the new head may be trained here -- confirm if full fine-tuning was intended.
learn.fit(epochs=4, lr=2e-2, wd=0.)

In [None]:
# plot the training/validation loss curves recorded by the learner during fit
learn.recorder.plot_losses()

In [None]:
# add test set: every file in the test folder becomes a test item
test = (path/'test/').ls()
data.add_test(test)
# learn was built from this same DataBunch; re-assigning keeps the two in sync
# if the data cell was re-run after the learner was created
learn.data = data

# make predictions on test set (one row of class scores per test item,
# presumably in the same order as `test`)
preds, _ = learn.get_preds(ds_type=DatasetType.Test)
# index of the highest-scoring class for each test item
# (presumably an index into data.classes)
labels = preds.argmax(dim=1)

# the test csv (file names -> true classes) is loaded once, in the helper cell
# below, rather than being read here as well

In [None]:
# helper functions

# ground truth for the test images: maps each test file name ('file' column)
# to its true class ('class' column)
test_csv = pd.read_csv('../input/timenet-test-labels/test.csv')
# folder holding the test images to be scored
test_path = path.joinpath('test')

def test_accuracy(test_path=test_path, csv_df=test_csv, learner=None):
    """Predict every image in `test_path` and score it against `csv_df`.

    Args:
        test_path: folder (pathlib/fastai Path) whose files are the test images.
        csv_df: DataFrame with a 'file' column (test file name) and a 'class'
            column (true label).
        learner: model used for prediction; defaults to the global `learn`.

    Returns:
        (error_paths, error_list, test_res):
            error_paths -- list of path strings of misclassified files,
            error_list  -- list of misclassified file names,
            test_res    -- dict {file name: [actual, predicted, correct]}.
        A file with no row in `csv_df` gets actual=None and counts as an error.

    Also prints the number of errors and the test accuracy in percent.
    """
    if learner is None:
        learner = learn

    # Build the ground-truth lookup once with exact file-name keys. This avoids
    # Series.str.match's regex/prefix semantics (e.g. '.' matching any char) and
    # parsing the printed repr of a Series (truncation, multi-word labels).
    # The first row wins on duplicate names, as with the previous first-match lookup.
    label_of = {}
    for name, cls in zip(csv_df['file'].astype(str), csv_df['class'].astype(str)):
        label_of.setdefault(name, cls)

    test_res = {}
    error_list = []
    error_paths = []
    for file in test_path.ls():
        fname = file.name
        # NOTE(review): learner.predict presumably applies the DataBunch's
        # transforms and normalization itself -- confirm before adding any here.
        pred = learner.predict(open_image(file))
        # first element of the prediction is the predicted category
        pred_label = str(pred[0])
        actual = label_of.get(fname)
        correct = actual == pred_label
        test_res[fname] = [actual, pred_label, correct]
        if not correct:
            error_list.append(fname)
            # use the real file path so a non-default test_path is honored
            error_paths.append(str(file))

    n_total = len(test_res)
    if n_total == 0:
        # avoid dividing by zero when the folder holds no files
        print('no test images found in {}'.format(test_path))
        return error_paths, error_list, test_res

    accuracy_pct = (1 - (len(error_list) / n_total)) * 100
    print('number of errors: {}/{} \ntest accuracy: {}'.format(len(error_list), n_total, accuracy_pct))

    # returns paths of misclassified files(list), file names (list) and test results (dict)
    return error_paths, error_list, test_res

In [None]:
# A function to plot an image grid of errors
def plot_errors(img_list, results=None, ncols=6):
    """Plot a grid of misclassified test images.

    Args:
        img_list: list of image path strings (e.g. `error_paths` from
            `test_accuracy`).
        results: dict {file name: [actual, predicted, correct]}; defaults to the
            global `test_res`.
        ncols: number of images per row.

    Each subplot is titled with the file name and its [actual, pred, correct]
    entry from `results`. Prints a message and returns if `img_list` is empty.
    """
    if results is None:
        results = test_res
    if not img_list:
        print('no errors to plot')
        return

    # Keep the original 6-row layout for short lists, but add rows instead of
    # failing (plt.subplot rejects indices beyond nrows * ncols) for longer ones.
    nrows = max(6, -(-len(img_list) // ncols))
    plt.figure(figsize=(26, 26 * nrows / 6))
    plt.subplots_adjust(hspace=0.3)

    for i, img_path in enumerate(img_list):
        ax = plt.subplot(nrows, ncols, i + 1)
        ax.axis('off')
        img_name = img_path.rsplit('/', 1)[-1]
        ax.set_title(img_name + '\n act, pred, correct \n' + str(results[img_name]))
        # context manager closes the file handle; imshow copies the pixel data
        with PIL.Image.open(img_path) as pil_img:
            ax.imshow(pil_img)

    plt.show()

#### Checking benchmark performance on test set

In [None]:
# run model on test set
# error_paths / error_list: misclassified files (as paths / as file names);
# test_res: {file name: [actual, predicted, correct]} for every test file
error_paths, error_list, test_res = test_accuracy()

In [None]:
# plot errors: one subplot per misclassified image, titled with its test_res entry
plot_errors(error_paths)