In [None]:
import tensorflow as tf
from tensorflow.keras.models import load_model

from PIL import Image as PILImage

import cv2
import numpy as np
import glob
import os

import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.utils import Sequence

from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras.models import Model

from keras.layers import Layer

from annoy import AnnoyIndex

In [None]:
# Target size for every image loaded by read_image. cv2.resize interprets this
# tuple as (width, height); both sides are 256 px here, so the order does not matter.
IMAGE_SIZE = (256, 256)

In [None]:
def imshow(a, size=1.0):
    """Display an array inline as an 8-bit image, optionally rescaled.

    Args:
        a: numpy array of pixel values. Values are clipped to [0, 255] and cast
            to uint8, so fractional parts are truncated.
        size: factor applied to both width and height before display; the
            default 1.0 skips resizing entirely.

    Relies on ``display``, which is not imported in this notebook and is
    presumably the IPython built-in available inside Jupyter.
    """
    pixels = a.clip(0, 255).astype("uint8")

    if size == 1.0:
        shown = pixels
    else:
        # cv2.resize takes the target as (width, height), i.e. (shape[1], shape[0]).
        target = (int(pixels.shape[1] * size), int(pixels.shape[0] * size))
        shown = cv2.resize(pixels, target, interpolation=cv2.INTER_AREA)

    display(PILImage.fromarray(shown))

In [None]:
class L2Normalization(Layer):
    """Keras layer that rescales its input to unit L2 norm along axis 1.

    The saved embedding models are loaded with this class supplied through
    ``custom_objects={"L2Normalization": L2Normalization}``, so the class name
    must stay exactly as it is.
    """

    def call(self, inputs):
        unit_vectors = tf.math.l2_normalize(inputs, axis=1)
        return unit_vectors

In [None]:
def read_image(file_path):
    """Load an image file as an RGB array resized to IMAGE_SIZE.

    Args:
        file_path: path to an image file that OpenCV can decode.

    Returns:
        uint8 array of shape (IMAGE_SIZE[1], IMAGE_SIZE[0], 3) in RGB channel
        order (OpenCV loads BGR, so the channels are converted).

    Raises:
        OSError: if OpenCV cannot read or decode ``file_path``.
    """
    img = cv2.imread(file_path)
    # cv2.imread signals failure (missing file, unsupported format, corrupt data)
    # by returning None rather than raising; without this check the problem
    # only surfaces later as an opaque assertion error inside cv2.resize.
    if img is None:
        raise OSError(f"Could not read or decode image: {file_path!r}")
    img = cv2.resize(img, IMAGE_SIZE)
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

In [None]:
# Gallery of indexed images: every .jpg exactly one directory level below Data/.
# Annoy item ids are used as positions in this list by the benchmark cell, so its
# order must match the order used when the .ann files were built.
# NOTE(review): glob order is filesystem-dependent; if the index-building code
# sorted its file list, sort here as well — TODO confirm against that code.
data_folder = "Data/*"
# The pattern contains no "**", so glob's recursive flag has no effect and is omitted.
image_files = glob.glob(os.path.join(data_folder, "*.jpg"))

In [None]:
# Baseline: ImageNet-pretrained ResNet50 with its classifier head removed.

# IMAGE_SIZE is (width, height) for cv2.resize, whereas Keras expects
# (height, width, channels); derive the shape from it instead of repeating
# literal 256s, so that changing IMAGE_SIZE keeps the model input consistent.
base_model = ResNet50(
    weights='imagenet',
    include_top=False,
    input_shape=(IMAGE_SIZE[1], IMAGE_SIZE[0], 3),
)

# NOTE(review): ImageNet ResNet50 weights are trained on inputs passed through
# tensorflow.keras.applications.resnet50.preprocess_input, but read_image returns
# raw RGB 0-255 values and they are fed in unchanged. Keep this consistent with
# however the .ann index was built — TODO confirm.

# Global average pooling collapses the final feature map into one 1-D embedding per image.
x = base_model.output
x = GlobalAveragePooling2D()(x)

# Embedding extractor: image batch in, pooled ResNet features out.
model_pure_resnet = Model(inputs=base_model.input, outputs=x)

# Must equal the dimension the .ann file was built with (ResNet50's final stage has 2048 channels).
embedding_dimension = 2048
index_pure_resnet = AnnoyIndex(embedding_dimension, 'euclidean')
index_pure_resnet.load('Annoys/embeddings_resnet_index_big.ann')

In [None]:
# Triplet-learning ResNet with frozen backbone weights, plus its Annoy index.
# The saved model contains the custom L2Normalization layer, so it must be
# supplied at load time.
model_resnet_frozen = load_model(
    "Models/embedding_resnet_model_big.keras",
    custom_objects=dict(L2Normalization=L2Normalization),
)

# Index vectors are 512-d; this must match the model's embedding size.
embedding_dimension = 512
index_resnet_frozen = AnnoyIndex(embedding_dimension, "euclidean")
index_resnet_frozen.load("Annoys/embeddings_index_big.ann")

In [None]:
# Triplet-learning ResNet with trained (unfrozen) backbone weights, plus its Annoy index.
# NOTE(review): this filename starts with a capital "E", unlike the sibling model
# files; it matters on case-sensitive filesystems — confirm it matches the file on disk.
model_resnet_trained = load_model(
    "Models/Embedding_resnet_model_exp_big.keras",
    custom_objects=dict(L2Normalization=L2Normalization),
)

# Index vectors are 512-d; this must match the model's embedding size.
embedding_dimension = 512
index_resnet_trained = AnnoyIndex(embedding_dimension, "euclidean")
index_resnet_trained.load("Annoys/embeddings_index_resnet_exp_big.ann")

In [None]:
# Custom embedding model, plus its Annoy index.
# The saved model uses the custom L2Normalization layer, so it is supplied at load time.
model_custom = load_model(
    "Models/embedding_custom_big.keras",
    custom_objects=dict(L2Normalization=L2Normalization),
)

# Index vectors are 512-d; this must match the model's embedding size.
embedding_dimension = 512
index_custom = AnnoyIndex(embedding_dimension, "euclidean")
index_custom.load("Annoys/embeddings_index_custom_big.ann")

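## Benchmark all models

For each query image, show the nearest gallery images retrieved by every embedding model, together with their Annoy (Euclidean) distances.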

In [None]:
import matplotlib.pyplot as plt

# Number of nearest neighbours retrieved and plotted per model.
N_MATCHES = 3

query_folder = "Queries/*"
# glob returns paths in filesystem order; sorting makes the figure order reproducible.
query_files = sorted(glob.glob(query_folder))

# NOTE(review): `models` shadows the `tensorflow.keras.models` module imported at the
# top of the notebook; the name is kept in case later cells rely on this list.
models = [model_pure_resnet, model_resnet_frozen, model_resnet_trained, model_custom]
indexes = [index_pure_resnet, index_resnet_frozen, index_resnet_trained, index_custom]
names = ["Pure ResNet", "ResNet Frozen", "ResNet Trained", "Custom"]


def _plot_matches(title, query_img, match_paths, distances):
    """Show a query image next to its retrieved matches in a single row.

    Args:
        title: figure title (the model name).
        query_img: RGB array of the query image, as returned by read_image.
        match_paths: gallery image paths, in the order returned by the index.
        distances: Annoy distances aligned element-wise with ``match_paths``.
    """
    # squeeze=False keeps `axes` 2-D even if the index returned no matches.
    fig, axes = plt.subplots(1, 1 + len(match_paths), figsize=(15, 5), squeeze=False)
    row = axes[0]

    row[0].imshow(query_img)
    row[0].set_title("Query")
    row[0].axis("off")

    for ax, path, dist in zip(row[1:], match_paths, distances):
        ax.imshow(read_image(path))
        ax.set_title(f"Distance: {dist:.4f}")
        ax.axis("off")

    fig.suptitle(title)
    plt.tight_layout()
    plt.subplots_adjust(top=0.85)
    plt.show()
    # One figure is created per (query, model) pair; release each one explicitly.
    plt.close(fig)


for query in query_files:
    img = read_image(query)
    # The input batch is the same for every model, so build it once per query.
    batch = np.expand_dims(img, axis=0)

    for model, index, name in zip(models, indexes, names):
        # verbose=0 suppresses the Keras progress bar that would otherwise be
        # printed for every (query, model) pair.
        embedding = model.predict(batch, verbose=0)[0]

        # Annoy item ids are positions in image_files, so they map straight back to paths.
        # The index may return fewer than N_MATCHES items; every returned item is plotted.
        similar_image_indices, distances = index.get_nns_by_vector(
            embedding, n=N_MATCHES, include_distances=True
        )
        match_paths = [image_files[i] for i in similar_image_indices]

        _plot_matches(name, img, match_paths, distances)