# Building an image retrieval system with deep features


# Fire up GraphLab Create
(See [Getting Started with SFrames](../Week%201/Getting%20Started%20with%20SFrames.ipynb) for setup instructions.)

In [1]:
import graphlab

In [2]:
# Cap the number of pylambda worker processes. Fewer workers preserve system
# memory, which keeps hosted notebooks from crashing.
MAX_PYLAMBDA_WORKERS = 4
graphlab.set_runtime_config('GRAPHLAB_DEFAULT_NUM_PYLAMBDA_WORKERS', MAX_PYLAMBDA_WORKERS)

[INFO] graphlab.cython.cy_server: GraphLab Create v2.1 started. Logging: /tmp/graphlab_server_1527233241.log


This non-commercial license of GraphLab Create for academic use is assigned to js133@rice.edu and will expire on October 21, 2018.


# Load the CIFAR-10 dataset

We will use a popular benchmark dataset in computer vision called CIFAR-10.  

(We've reduced the data to just 4 categories: 'cat', 'bird', 'automobile', and 'dog'.)

This dataset is already split into a training set and a test set. In this simple retrieval example there is no notion of "testing", so retrieval uses only the training data (the test set is loaded later only for the per-class nearest-neighbor comparison).

In [3]:
# Load the (4-category) CIFAR-10 training images together with their precomputed features.
TRAIN_DATA_PATH = 'image_train_data/'
image_train = graphlab.SFrame(TRAIN_DATA_PATH)

# Computing deep features for our images

The two commented-out lines below compute deep features. Because this computation takes a while, we have already run it and saved the results as the `deep_features` column of the data you loaded.

(Note: if you want to compute such deep features yourself and have a GPU on your machine, use the GPU-enabled GraphLab Create, which is significantly faster for this task.)

In [None]:
# Disabled on purpose: uncomment to recompute the deep features with the
# pretrained model at the URL below. This is slow without a GPU, and the
# results are already stored in image_train['deep_features'].
# deep_learning_model = graphlab.load_model('http://s3.amazonaws.com/GraphLab-Datasets/deeplearning/imagenet_model_iter45')
# image_train['deep_features'] = deep_learning_model.extract_features(image_train)

In [4]:
# Preview the first 10 rows: id, image, label, deep_features, image_array.
image_train.head(10)

id,image,label,deep_features,image_array
24,Height: 32 Width: 32,bird,"[0.242871761322, 1.09545373917, 0.0, ...","[73.0, 77.0, 58.0, 71.0, 68.0, 50.0, 77.0, 69.0, ..."
33,Height: 32 Width: 32,cat,"[0.525087952614, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[7.0, 5.0, 8.0, 7.0, 5.0, 8.0, 5.0, 4.0, 6.0, 7.0, ..."
36,Height: 32 Width: 32,cat,"[0.566015958786, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[169.0, 122.0, 65.0, 131.0, 108.0, 75.0, ..."
70,Height: 32 Width: 32,dog,"[1.12979578972, 0.0, 0.0, 0.778194487095, 0.0, ...","[154.0, 179.0, 152.0, 159.0, 183.0, 157.0, ..."
90,Height: 32 Width: 32,bird,"[1.71786928177, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[216.0, 195.0, 180.0, 201.0, 178.0, 160.0, ..."
97,Height: 32 Width: 32,automobile,"[1.57818555832, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[33.0, 44.0, 27.0, 29.0, 44.0, 31.0, 32.0, 45.0, ..."
107,Height: 32 Width: 32,dog,"[0.0, 0.0, 0.220677852631, 0.0, ...","[97.0, 51.0, 31.0, 104.0, 58.0, 38.0, 107.0, 61.0, ..."
121,Height: 32 Width: 32,bird,"[0.0, 0.23753464222, 0.0, 0.0, 0.0, 0.0, ...","[93.0, 96.0, 88.0, 102.0, 106.0, 97.0, 117.0, ..."
136,Height: 32 Width: 32,automobile,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 7.5737862587, 0.0, ...","[35.0, 59.0, 53.0, 36.0, 56.0, 56.0, 42.0, 62.0, ..."
138,Height: 32 Width: 32,bird,"[0.658935725689, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[205.0, 193.0, 195.0, 200.0, 187.0, 193.0, ..."


In [11]:
# Inspect the full deep-feature vector of the training image whose id is 24.
is_image_24 = image_train['id'] == 24
image_train[is_image_24]['deep_features']

dtype: array
Rows: ?
[array('d', [0.24287176132202148, 1.0954537391662598, 0.0, 0.39362990856170654, 0.0, 0.0, 11.894915580749512, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5788496136665344, 0.4954667389392853, 2.5141289234161377, 0.0, 1.5180106163024902, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.5813961029052734, 0.0, 0.0, 2.595609426498413, 2.7079553604125977, 0.0, 0.0, 0.0, 0.8509902954101562, 0.0, 0.7203489542007446, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2700355052947998, 0.0, 0.0, 0.0, 0.0, 0.08592796325683594, 0.0, 0.7010231018066406, 0.0, 0.0, 0.0, 0.0, 0.024805665016174316, 0.0, 0.0, 0.17549043893814087, 0.0, 0.0, 0.0, 0.0, 0.0, 2.392784595489502, 0.0, 0.0, 4.471865653991699, 0.0, 1.6358323097229004, 0.0, 4.417484760284424, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.4117904901504517, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1247677206993103, 0.0, 0.0, 0.8957164287567139, 0.0, 0.0, 0.3334987759590149, 0.0, 0.0, 0.20787304639816284

# Train a nearest-neighbors model for retrieving images using deep features

We will now build a simple image retrieval system that finds the nearest neighbors for any image.

In [12]:
# Index every training image by its deep features. Query results identify the
# matched training rows through their 'id' (the 'reference_label' column).
knn_model = graphlab.nearest_neighbors.create(
    image_train,
    features=['deep_features'],
    label='id',
)

# Use image retrieval model with deep features to find similar images

Let's find images similar to this cat picture.

In [13]:
CAT_ROW = 18  # row of image_train holding the example cat picture
cat = image_train[CAT_ROW:CAT_ROW + 1]
cat['image'].show()

Canvas is accessible via web browser at the URL: http://localhost:63664/index.html
Opening Canvas in default web browser.


In [14]:
# k=5 is the default. The rank-1 match has distance 0.0, i.e. presumably the
# query image itself, since it is also in the training set.
knn_model.query(cat, k=5)

query_label,reference_label,distance,rank
0,384,0.0,1
0,6910,36.9403137951,2
0,39777,38.4634888975,3
0,36870,39.7559623119,4
0,41734,39.7866014148,5


To save typing, we'll create a simple function that looks up the training images returned by a nearest-neighbors query:

In [15]:
def get_images_from_ids(query_result):
    """Return the rows of image_train whose 'id' appears in a query result.

    query_result: SFrame returned by a nearest-neighbors ``query``; its
    'reference_label' column holds the ids of the matched training images.

    NOTE(review): rows may not come back in neighbor-rank order -- check the
    SFrame.filter_by documentation if order matters.
    """
    neighbor_ids = query_result['reference_label']
    return image_train.filter_by(neighbor_ids, 'id')

In [16]:
cat_query = knn_model.query(cat)
cat_neighbors = get_images_from_ids(cat_query)

In [17]:
# Display the images of the cat's nearest neighbors in Canvas.
cat_neighbors['image'].show()

Canvas is updated and available in a tab in the default browser.


Very cool results showing similar cats.

## Finding similar images to a car

In [18]:
CAR_ROW = 8  # row of image_train holding the example car picture
car = image_train[CAR_ROW:CAR_ROW + 1]
car['image'].show()

Canvas is updated and available in a tab in the default browser.


In [19]:
car_query = knn_model.query(car)
car_neighbors = get_images_from_ids(car_query)
car_neighbors['image'].show()

Canvas is updated and available in a tab in the default browser.


# Just for fun, let's create a lambda to find and show nearest neighbor images

In [20]:
# Query the neighbors of the training image at row ``i`` and display their images.
show_neighbors = lambda i: (
    get_images_from_ids(
        knn_model.query(image_train[i:i + 1])
    )['image'].show()
)

In [21]:
# Row 8 is the car picture used above, so this repeats the car query.
show_neighbors(8)

Canvas is updated and available in a tab in the default browser.


In [22]:
# Neighbors of the training image at row 26.
show_neighbors(26)

Canvas is updated and available in a tab in the default browser.


In [24]:
# Summarize the label column: row count, missing values, and per-class frequencies.
label_summary = image_train['label'].sketch_summary()
label_summary


+------------------+-------+----------+
|       item       | value | is exact |
+------------------+-------+----------+
|      Length      |  2005 |   Yes    |
| # Missing Values |   0   |   Yes    |
| # unique values  |   4   |    No    |
+------------------+-------+----------+

Most frequent items:
+-------+------------+-----+-----+------+
| value | automobile | cat | dog | bird |
+-------+------------+-----+-----+------+
| count |    509     | 509 | 509 | 478  |
+-------+------------+-----+-----+------+


In [110]:
# Number of training images labelled 'dog' (sum of a 0/1 mask).
(image_train['label'] == 'dog').sum()


509

In [111]:
# Number of training images labelled 'cat' (sum of a 0/1 mask).
(image_train['label'] == 'cat').sum()

509

In [112]:
# Number of training images labelled 'bird' (sum of a 0/1 mask).
(image_train['label'] == 'bird').sum()

478

In [113]:
# Number of training images labelled 'automobile' (sum of a 0/1 mask).
(image_train['label'] == 'automobile').sum()

509

In [62]:
# Nearest-neighbors index over the dog training images only.
dog_images = image_train[image_train['label'] == 'dog']
dog_model = graphlab.nearest_neighbors.create(dog_images, features=['deep_features'], label='id')

In [63]:
# Nearest-neighbors index over the cat training images only.
cat_images = image_train[image_train['label'] == 'cat']
cat_model = graphlab.nearest_neighbors.create(cat_images, features=['deep_features'], label='id')

In [64]:
# Nearest-neighbors index over the bird training images only.
bird_images = image_train[image_train['label'] == 'bird']
bird_model = graphlab.nearest_neighbors.create(bird_images, features=['deep_features'], label='id')

In [65]:
# Nearest-neighbors index over the automobile training images only.
car_images = image_train[image_train['label'] == 'automobile']
car_model = graphlab.nearest_neighbors.create(car_images, features=['deep_features'], label='id')

In [35]:
# Load the held-out images and display the first one.
TEST_DATA_PATH = 'image_test_data/'
test_data = graphlab.SFrame(TEST_DATA_PATH)
test_data.head(1).show()

Canvas is accessible via web browser at the URL: http://localhost:63664/index.html
Opening Canvas in default web browser.


In [125]:
# Closest dog training image to the first test image.
nearest_dog = dog_model.query(test_data[0:1], k=1)
get_images_from_ids(nearest_dog)['image'].show()

Canvas is updated and available in a tab in the default browser.


In [124]:
# Closest cat training image to the first test image.
nearest_cat = cat_model.query(test_data[0:1], k=1)
get_images_from_ids(nearest_cat)['image'].show()

Canvas is updated and available in a tab in the default browser.


In [123]:
# Closest automobile training image to the first test image.
nearest_car = car_model.query(test_data[0:1], k=1)
get_images_from_ids(nearest_car)['image'].show()

Canvas is updated and available in a tab in the default browser.


In [68]:
# Mean distance from the first test image to its nearest dog training images
# (default k, i.e. 5 neighbors).
first_test_image = test_data[0:1]
dog_nhbrs = dog_model.query(first_test_image)
dog_nhbrs['distance'].mean()

37.77071136184156

In [126]:
# Mean distance from the first test image to its nearest cat training images
# (default k, i.e. 5 neighbors).
first_test_image = test_data[0:1]
cat_nhbrs = cat_model.query(first_test_image)
cat_nhbrs['distance'].mean()

36.15573070978294

In [70]:
# Split the test set by label. A generator expression is used so the loop
# variable does not leak into the notebook namespace (Python 2).
car_test, cat_test, dog_test, bird_test = (
    test_data[test_data['label'] == wanted]
    for wanted in ('automobile', 'cat', 'dog', 'bird')
)

In [71]:
# For every dog test image, its single nearest cat training image.
dog_cat_neighbors = cat_model.query(dataset=dog_test, k=1)

In [73]:
# For every dog test image, its single nearest bird training image.
dog_bird_neighbors = bird_model.query(dataset=dog_test, k=1)

In [75]:
# For every dog test image, its single nearest automobile training image.
dog_car_neighbors = car_model.query(dataset=dog_test, k=1)

In [76]:
# For every dog test image, its single nearest dog training image.
dog_dog_neighbors = dog_model.query(dataset=dog_test, k=1)

In [77]:
# One row per dog test image; each column is the distance to its nearest
# neighbor in one class.
# NOTE(review): this assumes each k=1 query returns exactly one row per test
# image, in the same order, so the columns line up -- confirm.
distance_columns = {
    'dog-dog': dog_dog_neighbors['distance'],
    'dog-cat': dog_cat_neighbors['distance'],
    'dog-car': dog_car_neighbors['distance'],
    'dog-bird': dog_bird_neighbors['distance'],
}
dog_distances = graphlab.SFrame(distance_columns)

In [78]:
# Each row: distances from one dog test image to its nearest bird/car/cat/dog training image.
dog_distances.head()

dog-bird,dog-car,dog-cat,dog-dog
41.7538647304,41.9579761457,36.4196077068,33.4773590373
41.3382958925,46.0021331807,38.8353268874,32.8458495684
38.6157590853,42.9462290692,36.9763410854,35.0397073189
37.0892269954,41.6866060048,34.5750072914,33.9010327697
38.272288694,39.2269664935,34.778824791,37.4849250909
39.1462089236,40.5845117698,35.1171578292,34.945165344
40.523040106,45.1067352961,40.6095830913,39.0957278345
38.1947918393,41.3221140974,39.9036867306,37.7696131032
40.1567131661,41.8244654995,38.0674700168,35.1089144603
45.5597962603,45.4976929401,42.7258732951,43.2422832585


In [101]:
def is_dog_correct(row, reference_key='dog-dog'):
    """Return 1 if the same-class distance is the smallest in ``row``, else 0.

    Args:
        row: mapping of column name -> distance, e.g. one row of
            ``dog_distances`` as passed by ``SFrame.apply``.
        reference_key: column holding the same-class distance. The default
            keeps the original dog behavior; pass e.g. 'cat-cat' to reuse this
            for other classes.

    Returns:
        1 when no distance in ``row`` is strictly smaller than
        ``row[reference_key]`` (ties count as correct), otherwise 0.
    """
    reference = row[reference_key]
    # Only a strictly smaller distance to another class disqualifies the row.
    return 0 if any(dist < reference for dist in row.values()) else 1

In [102]:
# Count the dog test images whose nearest neighbor overall is a dog.
dog_correct_flags = dog_distances.apply(is_dog_correct)
dog_correct_flags.sum()

678

In [108]:
# Total number of dog test images (denominator for the count above).
dog_distances.num_rows()

1000