In [None]:
import torch
import os
from PIL import Image
import clip
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [None]:
# Root folders used throughout the notebook, relative to the notebook's working directory.
# Forward slashes are accepted by Python's file APIs on Windows as well as on POSIX systems,
# whereas backslash-only paths resolve on Windows alone.
# The trailing separator is kept so both f"{models}/name" and models + "name" keep working.
datasets = "../../datasets/"
models = "../../models/"
test_images_dir = "../test_images/"

In [None]:
# Default to the CPU and switch to the GPU only when CUDA reports itself as available.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print("Using device:", device)

In [None]:
# Load the precomputed image features for Flickr8k from disk.
# When CUDA is unavailable, remap every stored tensor to the CPU so that a file saved
# from GPU tensors can still be loaded; with CUDA available the stored devices are kept.
# NOTE(review): torch.load unpickles the file, so only load feature files from a trusted source.
image_model = torch.load(
    f"{models}/image_features_flickr8k.pt",
    map_location=None if torch.cuda.is_available() else "cpu",
)

In [None]:
# Split the loaded mapping into two parallel collections: its keys become the list of
# image names, and its values are stacked along a new leading dimension into one tensor.
image_names = [name for name in image_model]
image_features = torch.stack([feature for feature in image_model.values()])

print(f"Number of images: {len(image_names)}")
print(f"Image features shape: {image_features.shape}")

In [None]:
# Load the CLIP ViT-B/16 model on the selected device together with the image
# preprocessing callable that is applied to PIL images before encoding.
model, preprocess = clip.load(name="ViT-B/16", device=device)
# Switch the model to evaluation mode; being the last expression of the cell,
# the returned value is also what the notebook displays.
model.eval()

In [None]:
# Image to image: embed the query image with CLIP so it can be compared
# against the stored image features.

# Path of the image you want to test (os.path.join avoids hand-built, mixed separators)
test_image_path = os.path.join(test_images_dir, "test.jpg")

# Open the file inside a context manager so its handle is released promptly,
# and convert the image to RGB (convert() returns an independent in-memory copy)
with Image.open(test_image_path) as raw_image:
    test_image = raw_image.convert("RGB")

# Preprocess the image, add a leading batch dimension and move it to the device
test_image_preprocessed = preprocess(test_image).unsqueeze(0).to(device)

# Get the image feature using the model.encode_image function; gradients are not needed here
with torch.no_grad():
    test_image_feature = model.encode_image(test_image_preprocessed)  # output shape torch.Size([1, 512])

    # Normalize the image feature to unit length along the last dimension
    test_image_feature /= test_image_feature.norm(dim=-1, keepdim=True)

test_image_feature.shape  # torch.Size([1, 512])

In [None]:
# Sanity-check the shapes of the two operands used by the similarity search.
query_shape = test_image_feature.shape
print(query_shape)  # torch.Size([1, 512]) according to the encoding cell's own annotation
gallery_shape = image_features.squeeze(1).shape
print(gallery_shape)  # squeeze(1) drops dim 1 when its size is 1; presumably (num_images, 512) — confirm

In [None]:
# Peek at the first entry of image_names as a quick sanity check of the loaded keys.
image_names[0]

In [None]:
# Rank every stored image by cosine similarity to the query image and plot the best matches.

# Bring the stored features onto the query's device and compare in float32.
# NOTE(review): the query feature and the stored features may differ in dtype
# (e.g. half precision on GPU vs. whatever was saved) — casting both sidesteps a mismatch.
gallery_features = image_features.squeeze(1).to(device).float()

# An explicit dim=-1 yields one score per stored image without a blanket squeeze(),
# which would collapse the result to a 0-dim tensor for a single-image gallery.
sims = F.cosine_similarity(test_image_feature.float(), gallery_features, dim=-1)
print(sims.shape)  # one similarity score per stored image

# Get the top 5 most similar images, capped by the number of stored images so
# topk() cannot fail on a smaller gallery
top_k = min(5, sims.shape[0])
top_scores, topk = sims.topk(top_k)

# Display the indices of the most similar images
print(topk)
print(topk.shape)

plt.figure(figsize=(15, 5))

retrieved_imgs = [image_names[j] for j in topk.tolist()]

for i, (img_name, score) in enumerate(zip(retrieved_imgs, top_scores.tolist())):
    print(img_name)

    img_path = os.path.join(datasets, "flickr8k", "Images", img_name)
    # Context manager releases the file handle; convert() returns an in-memory copy
    with Image.open(img_path) as raw_img:
        img = raw_img.convert("RGB")

    # One panel per retrieved image, laid out in a single row
    plt.subplot(1, top_k, i + 1)
    plt.imshow(img)
    plt.axis('off')
    plt.title(f"Sim: {score:.4f}")