#Building an image retrieval system with deep features


#Fire up GraphLab Create

In [2]:
import graphlab

In [8]:
graphlab.canvas.set_target('ipynb')

#Load the CIFAR-10 dataset

We will use a popular benchmark dataset in computer vision called CIFAR-10.  

(We've reduced the data to just 4 categories = {'cat','bird','automobile','dog'}.)

This dataset is already split into a training set and a test set. In the basic retrieval example there is no notion of "testing", so the nearest-neighbor models are built only from the training data; the test set is used later only to supply query images.

In [3]:
image_train = graphlab.SFrame('image_train_data/')



#Computing deep features for our images

The two commented-out lines later in this notebook (the `graphlab.load_model` and `extract_features` calls) compute deep features. This computation takes a little while, so the features have already been computed and saved as the `deep_features` column in the data you loaded.

(Note that if you would like to compute such deep features yourself and have a GPU on your machine, you should use the GPU-enabled GraphLab Create, which will be significantly faster for this task.)

In [4]:
image_train['label'].sketch_summary()


+------------------+-------+----------+
|       item       | value | is exact |
+------------------+-------+----------+
|      Length      |  2005 |   Yes    |
| # Missing Values |   0   |   Yes    |
| # unique values  |   4   |    No    |
+------------------+-------+----------+

Most frequent items:
+-------+------------+-----+-----+------+
| value | automobile | cat | dog | bird |
+-------+------------+-----+-----+------+
| count |    509     | 509 | 509 | 478  |
+-------+------------+-----+-----+------+


In [5]:
dog_image_train = image_train[image_train['label'] == 'dog']
cat_image_train = image_train[image_train['label'] == 'cat']
bird_image_train = image_train[image_train['label'] == 'bird']
car_image_train = image_train[image_train['label'] == 'automobile']

In [6]:
dog_knn_model = graphlab.nearest_neighbors.create(dog_image_train, features=['deep_features'], label='id')

PROGRESS: Starting brute force nearest neighbors model training.


In [11]:
cat_knn_model = graphlab.nearest_neighbors.create(cat_image_train, features=['deep_features'], label='id')

PROGRESS: Starting brute force nearest neighbors model training.


In [12]:
bird_knn_model = graphlab.nearest_neighbors.create(bird_image_train, features=['deep_features'], label='id')
car_knn_model = graphlab.nearest_neighbors.create(car_image_train, features=['deep_features'], label='id')

PROGRESS: Starting brute force nearest neighbors model training.
PROGRESS: Starting brute force nearest neighbors model training.
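
To make the idea of one nearest-neighbors model per category concrete, here is a minimal pure-Python sketch. It is not GraphLab Create code: the helper names (`split_by_label`, `nearest_in_class`), the toy ids, and the 2-dimensional feature vectors are all made up for illustration, and Euclidean distance is an assumption.

```python
import math
from collections import defaultdict

def euclidean(a, b):
    """Euclidean distance between two equal-length feature vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def split_by_label(rows):
    """Group (id, label, feature_vector) rows into one list per label."""
    groups = defaultdict(list)
    for row_id, label, vec in rows:
        groups[label].append((row_id, vec))
    return dict(groups)

def nearest_in_class(class_rows, query_vector):
    """Return (id, distance) of the closest item within a single class."""
    return min(
        ((row_id, euclidean(vec, query_vector)) for row_id, vec in class_rows),
        key=lambda pair: pair[1],
    )

# Hypothetical toy data standing in for the training SFrame.
toy_rows = [
    (1, "cat", [0.0, 1.0]),
    (2, "cat", [4.0, 4.0]),
    (3, "dog", [1.0, 1.0]),
    (4, "dog", [5.0, 5.0]),
    (5, "bird", [3.0, 0.0]),
]

by_label = split_by_label(toy_rows)
query = [0.0, 0.0]

# One "model" per class: the closest training item of each class to the query.
per_class = {label: nearest_in_class(rows, query) for label, rows in by_label.items()}
for label in sorted(per_class):
    print(label, per_class[label])
```

Each entry of `per_class` plays the role of a `k=1` query against one per-category model: it gives the closest training item of that category and its distance to the query.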


In [13]:
image_test = graphlab.SFrame('image_test_data/')

In [22]:
cat_5neighbors = cat_knn_model.query(image_test[0:1])

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.196464    | 9.709ms      |
PROGRESS: | Done         |         | 100         | 44.644ms     |
PROGRESS: +--------------+---------+-------------+--------------+


In [23]:
dog_5neighbors = dog_knn_model.query(image_test[0:1])

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.196464    | 9.932ms      |
PROGRESS: | Done         |         | 100         | 42.833ms     |
PROGRESS: +--------------+---------+-------------+--------------+


In [69]:
cat_5neighbors.head()
cat_5neighbors.show()


In [66]:
cat_image_train[cat_image_train['id'] == 16289]['image'].show()

In [70]:
dog_5neighbors.show()

In [68]:
dog_image_train[dog_image_train['id'] == 16976]['image'].show()

#Split the test data into 4 categories

In [28]:
cat_image_test = image_test[image_test['label'] == 'cat']
dog_image_test = image_test[image_test['label'] == 'dog']
bird_image_test = image_test[image_test['label'] == 'bird']
car_image_test = image_test[image_test['label'] == 'automobile']

In [32]:
dog_cat_neighbors = cat_knn_model.query(dog_image_test, k=1)
dog_bird_neighbors = bird_knn_model.query(dog_image_test, k=1)
dog_car_neighbors = car_knn_model.query(dog_image_test, k=1)
dog_dog_neighbors = dog_knn_model.query(dog_image_test, k=1)

PROGRESS: Starting blockwise querying.
PROGRESS: max rows per data block: 7668
PROGRESS: number of reference data blocks: 8
PROGRESS: number of query data blocks: 1
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 1000         | 64000   | 12.5737     | 258.02ms     |
PROGRESS: | Done         | 509000  | 100         | 288.12ms     |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: Starting blockwise querying.
PROGRESS: max rows per data block: 7668
PROGRESS: number of reference data blocks: 8
PROGRESS: number of query data blocks: 1
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 1000         | 60000   | 12.5523     | 257.106ms    |
PROGRESS: 

In [34]:
dog_car_neighbors.head()

query_label,reference_label,distance,rank
0,33859,41.9579761457,1
1,2046,46.0021331807,1
2,19594,42.9462290692,1
3,11000,41.6866060048,1
4,19594,39.2269664935,1
5,49314,40.5845117698,1
6,40822,45.1067352961,1
7,44997,41.3221140974,1
8,33859,41.8244654995,1
9,33859,45.4976929401,1


#Create a new SFrame of nearest-neighbor distances

In [35]:
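# Fixed: 'dog-cat' now uses dog_cat_neighbors; the recorded output below predates this fix, so its dog-cat column duplicates dog-car.
dog_distances = graphlab.SFrame({'dog-car': dog_car_neighbors['distance'], 'dog-bird': dog_bird_neighbors['distance'], 'dog-cat': dog_cat_neighbors['distance'], 'dog-dog': dog_dog_neighbors['distance']})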

In [38]:
dog_distances.head()

dog-bird,dog-car,dog-cat,dog-dog
41.7538647304,41.9579761457,41.9579761457,33.4773590373
41.3382958925,46.0021331807,46.0021331807,32.8458495684
38.6157590853,42.9462290692,42.9462290692,35.0397073189
37.0892269954,41.6866060048,41.6866060048,33.9010327697
38.272288694,39.2269664935,39.2269664935,37.4849250909
39.1462089236,40.5845117698,40.5845117698,34.945165344
40.523040106,45.1067352961,45.1067352961,39.0957278345
38.1947918393,41.3221140974,41.3221140974,37.7696131032
40.1567131661,41.8244654995,41.8244654995,35.1089144603
45.5597962603,45.4976929401,45.4976929401,43.2422832585


In [45]:
def is_dog_correct(row):
    # Correct when the nearest dog is strictly closer than the nearest cat, car, and bird.
    return row['dog-dog'] < min(row['dog-cat'], row['dog-car'], row['dog-bird'])
dog_distances['predict'] = dog_distances.apply(is_dog_correct)

In [57]:
dog_distances['predict'].sketch_summary()


+--------------------+----------------+----------+
|        item        |     value      | is exact |
+--------------------+----------------+----------+
|       Length       |      1000      |   Yes    |
|        Min         |      0.0       |   Yes    |
|        Max         |      1.0       |   Yes    |
|        Mean        |     0.879      |   Yes    |
|        Sum         |     879.0      |   Yes    |
|      Variance      |    0.106359    |   Yes    |
| Standard Deviation | 0.326127275768 |   Yes    |
|  # Missing Values  |       0        |   Yes    |
|  # unique values   |       2        |    No    |
+--------------------+----------------+----------+

Most frequent items:
+-------+-----+-----+
| value |  1  |  0  |
+-------+-----+-----+
| count | 879 | 121 |
+-------+-----+-----+

Quantiles: 
+-----+-----+-----+-----+-----+-----+-----+-----+------+
|  0% |  1% |  5% | 25% | 50% | 75% | 95% | 99% | 100% |
+-----+-----+-----+-----+-----+-----+-----+-----+------+
| 0.0 | 0.0 | 0.0 | 

In [71]:
dog_distances['predict'].sum()
dog_distances.show()

In [48]:
cat_cat_neighbors = cat_knn_model.query(cat_image_test, k=1)
cat_bird_neighbors = bird_knn_model.query(cat_image_test, k=1)
cat_car_neighbors = car_knn_model.query(cat_image_test, k=1)
cat_dog_neighbors = dog_knn_model.query(cat_image_test, k=1)

PROGRESS: Starting blockwise querying.
PROGRESS: max rows per data block: 7668
PROGRESS: number of reference data blocks: 8
PROGRESS: number of query data blocks: 1
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 1000         | 63000   | 12.3772     | 257.575ms    |
PROGRESS: | Done         | 509000  | 100         | 300.636ms    |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: Starting blockwise querying.
PROGRESS: max rows per data block: 7668
PROGRESS: number of reference data blocks: 8
PROGRESS: number of query data blocks: 1
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 1000         | 60000   | 12.5523     | 255.668ms    |
PROGRESS: 

In [58]:
cat_distances = graphlab.SFrame({'cat-car': cat_car_neighbors['distance'], 'cat-bird': cat_bird_neighbors['distance'], 'cat-cat': cat_cat_neighbors['distance'], 'cat-dog': cat_dog_neighbors['distance']})

In [59]:
cat_distances.head()

cat-bird,cat-car,cat-cat,cat-dog
38.074265869,39.6710582792,34.623719208,37.4642628784
36.3674024138,43.0089056688,33.8680579302,29.3472319585
35.3039394947,38.6010006604,32.4615168902,32.2599640475
38.8944029601,39.3566307091,35.7708210254,35.3852085188
34.2820409875,38.3572372618,31.1577686417,30.0442985088
44.5352170178,42.0904793181,41.3986035847,35.4741000424
34.0290595084,39.0520251253,30.9894594959,32.5845275226
39.0236924983,39.3058645069,37.0814607387,37.6502852614
40.8334054297,43.0248129799,39.9883863688,36.9801353512
40.1258835601,45.6749176426,39.7076633097,41.1259410707


In [60]:
def is_cat_correct(row):
    # Correct when the nearest cat is strictly closer than the nearest dog, car, and bird.
    return row['cat-cat'] < min(row['cat-dog'], row['cat-car'], row['cat-bird'])
cat_distances['predict'] = cat_distances.apply(is_cat_correct)

In [61]:
cat_distances['predict'].sum()

548
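
The count above can be turned into an accuracy by dividing by the number of test images. The sketch below redoes the same comparison in plain Python on the ten rows printed by `cat_distances.head()`; the names `CLASSES`, `cat_rows`, and `predict_class` are introduced here for illustration only. It predicts the class with the smallest distance, which matches `is_cat_correct` except when two distances tie exactly.

```python
# Column order matches the printed table: cat-bird, cat-car, cat-cat, cat-dog.
CLASSES = ["bird", "car", "cat", "dog"]

# The ten rows shown by cat_distances.head() above.
cat_rows = [
    (38.074265869, 39.6710582792, 34.623719208, 37.4642628784),
    (36.3674024138, 43.0089056688, 33.8680579302, 29.3472319585),
    (35.3039394947, 38.6010006604, 32.4615168902, 32.2599640475),
    (38.8944029601, 39.3566307091, 35.7708210254, 35.3852085188),
    (34.2820409875, 38.3572372618, 31.1577686417, 30.0442985088),
    (44.5352170178, 42.0904793181, 41.3986035847, 35.4741000424),
    (34.0290595084, 39.0520251253, 30.9894594959, 32.5845275226),
    (39.0236924983, 39.3058645069, 37.0814607387, 37.6502852614),
    (40.8334054297, 43.0248129799, 39.9883863688, 36.9801353512),
    (40.1258835601, 45.6749176426, 39.7076633097, 41.1259410707),
]

def predict_class(distances):
    """Return the class whose nearest training image is closest to the query."""
    return min(zip(distances, CLASSES))[1]

predictions = [predict_class(row) for row in cat_rows]
correct = sum(1 for p in predictions if p == "cat")
accuracy = correct / float(len(cat_rows))
print(predictions)
print(correct, accuracy)
```

On these ten rows only, four cats are classified correctly; the misclassified ones are all closest to a dog. If the cat test set holds 1000 images, as the progress output suggests, the count of 548 corresponds to an accuracy of 0.548.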

In [None]:
#deep_learning_model = graphlab.load_model('http://s3.amazonaws.com/GraphLab-Datasets/deeplearning/imagenet_model_iter45')
#image_train['deep_features'] = deep_learning_model.extract_features(image_train)

In [3]:
image_train.head()

id,image,label,deep_features,image_array
24,Height: 32 Width: 32,bird,"[0.242871761322, 1.09545373917, 0.0, ...","[73.0, 77.0, 58.0, 71.0, 68.0, 50.0, 77.0, 69.0, ..."
33,Height: 32 Width: 32,cat,"[0.525087952614, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[7.0, 5.0, 8.0, 7.0, 5.0, 8.0, 5.0, 4.0, 6.0, 7.0, ..."
36,Height: 32 Width: 32,cat,"[0.566015958786, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[169.0, 122.0, 65.0, 131.0, 108.0, 75.0, ..."
70,Height: 32 Width: 32,dog,"[1.12979578972, 0.0, 0.0, 0.778194487095, 0.0, ...","[154.0, 179.0, 152.0, 159.0, 183.0, 157.0, ..."
90,Height: 32 Width: 32,bird,"[1.71786928177, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[216.0, 195.0, 180.0, 201.0, 178.0, 160.0, ..."
97,Height: 32 Width: 32,automobile,"[1.57818555832, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[33.0, 44.0, 27.0, 29.0, 44.0, 31.0, 32.0, 45.0, ..."
107,Height: 32 Width: 32,dog,"[0.0, 0.0, 0.220677852631, 0.0, ...","[97.0, 51.0, 31.0, 104.0, 58.0, 38.0, 107.0, 61.0, ..."
121,Height: 32 Width: 32,bird,"[0.0, 0.23753464222, 0.0, 0.0, 0.0, 0.0, ...","[93.0, 96.0, 88.0, 102.0, 106.0, 97.0, 117.0, ..."
136,Height: 32 Width: 32,automobile,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 7.5737862587, 0.0, ...","[35.0, 59.0, 53.0, 36.0, 56.0, 56.0, 42.0, 62.0, ..."
138,Height: 32 Width: 32,bird,"[0.658935725689, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[205.0, 193.0, 195.0, 200.0, 187.0, 193.0, ..."


#Train a nearest-neighbors model for retrieving images using deep features

We will now build a simple image retrieval system that finds the nearest neighbors for any image.

In [4]:
knn_model = graphlab.nearest_neighbors.create(image_train, features=['deep_features'], label='id')

PROGRESS: Starting brute force nearest neighbors model training.
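
As a rough illustration of what a brute-force nearest-neighbors query does, here is a minimal pure-Python sketch. It is not GraphLab Create's implementation: `brute_force_query`, the toy ids, and the 3-dimensional vectors are hypothetical, and Euclidean distance is an assumption.

```python
import math

def euclidean(a, b):
    """Euclidean distance between two equal-length feature vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def brute_force_query(reference, query_vector, k=5):
    """Return the k closest reference items as (reference_label, distance, rank).

    `reference` is a list of (label, feature_vector) pairs. Every reference
    vector is compared against the query, hence "brute force".
    """
    scored = sorted((euclidean(vec, query_vector), label) for label, vec in reference)
    return [(label, dist, rank) for rank, (dist, label) in enumerate(scored[:k], start=1)]

# Toy 3-dimensional "deep features"; ids and values are made up for illustration.
toy_reference = [
    (101, [0.0, 0.0, 0.0]),
    (102, [1.0, 0.0, 0.0]),
    (103, [0.0, 3.0, 4.0]),
    (104, [2.0, 2.0, 1.0]),
]

# Query with a vector that is itself in the reference set.
neighbors = brute_force_query(toy_reference, [0.0, 0.0, 0.0], k=3)
for label, dist, rank in neighbors:
    print(label, dist, rank)
```

Because the query vector is also part of the reference set, the first result is the item itself at distance 0.0, which is the same pattern a query shows when the query image comes from the training data.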


#Use image retrieval model with deep features to find similar images

Let's find images similar to this cat picture.

In [5]:
graphlab.canvas.set_target('ipynb')
cat = image_train[18:19]
cat['image'].show()

In [7]:
knn_model.query(cat)

PROGRESS: Starting pairwise querying...
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 7.97ms       |
PROGRESS: | Done         |         | 100         | 62.996ms     |
PROGRESS: +--------------+---------+-------------+--------------+


query_label,reference_label,distance,rank
0,384,0.0,1
0,6910,36.9403137951,2
0,39777,38.4634888975,3
0,36870,39.7559623119,4
0,41734,39.7866014148,5


To save typing, we create a simple helper function that looks up the images for the nearest neighbors returned by a query:

In [9]:
def get_images_from_ids(query_result):
    return image_train.filter_by(query_result['reference_label'],'id')
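
The helper above amounts to a membership filter on the `id` column. A minimal pure-Python sketch of that idea follows; `filter_by_ids` and the toy table are hypothetical stand-ins rather than the GraphLab Create API, with ids and labels taken from the first rows of `image_train.head()`.

```python
def filter_by_ids(table, ids, column="id"):
    """Keep the rows of `table` whose `column` value appears in `ids`."""
    wanted = set(ids)
    return [row for row in table if row[column] in wanted]

# A few (id, label) pairs from image_train.head(); other columns omitted.
toy_table = [
    {"id": 24, "label": "bird"},
    {"id": 33, "label": "cat"},
    {"id": 36, "label": "cat"},
    {"id": 70, "label": "dog"},
]

# Pretend these are the reference_label values of a query, most similar first.
toy_reference_labels = [36, 33]

matches = filter_by_ids(toy_table, toy_reference_labels)
print([row["id"] for row in matches])
```

In this sketch the matching rows come back in the table's own order, not in rank order. Whether the real `filter_by` preserves the rank order of the query result is not stated in this notebook, so treat the displayed order of neighbor images with that caveat in mind.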

In [10]:
cat_neighbors = get_images_from_ids(knn_model.query(cat))

PROGRESS: Starting pairwise querying...
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 7.302ms      |
PROGRESS: | Done         |         | 100         | 69.934ms     |
PROGRESS: +--------------+---------+-------------+--------------+


In [11]:
cat_neighbors['image'].show()

The results look good: the retrieved neighbors are similar-looking cats.

##Finding similar images to a car

In [12]:
car = image_train[8:9]
car['image'].show()

In [13]:
get_images_from_ids(knn_model.query(car))['image'].show()

PROGRESS: Starting pairwise querying...
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 26.12ms      |
PROGRESS: | Done         |         | 100         | 70.419ms     |
PROGRESS: +--------------+---------+-------------+--------------+


#Just for fun, let's create a lambda to find and show nearest neighbor images

In [14]:
show_neighbors = lambda i: get_images_from_ids(knn_model.query(image_train[i:i+1]))['image'].show()

In [15]:
show_neighbors(8)

PROGRESS: Starting pairwise querying...
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 16.302ms     |
PROGRESS: | Done         |         | 100         | 72.25ms      |
PROGRESS: +--------------+---------+-------------+--------------+


In [16]:
show_neighbors(26)

PROGRESS: Starting pairwise querying...
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 14.465ms     |
PROGRESS: | Done         |         | 100         | 69.576ms     |
PROGRESS: +--------------+---------+-------------+--------------+
