In [None]:
import torch
import os
from PIL import Image
import clip
import torch.nn.functional as F
import csv
import matplotlib.pyplot as plt

In [None]:
# Root directories for the datasets and the saved model artifacts, relative to
# this notebook. They are built with os.path.join so the separators match the
# host OS: hard-coded "\\" separators are only path separators on Windows and
# become literal filename characters elsewhere. The trailing "" component keeps
# a trailing separator, so the values are unchanged on Windows.
datasets = os.path.join("..", "..", "datasets", "")
models = os.path.join("..", "..", "models", "")

In [None]:
# Pick the compute device once: use the GPU when CUDA is usable, otherwise
# fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

In [None]:
# Load the saved caption features from disk.
# map_location=device remaps the stored tensors onto the device selected for
# this session, so a file that was saved from a GPU still loads when the
# notebook is running on the CPU fallback.
# NOTE(review): torch.load unpickles the file - only load files you trust.
caption_model = torch.load(
    os.path.join(models, "caption_features_flickr8k.pt"),
    map_location=device,
)

In [None]:
# Stack every stored embedding into one tensor (in the mapping's iteration
# order) and take element 0 of each key as the image name.
# NOTE(review): assumes each key is a sequence whose first item is the image
# file name - confirm against the script that wrote the feature file.
caption_embs = torch.stack([emb for emb in caption_model.values()])
caption_img_names = [key[0] for key in caption_model]
print(f"Number of captions: {len(caption_img_names)}")

In [None]:
# Drop the axis at position 1 if it has size 1 (otherwise the tensor is left
# unchanged), so re-running this cell is harmless.
caption_embs = torch.squeeze(caption_embs, dim=1)

In [None]:
# Load the CLIP "ViT-B/32" model together with its preprocessing transform,
# placing the model on the selected device.
model, preprocess = clip.load(name="ViT-B/32", device=device)
# Inference only: switch the model to evaluation mode.
model.eval()

In [None]:
# Text query: embed a free-form sentence with CLIP's text encoder.

# NOTE: despite its name, `test_image_preprocessed` holds the raw query *text*.
# The name is kept because a later cell prints it.
test_image_preprocessed = "A black and white dog is running through the grass"

# Tokenize the query (truncate=True is passed so over-long text is truncated)
# and move the token tensor to the selected device.
text_input = clip.tokenize([test_image_preprocessed], truncate=True).to(device)

# Encode the text without tracking gradients - this is inference only.
with torch.no_grad():
    test_text_feature = model.encode_text(text_input)  # expected shape: torch.Size([1, 512])

    # L2-normalize the text feature along the embedding axis.
    test_text_feature = test_text_feature / test_text_feature.norm(dim=-1, keepdim=True)

test_text_feature.shape  # expected: torch.Size([1, 512])

In [None]:
# Sanity check: print the query-embedding shape, then the caption-embedding
# shape, one per line.
# Expected (from an earlier run): torch.Size([1, 512]) and torch.Size([40455, 512]).
print(test_text_feature.shape, caption_embs.squeeze(1).shape, sep="\n")

In [None]:
# Read the (image name, caption) pairs from the Flickr8k captions file.
CAPTIONS_PATH = os.path.join(datasets, "flickr8k", "captions.txt")
captions = []

# newline="" lets the csv module handle line endings itself, as its
# documentation recommends for files passed to csv.reader.
with open(CAPTIONS_PATH, "r", encoding="utf-8", newline="") as f:
    reader = csv.reader(f)
    next(reader, None)  # skip header: image,caption (tolerates an empty file)

    for row in reader:
        # Skip blank or malformed rows that have no caption column.
        if len(row) < 2:
            continue
        # A caption containing unquoted commas is split into several fields;
        # re-join everything after the first field instead of failing to unpack.
        img_name, caption = row[0], ",".join(row[1:])
        captions.append((img_name.strip(), caption.strip()))

print("Total captions:", len(captions))

In [None]:
# Number of best-matching captions to retrieve and display.
TOP_K = 5

# Cosine similarity between the single query embedding and every caption
# embedding; the (1, D) query broadcasts against the (N, D) caption matrix.
sims = F.cosine_similarity(test_text_feature, caption_embs.squeeze(1).to(device))
print(sims.shape)  # expected: torch.Size([40455])

# Take the TOP_K most similar captions, but never more than are available.
k = min(TOP_K, sims.numel())
top_scores, topk = sims.topk(k)

print(topk)
print(topk.shape)

# NOTE(review): this lookup assumes row i of `captions` corresponds to row i of
# `caption_embs` (same order as the captions file) - confirm against the script
# that produced the feature file.
retrieved_text = [captions[j] for j in topk.tolist()]
print(f"input_text: {test_image_preprocessed}")
print("Retrieved captions:")

if k:
    # One panel per retrieved caption; squeeze=False keeps `axes` 2-D even when k == 1.
    fig, axes = plt.subplots(1, k, figsize=(15, 5), squeeze=False)
    for ax, (img_name, caption), score in zip(axes[0], retrieved_text, top_scores.tolist()):
        print(f"{img_name}: {caption}")

        img_path = os.path.join(f"{datasets}/flickr8k/Images", img_name)
        # Open inside a context manager so the underlying file handle is closed
        # as soon as the pixels have been converted.
        with Image.open(img_path) as img_file:
            img = img_file.convert("RGB")

        ax.imshow(img)
        ax.axis("off")
        ax.set_title(f"Sim: {score:.4f}")

    plt.show()