In [13]:
# import keras dependencies 
from keras.models import Model
from keras.applications import MobileNet as CNN
from keras.applications.mobilenet import preprocess_input
from keras.layers import Flatten
from keras.preprocessing.image import ImageDataGenerator
import os
import numpy as np
from IPython.display import Image 
from PIL import Image
import matplotlib.pyplot as plt
import time

#### Load model pre-trained on `ImageNet` dataset

In [10]:
model = CNN(include_top=True)
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         (None, 224, 224, 3)       0         
_________________________________________________________________
conv1_pad (ZeroPadding2D)    (None, 225, 225, 3)       0         
_________________________________________________________________
conv1 (Conv2D)               (None, 112, 112, 32)      864       
_________________________________________________________________
conv1_bn (BatchNormalization (None, 112, 112, 32)      128       
_________________________________________________________________
conv1_relu (ReLU)            (None, 112, 112, 32)      0         
_________________________________________________________________
conv_dw_1 (DepthwiseConv2D)  (None, 112, 112, 32)      288       
_________________________________________________________________
conv_dw_1_bn (BatchNormaliza (None, 112, 112, 32)      128       
__________

The above network loads the `MobileNet` model trained on _ImageNet_. We will try and adapt this architecture which is generally effective in classifying images towards exploring features in our de-duplicaion database.
#### Slice network to limit output features
We can now eyeball the function of all network layers in the model. To limit the number of our output features we cut the network off at `Layer(-3)`.

In [11]:
x = Flatten()(model.layers[-3].output)
model_sliced = Model(inputs=model.input, outputs=[x])
# model_sliced.summary()

In [14]:
img_gen = ImageDataGenerator(preprocessing_function=preprocess_input)
TARGET_SIZE = (224, 224)
BATCH_SIZE = 64

IMG_DIR_QUERY = '/Users/zubin.john/forge/image-dedup/Transformed_dataset/query_for_cnn'
IMG_DIR_RETRIEVALS = '/Users/zubin.john/forge/image-dedup/Transformed_dataset/retrievals_for_cnn'

t1 = time.time()

img_batches_rets = img_gen.flow_from_directory(
    directory=IMG_DIR_RETRIEVALS,
    target_size=TARGET_SIZE,
    batch_size=BATCH_SIZE,
    color_mode='rgb',
    shuffle=False
)

img_batches_query = img_gen.flow_from_directory(
    directory=IMG_DIR_QUERY,
    target_size=TARGET_SIZE,
    batch_size=BATCH_SIZE,
    color_mode='rgb',
    shuffle=False
)

feat_vecs_rets = model_sliced.predict_generator(img_batches_rets, len(img_batches_rets), verbose=1)
feat_vecs_query = model_sliced.predict_generator(img_batches_query, len(img_batches_query), verbose=1)

print(time.time() - t1)

Found 12750 images belonging to 1 classes.
Found 2550 images belonging to 1 classes.
1337.1259410381317


In [15]:
file_mapping_rets = dict(zip(range(len(img_batches_rets.filenames)), img_batches_rets.filenames))
file_mapping_query = dict(zip(range(len(img_batches_query.filenames)), img_batches_query.filenames))

In [18]:
def get_normalized_matrix(x):
    x_norm_per_row = np.linalg.norm(x, axis=1)
    x_norm_per_row = x_norm_per_row[:, np.newaxis] # adding another axis
    x_norm_per_row_tiled = np.tile(x_norm_per_row, (1, x.shape[1]))
    x_normalized = x/x_norm_per_row_tiled 
    return x_normalized

feat_vecs_query_norm = get_normalized_matrix(feat_vecs_query)
feat_vecs_rets_norm = get_normalized_matrix(feat_vecs_rets)

print(time.time() - t1)

2172.8780562877655


In [19]:
dist_vec = np.dot(feat_vecs_query_norm, feat_vecs_rets_norm.T)
print(time.time() - t1)
dist_vec.shape

2195.6226239204407


In [22]:
def get_matches_above_threshold(row, thresh):
    valid_inds = np.where(row >= thresh)[0]
    valid_vals = row[valid_inds]
    return valid_inds, valid_vals


dict_ret = {}

for i in range(dist_vec.shape[0]):
    valid_inds, valid_vals = get_matches_above_threshold(dist_vec[i, :], 0.83)
    retrieved_files = [file_mapping_rets[j] for j in valid_inds]
    query_name = file_mapping_query[i]
    dict_ret[query_name] = dict(zip(retrieved_files, valid_vals))

print(time.time() - t1)

2333.9959831237793


In [23]:
dict_ret['Query/ukbench00147.jpg']

{'Retrieval/ukbench00147_cropped.jpg': 0.9684294,
 'Retrieval/ukbench00147_hflip.jpg': 0.97780776,
 'Retrieval/ukbench00147_resize.jpg': 0.9834271,
 'Retrieval/ukbench00147_rotation.jpg': 0.9295622,
 'Retrieval/ukbench00147_vflip.jpg': 0.94430196}

#### Save all model artifacts

In [29]:
with open('var_1r.pkl', 'wb') as wb:
    pickle.dump(feat_vecs_rets, wb)

with open('var_1q.pkl', 'wb') as wb:
    pickle.dump(feat_vecs_query, wb)
    
with open('docs_1r.pkl', 'wb') as wb:
    pickle.dump(file_mapping_rets, wb)

with open('docs_1q.pkl', 'wb') as wb:
    pickle.dump(file_mapping_query, wb)

### Another variant

In [32]:
y = Flatten()(model.layers[-9].output)
model_sliced_2 = Model(inputs=model.input, outputs=[y])
model_sliced_2.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         (None, 224, 224, 3)       0         
_________________________________________________________________
conv1_pad (ZeroPadding2D)    (None, 225, 225, 3)       0         
_________________________________________________________________
conv1 (Conv2D)               (None, 112, 112, 32)      864       
_________________________________________________________________
conv1_bn (BatchNormalization (None, 112, 112, 32)      128       
_________________________________________________________________
conv1_relu (ReLU)            (None, 112, 112, 32)      0         
_________________________________________________________________
conv_dw_1 (DepthwiseConv2D)  (None, 112, 112, 32)      288       
_________________________________________________________________
conv_dw_1_bn (BatchNormaliza (None, 112, 112, 32)      128       
__________

In [33]:
feat_vecs_rets_2 = model_sliced_2.predict_generator(img_batches_rets, len(img_batches_rets), verbose=1)
feat_vecs_query_2 = model_sliced_2.predict_generator(img_batches_query, len(img_batches_query), verbose=1)



In [35]:
file_mapping_rets = dict(zip(range(len(img_batches_rets.filenames)), img_batches_rets.filenames))
file_mapping_query = dict(zip(range(len(img_batches_query.filenames)), img_batches_query.filenames))

feat_vecs_query_norm_2 = get_normalized_matrix(feat_vecs_query_2)
feat_vecs_rets_norm_2 = get_normalized_matrix(feat_vecs_rets_2)

#dist_vec = np.dot(feat_vecs_query_norm_2, feat_vecs_rets_norm_2.T)

KeyboardInterrupt: 

#### Save all model artifacts

In [36]:
with open('var_2r.pkl', 'wb') as wb:
    pickle.dump(feat_vecs_rets, wb)

with open('var_2q.pkl', 'wb') as wb:
    pickle.dump(feat_vecs_query, wb)
    
with open('docs_2r.pkl', 'wb') as wb:
    pickle.dump(file_mapping_rets, wb)

with open('docs_2q.pkl', 'wb') as wb:
    pickle.dump(file_mapping_query, wb)