In [None]:
import torch
import open_clip
import torch.nn.functional as F

In [None]:
# Root folders for the datasets and the saved feature files, relative to this notebook.
# Forward slashes are used because they are accepted as path separators on Windows,
# Linux and macOS alike, whereas escaped backslashes only work on Windows.
# The trailing slash is kept so code that concatenates a file name directly still works.
datasets = "../../datasets/"
models = "../../models/"

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

In [None]:
# Load the saved caption / image feature objects.
# map_location="cpu" lets files that were saved from CUDA tensors load on a machine
# without a GPU; the stacked embeddings are moved to `device` in a later cell.
# NOTE: torch.load unpickles the file, so only load feature files from a trusted source.
caption_model = torch.load(f"{models}/caption_features_flickr8k.pt", map_location="cpu")
image_model = torch.load(f"{models}/image_features_flickr8k.pt", map_location="cpu")

In [None]:
# Image side: the mapping keys become the name list, and the per-key features are
# stacked into one tensor after removing a size-1 leading dimension from each entry.
# Presumably the result is [num_images, 512] -- TODO confirm against the saved file.
image_names = list(image_model)
image_embs = torch.stack(tuple(feat.squeeze(0) for feat in image_model.values()))

# Caption side: same idea, but the stored features are stacked as they are.
# NOTE(review): a later cell squeezes dim 1, so this may be [num_captions, 1, 512] -- confirm.
caption_img_names = list(caption_model)
caption_embs = torch.stack(tuple(caption_model.values()))

In [None]:
# Drop dimension 1 from both embedding tensors when it has size 1;
# squeeze leaves the tensor unchanged if that dimension has any other size.
image_embs = torch.squeeze(image_embs, dim=1)
caption_embs = torch.squeeze(caption_embs, dim=1)

In [None]:

# Sanity-check the tensor layouts before running retrieval.
# Expected (per the notebook's own notes): 2-D tensors with 512 features per row.
print("caption_embs shape:", caption_embs.size())
print("image_embs shape:", image_embs.size())
# A single row, and the same row with a leading batch dimension added back.
print("caption_embs[0] shape:", caption_embs[0].size())
print("caption_embs[0].unsqueeze(0) shape:", caption_embs[0].unsqueeze(0).size())

In [None]:
# Move both embedding tensors onto the selected device in one step.
image_embs, caption_embs = (
    tensor.to(device) for tensor in (image_embs, caption_embs)
)

In [None]:
# Text to image Recall
def text_to_image_recall(image_embs, caption_embs, caption_img_names, image_names, K=10):
    """Compute and print text-to-image Recall@K.

    Every caption embedding is scored against all image embeddings with cosine
    similarity. A caption counts as a hit when the image name paired with it
    appears among the names of its K highest-scoring images.

    Args:
        image_embs: 2-D tensor with one image embedding per row.
        caption_embs: 2-D tensor with one caption embedding per row.
        caption_img_names: Sequence whose entry ``i`` is the ground-truth image
            name for row ``i`` of ``caption_embs``.
        image_names: Sequence whose entry ``j`` is the name of row ``j`` of
            ``image_embs``.
        K: How many top-ranked images to inspect per caption. Values larger
            than the number of images are clamped to that number.

    Returns:
        float: Fraction of captions that were hits, or 0.0 when there are no
        captions.
    """
    total = len(caption_embs)
    # torch.topk raises when k exceeds the number of candidates, so clamp it.
    k = min(K, len(image_embs))
    hits = 0

    for i in range(total):
        # [1, dim] against [num_images, dim] already yields a 1-D score vector.
        # No squeeze here: it would collapse a one-image gallery to a 0-d tensor
        # that cannot be iterated.
        sims = F.cosine_similarity(caption_embs[i].unsqueeze(0), image_embs, dim=1)
        # Convert once to plain ints instead of indexing the list with 0-d tensors.
        top_indices = sims.topk(k).indices.tolist()

        retrieved_imgs = [image_names[j] for j in top_indices]
        if caption_img_names[i] in retrieved_imgs:
            hits += 1

    # Guard against division by zero when no captions were supplied.
    recall = hits / total if total else 0.0
    print(f"Text to Image Recall@{K}: {recall:.4f}")
    return recall

In [None]:
# Image to text Recall
def image_to_text_recall(image_embs, caption_embs, caption_img_names, image_names, K=10):
    """Compute and print image-to-text Recall@K.

    Every image embedding is scored against all caption embeddings with cosine
    similarity. An image counts as a hit when its own name appears among the
    image names attached to its K highest-scoring captions.

    Args:
        image_embs: 2-D tensor with one image embedding per row.
        caption_embs: 2-D tensor with one caption embedding per row.
        caption_img_names: Sequence whose entry ``j`` is the image name that
            row ``j`` of ``caption_embs`` belongs to.
        image_names: Sequence whose entry ``i`` is the name of row ``i`` of
            ``image_embs``.
        K: How many top-ranked captions to inspect per image. Values larger
            than the number of captions are clamped to that number.

    Returns:
        float: Fraction of images that were hits, or 0.0 when there are no
        images.
    """
    total = len(image_embs)
    # torch.topk raises when k exceeds the number of candidates, so clamp it.
    k = min(K, len(caption_embs))
    hits = 0

    for i in range(total):
        # [1, dim] against [num_captions, dim] already yields a 1-D score vector.
        # No squeeze here: it would collapse a one-caption gallery to a 0-d
        # tensor that cannot be iterated.
        sims = F.cosine_similarity(image_embs[i].unsqueeze(0), caption_embs, dim=1)
        # Convert once to plain ints instead of indexing the list with 0-d tensors.
        top_indices = sims.topk(k).indices.tolist()

        retrieved_captions = [caption_img_names[j] for j in top_indices]
        if image_names[i] in retrieved_captions:
            hits += 1

    # Guard against division by zero when no images were supplied.
    recall = hits / total if total else 0.0
    print(f"Image to Text Recall@{K}: {recall:.4f}")
    return recall

In [None]:
# Report Recall@1/5/10 for both retrieval directions: all text-to-image
# results first, then all image-to-text results. Each call prints its own score.
for recall_fn in (text_to_image_recall, image_to_text_recall):
    for k in (1, 5, 10):
        recall_fn(image_embs, caption_embs, caption_img_names, image_names, K=k)