In [1]:
import numpy as np
import cv2
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras.layers import GlobalMaxPooling2D, Input

import pickle

from sklearn.neighbors import NearestNeighbors
from PIL import Image

In [2]:
#Dataset from kaggle https://www.kaggle.com/datasets/vikashrajluhaniwal/fashion-images

In [3]:
feature_model = tf.keras.applications.resnet50.ResNet50(weights='imagenet', include_top=False,
                                                       input_shape=(224,224,3))
feature_model.trainable = False

preprocess_layer = tf.keras.applications.resnet50.preprocess_input

inputs = Input(shape=(224,224,3))
x = preprocess_layer(inputs)
x = feature_model(x)
output = GlobalMaxPooling2D()(x)

model = tf.keras.Model(inputs, output)

Metal device set to: Apple M1

systemMemory: 8.00 GB
maxCacheSize: 2.67 GB



In [4]:
def pillow_read(image_path, target_size=(224, 224)):
    img = Image.open(image_path)
    img = img.resize(target_size)
    img = np.array(img.getdata()).reshape(img.size[0], img.size[1], 3)
    return img

In [5]:
def extract(img_path, model):
    img = pillow_read(img_path)
    img = np.expand_dims(img, 0)
    pred = model.predict(img, verbose=0).flatten()
    return pred/np.linalg.norm(pred)

In [8]:
def recommend(image_path, vectors_path, filenames_path, top_k=5):
    result = extract(image_path, model)

    f_vectors = pickle.load(open(vectors_path, 'rb'))
    f_vectors = np.array(f_vectors)
    file_names = pickle.load(open(filenames_path, 'rb'))

    neighbours = NearestNeighbors(n_neighbors=top_k, algorithm='brute', metric='euclidean')
    neighbours.fit(f_vectors)

    distance, indices = neighbours.kneighbors([result])

    nearest_images = [file_names[i] for i in indices[0][1:]]
    original_image = image_path

    return [original_image] + nearest_images

def display_images(images_path):
    original = pillow.Image.open(images_path[0])
    images = images_path[1:]
    images = [pillow_read(f'Dataset/{image}') for image in images]

    plt.figure(figsize=(4,4))
    plt.imshow(original)
    plt.title('Original Image', fontsize=10)
    plt.axis('off')
    # plt.show()
    
    fig, ax = plt.subplots(1, len(images), figsize=(17, 5))
    plt.suptitle('Recommendations', fontsize=15)
    for i, image in enumerate(images):
        ax[i].imshow(image)
        ax[i].axis('off')
    plt.show()


In [9]:
image_path = '5405.jpg'
vectors_path = 'f_vectors.pkl'
filenames_path = 'filenames.pkl'

recommendations = recommend(image_path, vectors_path, filenames_path, top_k=5)
recommendations
# display_images(recommendations)


['5405.jpg', '5402.jpg', '5404.jpg', '5403.jpg', '1831.jpg']