## Test Notebook (please download our model using the link in the README before running!)

In [None]:
import numpy as np
import numpy.random as npr
import matplotlib.pyplot as plt
%matplotlib inline
plt.style.use('bmh')
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import models
from sklearn.model_selection import train_test_split
from keras.applications.xception import preprocess_input
from sklearn.metrics import confusion_matrix
import seaborn as sns
import pandas as pd

In [None]:
# Human-readable names of the ten scene categories.
# NOTE(review): presumably position i corresponds to model output index i
# (i.e. label i + 1 in the 1-10 encoding) -- confirm against training code.
classes = [
    'Stadium',
    'Building/s',
    'Traffic Sign',
    'Forest',
    'Flower/s',
    'Street',
    'Classroom',
    'Bridge',
    'Statue',
    'Lake',
]

# Restore the trained network from the HDF5 file in the working directory
# (see the notebook title: the file must be downloaded first).
model = models.load_model('my_model.hdf5')

In [None]:
# LOAD IN YOUR TEST DATA / LABELS HERE
# Replace both placeholder paths with your own files. The data and the labels
# must come from two different files.
data_test = np.load('YOUR-DATA.npy') # Data loaded in D x N (D = 300*300*3, see the reshape in test_model)
labels_test = np.load('YOUR-LABELS.npy') # Integer labels encoded 1-10, one per sample

In [None]:
def test_model(test_data, test_labels):
    """Evaluate the loaded model on a labelled test set.

    Prints the accuracy and the predicted label of every sample, then draws
    a confusion-matrix heatmap. Nothing is returned.

    Args:
        test_data: Image array laid out as D x N, where each column reshapes
            to one 300 x 300 x 3 image. Do not transpose below if your data
            is already N x D.
        test_labels: Integer-encoded labels starting at 1, one per sample.
            Any shape holding N values is accepted ((N,), (N, 1), (1, N)).

    Raises:
        ValueError: If the number of labels differs from the number of
            predictions produced by the model.

    NOTE: The confusion matrix uses labels 0-9 instead of 1-10; the printed
    predicted labels use the same 1-based encoding as ``test_labels``.
    """
    # Flatten first so a column/row vector of labels cannot broadcast against
    # the 1-D predictions, then shift the 1-based labels to 0-based.
    new_labels = np.asarray(test_labels).ravel() - 1

    # Preprocess data
    data_test_preprocessed = preprocess_input(
        test_data.T.reshape(-1, 300, 300, 3)  # Do not transpose if data is N x D
    )

    # Predict class scores, then take the highest-scoring class per sample
    preds = model.predict(data_test_preprocessed)
    int_preds = np.argmax(preds, axis=1)

    if len(new_labels) != len(int_preds):
        raise ValueError(
            f'Got {len(new_labels)} labels for {len(int_preds)} predictions'
        )

    # Fraction of samples whose predicted class equals the true class
    accuracy = float(np.mean(int_preds == new_labels))

    print('Accuracy: ', accuracy)
    # Report class labels (not the raw score matrix), shifted back to the
    # 1-based encoding used by the input labels.
    print('\nPredicted Labels: ', int_preds + 1)

    # Plot confusion matrix
    print('Confusion Matrix: ')
    # Pass the full list of class ids so the matrix keeps one row/column per
    # model output even when some class never occurs in this test set;
    # otherwise the matrix shrinks and no longer matches the axis labels.
    class_ids = np.arange(preds.shape[1])
    tick_labels = [str(c) for c in class_ids]
    conf = pd.DataFrame(
        confusion_matrix(new_labels, int_preds, labels=class_ids),
        index=tick_labels,
        columns=tick_labels,
    )
    plt.figure(figsize=(8, 8))
    sns.heatmap(conf, annot=True, cbar=False, fmt="d")

In [None]:
# Run the evaluation on the test data and labels loaded earlier in this
# notebook: prints the accuracy and prediction output and draws the
# confusion-matrix heatmap.
test_model(data_test, labels_test)