# Photo Classifier: Model evaluation

## Imports

In [1]:
import random
import numpy as np
import pandas as pd
from collections import Counter
from ipywidgets import interact
import ipywidgets as widgets

In [2]:
from file_io import load_pickle_file
from constants import MODEL_FILE_PATH, IMAGE_DATA_PATH
from image_transforms import IMAGE_TRANSFORMS
from interpreter import Interpreter

In [3]:
import random
random.seed(42)

In [4]:
DATA_SET = 'test'

## Load model and data

In [5]:
model = load_pickle_file(MODEL_FILE_PATH)
image_data = load_pickle_file(IMAGE_DATA_PATH)

## Prediction

In [6]:
x = image_data.get_images(DATA_SET)
y_true = image_data.get_classes(DATA_SET)

In [7]:
y_pred, probabilities = model.predict(x, IMAGE_TRANSFORMS[DATA_SET])

In [8]:
# Accuracy calculation
from sklearn.preprocessing import MultiLabelBinarizer
one_hot_encoder = MultiLabelBinarizer()
y_pred_one_hot_encoded = one_hot_encoder.fit_transform(y_pred)
y_true_one_hot_encoded = one_hot_encoder.fit_transform(y_true)
(y_pred_one_hot_encoded == y_true_one_hot_encoded).mean()

0.8666666666666667

## Results interpretation

In [9]:
interpreter = Interpreter(x, y_pred, y_true, probabilities, model.class_to_label_mapping)

TypeError: unhashable type: 'list'

In [None]:
accuracy = interpreter.calculate_accuracy()
accuracy

In [None]:
confusion_matrix = interpreter.calculate_confusion_matrix()
confusion_matrix

In [None]:
accuracy_by_label = interpreter.calculate_accuracy_by_label()
accuracy_by_label

In [None]:
misclassified_samples = interpreter.get_misclassified_samples()
misclassified_samples

In [None]:
most_uncertain_samples = interpreter.get_most_uncertain_samples(5)
most_uncertain_samples

In [None]:
most_incorrect_samples = interpreter.get_most_incorrect_samples(5)
most_incorrect_samples

In [None]:
interpreter.plot_most_incorrect_samples(5)

In [None]:
interpreter.plot_most_uncertain_samples(5)

In [None]:
@interact(index=widgets.IntSlider(min=0, max=len(y_true)-1, step=1, value=0, continuous_update=False))
def show_prediction(index=0):
    interpreter.plot_prediction(index)