In [61]:
from sklearn.cluster import KMeans
import numpy as np
import torch
import torch.nn as nn
from torchvision import models
from torch.utils.data import Dataset
import os
from PIL import Image
from tqdm import tqdm
from torchvision import transforms, models
import gradio as gr
from sklearn.metrics.pairwise import cosine_similarity

# Embeddings produced by the fine-tuned model (presumably one row per image —
# verify against how the file was generated).
# map_location="cpu" lets the file load on machines without CUDA even when the
# tensor was saved from a GPU; the code below converts it to NumPy, which
# requires CPU memory anyway.
# NOTE(review): torch.load unpickles the file — only load trusted files.
features_finetuned = torch.load("../data/features_finetuned_9k.pt", map_location="cpu")

# Partition the embeddings into 5 groups; the fixed random_state keeps the
# cluster labels reproducible between runs.
kmeans_finetuned = KMeans(n_clusters=5, random_state=0)
clusters_finetuned = kmeans_finetuned.fit_predict(features_finetuned.cpu().numpy())

batch_size = 32
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

images_dir = "../data/images_12/resized"
# os.listdir returns entries in arbitrary order, so sort to get the same list
# on every run and every machine.
images = sorted(
    os.path.join(images_dir, file)
    for file in os.listdir(images_dir)
    if file.endswith(".jpg")
)

def get_image_names(directory):
    """Return the sorted names of the ``.jpg`` files found in ``directory``.

    Args:
        directory: Path of the folder to scan (its sub-folders are not visited).

    Returns:
        list[str]: Entry names (without the directory part) whose extension is
        ``.jpg``, compared case-insensitively, in sorted order.
    """
    # This has to be a tuple: iterating over the bare string '.jpg' would test
    # every single character ('.', 'j', 'p', 'g') and let names such as
    # 'x.png' or 'x.zip' through.
    image_extensions = ('.jpg',)
    return [
        filename
        for filename in sorted(os.listdir(directory))
        if filename.lower().endswith(image_extensions)
    ]

directory = "../data/images_12/resized"
image_names = get_image_names(directory)

# Two-column object array: column 0 holds the file name, column 1 its embedding.
# NOTE(review): zip() stops at the shorter input and pairs purely by position,
# so this assumes the rows of features_finetuned follow the same (sorted) order
# as image_names — confirm against the script that produced the features.
finetunning_matrix = np.array(
    [[name, embedding] for name, embedding in zip(image_names, features_finetuned)],
    dtype=object,
)

def get_clusters(embs, kmeans):
    """Predict a cluster for every embedding with the given model.

    Args:
        embs: Tensor of embeddings; it is moved to the CPU and converted to a
            NumPy array before being handed to the model.
        kmeans: Object exposing ``predict`` (no fitting happens here).

    Returns:
        Whatever ``kmeans.predict`` returns for the converted array.
    """
    embs_np = embs.cpu().numpy()
    labels = kmeans.predict(embs_np)
    return labels

def get_cluster_embs(embs, clusters, idx):
    """Select the embeddings that share a cluster with item ``idx``.

    Args:
        embs: Indexable collection of embeddings (supports integer-array indexing).
        clusters: Array of cluster labels, aligned with ``embs``.
        idx: Position of the reference item.

    Returns:
        The entries of ``embs`` whose label equals ``clusters[idx]``
        (the reference item itself included).
    """
    target_label = clusters[idx]
    member_positions = np.where(clusters == target_label)[0]
    return embs[member_positions]

def get_similarity_embs(embs):
    """Mean cosine similarity of each embedding to all embeddings.

    Args:
        embs: 2-D tensor with one embedding per row.

    Returns:
        1-D tensor with one value per row: the average of that row's cosine
        similarity to every row (itself included).
    """
    row_norms = embs.norm(dim=1, keepdim=True)
    unit_rows = embs / row_norms  # rows scaled to unit length
    pairwise = unit_rows.matmul(unit_rows.t())
    return pairwise.mean(dim=1)

def top_k_similar_images(embeddings_with_names, specific_image_name, k=10):
    """Return the names of the ``k`` entries most similar to a given image.

    Args:
        embeddings_with_names: 2-D object array whose rows are
            ``[name, embedding]``.
        specific_image_name: Name (column 0) of the query image.
        k: Maximum number of names to return.

    Returns:
        list: Up to ``k`` names ordered by decreasing cosine similarity to the
        query embedding. The query is itself one of the candidates, so it is
        normally part of the result.

    Raises:
        ValueError: If no row is named ``specific_image_name``.
    """
    # Position of the first row carrying the requested name, or None.
    query_row = next(
        (
            position
            for position, entry in enumerate(embeddings_with_names)
            if entry[0] == specific_image_name
        ),
        None,
    )
    if query_row is None:
        raise ValueError("The specific image name was not found in the data matrix.")

    query_vector = embeddings_with_names[query_row][1].reshape(1, -1)
    stacked_embeddings = np.array([entry[1] for entry in embeddings_with_names])
    scores = cosine_similarity(query_vector, stacked_embeddings)

    # Negate so that argsort's ascending order puts the best match first.
    ranking = np.argsort(-scores)
    best_rows = ranking[0][:k]
    best_names = embeddings_with_names[best_rows, 0]
    return best_names.tolist()



In [62]:
# Quick check: names of the 3 entries closest to "0.jpg" by cosine similarity.
top_k_similar_images(finetunning_matrix, "0.jpg", k=3)

['0.jpg', '1.jpg', '36.jpg']

In [73]:
# Cluster label for every embedding, predicted with the k-means model.
clusters = get_clusters(features_finetuned, kmeans_finetuned)

def images_to_show(img_path, k=6):
    """Return the paths of the ``k`` stored images most similar to the query.

    Args:
        img_path: Path of the query image as delivered by the Gradio input.
            Only its base name is used, with ".jpeg" rewritten to ".jpg".
            ``None`` (no image selected) is accepted.
        k: Number of output slots to fill. Defaults to 6, the number of result
            images wired to this handler.

    Returns:
        list: Exactly ``k`` entries — paths inside ``images_dir`` ordered from
        most to least similar, padded with ``None`` when fewer than ``k``
        matches exist (or when ``img_path`` is ``None``).

    Raises:
        ValueError: Propagated from ``top_k_similar_images`` when the image
            name is not present in ``finetunning_matrix``.
    """
    if img_path is None:
        # Nothing to search for: leave every slot empty instead of failing in
        # os.path.basename(None).
        return [None] * k

    img_name = os.path.basename(img_path).replace(".jpeg", ".jpg")
    images_name = top_k_similar_images(finetunning_matrix, img_name, k=k)
    paths = [os.path.join(images_dir, name) for name in images_name]
    # One value is needed per output component, so pad short result lists
    # rather than indexing past their end.
    return paths + [None] * (k - len(paths))

def set_as_input(img_path):
    """Promote a result image to the query slot and blank out all results.

    Args:
        img_path: File path of the result image that was picked, or ``None``
            when that slot is empty.

    Returns:
        tuple: Seven values, one per output of the click handler — the new
        input path followed by six references to one array filled with 255
        (white for 8-bit images) that has the picked image's shape. When
        ``img_path`` is ``None``, seven no-op ``gr.update()`` values are
        returned so the interface is left unchanged.
    """
    if img_path is None:
        # Empty slot clicked: change nothing rather than fail in Image.open.
        return tuple(gr.update() for _ in range(7))

    # Image.open keeps the underlying file open; the context manager closes it
    # as soon as the pixel data has been read by NumPy.
    with Image.open(img_path) as img:
        blank = np.ones_like(img) * 255
    return img_path, blank, blank, blank, blank, blank, blank

# Layout: query image and "Search" button on top, followed by a 2 x 3 grid of
# result slots, each with its own "Set as input" button.
with gr.Blocks() as gui:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                img_in = gr.Image(type="filepath")
                btn = gr.Button("Search")

        result_images = []
        result_buttons = []
        for _grid_row in range(2):
            with gr.Row():
                for _grid_col in range(3):
                    with gr.Column():
                        result_images.append(
                            gr.Image(show_download_button=False, interactive=False, type="filepath")
                        )
                        result_buttons.append(gr.Button("Set as input"))

    # Expose every component under its own name as well.
    img_out1, img_out2, img_out3, img_out4, img_out5, img_out6 = result_images
    btn1, btn2, btn3, btn4, btn5, btn6 = result_buttons

    # "Search" fills the six result slots.
    btn.click(images_to_show, inputs=img_in, outputs=result_images)

    # Each "Set as input" button feeds its own image back into the query slot
    # and overwrites all six result slots.
    for slot_image, slot_button in zip(result_images, result_buttons):
        slot_button.click(set_as_input, inputs=slot_image, outputs=[img_in, *result_images])

gui.launch()

Running on local URL:  http://127.0.0.1:7887

To create a public link, set `share=True` in `launch()`.


