<a href="https://colab.research.google.com/github/pertvirt/hello_world/blob/master/perceptual_image_similarity.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
import base64
import glob
import mimetypes
import os
import pickle
from pathlib import Path

import keras
import lycon
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML, display
from keras.preprocessing.image import ImageDataGenerator
from lshash.lshash import LSHash
from sklearn.preprocessing import normalize
from tqdm import tqdm_notebook

# Show the installed Keras version so readers can match the environment this
# notebook was run with (the recorded output below is '2.2.5').
keras.__version__

'2.2.5'

In [0]:
# Install the `lshash3` package, which provides `lshash.lshash.LSHash` used below.
# NOTE(review): the import cell above already does `from lshash.lshash import LSHash`,
# so on a fresh runtime that import fails before this cell runs — move this install
# to the very first code cell (and consider pinning a version).
!pip install lshash3

In [0]:
# Extract the uploaded `data` archive (presumably data.zip) into the working
# directory; the image folders are then read from /content/data below.
!unzip data

In [8]:
# Root of the extracted image dataset: one sub-directory per class
# (octopus, ant, lotus, ...), as expected by flow_from_directory below.
DATA_PATH = Path('/content/data/')

# Peek at the first few class directories.
glob.glob(os.path.join(DATA_PATH, '*'))[:5]

['/content/data/octopus',
 '/content/data/ant',
 '/content/data/lotus',
 '/content/data/crab',
 '/content/data/rhino']

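## Extracting feature vectors with VGG16 (pretrained on ImageNet)

Running VGG16 over every image takes a while, so this notebook loads precomputed features from `./caltech_images/vgg16_features/`.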

In [9]:
batch_size = 16

# The ImageNet weights shipped with keras.applications.VGG16 expect inputs prepared
# by vgg16.preprocess_input (RGB->BGR plus per-channel ImageNet mean subtraction on
# 0-255 pixels), not pixels rescaled to [0, 1]; feeding [0, 1] inputs degrades the
# extracted features.
data_generator = ImageDataGenerator(
    preprocessing_function=keras.applications.vgg16.preprocess_input)

# include_top=False keeps only the convolutional base, so the network outputs
# spatial feature maps rather than ImageNet class scores.
net = keras.applications.VGG16(include_top=False, weights='imagenet')

# class_mode=None yields image batches only (no labels); shuffle=False keeps
# generator.filenames in the same order as the predictions.
generator = data_generator.flow_from_directory(DATA_PATH,
                                               target_size=(224, 224),
                                               batch_size=batch_size,
                                               class_mode=None,
                                               shuffle=False)






Downloading data from https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5


Found 605 images belonging to 11 classes.


In [0]:
# VGG16 inference over every image is slow, so reuse the cached features when the
# cache files exist and only run the network otherwise.
# NOTE(review): the cache holds 9088 vectors (see the shape output below) while
# DATA_PATH contains 605 images, so it may describe a different image set — confirm
# that the cached filenames exist under DATA_PATH before displaying results.
FEATURES_DIR = Path('./caltech_images/vgg16_features')
features_file = FEATURES_DIR / 'features.vgg16.npy'
filenames_file = FEATURES_DIR / 'filenames.pickle'

if features_file.exists() and filenames_file.exists():
    feature_vectors = np.load(features_file)
    # NOTE: pickle.load can execute arbitrary code — only load caches you created.
    with open(filenames_file, 'rb') as input_file:
        feature_filenames = pickle.load(input_file)
else:
    # len(generator) counts the final partial batch too (floor division would drop
    # up to batch_size - 1 images); shuffle=False keeps filenames aligned with rows.
    feature_vectors = net.predict_generator(generator, steps=len(generator), verbose=1)
    feature_filenames = generator.filenames

In [0]:
# Flatten each image's VGG16 activation map into a single row, then scale every
# row to unit L2 norm so that dot products between rows equal cosine similarities.
feature_vectors = normalize(
    feature_vectors.reshape(feature_vectors.shape[0], -1),
    norm='l2',
    axis=1,
)
feature_vectors.shape

(9088, 25088)

In [0]:
# Map each image's relative path (e.g. 'laptop/image_0027.jpg') to its feature row.
# zip() would silently truncate if the cached vectors and filenames were out of
# sync, and duplicate names would silently overwrite each other, so check both.
assert len(feature_filenames) == len(feature_vectors), (
    f"{len(feature_filenames)} filenames vs {len(feature_vectors)} feature vectors")
features_table = dict(zip(feature_filenames, feature_vectors))
assert len(features_table) == len(feature_filenames), "duplicate filenames in feature cache"

## Building the locality-sensitive hashing (LSH) index with `lshash`

In [0]:
# LSH parameters: bits per hash and number of independent hash tables.
LSH_HASH_SIZE = 16
LSH_NUM_TABLES = 8
LSH_SEED = 42

# The hash tables are built from random projection planes; lshash draws them with
# np.random.randn (per its source — TODO confirm for the installed lshash3), so seed
# NumPy's global RNG to make bucket assignments and query results reproducible.
np.random.seed(LSH_SEED)
hash_index = LSHash(hash_size=LSH_HASH_SIZE,
                    input_dim=feature_vectors.shape[-1],
                    num_hashtables=LSH_NUM_TABLES)

In [0]:
# Insert every feature vector into the LSH index, tagging it with its filename so
# query hits can be mapped back to image files.
for image_name, vector in tqdm_notebook(features_table.items()):
    hash_index.index(vector, extra_data=image_name)

HBox(children=(IntProgress(value=0, max=9088), HTML(value='')))




## Running queries for similar images

In [0]:
def similar_images(index, results=10):
    """Render the query image and its nearest neighbours as a row of thumbnails.

    Parameters
    ----------
    index : int or str
        Either a position in ``features_table`` (insertion order) or a filename
        key of ``features_table`` such as ``'laptop/image_0027.jpg'``.
    results : int, default 10
        Number of neighbours to show. One extra hit is requested because the
        query image is normally returned as its own closest match.

    Relies on the notebook globals ``features_table``, ``hash_index`` and
    ``DATA_PATH``. Fewer than ``results + 1`` thumbnails are shown when the LSH
    buckets yield fewer candidates, instead of raising IndexError.
    """
    if isinstance(index, (int, np.integer)):
        index = list(features_table)[index]
    query_vector = features_table[index]

    response = hash_index.query(query_vector, num_results=results + 1,
                                distance_func='cosine')

    thumbnails = []
    for stored_item, _distance in response:
        # Each hit is ((vector, extra_data), distance); extra_data is the filename
        # passed to hash_index.index(...).
        image_path = DATA_PATH / stored_item[1]
        mime_type = mimetypes.guess_type(str(image_path))[0] or 'image/jpeg'
        # Embed the bytes as a data URI: a bare filesystem path in `src` is resolved
        # by the browser rather than the kernel, so the image would not render.
        encoded = base64.b64encode(image_path.read_bytes()).decode('ascii')
        thumbnails.append(
            "<img style='height: 120px; margin: 0px; float: left; "
            "border: 1px solid black;' "
            f"src='data:{mime_type};base64,{encoded}' />"
        )
    display(HTML(''.join(thumbnails)))

In [0]:
# Neighbours of a laptop image, queried by its filename key.
similar_images(index='laptop/image_0027.jpg')

In [0]:
# Neighbours of a strawberry image, queried by its filename key.
similar_images(index='strawberry/image_0027.jpg')

In [0]:
# Neighbours of a water lily image, queried by its filename key.
similar_images(index='water_lilly/image_0009.jpg')