In [64]:
#imports
import torch
import clip
from PIL import Image
import json
import glob
import numpy as np

In [78]:
# Load the CLIP ViT-B/32 model together with its image preprocessing transform,
# placing it on the GPU when CUDA is available and on the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

In [178]:
# open captions
# Each line of the Flickr30k token file has the form
#   "<image_id>.jpg#<caption_no>\t<caption>"
# We collect the integer image id and the caption text of every line, in file
# order. The `with` block guarantees the file is closed; splitting on the tab
# (instead of skipping a fixed 6 characters after the first '.') also copes with
# multi-digit caption numbers or a different image extension.
caption_ids = []
captions = []
with open('datasets/flickr30k/results_20130124.token', encoding="utf8") as token_file:
    for line in token_file:
        line = line.rstrip('\n')
        if not line:
            continue  # tolerate blank (e.g. trailing) lines instead of int("") crashing
        image_name, caption = line.split('\t', 1)
        # "1000092795.jpg#0" -> 1000092795
        caption_ids.append(int(image_name.split('.', 1)[0]))
        captions.append(caption)

In [225]:
# get image ids, captions and caption ids in lists
# Keep the first N_IMAGES images (by numeric id) together with *all* of their
# captions. Captions are chosen by membership in that image set rather than by
# slicing the first 5000 sorted captions, so the two subsets stay aligned even
# if an image does not have exactly five captions.
N_IMAGES = 1000

image_ids = sorted(set(caption_ids))[:N_IMAGES]
selected_ids = set(image_ids)

# sort (id, caption) pairs so captions are grouped by image, in id order
sorted_pairs = sorted(
    (caption_id, caption)
    for caption_id, caption in zip(caption_ids, captions)
    if caption_id in selected_ids
)
caption_ids = [caption_id for caption_id, _ in sorted_pairs]
captions = [caption for _, caption in sorted_pairs]

In [None]:
# load images
# Preprocess every selected image with CLIP's transform and stack them into a
# single (num_images, 3, 224, 224) batch on `device`.
# Converting through int() keeps this cell re-runnable even after `image_ids`
# has been turned into a tensor at the end of the cell.
image_id_list = [int(image_id) for image_id in image_ids]

images = []
for count, image_id in enumerate(image_id_list, start=1):
    filename = 'datasets/flickr30k/flickr30k-images/' + str(image_id) + '.jpg'
    # context manager closes the file handle once preprocess has read the pixels
    with Image.open(filename) as img:
        images.append(preprocess(img).unsqueeze(0).to(device))
    if count % 100 == 0:
        print(count)

# turn list to tensor; no preallocated CPU `out=` buffer, so this works on any
# device and for any number of images
images = torch.cat(images)
# int64, not torch.Tensor's float32: float32 only represents integers exactly up
# to 2**24, so large Flickr30k ids get rounded and distinct ids could compare
# equal, producing false hits in the recall cells.
image_ids = torch.tensor(image_id_list, dtype=torch.int64)

In [227]:
# encode images
# Run CLIP's image encoder in batches of 32 (no gradients needed for inference).
encoded_images = []
with torch.no_grad():
    for i in range(0, len(images), 32):
        # .to(device) is a no-op when `images` already lives on the model's device
        encoded_images.append(model.encode_image(images[i:i+32].to(device)))
        if (i % 256) == 0:
            print(i)

# Turn the list into one (num_images, embed_dim) tensor. Unlike a preallocated
# torch.Tensor(1000, 512) CPU `out=` buffer, this works when the model runs on
# CUDA and for any number of images. .float().cpu() yields the float32 CPU tensor
# the recall cells index against (in case the model returns half precision on GPU).
encoded_images = torch.cat(encoded_images).float().cpu()
# L2-normalise each row so dot products are cosine similarities
encoded_images /= encoded_images.norm(dim=-1, keepdim=True)

0
256
512
768


In [None]:
# tokenize all captions and move the tokens to the model's device
text = clip.tokenize(captions).to(device)

# encode captions in batches of 32 (no gradients needed for inference)
text_features = []
with torch.no_grad():
    for i in range(0, len(text), 32):
        text_features.append(model.encode_text(text[i:i+32]))
        if (i % 512) == 0:
            print(i)

# Turn the list into one (num_captions, embed_dim) tensor. Unlike a preallocated
# torch.Tensor(5000, 512) CPU `out=` buffer, this works when the model runs on
# CUDA and for any number of captions. .float().cpu() yields the float32 CPU
# tensor the recall cells use (in case the model returns half precision on GPU).
text_features = torch.cat(text_features).float().cpu()
# L2-normalise each row so dot products are cosine similarities
text_features /= text_features.norm(dim=-1, keepdim=True)

# (image id of the caption, normalised caption embedding) pairs, in caption order
encoded_captions = list(zip(caption_ids, text_features))

In [228]:
# text -> image Recall@1: fraction of captions whose own image is the single most
# similar image. All captions are scored at once with one similarity matrix; the
# original softmax(100 * sim) is dropped because it is monotonic per row and so
# does not change the top-k ranking.
caption_id_tensor = torch.tensor([caption_id for caption_id, _ in encoded_captions])
caption_matrix = torch.stack([feature for _, feature in encoded_captions])
# (num_captions, 1) indices into encoded_images / image_ids; CPU to index image_ids
top_indices = (caption_matrix @ encoded_images.T).topk(1, dim=-1).indices.cpu()
recall_1 = (image_ids[top_indices] == caption_id_tensor.unsqueeze(1)).any(dim=1).float()

recall_1.mean()

tensor(0.5816)

In [229]:
# text -> image Recall@5: fraction of captions whose own image is among the 5
# most similar images. All captions are scored at once with one similarity
# matrix; softmax(100 * sim) is dropped because it is monotonic per row and so
# does not change the top-k ranking.
caption_id_tensor = torch.tensor([caption_id for caption_id, _ in encoded_captions])
caption_matrix = torch.stack([feature for _, feature in encoded_captions])
# (num_captions, 5) indices into encoded_images / image_ids; CPU to index image_ids
top_indices = (caption_matrix @ encoded_images.T).topk(5, dim=-1).indices.cpu()
recall_5 = (image_ids[top_indices] == caption_id_tensor.unsqueeze(1)).any(dim=1).float()

recall_5.mean()

tensor(0.8302)

In [230]:
# text -> image Recall@10: fraction of captions whose own image is among the 10
# most similar images. All captions are scored at once with one similarity
# matrix; softmax(100 * sim) is dropped because it is monotonic per row and so
# does not change the top-k ranking.
caption_id_tensor = torch.tensor([caption_id for caption_id, _ in encoded_captions])
caption_matrix = torch.stack([feature for _, feature in encoded_captions])
# (num_captions, 10) indices into encoded_images / image_ids; CPU to index image_ids
top_indices = (caption_matrix @ encoded_images.T).topk(10, dim=-1).indices.cpu()
recall_10 = (image_ids[top_indices] == caption_id_tensor.unsqueeze(1)).any(dim=1).float()

recall_10.mean()

tensor(0.9016)

In [231]:
# image -> text Recall@1: fraction of images whose single most similar caption
# belongs to that image. Note: this reuses the name recall_1 and so replaces the
# text -> image result. softmax(100 * sim) is dropped because it is monotonic per
# row and does not change the top-k ranking.
caption_id_tensor = torch.tensor(caption_ids)
# (num_images, 1) indices into text_features / caption_ids; CPU to index the id tensor
top_indices = (encoded_images @ text_features.T).topk(1, dim=-1).indices.cpu()
recall_1 = (caption_id_tensor[top_indices] == image_ids.unsqueeze(1)).any(dim=1).float()

recall_1.mean()

tensor(0.7730)

In [232]:
# image -> text Recall@5: fraction of images for which at least one of the 5 most
# similar captions belongs to that image. Note: this reuses the name recall_5 and
# so replaces the text -> image result. softmax(100 * sim) is dropped because it
# is monotonic per row and does not change the top-k ranking.
caption_id_tensor = torch.tensor(caption_ids)
# (num_images, 5) indices into text_features / caption_ids; CPU to index the id tensor
top_indices = (encoded_images @ text_features.T).topk(5, dim=-1).indices.cpu()
recall_5 = (caption_id_tensor[top_indices] == image_ids.unsqueeze(1)).any(dim=1).float()

recall_5.mean()

tensor(0.9360)

In [233]:
# image -> text Recall@10: fraction of images for which at least one of the 10
# most similar captions belongs to that image. Note: this reuses the name
# recall_10 and so replaces the text -> image result. softmax(100 * sim) is
# dropped because it is monotonic per row and does not change the top-k ranking.
caption_id_tensor = torch.tensor(caption_ids)
# (num_images, 10) indices into text_features / caption_ids; CPU to index the id tensor
top_indices = (encoded_images @ text_features.T).topk(10, dim=-1).indices.cpu()
recall_10 = (caption_id_tensor[top_indices] == image_ids.unsqueeze(1)).any(dim=1).float()

recall_10.mean()

tensor(0.9750)