# Building an image retrieval system with deep features


# Fire up GraphLab Create

In [1]:
import graphlab



# Load the CIFAR-10 dataset

We will use a popular benchmark dataset in computer vision called CIFAR-10.  

(We've reduced the data to just 4 categories: 'cat', 'bird', 'automobile', and 'dog'.)

This dataset is already split into a training set and test set. In this simple retrieval example, there is no notion of "testing", so we will only use the training data.

In [2]:
image_train = graphlab.SFrame('image_train_data/')

[INFO] GraphLab Server Version: 1.6.1


# Computing deep features for our images

The two commented-out lines below show how to compute deep features. This computation takes a little while, so we have already computed them and saved the results as a column in the data you loaded.

(Note that if you would like to compute such deep features yourself and have a GPU on your machine, you should use the GPU-enabled version of GraphLab Create, which will be significantly faster for this task.)

In [None]:
#deep_learning_model = graphlab.load_model('http://s3.amazonaws.com/GraphLab-Datasets/deeplearning/imagenet_model_iter45')
#image_train['deep_features'] = deep_learning_model.extract_features(image_train)

In [3]:
image_train.head()

id,image,label,deep_features,image_array
24,Height: 32 Width: 32,bird,"[0.242871761322, 1.09545373917, 0.0, ...","[73.0, 77.0, 58.0, 71.0, 68.0, 50.0, 77.0, 69.0, ..."
33,Height: 32 Width: 32,cat,"[0.525087952614, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[7.0, 5.0, 8.0, 7.0, 5.0, 8.0, 5.0, 4.0, 6.0, 7.0, ..."
36,Height: 32 Width: 32,cat,"[0.566015958786, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[169.0, 122.0, 65.0, 131.0, 108.0, 75.0, ..."
70,Height: 32 Width: 32,dog,"[1.12979578972, 0.0, 0.0, 0.778194487095, 0.0, ...","[154.0, 179.0, 152.0, 159.0, 183.0, 157.0, ..."
90,Height: 32 Width: 32,bird,"[1.71786928177, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[216.0, 195.0, 180.0, 201.0, 178.0, 160.0, ..."
97,Height: 32 Width: 32,automobile,"[1.57818555832, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[33.0, 44.0, 27.0, 29.0, 44.0, 31.0, 32.0, 45.0, ..."
107,Height: 32 Width: 32,dog,"[0.0, 0.0, 0.220677852631, 0.0, ...","[97.0, 51.0, 31.0, 104.0, 58.0, 38.0, 107.0, 61.0, ..."
121,Height: 32 Width: 32,bird,"[0.0, 0.23753464222, 0.0, 0.0, 0.0, 0.0, ...","[93.0, 96.0, 88.0, 102.0, 106.0, 97.0, 117.0, ..."
136,Height: 32 Width: 32,automobile,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 7.5737862587, 0.0, ...","[35.0, 59.0, 53.0, 36.0, 56.0, 56.0, 42.0, 62.0, ..."
138,Height: 32 Width: 32,bird,"[0.658935725689, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[205.0, 193.0, 195.0, 200.0, 187.0, 193.0, ..."


# Train a nearest-neighbors model for retrieving images using deep features

We will now build a simple image retrieval system that finds the nearest neighbors for any image.

In [4]:
knn_model = graphlab.nearest_neighbors.create(image_train,features=['deep_features'],
                                             label='id')
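Conceptually, a brute-force nearest-neighbors model just stores the reference rows and, at query time, compares the query against every one of them and keeps the closest. The plain-Python sketch below illustrates that idea; it is not GraphLab's implementation. It assumes Euclidean distance (our assumption about what the model picks automatically for a single numeric vector feature), uses the hypothetical helper names `euclidean` and `brute_force_query`, and uses toy 2-D vectors in place of the much longer deep-feature vectors.

```python
import math

def euclidean(a, b):
    # Straight-line distance between two equal-length feature vectors
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def brute_force_query(reference, query_vec, k=5):
    """Return up to k (reference_label, distance, rank) tuples, closest first.

    reference: list of (label, feature_vector) pairs, e.g. (id, deep_features).
    """
    # Compare the query against every reference row, then sort by distance
    scored = [(label, euclidean(vec, query_vec)) for label, vec in reference]
    scored.sort(key=lambda pair: pair[1])
    return [(label, dist, rank)
            for rank, (label, dist) in enumerate(scored[:k], start=1)]

# Hypothetical toy data: ids 1, 2, 3 with 2-D "features"
reference = [(1, [0.0, 0.0]), (2, [3.0, 4.0]), (3, [6.0, 8.0])]
print(brute_force_query(reference, [0.0, 0.0], k=2))
# [(1, 0.0, 1), (2, 5.0, 2)]
```

Note that a query point which is itself in the reference set comes back first at distance 0, which is presumably why id 384 tops the cat query's results at distance 0.0.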

PROGRESS: Starting brute force nearest neighbors model training.


# Use the image retrieval model with deep features to find similar images

Let's find images similar to this cat picture.

In [5]:
graphlab.canvas.set_target('ipynb')
cat = image_train[18:19]
cat['image'].show()

In [6]:
knn_model.query(cat)

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 50.339ms     |
PROGRESS: | Done         |         | 100         | 349.667ms    |
PROGRESS: +--------------+---------+-------------+--------------+


query_label,reference_label,distance,rank
0,384,0.0,1
0,6910,36.9403137951,2
0,39777,38.4634888975,3
0,36870,39.7559623119,4
0,41734,39.7866014148,5


To save typing, let's create a simple function that retrieves the images of the nearest neighbors:

In [7]:
def get_images_from_ids(query_result):
    return image_train.filter_by(query_result['reference_label'],'id')
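Because this helper looks the images up with `filter_by`, it is safest not to assume the rows come back in rank order. If order matters, one option is to reorder by rank yourself. Below is a plain-Python sketch of that idea, with rows modeled as dicts; `images_in_rank_order` is a hypothetical name, not a GraphLab function, and the query ranks are taken from the cat query output above.

```python
def images_in_rank_order(query_result, image_rows):
    """Return image rows sorted by the rank of their id in a query result.

    query_result: list of dicts with 'reference_label' and 'rank' keys.
    image_rows:   list of dicts with an 'id' key (e.g. rows of image_train).
    """
    rank_of = {r['reference_label']: r['rank'] for r in query_result}
    # Keep only the retrieved images, then order them by their rank
    matches = [row for row in image_rows if row['id'] in rank_of]
    return sorted(matches, key=lambda row: rank_of[row['id']])

# Ranks from the cat query output; the image rows are a toy stand-in
query_result = [{'reference_label': 384, 'rank': 1},
                {'reference_label': 6910, 'rank': 2},
                {'reference_label': 39777, 'rank': 3}]
image_rows = [{'id': 6910}, {'id': 24}, {'id': 39777}, {'id': 384}]
print([row['id'] for row in images_in_rank_order(query_result, image_rows)])
# [384, 6910, 39777]
```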

In [9]:
cat_neighbors = get_images_from_ids(knn_model.query(cat))

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 35.378ms     |
PROGRESS: | Done         |         | 100         | 280.742ms    |
PROGRESS: +--------------+---------+-------------+--------------+


In [10]:
cat_neighbors['image'].show()

Very cool: the results show similar-looking cats.

## Finding images similar to a car

In [11]:
car = image_train[8:9]
car['image'].show()

In [12]:
get_images_from_ids(knn_model.query(car))['image'].show()

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 28.468ms     |
PROGRESS: | Done         |         | 100         | 246.381ms    |
PROGRESS: +--------------+---------+-------------+--------------+


# Just for fun, let's create a lambda to find and show nearest-neighbor images

In [13]:
show_neighbors = lambda i: get_images_from_ids(knn_model.query(image_train[i:i+1]))['image'].show()

In [14]:
show_neighbors(8)

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 34.69ms      |
PROGRESS: | Done         |         | 100         | 276.995ms    |
PROGRESS: +--------------+---------+-------------+--------------+


In [15]:
show_neighbors(26)

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 22.603ms     |
PROGRESS: | Done         |         | 100         | 271.55ms     |
PROGRESS: +--------------+---------+-------------+--------------+


In [16]:
show_neighbors(1222)

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 36.073ms     |
PROGRESS: | Done         |         | 100         | 235.274ms    |
PROGRESS: +--------------+---------+-------------+--------------+


In [17]:
show_neighbors(2000)

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 35.961ms     |
PROGRESS: | Done         |         | 100         | 230.927ms    |
PROGRESS: +--------------+---------+-------------+--------------+


## Quiz questions

In [26]:
image_train["label"].sketch_summary()


+------------------+-------+----------+
|       item       | value | is exact |
+------------------+-------+----------+
|      Length      |  2005 |   Yes    |
| # Missing Values |   0   |   Yes    |
| # unique values  |   4   |    No    |
+------------------+-------+----------+

Most frequent items:
+-------+------------+-----+-----+------+
| value | automobile | cat | dog | bird |
+-------+------------+-----+-----+------+
| count |    509     | 509 | 509 | 478  |
+-------+------------+-----+-----+------+


### Q1: The least common category is bird (478 images)
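The answer can be read straight off the counts in the `sketch_summary` output. As a small plain-Python check, the sketch below copies those counts into a `collections.Counter` and picks the label with the smallest count.

```python
from collections import Counter

# Category counts copied from the sketch_summary output above
label_counts = Counter({'automobile': 509, 'cat': 509, 'dog': 509, 'bird': 478})

# The label with the smallest count is the least common category
least_common_label, least_count = min(label_counts.items(), key=lambda kv: kv[1])
print('%s: %d images' % (least_common_label, least_count))  # bird: 478 images
print(sum(label_counts.values()))                          # 2005
```

The total matches the `Length` of 2005 reported by `sketch_summary`.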

In [30]:
image_test = graphlab.SFrame('image_test_data/')
image_test[0:1]["image"].show()

In [32]:
knn_model.query(image_test[0:1])

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 34.654ms     |
PROGRESS: | Done         |         | 100         | 203.667ms    |
PROGRESS: +--------------+---------+-------------+--------------+


query_label,reference_label,distance,rank
0,16289,34.623719208,1
0,45646,36.0068799284,2
0,32139,36.5200813436,3
0,25713,36.7548502521,4
0,331,36.8731228168,5


In [158]:
# image_train_cat is only defined further down (In [132]), so index image_train directly
image_train[image_train["id"]==16289]["image"].show()

In [40]:
knn_model.query(image_test[0:1])[0:1]["reference_label"]

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 22.698ms     |
PROGRESS: | Done         |         | 100         | 193.848ms    |
PROGRESS: +--------------+---------+-------------+--------------+


dtype: int
Rows: 1
[16289]

In [93]:
# knn_model was built on image_train, so the reference labels returned by
# knn_model.query() are training-set ids: look them up in image_train (via
# get_images_from_ids), not in image_test.
show_test_neighbors = lambda i: get_images_from_ids(knn_model.query(image_test[i:i+1]))['image'].show()

In [68]:
image_train.filter_by(knn_model.query(image_test[0:1])[0:1]['reference_label'],'id')

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 26.42ms      |
PROGRESS: | Done         |         | 100         | 204.089ms    |
PROGRESS: +--------------+---------+-------------+--------------+


id,image,label,deep_features,image_array
16289,Height: 32 Width: 32,cat,"[0.964287519455, 0.0, 0.0, 0.0, 1.12515509129, ...","[215.0, 219.0, 231.0, 215.0, 219.0, 232.0, ..."


In [67]:
image_train.filter_by(knn_model.query(image_test[0:1])[0:1]['reference_label'],'id')["image"].show()

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.0498753   | 29.673ms     |
PROGRESS: | Done         |         | 100         | 201.498ms    |
PROGRESS: +--------------+---------+-------------+--------------+


In [132]:
image_train_cat=image_train[image_train["label"]=="cat"]
image_train_dog=image_train[image_train["label"]=="dog"]
image_train_automobile=image_train[image_train["label"]=="automobile"]
image_train_bird=image_train[image_train["label"]=="bird"]

In [133]:
image_train_cat["image"].show()

In [134]:
image_test["image"].show()

In [177]:
knn_model_cat = graphlab.nearest_neighbors.create(image_train_cat,features=['deep_features'], label='id')
knn_model_dog = graphlab.nearest_neighbors.create(image_train_dog,features=['deep_features'], label='id')
knn_model_automobile = graphlab.nearest_neighbors.create(image_train_automobile,features=['deep_features'], label='id')
knn_model_bird = graphlab.nearest_neighbors.create(image_train_bird,features=['deep_features'], label='id')

PROGRESS: Starting brute force nearest neighbors model training.
PROGRESS: Starting brute force nearest neighbors model training.
PROGRESS: Starting brute force nearest neighbors model training.
PROGRESS: Starting brute force nearest neighbors model training.


In [136]:
knn_model_cat.query(image_test[0:1])

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.196464    | 15.055ms     |
PROGRESS: | Done         |         | 100         | 89.818ms     |
PROGRESS: +--------------+---------+-------------+--------------+


query_label,reference_label,distance,rank
0,16289,34.623719208,1
0,45646,36.0068799284,2
0,32139,36.5200813436,3
0,25713,36.7548502521,4
0,331,36.8731228168,5


In [164]:
image_train[image_train["id"]==16289]["image"].show()

In [149]:
image_train_cat.filter_by(knn_model_cat.query(image_test[0:1])[0:1]['reference_label'],'id')["image"].show()

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.196464    | 19.057ms     |
PROGRESS: | Done         |         | 100         | 86.574ms     |
PROGRESS: +--------------+---------+-------------+--------------+


In [125]:
knn_model_dog.query(image_test[0:1])

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.1         | 20.069ms     |
PROGRESS: | Done         |         | 100         | 119.711ms    |
PROGRESS: +--------------+---------+-------------+--------------+


query_label,reference_label,distance,rank
0,6122,35.1982843489,1
0,9657,36.3923861996,2
0,4551,36.5853376044,3
0,640,37.0058415107,4
0,6924,37.235687683,5


In [170]:
image_train_dog["image"].show()

In [151]:
sum(knn_model_dog.query(image_test[0:1])["distance"])/5

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.196464    | 17.895ms     |
PROGRESS: | Done         |         | 100         | 87.153ms     |
PROGRESS: +--------------+---------+-------------+--------------+


37.77071136184157

In [152]:
len(knn_model_dog.query(image_test[0:1])["distance"])

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.196464    | 16.473ms     |
PROGRESS: | Done         |         | 100         | 91.712ms     |
PROGRESS: +--------------+---------+-------------+--------------+


5

In [155]:
knn_model_dog.query(image_test[0:1])["distance"]

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.196464    | 19.783ms     |
PROGRESS: | Done         |         | 100         | 86.721ms     |
PROGRESS: +--------------+---------+-------------+--------------+


dtype: float
Rows: 5
[37.464262878423774, 37.56668321685285, 37.60472670789396, 37.70655851529755, 38.511325490739715]
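The `sum(...)/5` cell above hard-codes the number of neighbors. Dividing by the length of the distance list keeps the average correct for any `k`; the sketch below does this in plain Python with the five distances printed above. (GraphLab's SArray also offers a `mean()` method, which should give the same number.)

```python
# The five dog-neighbor distances printed by In [155]
distances = [37.464262878423774, 37.56668321685285, 37.60472670789396,
             37.70655851529755, 38.511325490739715]

# Divide by len() rather than a hard-coded 5 so the average works for any k
mean_distance = sum(distances) / float(len(distances))
print(round(mean_distance, 4))  # 37.7707
```

This agrees with the 37.7707... printed by In [151].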

In [150]:
image_train_dog.filter_by(knn_model_dog.query(image_test[0:1])[0:1]['reference_label'],'id')["image"].show()

PROGRESS: Starting pairwise querying.
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 0            | 1       | 0.196464    | 22.776ms     |
PROGRESS: | Done         |         | 100         | 89.721ms     |
PROGRESS: +--------------+---------+-------------+--------------+


### Section 4

In [171]:
image_test_cat=image_test[image_test["label"]=="cat"]
image_test_dog=image_test[image_test["label"]=="dog"]
image_test_automobile=image_test[image_test["label"]=="automobile"]
image_test_bird=image_test[image_test["label"]=="bird"]

In [179]:
dog_cat_neighbors = knn_model_cat.query(image_test_dog, k=1)
dog_dog_neighbors = knn_model_dog.query(image_test_dog, k=1)
dog_automobile_neighbors = knn_model_automobile.query(image_test_dog, k=1)
dog_bird_neighbors = knn_model_bird.query(image_test_dog, k=1)

PROGRESS: Starting blockwise querying.
PROGRESS: max rows per data block: 7668
PROGRESS: number of reference data blocks: 4
PROGRESS: number of query data blocks: 1
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 1000         | 127000  | 24.9509     | 356.104ms    |
PROGRESS: | Done         | 509000  | 100         | 388.709ms    |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: Starting blockwise querying.
PROGRESS: max rows per data block: 7668
PROGRESS: number of reference data blocks: 4
PROGRESS: number of query data blocks: 1
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 1000         | 127000  | 24.9509     | 357.69ms     |
PROGRESS: 
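Each `query(..., k=1)` call above returns, for every dog test image, the distance to its single closest training image within one class. The plain-Python sketch below shows that per-class 1-nearest-neighbor distance computation on toy 2-D vectors; the function names and data are hypothetical, and Euclidean distance is assumed.

```python
import math

def euclidean(a, b):
    # Straight-line distance between two equal-length feature vectors
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def per_class_nearest_distances(query_vec, reference_by_class):
    """Map each class name to the distance from query_vec to its closest member."""
    return {label: min(euclidean(query_vec, vec) for vec in vectors)
            for label, vectors in reference_by_class.items()}

# Hypothetical toy reference sets standing in for image_train_dog, image_train_cat, ...
reference_by_class = {
    'dog': [[1.0, 1.0], [2.0, 2.0]],
    'cat': [[4.0, 5.0]],
    'bird': [[10.0, 0.0]],
}
nearest = per_class_nearest_distances([1.0, 2.0], reference_by_class)
print(sorted(nearest.items()))
```

Each such dict corresponds to one row of the `dog_distances` SFrame built in the next cell.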

In [180]:
dog_distances = graphlab.SFrame({'dog-dog': dog_dog_neighbors["distance"],
                                 'dog-cat': dog_cat_neighbors["distance"],
                                 'dog-automobile': dog_automobile_neighbors["distance"],
                                 'dog-bird': dog_bird_neighbors["distance"]})

In [181]:
dog_distances

dog-automobile,dog-bird,dog-cat,dog-dog
41.9579761457,41.7538647304,36.4196077068,33.4773590373
46.0021331807,41.3382958925,38.8353268874,32.8458495684
42.9462290692,38.6157590853,36.9763410854,35.0397073189
41.6866060048,37.0892269954,34.5750072914,33.9010327697
39.2269664935,38.272288694,34.778824791,37.4849250909
40.5845117698,39.1462089236,35.1171578292,34.945165344
45.1067352961,40.523040106,40.6095830913,39.0957278345
41.3221140974,38.1947918393,39.9036867306,37.7696131032
41.8244654995,40.1567131661,38.0674700168,35.1089144603
45.4976929401,45.5597962603,42.7258732951,43.2422832585


In [182]:
dog_distances[0]['dog-cat']

36.41960770675437

In [183]:
def is_dog_correct(row):
    # 1 if the nearest training dog is at least as close as the nearest
    # training automobile, bird, and cat; 0 if any of those is strictly closer
    if row['dog-dog'] <= min(row['dog-automobile'], row['dog-bird'], row['dog-cat']):
        return 1
    return 0

In [190]:
dog_distances.apply(is_dog_correct).sum()

678

In [192]:
len(dog_distances)

1000
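`is_dog_correct` counts a dog test image as correct when its nearest training dog is at least as close as the nearest training image of every other class. Since the overall nearest neighbor always belongs to the class with the smallest per-class distance, this is the same as asking whether a 1-nearest-neighbor classifier over the four classes labels the image "dog" (ties counted as correct). As a check, the sketch below applies the same rule in plain Python to the ten rows of `dog_distances` shown above (rounded to three decimals) and computes the accuracy implied by 678 correct out of 1000. `dog_row_is_correct` is a hypothetical name.

```python
CLASSES = ('automobile', 'bird', 'cat', 'dog')

# First ten rows of dog_distances (In [181]), rounded to three decimals;
# each tuple lists the nearest distance to automobile, bird, cat, dog
table = [
    (41.958, 41.754, 36.420, 33.477), (46.002, 41.338, 38.835, 32.846),
    (42.946, 38.616, 36.976, 35.040), (41.687, 37.089, 34.575, 33.901),
    (39.227, 38.272, 34.779, 37.485), (40.585, 39.146, 35.117, 34.945),
    (45.107, 40.523, 40.610, 39.096), (41.322, 38.195, 39.904, 37.770),
    (41.824, 40.157, 38.067, 35.109), (45.498, 45.560, 42.726, 43.242),
]
rows = [dict(zip(['dog-' + c for c in CLASSES], values)) for values in table]

def dog_row_is_correct(row):
    # Same rule as is_dog_correct: 0 if any other class is strictly closer, else 1
    others = ('dog-automobile', 'dog-bird', 'dog-cat')
    return 0 if any(row['dog-dog'] > row[col] for col in others) else 1

correct = sum(dog_row_is_correct(row) for row in rows)
print('%d of %d sample rows correct' % (correct, len(rows)))  # 8 of 10 sample rows correct
print('overall dog accuracy: %.3f' % (678 / 1000.0))          # overall dog accuracy: 0.678
```

In the two failing sample rows (the 5th and 10th), the nearest training image is a cat.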

In [185]:
cat_cat_neighbors = knn_model_cat.query(image_test_cat, k=1)
cat_dog_neighbors = knn_model_dog.query(image_test_cat, k=1)
cat_automobile_neighbors = knn_model_automobile.query(image_test_cat, k=1)
cat_bird_neighbors = knn_model_bird.query(image_test_cat, k=1)

PROGRESS: Starting blockwise querying.
PROGRESS: max rows per data block: 7668
PROGRESS: number of reference data blocks: 4
PROGRESS: number of query data blocks: 1
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 1000         | 127000  | 24.9509     | 339.086ms    |
PROGRESS: | Done         | 509000  | 100         | 381.079ms    |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: Starting blockwise querying.
PROGRESS: max rows per data block: 7668
PROGRESS: number of reference data blocks: 4
PROGRESS: number of query data blocks: 1
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | Query points | # Pairs | % Complete. | Elapsed Time |
PROGRESS: +--------------+---------+-------------+--------------+
PROGRESS: | 1000         | 127000  | 24.9509     | 336.891ms    |
PROGRESS: 

In [186]:
cat_distances = graphlab.SFrame({'cat-dog': cat_dog_neighbors["distance"],
                                 'cat-cat': cat_cat_neighbors["distance"],
                                 'cat-automobile': cat_automobile_neighbors["distance"],
                                 'cat-bird': cat_bird_neighbors["distance"]})

In [187]:
def is_cat_correct(row):
    # 1 if the nearest training cat is at least as close as the nearest
    # training automobile, bird, and dog; 0 if any of those is strictly closer
    if row['cat-cat'] <= min(row['cat-automobile'], row['cat-bird'], row['cat-dog']):
        return 1
    return 0

In [189]:
cat_distances.apply(is_cat_correct).sum()

548
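`is_dog_correct` and `is_cat_correct` differ only in the class name baked into the column names, so one function parameterized by the true class would avoid the duplication. The plain-Python sketch below assumes the `'<true>-<other>'` column naming used above; `is_correct_for` and `predicted_class` are hypothetical names. With an SFrame, one would pass the class through a lambda, e.g. `cat_distances.apply(lambda row: is_correct_for(row, 'cat'))`.

```python
CLASSES = ('automobile', 'bird', 'cat', 'dog')

def is_correct_for(row, true_class, classes=CLASSES):
    """Return 1 if the true class's nearest distance is <= every other class's, else 0.

    row: dict with keys like 'cat-cat', 'cat-dog', ... (true class first).
    """
    own = row[true_class + '-' + true_class]
    others = [row[true_class + '-' + c] for c in classes if c != true_class]
    return 1 if own <= min(others) else 0

def predicted_class(row, true_class, classes=CLASSES):
    # Class with the closest nearest neighbor; ties go to the true class,
    # so predicted_class(...) == true_class exactly when is_correct_for(...) == 1
    return min(classes, key=lambda c: (row[true_class + '-' + c], c != true_class))

# Row 5 of dog_distances from In [181], rounded: its nearest training image is a cat
row = {'dog-automobile': 39.227, 'dog-bird': 38.272, 'dog-cat': 34.779, 'dog-dog': 37.485}
print('%d %s' % (is_correct_for(row, 'dog'), predicted_class(row, 'dog')))  # 0 cat
```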