In [1]:
#imports
import torch
import clip
from PIL import Image
import json
import glob

In [2]:
# Load the CLIP ViT-B/32 model together with the `preprocess` transform applied to images below.
# The model goes on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

In [4]:
# Load the COCO val2017 caption annotations.
# `with` closes the file handle once it is parsed (a bare open() leaves it dangling), and the
# explicit UTF-8 encoding (what the JSON spec mandates) avoids depending on the platform default.
with open('datasets/MSCOCO/annotations/captions_val2017.json', encoding='utf-8') as f:
    data = json.load(f)

In [97]:
# Collect the sorted image ids, plus every caption string and the id of the image it describes
# (captions[j] belongs to image caption_ids[j]).
image_id_list = sorted(x.get('id') for x in data['images'])

# A set makes the membership filter O(1) per annotation instead of scanning the whole id list,
# and filtering once lets both caption lists come from a single pass.
known_ids = set(image_id_list)
annotations = [x for x in data['annotations'] if x.get('image_id') in known_ids]

captions = [x.get('caption') for x in annotations]
caption_ids = [x.get('image_id') for x in annotations]

# Integer dtype keeps ids exact; the float32 default of `torch.Tensor` cannot represent every
# integer above 2**24. Row i of the image features is assumed to correspond to image_ids[i].
image_ids = torch.tensor(image_id_list, dtype=torch.long)

In [None]:
# Load and CLIP-preprocess every val image.
# glob() returns paths in filesystem order, which is not guaranteed to be sorted. The retrieval
# cells map row i of the image tensor to image_ids[i] (sorted ids), so the paths are sorted here.
# Assumes COCO-style zero-padded "<image_id>.jpg" file names, where lexical order == id order.
image_paths = sorted(glob.glob('datasets/MSCOCO/val/*.jpg'))
assert len(image_paths) == len(image_ids), (
    f"found {len(image_paths)} images but {len(image_ids)} annotated image ids"
)

images = []
for i, filename in enumerate(image_paths, start=1):
    # the context manager closes the underlying file once the image has been preprocessed
    with Image.open(filename) as img:
        im = preprocess(img).unsqueeze(0).to(device)
    images.append(im)
    if i % 1000 == 0:
        print(i)

# turn list to tensor; torch.cat allocates a correctly sized result on the inputs' device
# (a preallocated CPU `out=` buffer fails when `device` is "cuda")
images = torch.cat(images)

In [None]:
# Encode the preprocessed images with CLIP in mini-batches, then L2-normalise each embedding
# so a dot product between image and text features is their cosine similarity.
batch_size = 32
encoded_images = []
with torch.no_grad():
    for i in range(0, len(images), batch_size):
        encoded_images.append(model.encode_image(images[i:i + batch_size]))
        if i % 512 == 0:
            print(i)

# turn list to tensor. Gather as CPU float32 (what the previous preallocated `out=` buffer held):
# that buffer raised a device mismatch on CUDA, and the retrieval cells index the CPU
# `image_ids` tensor with top-k indices computed from these features.
encoded_images = torch.cat(encoded_images).float().cpu()
encoded_images /= encoded_images.norm(dim=-1, keepdim=True)

In [None]:
# Tokenize every caption and encode them with CLIP in mini-batches.
text = clip.tokenize(captions).to(device)

batch_size = 64
text_features = []
with torch.no_grad():
    for i in range(0, len(text), batch_size):
        text_features.append(model.encode_text(text[i:i + batch_size]))
        if i % 1024 == 0:
            print(i)

# turn list to tensor (CPU float32, like the image features; no hardcoded caption count or
# CPU `out=` buffer, which broke on CUDA) and L2-normalise each row
text_features = torch.cat(text_features).float().cpu()
text_features /= text_features.norm(dim=-1, keepdim=True)
assert len(text_features) == len(caption_ids), "one feature row expected per caption"

# pair each caption embedding with the id of the image it describes
encoded_captions = list(zip(caption_ids, text_features))

In [91]:
# Text -> image retrieval Recall@1: for each caption, is its own image the top-ranked image?
# Softmax is monotonic, so ranking raw similarities yields the same top-k without computing it.
# Captions are scored in chunks (one matmul per chunk instead of one per caption).
top_k = 1
chunk_size = 1024
query_ids = torch.tensor([image_id for image_id, _ in encoded_captions])
query_feats = torch.stack([feature for _, feature in encoded_captions])

hits = []
for start in range(0, len(query_feats), chunk_size):
    sims = query_feats[start:start + chunk_size] @ encoded_images.T
    top = sims.topk(top_k, dim=-1).indices.cpu()          # (chunk, top_k) image rows
    retrieved_ids = image_ids[top]                         # (chunk, top_k) image ids
    targets = query_ids[start:start + chunk_size, None]    # (chunk, 1) ground-truth ids
    hits.append((retrieved_ids == targets).any(dim=-1))

recall_1 = torch.cat(hits).float()

recall_1.mean()

tensor(0.3036)

In [92]:
# Text -> image retrieval Recall@5: for each caption, is its own image among the 5 top-ranked images?
# Softmax is monotonic, so ranking raw similarities yields the same top-k without computing it.
# Captions are scored in chunks (one matmul per chunk instead of one per caption).
top_k = 5
chunk_size = 1024
query_ids = torch.tensor([image_id for image_id, _ in encoded_captions])
query_feats = torch.stack([feature for _, feature in encoded_captions])

hits = []
for start in range(0, len(query_feats), chunk_size):
    sims = query_feats[start:start + chunk_size] @ encoded_images.T
    top = sims.topk(top_k, dim=-1).indices.cpu()          # (chunk, top_k) image rows
    retrieved_ids = image_ids[top]                         # (chunk, top_k) image ids
    targets = query_ids[start:start + chunk_size, None]    # (chunk, 1) ground-truth ids
    hits.append((retrieved_ids == targets).any(dim=-1))

recall_5 = torch.cat(hits).float()

recall_5.mean()

tensor(0.5478)

In [93]:
# Text -> image retrieval Recall@10: for each caption, is its own image among the 10 top-ranked images?
# Softmax is monotonic, so ranking raw similarities yields the same top-k without computing it.
# Captions are scored in chunks (one matmul per chunk instead of one per caption).
top_k = 10
chunk_size = 1024
query_ids = torch.tensor([image_id for image_id, _ in encoded_captions])
query_feats = torch.stack([feature for _, feature in encoded_captions])

hits = []
for start in range(0, len(query_feats), chunk_size):
    sims = query_feats[start:start + chunk_size] @ encoded_images.T
    top = sims.topk(top_k, dim=-1).indices.cpu()          # (chunk, top_k) image rows
    retrieved_ids = image_ids[top]                         # (chunk, top_k) image ids
    targets = query_ids[start:start + chunk_size, None]    # (chunk, 1) ground-truth ids
    hits.append((retrieved_ids == targets).any(dim=-1))

recall_10 = torch.cat(hits).float()

recall_10.mean()

tensor(0.6609)

In [94]:
# Image -> text retrieval Recall@1: for each image, does the top-ranked caption describe it?
# Softmax is monotonic, so ranking raw similarities yields the same top-k without computing it.
# Images are scored in chunks, and retrieved caption rows are mapped to image ids in one lookup.
top_k = 1
chunk_size = 512
caption_id_tensor = torch.tensor(caption_ids)

hits = []
for start in range(0, len(encoded_images), chunk_size):
    sims = encoded_images[start:start + chunk_size] @ text_features.T
    top = sims.topk(top_k, dim=-1).indices.cpu()           # (chunk, top_k) caption rows
    retrieved_ids = caption_id_tensor[top]                  # image id each retrieved caption belongs to
    targets = image_ids[start:start + chunk_size, None]     # (chunk, 1) id of each query image
    hits.append((retrieved_ids == targets).any(dim=-1))

recall_1 = torch.cat(hits).float()

recall_1.mean()

tensor(0.5002)

In [95]:
# Image -> text retrieval Recall@5: for each image, does any of the 5 top-ranked captions describe it?
# Softmax is monotonic, so ranking raw similarities yields the same top-k without computing it.
# Images are scored in chunks, and retrieved caption rows are mapped to image ids in one lookup.
top_k = 5
chunk_size = 512
caption_id_tensor = torch.tensor(caption_ids)

hits = []
for start in range(0, len(encoded_images), chunk_size):
    sims = encoded_images[start:start + chunk_size] @ text_features.T
    top = sims.topk(top_k, dim=-1).indices.cpu()           # (chunk, top_k) caption rows
    retrieved_ids = caption_id_tensor[top]                  # image id each retrieved caption belongs to
    targets = image_ids[start:start + chunk_size, None]     # (chunk, 1) id of each query image
    hits.append((retrieved_ids == targets).any(dim=-1))

recall_5 = torch.cat(hits).float()

recall_5.mean()

tensor(0.7498)

In [96]:
# Image -> text retrieval Recall@10: for each image, does any of the 10 top-ranked captions describe it?
# Softmax is monotonic, so ranking raw similarities yields the same top-k without computing it.
# Images are scored in chunks, and retrieved caption rows are mapped to image ids in one lookup.
top_k = 10
chunk_size = 512
caption_id_tensor = torch.tensor(caption_ids)

hits = []
for start in range(0, len(encoded_images), chunk_size):
    sims = encoded_images[start:start + chunk_size] @ text_features.T
    top = sims.topk(top_k, dim=-1).indices.cpu()           # (chunk, top_k) caption rows
    retrieved_ids = caption_id_tensor[top]                  # image id each retrieved caption belongs to
    targets = image_ids[start:start + chunk_size, None]     # (chunk, 1) id of each query image
    hits.append((retrieved_ids == targets).any(dim=-1))

recall_10 = torch.cat(hits).float()

recall_10.mean()

tensor(0.8328)

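Old code (superseded by the batched cells above; kept for reference only — running these cells overwrites `images` and `text_features`):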

In [None]:
# Superseded: load and encode images one at a time (slow), stopping after the first 1000 files.
# Note `images` holds normalised *embeddings* here, not preprocessed pixels.
# Paths are sorted because glob() order is filesystem-dependent; this makes the 1000-image subset
# deterministic (assumes COCO-style zero-padded "<image_id>.jpg" names, so it lines up with the
# sorted image_ids).
max_images = 1000
images = []
with torch.no_grad():
    for i, filename in enumerate(sorted(glob.glob('datasets/MSCOCO/val/*.jpg')), start=1):
        # the context manager closes the underlying file once the image has been preprocessed
        with Image.open(filename) as img:
            im = preprocess(img).unsqueeze(0).to(device)
        images.append(model.encode_image(im))
        if i % 100 == 0:
            print(i)
        if i == max_images:
            break

# turn list to tensor as CPU float32 (what the old preallocated `out=` buffer held; that buffer
# raised a device mismatch on CUDA), then L2-normalise each row
images = torch.cat(images).float().cpu()
images /= images.norm(dim=-1, keepdim=True)

In [None]:
# Superseded: encode captions one at a time (slow), stopping after the first 5003.
text = clip.tokenize(captions).to(device)

max_captions = 5003
text_features = []
with torch.no_grad():
    for i, caption in enumerate(text, start=1):
        # encode_text expects a batch dimension; 77 is the tokenized sequence length used here
        t = model.encode_text(caption.reshape(1, 77))
        t /= t.norm(dim=-1, keepdim=True)
        text_features.append(t)
        if i % 100 == 0:
            print(i)
        if i == max_captions:
            break

# turn list to tensor as CPU float32 (what the old preallocated `out=` buffer held; that buffer
# raised a device mismatch on CUDA)
text_features = torch.cat(text_features).float().cpu()

# Store the (image_id, feature) pairs as `encoded_captions`, the name the retrieval cells use.
# The previous version rebound `captions` itself, destroying the raw caption strings, so
# `clip.tokenize(captions)` failed whenever this cell was re-run.
encoded_captions = list(zip(caption_ids, text_features))