In [14]:
import torch
import os
from PIL import Image
import clip
import torch.nn.functional as F
import csv

In [2]:
caption_model = torch.load('caption_features_flickr8k.pt')

In [8]:
caption_embs = torch.stack([cap[2] for cap in caption_model])  # shape: [N_cap, 512]
caption_img_names = [cap[0] for cap in caption_model]
print(f"Number of captions: {len(caption_img_names)}") 

Number of captions: 40455


In [9]:
caption_embs = caption_embs.squeeze(1)

In [7]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
model.eval()

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          

In [11]:
# Image to image
test_text_feature = []
# prepocess the text and move it to the device
test_image_preprocessed = "A black and white dog is running through the grass"

text_input = clip.tokenize([test_image_preprocessed], truncate=True).to(device)
# Get the image feature using the model.encode_image function
with torch.no_grad():
    test_text_feature = model.encode_text(text_input) # output shape torch.Size([1, 512])
# Normalize the image feature
    test_text_feature /= test_text_feature.norm(dim=-1, keepdim=True)
test_text_feature.shape # torch.Size([1, 512])

torch.Size([1, 512])

In [None]:
print(test_text_feature.shape)     #torch.Size([1, 512])
print(caption_embs.squeeze(1).shape) #torch.Size([40455, 512])

torch.Size([1, 512])
torch.Size([40455, 512])


In [15]:
CAPTIONS_PATH ="../datasets/flickr8k/captions.txt"
captions = []

with open(CAPTIONS_PATH, "r", encoding="utf-8") as f:
    reader = csv.reader(f)
    next(reader)  # skip header: image,caption

    for row in reader:
        if len(row) < 2:
            continue
        img_name, caption = row
        captions.append((img_name.strip(), caption.strip()))

print("Total captions:", len(captions))

Total captions: 40455


In [16]:
sims = F.cosine_similarity(test_text_feature, caption_embs.squeeze(1)).squeeze()
print(sims.shape)  #torch.Size([40455])
# Get the top 5 most similar images
topk = sims.topk(5).indices
# Display the top 5 most similar images
print(topk)
print(topk.shape)

retrieved_text = [captions[j] for j in topk]
print("Retrieved captions:")
for img_name, caption in retrieved_text:
    print(f"{img_name}: {caption}")

torch.Size([40455])
tensor([12522,    31, 13940,  1485, 14771])
torch.Size([5])
Retrieved captions:
2637904605_fc355816fc.jpg: A black and white dog is running through the grass .
1009434119_febe49276a.jpg: A black and white dog is running through the grass .
2730994020_64ac1d18be.jpg: A black and white dog is running in the grass .
1330645772_24f831ff8f.jpg: A black and white dog is running in the grass .
2792409624_2731b1072c.jpg: A black and white dog is running through grass .
