In [22]:
device = 0
results_name = 'CLIP-ViT-B-16'
clip_model = 'CLIP-ViT-B/16' # Options: ['CLIP-ViT-B/16', 'CLIP-ViT-B/32', 'CLIP-ViT-L/14', 'CLIP-ViT-L/14@336']

out_dir = '/proj/vondrick4/naveen/coir-ret-results'
dataset_split = 'val'
lasco_data_path = '/local/vondrick/naveen/coir-data/LaSCo'
device_map = 'cuda:{}'.format(device)

In [23]:
clip_checkpoints = {
    'CLIP-ViT-B/16': '/local/vondrick/naveen/pretrained_models/clip/clip-vit-base-patch16',
    'CLIP-ViT-B/32': '/local/vondrick/naveen/pretrained_models/clip/clip-vit-base-patch32',
    'CLIP-ViT-L/14': '/local/vondrick/naveen/pretrained_models/clip/clip-vit-large-patch14',
    'CLIP-ViT-L/14@336': '/local/vondrick/naveen/pretrained_models/clip/clip-vit-large-patch14-336'
}

In [24]:
import sys
sys.path.append('/proj/vondrick4/naveen/CoIR')
import warnings
warnings.filterwarnings("ignore")
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import CLIPVisionModelWithProjection, CLIPTextModelWithProjection
import faiss
import torch
import numpy as np
import os
import json

from src.datasets.lasco_corpus_dataset import lasco_corpus_dataset_clip
from src.datasets.lasco_retrieval_dataset import lasco_retrieval_dataset_clip
from src.metrics.metrics import calculate_recall

In [4]:
os.mkdir(os.path.join(out_dir, results_name))

FileExistsError: [Errno 17] File exists: '/proj/vondrick4/naveen/coir-ret-results/CLIP-ViT-B-16'

In [5]:
clip_checkpoint_path = clip_checkpoints[clip_model]

In [6]:
print('Using device: {}'.format(device_map))

Using device: cuda:0


In [7]:
image_encoder = CLIPVisionModelWithProjection.from_pretrained(pretrained_model_name_or_path=clip_checkpoint_path, local_files_only=True).to(device)
text_encoder = CLIPTextModelWithProjection.from_pretrained(pretrained_model_name_or_path=clip_checkpoint_path, local_files_only=True).to(device)

image_encoder.eval()
text_encoder.eval()
print('Model loaded')

Model loaded


In [8]:
d = image_encoder.config.projection_dim
index = faiss.IndexFlatIP(d)

res = faiss.StandardGpuResources()
index = faiss.index_cpu_to_gpu(res, device, index)

In [9]:
corpus_dataset = lasco_corpus_dataset_clip(dataset_split, lasco_data_path, clip_checkpoint_path)
corpus_dataloader = DataLoader(
    dataset=corpus_dataset,
    collate_fn=corpus_dataset.collate_fn,
    batch_size=100,
    shuffle=False,
    num_workers=10,
    pin_memory=True,
    drop_last=False,
    persistent_workers=True
)

In [10]:
retrieval_dataset = lasco_retrieval_dataset_clip(dataset_split, lasco_data_path, clip_checkpoint_path)
retrieval_dataloader = DataLoader(
    dataset=retrieval_dataset,
    collate_fn=retrieval_dataset.collate_fn,
    batch_size=100,
    shuffle=False,
    num_workers=10,
    pin_memory=True,
    drop_last=False,
    persistent_workers=True
)

In [11]:
# Create embeddings of images in the corpus
index_cntr = 0
index_id_to_image_id_map = {}

for batch_idx, batch in enumerate(tqdm(corpus_dataloader, desc="Indexing Corpus")):
    with torch.no_grad():
        batch['image']['pixel_values'] = batch['image']['pixel_values'].to(device_map)
        image_embeds = image_encoder(**batch['image']).image_embeds
        image_embeds = image_embeds / torch.linalg.vector_norm(image_embeds, ord=2, dim=1,keepdim=True)
        index.add(image_embeds.cpu())

        batch_len = len(batch['image-key'])
        batch_start_indx = index_cntr
        batch_end_indx = batch_start_indx + batch_len

        for key, value in zip(list(range(batch_start_indx, batch_end_indx)), batch['image-key']):
            index_id_to_image_id_map[key] = value
        index_cntr += batch_len

Indexing Corpus: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 399/399 [02:35<00:00,  2.56it/s]


In [12]:
map_func = np.vectorize(lambda x: index_id_to_image_id_map[x])

In [13]:
output = []

for batch_idx, batch in enumerate(tqdm(retrieval_dataloader , desc="Retrieval Task")):
    with torch.no_grad():
        batch['query-image']['pixel_values'] = batch['query-image']['pixel_values'].to(device_map)
        batch['query-text']['input_ids'] = batch['query-text']['input_ids'].to(device_map)
        batch['query-text']['attention_mask'] = batch['query-text']['attention_mask'].to(device_map)
        
        query_image_embeds = image_encoder(**batch['query-image']).image_embeds
        query_image_embeds = query_image_embeds / torch.linalg.vector_norm(query_image_embeds, ord=2, dim=1,keepdim=True)
        
        query_text_embeds = text_encoder(**batch['query-text']).text_embeds
        query_text_embeds = query_text_embeds / torch.linalg.vector_norm(query_text_embeds, ord=2, dim=1,keepdim=True)

    target_hat_embeds = query_image_embeds + query_text_embeds
    target_hat_embeds = target_hat_embeds / torch.linalg.vector_norm(target_hat_embeds, ord=2, dim=1, keepdim=True)

    D, I = index.search(target_hat_embeds.cpu(), k=50)
    I = map_func(I)

    batch_size = len(batch['query-image-id'])
    for i in range(batch_size):
        output.append({
            'id': batch['id'][i],
            'query-image-id': batch['query-image-id'][i],
            'target-image-id': batch['target-image-id'][i],
            'query-text-raw': batch['query-text-raw'][i],
            'top_50_ret_cands': I[i][:].tolist(),
            'top_50_ret_cands_cos_sims': D[i][:].tolist()
        })

Retrieval Task:   9%|█████████▊                                                                                                       | 26/301 [00:13<02:23,  1.92it/s]


KeyboardInterrupt: 

In [14]:
with open(os.path.join(out_dir, results_name, 'outputs'+'-lasco-'+dataset_split+'.json'), "w") as json_file:
    json.dump(output, json_file, indent=4)

In [21]:
metrics = []
ground_truths = np.array(list(map(lambda x: x['target-image-id'], output)))[:, np.newaxis]
retrieved_candidates = np.array(list(map(lambda x: x['top_50_ret_cands'], output)))

metrics.append({"Recall@1": 100*calculate_recall(ground_truths, retrieved_candidates, 1)})
metrics.append({"Recall@5": 100*calculate_recall(ground_truths, retrieved_candidates, 5)})
metrics.append({"Recall@10": 100*calculate_recall(ground_truths, retrieved_candidates, 10)})
metrics.append({"Recall@50": 100*calculate_recall(ground_truths, retrieved_candidates, 50)})

with open(os.path.join(out_dir, results_name, 'metrics'+'-lasco-'+dataset_split+'.json'), "w") as json_file:
    json.dump(metrics, json_file, indent=4)