In [1]:
import os
import turicreate as tc
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

In [2]:
# Load every image under the training directory. with_path=True keeps each
# file's path in a "path" column, which the label-extraction cell reads.
train_data = tc.image_analysis.load_images(
    "dataset/train/",
    with_path=True,
)
len(train_data)

4838

In [3]:
# The class label is the name of the directory that directly contains the
# image file (e.g. ".../<label>/<file>" -> "<label>").
train_data["label"] = train_data["path"].apply(
    lambda image_path: os.path.basename(os.path.dirname(image_path))
)
train_data["label"].summary()


+------------------+-------+----------+
|       item       | value | is exact |
+------------------+-------+----------+
|      Length      |  4838 |   Yes    |
| # Missing Values |   0   |   Yes    |
| # unique values  |   20  |    No    |
+------------------+-------+----------+

Most frequent items:
+-----------+-------+
|   value   | count |
+-----------+-------+
| pineapple |  260  |
|   apple   |  250  |
|   banana  |  250  |
|  doughnut |  250  |
|   grape   |  250  |
|  hot dog  |  250  |
| ice cream |  250  |
|   juice   |  250  |
|   muffin  |  250  |
|   salad   |  250  |
+-----------+-------+


In [4]:
# Train a SqueezeNet v1.1 based classifier on the labelled training images.
# No validation_set is passed, so Turi Create holds out part of the training
# data for validation tracking (see the PROGRESS message in the output).
model = tc.image_classifier.create(
    train_data,
    target="label",
    model="squeezenet_v1.1",
    verbose=True,
    max_iterations=160,
)

PROGRESS: Creating a validation set from 5 percent of training data. This may take a while.
          You can set ``validation_set=None`` to disable validation tracking.



In [5]:
# Load the held-out test images, keeping each file's path in a "path" column
# so the labels can be derived from the directory names.
test_data = tc.image_analysis.load_images(
    "dataset/test/",
    with_path=True,
)
len(test_data)

952

In [7]:
# Same convention as the training set: the label is the name of the directory
# that directly contains the image file.
test_data["label"] = test_data["path"].apply(
    lambda image_path: os.path.basename(os.path.dirname(image_path))
)
test_data["label"].summary()


+------------------+-------+----------+
|       item       | value | is exact |
+------------------+-------+----------+
|      Length      |  952  |   Yes    |
| # Missing Values |   0   |   Yes    |
| # unique values  |   20  |    No    |
+------------------+-------+----------+

Most frequent items:
+-----------+-------+
|   value   | count |
+-----------+-------+
|   apple   |   50  |
|   banana  |   50  |
|    cake   |   50  |
|   candy   |   50  |
|   carrot  |   50  |
|   cookie  |   50  |
|  doughnut |   50  |
|   grape   |   50  |
|  hot dog  |   50  |
| ice cream |   50  |
+-----------+-------+


In [8]:
# Evaluate the trained model on the test set and print the headline metrics.
metrics = model.evaluate(test_data)

# (printed title, key in the metrics result) pairs, in display order.
reported_metrics = (
    ("Accuracy: ", "accuracy"),
    ("Precision: ", "precision"),
    ("Recall: ", "recall"),
    ("Confusion Matrix: ", "confusion_matrix"),
)
for title, key in reported_metrics:
    print(title, metrics[key])

Accuracy:  0.6491596638655462
Precision:  0.6482360441779452
Recall:  0.6461326530612246
Confusion Matrix:  +--------------+-----------------+-------+
| target_label | predicted_label | count |
+--------------+-----------------+-------+
|   hot dog    |      salad      |   3   |
|    salad     |      candy      |   1   |
|     cake     |    ice cream    |   2   |
|  pineapple   |      carrot     |   1   |
|    apple     |      grape      |   1   |
|  ice cream   |    ice cream    |   23  |
|     cake     |      waffle     |   1   |
|  strawberry  |       cake      |   2   |
|    salad     |    watermelon   |   1   |
|  strawberry  |      salad      |   1   |
+--------------+-----------------+-------+
[209 rows x 3 columns]
Note: Only the head of the SFrame is printed.
You can use print_rows(num_rows=m, num_columns=n) to print more rows and columns.


In [9]:
# Distinct class labels found in the test set, in sorted order. This ordering
# defines the row/column order of the confusion matrix built below.
labels = (
    test_data["label"]
    .unique()
    .sort()
)

In [12]:
# Helper functions for visualizing the confusion matrix as a heatmap.

#function to compute the confusion matrix

def compute_confusion_matrix(metrics, labels):
    """Build a dense confusion matrix from the evaluation metrics.

    Args:
        metrics: Mapping with a "confusion_matrix" entry that iterates over
            rows, each providing "target_label", "predicted_label" and
            "count".
        labels: Ordered sequence of class labels. The position of a label in
            this sequence is its row/column index in the returned matrix.

    Returns:
        A square integer NumPy array of shape (len(labels), len(labels)) in
        which entry [i, j] is the count for true label i predicted as label j.
        Pairs that do not appear in the metrics stay at 0.

    Raises:
        KeyError: If a row refers to a label that is not in ``labels``.
    """
    number_of_label = len(labels)
    labels_to_index = {l: i for i, l in enumerate(labels)}

    # The ``np.int`` alias was removed from NumPy; the builtin ``int`` selects
    # the same default integer dtype.
    confusion_matrix = np.zeros((number_of_label, number_of_label), dtype=int)
    for row in metrics["confusion_matrix"]:
        true_label = labels_to_index[row["target_label"]]
        predicted_label = labels_to_index[row["predicted_label"]]
        confusion_matrix[true_label, predicted_label] = row["count"]
    return confusion_matrix

#function to plot the computed confusion matrix to heatmap

def plot_confusion_matrix(conf, labels, figsize=(8,8)):
    """Show a confusion matrix as an annotated heatmap.

    Args:
        conf: Square 2-D integer array, as returned by
            compute_confusion_matrix: rows are true labels, columns are
            predicted labels. Cells are annotated with the "d" (integer)
            format.
        labels: Class labels in the same order as the rows/columns of
            ``conf``; used as tick labels on both axes.
        figsize: (width, height) of the figure in inches.
    """
    tick_labels = list(labels)
    fig, ax = plt.subplots(figsize=figsize)

    # Hand the labels to seaborn directly so every row/column gets exactly one
    # tick label, instead of relabelling whatever ticks were auto-selected.
    sns.heatmap(
        conf,
        annot=True,
        fmt="d",
        xticklabels=tick_labels,
        yticklabels=tick_labels,
        ax=ax,
    )
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", fontsize=12)
    plt.setp(ax.get_yticklabels(), rotation=0, ha="right", fontsize=12)

    # Columns are predictions and rows are ground truth.
    ax.set_xlabel("Predicted Labels", fontsize=12)
    ax.set_ylabel("True Labels", fontsize=12)
    plt.show()

In [13]:
# Build the confusion matrix from the evaluation metrics and plot it.
# `labels` must be passed to both calls: it fixes the row/column order of the
# matrix and supplies the tick labels of the heatmap.
conf = compute_confusion_matrix(metrics, labels)
plot_confusion_matrix(conf, labels)

<function __main__.plot_confusion_matrix(conf, labels, figsize=(8, 8))>