## Importing packages

In [8]:
import tensorflow
from tensorflow.keras.layers import GlobalMaxPooling2D
from tensorflow.keras.preprocessing import  image
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
import numpy as np
import os
from PIL import Image
from tqdm import tqdm
from sklearn.neighbors import NearestNeighbors
import pickle
from matplotlib import pyplot as plt


Next, build the ResNet50 feature extractor: ImageNet weights, classification head removed, followed by global average pooling.

In [9]:
# ResNet50 backbone pretrained on ImageNet, without the classification head,
# frozen so it serves purely as a fixed feature extractor.
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
base_model.trainable = False

# Global average pooling collapses the (7, 7, 2048) feature map into a single
# 2048-d embedding per image.
model = tensorflow.keras.Sequential([
    base_model,
    tensorflow.keras.layers.GlobalAveragePooling2D(),
])

# summary() prints the table itself and returns None; wrapping it in print()
# only appends a stray "None" line to the cell output.
model.summary()

Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 resnet50 (Functional)       (None, 7, 7, 2048)        23587712  
                                                                 
 global_average_pooling2d_2  (None, 2048)              0         
  (GlobalAveragePooling2D)                                       
                                                                 
Total params: 23587712 (89.98 MB)
Trainable params: 0 (0.00 Byte)
Non-trainable params: 23587712 (89.98 MB)
_________________________________________________________________
None


In [10]:
def extract_features(image_path, model, target_size=(224, 224)):
    """Embed one image as an L2-normalised feature vector.

    Loads the image at ``image_path`` resized to ``target_size``, adds a batch
    dimension, applies ResNet50's ``preprocess_input`` and runs it through
    ``model``. The prediction is flattened to 1-D and scaled to unit L2 norm.

    Args:
        image_path: Path to an image file readable by ``image.load_img``.
        model: Keras model accepting a single-image batch. In this notebook it
            is ResNet50 + global average pooling with input (224, 224, 3).
        target_size: ``(height, width)`` to resize the image to. Must match the
            model's input shape. Defaults to ``(224, 224)``.

    Returns:
        1-D numpy array with unit L2 norm. If the raw embedding is all zeros
        (norm 0), it is returned unchanged instead of producing NaNs.
    """
    img = image.load_img(image_path, target_size=target_size)
    img_arr = image.img_to_array(img)
    # The model expects a leading batch dimension: (H, W, C) -> (1, H, W, C).
    batch = np.expand_dims(img_arr, axis=0)
    preprocessed = preprocess_input(batch)
    result = model.predict(preprocessed, verbose=0).flatten()
    norm = np.linalg.norm(result)
    # Dividing by a zero norm would yield NaNs that poison the nearest-neighbour
    # search; an all-zero embedding is left as-is.
    if norm == 0:
        return result
    return result / norm

In [11]:
# Collect the path of every file in the image directory.
# os.listdir returns entries in arbitrary order. feature_list (built below) is
# index-aligned with filenames, so sort to make that order deterministic across
# machines and runs.
# NOTE: whenever filenames.pkl is regenerated, feature_list.pkl must be too.
path = "fashion-dataset/images"
filenames = [os.path.join(path, fn) for fn in sorted(os.listdir(path))]

# Use a context manager so the file handle is flushed and closed.
with open("filenames.pkl", "wb") as f:
    pickle.dump(filenames, f)

In [12]:
# Embed every image; feature_list[i] is the embedding of filenames[i].
# This takes hours on the full dataset, so the result is cached to disk and
# reloaded by the next cell.
feature_list = [extract_features(image_path, model) for image_path in tqdm(filenames)]

# Use a context manager so the (large) pickle is fully flushed and closed.
with open('feature_list.pkl', 'wb') as f:
    pickle.dump(feature_list, f)

100%|██████████| 44441/44441 [2:41:53<00:00,  4.58it/s]  


In [13]:
# Unpickle the cached filenames and features written by the cells above.
# pickle.load can execute arbitrary code: only load files this notebook wrote.
with open("filenames.pkl", "rb") as f:
    filenames = pickle.load(f)
with open("feature_list.pkl", "rb") as f:
    feature_list = pickle.load(f)

# The two lists are index-aligned; a length mismatch means one cache is stale.
assert len(filenames) == len(feature_list), "filenames.pkl and feature_list.pkl are out of sync"

In [14]:
def inference(image_path, model, algorithm='brute', metric='euclidean',
              n_neighbors=5, features=None):
    """Find the dataset images whose embeddings are closest to a query image.

    The query is embedded with ``extract_features``. A ``NearestNeighbors``
    index is then fitted on ``features`` and queried.

    Args:
        image_path: Path to the query image.
        model: Keras feature-extraction model passed to ``extract_features``.
        algorithm: ``NearestNeighbors`` search algorithm. Default ``'brute'``.
        metric: Distance metric. Default ``'euclidean'``.
        n_neighbors: Number of neighbours to return. Default 5.
        features: Sequence of embeddings to search. Defaults to the notebook's
            global ``feature_list``, whose rows align with ``filenames``.

    Returns:
        Tuple ``(indices, distances)`` of 1-D arrays, nearest first. The indices
        point into ``features``.

    Note:
        The index is refitted on every call. That is cheap for ``'brute'``, but
        for many queries or tree-based algorithms, fit once and reuse it.
    """
    if features is None:
        features = feature_list
    normalized_result = extract_features(image_path, model)
    neighbours = NearestNeighbors(n_neighbors=n_neighbors, algorithm=algorithm, metric=metric)
    neighbours.fit(features)
    # kneighbors expects a 2-D array of queries; wrap the single vector.
    distances, indices = neighbours.kneighbors([normalized_result])
    return indices[0], distances[0]