In [99]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.layers import Layer
from tensorflow.keras.models import Model
import cv2
from PIL import Image as PILImage
import glob
import os
from annoy import AnnoyIndex
import matplotlib.pyplot as plt

import csv
import pandas as pd


In [100]:
# Target image size, square 256 x 256 — the same value resize_crop uses as its default.
# NOTE(review): not referenced by the visible cells; resize_crop hard-codes its default.
IMAGE_SIZE = (256,) * 2

In [101]:
def imshow(a, size=1.0):
    """Render an image array inline in the notebook.

    Args:
        a: numeric image array. Values are clipped to [0, 255] and cast to uint8,
            so it should already be on a 0-255 scale. Arrays from ``read_image``
            are divided by 255 in ``preprocess`` and would render almost black
            unless multiplied back by 255.
        size: scale factor applied to both width and height before display;
            1.0 shows the image at its native resolution.
    """
    img = a.clip(0, 255).astype("uint8")

    if size == 1.0:
        display(PILImage.fromarray(img))
        return

    # cv2.resize takes the target size as (width, height).
    height, width = img.shape[:2]
    scaled_dims = (int(width * size), int(height * size))
    display(PILImage.fromarray(cv2.resize(img, scaled_dims, interpolation=cv2.INTER_AREA)))

In [102]:
def resize_crop(image, image_size=(256, 256)):
    """Center-crop ``image`` to the target aspect ratio, then resize it.

    Args:
        image: array whose first two axes are (height, width), as returned by
            ``cv2.imread``; any trailing channel axis is kept as-is.
        image_size: target ``(height, width)``. Defaults to (256, 256).

    Returns:
        The cropped image resized to ``image_size`` with ``cv2.resize``.
    """
    target_height, target_width = image_size
    original_height, original_width = image.shape[:2]

    original_aspect = original_width / original_height
    target_aspect = target_width / target_height

    if original_aspect > target_aspect:
        # Too wide: keep the full height and trim the width to height * target_aspect
        # (for a square target this is exactly original_height). Clamp to [1, W] so
        # float rounding can never produce an empty or oversized slice.
        new_width = min(original_width, max(1, round(original_height * target_aspect)))
        crop_x = (original_width - new_width) // 2
        cropped_image = image[:, crop_x:crop_x + new_width]
    elif original_aspect < target_aspect:
        # Too tall: keep the full width and trim the height to width / target_aspect.
        new_height = min(original_height, max(1, round(original_width / target_aspect)))
        crop_y = (original_height - new_height) // 2
        cropped_image = image[crop_y:crop_y + new_height, :]
    else:
        cropped_image = image

    # cv2.resize expects dsize as (width, height), the reverse of numpy's shape order.
    return cv2.resize(cropped_image, (target_width, target_height))

def preprocess(image):
    """Normalise, crop/resize and lightly blur an image.

    Steps, in order:
      1. divide pixel values by 255.0 (the result is a float array);
      2. center-crop and resize with ``resize_crop`` using its default target size;
      3. apply a 5x5 Gaussian blur, with sigmaX=0 so OpenCV derives sigma from the
         kernel size.
    """
    scaled = image / 255.0
    return cv2.GaussianBlur(resize_crop(scaled), (5, 5), 0)

def read_image(image_file):
    """Load an image from disk and run it through ``preprocess``.

    Args:
        image_file: path to an image file readable by ``cv2.imread``.

    Returns:
        The preprocessed image. Channels stay in the BGR order produced by
        ``cv2.imread``; no colour conversion is applied here.

    Raises:
        OSError: if ``cv2.imread`` cannot read the file (missing, unreadable or
            unsupported format). ``cv2.imread`` reports failure by returning None
            instead of raising, which would otherwise surface later as an obscure
            TypeError inside ``preprocess``.
    """
    image = cv2.imread(image_file)
    if image is None:
        raise OSError(f"cv2.imread could not read image file: {image_file!r}")
    return preprocess(image)

In [103]:
# Gallery images: every .jpg exactly one directory below Data_Final
# (the joined pattern is Data_Final/*/*.jpg). recursive=True is omitted because
# the pattern contains no "**", so it would have no effect.
# NOTE(review): image_files is not used by the visible cells.
data_folder = "Data_Final/*"
image_files = glob.glob(os.path.join(data_folder, "*.jpg"))

In [104]:
# Precomputed gallery embeddings: a 'filename' column plus one column per embedding
# component, which is the layout get_most_similar_images reads.
# The path is relative to the kernel's working directory; the commented alternative
# is the same file addressed from a different working directory.
# NOTE(review): the file name suggests these embeddings came from classification_model,
# while the embedding_model cell loads embedding_model_julka.keras — confirm the pairing.
csv_file = "App/src/CSVs/classification_model_embeddings.csv"
#csv_file = "CSVs/classification_model_embeddings.csv"

In [105]:
# Load the gallery embeddings; get_most_similar_images uses the 'filename' column as
# labels and treats every remaining column as an embedding component.
embeddings_df = pd.read_csv(filepath_or_buffer=csv_file)

In [106]:
#embedding_model = load_model("App/src/Models/classification_model.keras")
# Model used to embed query images at search time.
# NOTE(review): csv_file points at classification_model_embeddings.csv, but this loads
# embedding_model_julka.keras. Cosine similarities are only meaningful when the query
# and gallery embeddings come from the same model — confirm which pair is intended.
#embedding_model = load_model("App/src/Models/classification_model.keras")
embedding_model = load_model(filepath="Models/embedding_model_julka.keras")

In [107]:
from sklearn.metrics.pairwise import cosine_similarity

def get_most_similar_images(query_image_path, embedding_model, embeddings_df, k=5):
    """Rank gallery images by cosine similarity to a query image.

    Args:
        query_image_path: path of the query image; loaded with ``read_image``.
        embedding_model: object whose ``predict`` maps a batch containing one
            preprocessed image to its embedding (here, the Keras model loaded above).
        embeddings_df: DataFrame with a ``'filename'`` column; every other column
            is treated as one component of that image's stored embedding.
        k: number of matches to return (must be >= 0). If ``k`` exceeds the number
            of gallery rows, every row is returned.

    Returns:
        ``(filenames, similarities)``: two arrays of length ``min(k, n_rows)``,
        ordered from most to least similar. The scores are cosine similarities in
        [-1, 1] — larger means more similar; they are not distances.

    Raises:
        ValueError: if ``k`` is negative, or if the query embedding's length does not
            match the number of embedding columns in ``embeddings_df`` (for example
            when the CSV was produced by a different model).
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    query_image = read_image(query_image_path)
    # predict works on batches, so add a leading batch axis of size 1.
    query_embedding = embedding_model.predict(np.expand_dims(query_image, axis=0)).flatten()

    embeddings = embeddings_df.drop(columns=['filename']).values
    filenames = embeddings_df['filename'].values

    if query_embedding.shape[0] != embeddings.shape[1]:
        raise ValueError(
            f"query embedding has {query_embedding.shape[0]} components but "
            f"embeddings_df has {embeddings.shape[1]} embedding columns; "
            "were they produced by the same model?"
        )

    similarities = cosine_similarity([query_embedding], embeddings)[0]

    # argsort is ascending: reverse first, then take k, so that k == 0 selects
    # nothing (the slice [-0:] would have selected every row).
    top_k_indices = similarities.argsort()[::-1][:k]

    return filenames[top_k_indices], similarities[top_k_indices]

In [None]:
def bgr_to_rgb(image):
    """Return a channel-reversed view of a BGR image so matplotlib shows true colours.

    ``cv2.imread`` loads colour images in BGR order, while ``plt.imshow`` assumes RGB.
    ``cv2.cvtColor`` cannot be used here: ``read_image`` returns float64 arrays
    (the result of ``image / 255.0``), a depth cvtColor does not accept. Reversing
    the last axis works for any dtype and leaves the model input untouched.
    """
    return image[..., ::-1]


query_folder = "Queries_Big/*"
# sorted() makes the figure order reproducible; glob's order depends on the filesystem.
query_files = sorted(glob.glob(query_folder))
top_k = 3

for query in query_files:
    print(query)
    query_img = read_image(query)

    matches, sims = get_most_similar_images(query, embedding_model, embeddings_df, k=top_k)

    # One panel for the query plus one per returned match; this also copes with
    # galleries holding fewer than top_k images.
    panels = [(query_img, "Query")] + [
        (read_image(os.path.join("Data_Final", name)), f"Similarity: {score:.4f}")
        for name, score in zip(matches, sims)
    ]

    fig, axes = plt.subplots(1, len(panels), figsize=(15, 5), squeeze=False)
    for ax, (img, title) in zip(axes[0], panels):
        ax.imshow(bgr_to_rgb(img))
        ax.set_title(title)
        ax.axis("off")

    # Show each figure next to its printed query path, then free it so a long
    # query list does not accumulate open figures.
    plt.show()
    plt.close(fig)