In [1]:
import turicreate as tc
import matplotlib.pyplot as plt

In [2]:
train_data = tc.image_analysis.load_images("snacks/train", with_path=True)
len(train_data)

4838

In [3]:
test_data = tc.image_analysis.load_images("snacks/test", with_path=True)
len(test_data)

952

In [4]:
healthy = [
    'apple',
    'banana',
    'carrot',
    'grape',
    'juice',
    'orange',
    'pineapple',
    'salad',
    'strawberry',
    'watermelon',
]

unhealthy = [
    'cake',
    'candy',
    'cookie',
    'doughnut',
    'hot dog',
    'ice cream',
    'muffin',
    'popcorn',
    'pretzel',
    'waffle',
]

import os
train_data["label"] = train_data["path"].apply(lambda path: "healthy"
                         if any("/" + class_name in path for class_name in healthy) 
                         else "unhealthy")
train_data["label"].value_counts().print_rows(num_rows=20)

+-----------+-------+
|   value   | count |
+-----------+-------+
|  healthy  |  2507 |
| unhealthy |  2331 |
+-----------+-------+
[2 rows x 2 columns]



In [5]:
binary_model = tc.image_classifier.create(train_data, target="label", 
                                   model="squeezenet_v1.1",
                                   verbose=True, max_iterations=100)

PROGRESS: Creating a validation set from 5 percent of training data. This may take a while.
          You can set ``validation_set=None`` to disable validation tracking.



In [6]:
test_data["label"] = test_data["path"].apply(lambda path: "healthy"
                         if any("/" + class_name in path for class_name in healthy) 
                         else "unhealthy")
test_data["label"].value_counts().print_rows(num_rows=20)

+-----------+-------+
|   value   | count |
+-----------+-------+
|  healthy  |  489  |
| unhealthy |  463  |
+-----------+-------+
[2 rows x 2 columns]



In [7]:
metrics = binary_model.evaluate(test_data)
print("Accuracy: ", metrics["accuracy"])
print("Precision: ", metrics["precision"])
print("Recall: ", metrics["recall"])

Accuracy:  0.8665966386554622
Precision:  0.8783783783783784
Recall:  0.8423326133909287


In [8]:
multi_model = tc.load_model("MultiSnacks.model")

In [9]:
output = multi_model.classify(test_data)
imgs_with_pred = test_data.add_columns(output)
imgs_with_pred

path,image,label,class,probability
snacks/test/apple/00341c3 c5825fc7e.jpg ...,Height: 256 Width: 256,healthy,apple,0.6897836167016994
snacks/test/apple/004be96 d7985d83e.jpg ...,Height: 256 Width: 384,healthy,apple,0.5662980389419667
snacks/test/apple/01ac2a4 2f2a22ee7.jpg ...,Height: 256 Width: 341,healthy,orange,0.5852423689463567
snacks/test/apple/03bfc0b 1cc6bde63.jpg ...,Height: 256 Width: 384,healthy,orange,0.918169075531706
snacks/test/apple/09ed54b 36eaa5316.jpg ...,Height: 256 Width: 455,healthy,orange,0.373284750216909
snacks/test/apple/0f8670e 41c97c8cb.jpg ...,Height: 256 Width: 361,healthy,apple,0.9886308233209152
snacks/test/apple/137591b b5ac95a5e.jpg ...,Height: 256 Width: 385,healthy,orange,0.5509159595718824
snacks/test/apple/1382c47 d4df56b77.jpg ...,Height: 256 Width: 332,healthy,apple,0.3940126026938378
snacks/test/apple/1acfd56 0a4424e04.jpg ...,Height: 341 Width: 256,healthy,watermelon,0.8942638950206466
snacks/test/apple/1db0cb7 5f37d6cba.jpg ...,Height: 256 Width: 341,healthy,apple,0.979446294141604


In [14]:
# Find images the model predicts are in classes in the healthy array
imgs_healthy_classes = imgs_with_pred.filter_by(healthy, "class")
imgs_healthy_classes.num_rows()

521

In [16]:
imgs_healthy_classes.head()

path,image,label,class,probability
snacks/test/apple/00341c3 c5825fc7e.jpg ...,Height: 256 Width: 256,healthy,apple,0.6897836167016994
snacks/test/apple/004be96 d7985d83e.jpg ...,Height: 256 Width: 384,healthy,apple,0.5662980389419667
snacks/test/apple/01ac2a4 2f2a22ee7.jpg ...,Height: 256 Width: 341,healthy,orange,0.5852423689463567
snacks/test/apple/03bfc0b 1cc6bde63.jpg ...,Height: 256 Width: 384,healthy,orange,0.918169075531706
snacks/test/apple/09ed54b 36eaa5316.jpg ...,Height: 256 Width: 455,healthy,orange,0.373284750216909
snacks/test/apple/0f8670e 41c97c8cb.jpg ...,Height: 256 Width: 361,healthy,apple,0.9886308233209152
snacks/test/apple/137591b b5ac95a5e.jpg ...,Height: 256 Width: 385,healthy,orange,0.5509159595718824
snacks/test/apple/1382c47 d4df56b77.jpg ...,Height: 256 Width: 332,healthy,apple,0.3940126026938378
snacks/test/apple/1acfd56 0a4424e04.jpg ...,Height: 341 Width: 256,healthy,watermelon,0.8942638950206466
snacks/test/apple/1db0cb7 5f37d6cba.jpg ...,Height: 256 Width: 341,healthy,apple,0.979446294141604


In [20]:
# Find images predicted to be in healthy classes whose label is 'unhealthy'
imgs_unhealthy_label = imgs_healthy_classes.filter_by(["unhealthy"], "label")
imgs_unhealthy_label.num_rows()

75

In [21]:
imgs_unhealthy_label.head()

path,image,label,class,probability
snacks/test/cake/260e1012 91d6d34d.jpg ...,Height: 256 Width: 341,unhealthy,pineapple,0.8842416870989014
snacks/test/cake/49464fc6 1be971db.jpg ...,Height: 256 Width: 384,unhealthy,juice,0.4336994420161237
snacks/test/cake/600c36d1 e6f5da87.jpg ...,Height: 256 Width: 341,unhealthy,strawberry,0.7864199598025318
snacks/test/cake/6172064b bb38cafb.jpg ...,Height: 256 Width: 341,unhealthy,juice,0.5075976431874204
snacks/test/cake/7d8937c1 93a599eb.jpg ...,Height: 256 Width: 333,unhealthy,banana,0.355667495561298
snacks/test/cake/8199f71f 6863969b.jpg ...,Height: 256 Width: 385,unhealthy,carrot,0.8500653383760911
snacks/test/cake/88aa2154 914ce3a7.jpg ...,Height: 256 Width: 342,unhealthy,juice,0.740875934570468
snacks/test/cake/97067d14 00252334.jpg ...,Height: 256 Width: 342,unhealthy,apple,0.3113456860337111
snacks/test/cake/bc41ce28 fc883cd5.jpg ...,Height: 256 Width: 341,unhealthy,pineapple,0.5350222560583726
snacks/test/cake/be1f7d58 66cf7930.jpg ...,Height: 256 Width: 256,unhealthy,salad,0.5652848267875529


In [22]:
correct = imgs_healthy_classes.num_rows() - imgs_unhealthy_label.num_rows()
accuracy = correct / imgs_with_pred.num_rows()
accuracy

0.4684873949579832