In [10]:
# ---- Experiment configuration -------------------------------------------------
# GPU ordinal, plus the matching torch device string derived from it.
device = 1
device_map = 'cuda:{}'.format(device)

# CLIP variant to evaluate and the sub-directory name its results are stored under.
# Options: ['CLIP-ViT-B/16', 'CLIP-ViT-B/32', 'CLIP-ViT-L/14', 'CLIP-ViT-L/14@336']
clip_model = 'CLIP-ViT-B/32'
results_name = 'CLIP-ViT-B-32'
batch_size = 64

# Dataset root / split to evaluate, and the root directory for written results.
lasco_data_path = '/proj/vondrick4/naveen/coir-data/flickr-30k'
dataset_split = 'test'
out_dir = '/proj/vondrick4/naveen/coir-ret-results'

# Local checkpoint directory for every supported CLIP variant, keyed by `clip_model`.
clip_checkpoints = dict([
    ('CLIP-ViT-B/16', '/local/vondrick/naveen/pretrained_models/clip/clip-vit-base-patch16'),
    ('CLIP-ViT-B/32', '/local/vondrick/naveen/pretrained_models/clip/clip-vit-base-patch32'),
    ('CLIP-ViT-L/14', '/local/vondrick/naveen/pretrained_models/clip/clip-vit-large-patch14'),
    ('CLIP-ViT-L/14@336', '/local/vondrick/naveen/pretrained_models/clip/clip-vit-large-patch14-336'),
])

In [5]:
import sys
# Put the CoIR project root on the import path so the `src.*` imports below resolve.
sys.path.append('/proj/vondrick4/naveen/CoIR')
import warnings
# Suppresses every Python warning for the rest of the session.
warnings.filterwarnings("ignore")
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import CLIPVisionModelWithProjection, CLIPTextModelWithProjection
import faiss
import torch
import numpy as np
import os
import json

# NOTE(review): the dataset classes are named `mscoco_5k_*` although they are imported from
# the `flickr_30k` package — presumably a name carried over from an MSCOCO variant; confirm.
from src.datasets.flickr_30k.flickr_30k_corpus_dataset import mscoco_5k_image_corpus_dataset_clip, mscoco_5k_text_corpus_dataset_clip
from src.datasets.flickr_30k.flickr_30k_retrieval_dataset import mscoco_5k_retrieval_dataset_clip
from src.metrics.metrics import calculate_recall

In [6]:
# Create the per-model results directory (no error if it already exists), then look up
# the checkpoint directory for the configured CLIP variant.
_results_dir = os.path.join(out_dir, results_name)
os.makedirs(_results_dir, exist_ok=True)

clip_checkpoint_path = clip_checkpoints[clip_model]

print('Using device: {}'.format(device_map))

Using device: cuda:1


In [7]:
# Load the CLIP image and text towers (with projection heads) from the local checkpoint.
# `local_files_only=True` prevents any attempt to download the weights.
# The models are moved with `device_map` (the 'cuda:N' string) rather than the bare
# integer `device`, so they are placed with exactly the same device spec that every
# batch is moved to later in the notebook.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    pretrained_model_name_or_path=clip_checkpoint_path,
    local_files_only=True,
).to(device_map)
text_encoder = CLIPTextModelWithProjection.from_pretrained(
    pretrained_model_name_or_path=clip_checkpoint_path,
    local_files_only=True,
).to(device_map)

# Inference only: switch both towers to eval mode.
image_encoder.eval()
text_encoder.eval()
print('Model loaded')

Model loaded


In [8]:
# Dimensionality of the projected embeddings, read from the image encoder's config;
# both indexes are built with this same size.
d = image_encoder.config.projection_dim

# One GPU resources handle serves both indexes.
res = faiss.StandardGpuResources()

# Flat inner-product indexes, created on CPU and moved straight to GPU `device`.
image_index = faiss.index_cpu_to_gpu(res, device, faiss.IndexFlatIP(d))
text_index = faiss.index_cpu_to_gpu(res, device, faiss.IndexFlatIP(d))

In [11]:
# Image corpus: dataset plus a loader that walks it once, in order.
corpus_dataset_image = mscoco_5k_image_corpus_dataset_clip(
    dataset_split,
    lasco_data_path,
    clip_checkpoint_path,
)
corpus_dataloader_image = DataLoader(
    corpus_dataset_image,
    batch_size=batch_size,
    shuffle=False,      # keep corpus order fixed
    drop_last=False,    # the final partial batch must be kept
    collate_fn=corpus_dataset_image.collate_fn,
    num_workers=10,
    persistent_workers=True,
    pin_memory=True,
)

In [12]:
# Text corpus: dataset plus a loader that walks it once, in order.
corpus_dataset_text = mscoco_5k_text_corpus_dataset_clip(
    dataset_split,
    lasco_data_path,
    clip_checkpoint_path,
)
corpus_dataloader_text = DataLoader(
    corpus_dataset_text,
    batch_size=batch_size,
    shuffle=False,      # keep corpus order fixed
    drop_last=False,    # the final partial batch must be kept
    collate_fn=corpus_dataset_text.collate_fn,
    num_workers=10,
    persistent_workers=True,
    pin_memory=True,
)

In [13]:
# Query-side dataset and loader for the retrieval evaluation.
retrieval_dataset = mscoco_5k_retrieval_dataset_clip(
    dataset_split,
    lasco_data_path,
    clip_checkpoint_path,
)
retrieval_dataloader = DataLoader(
    retrieval_dataset,
    batch_size=batch_size,
    shuffle=False,      # keep query order fixed
    drop_last=False,    # the final partial batch must be kept
    collate_fn=retrieval_dataset.collate_fn,
    num_workers=10,
    persistent_workers=True,
    pin_memory=True,
)

In [14]:
# Embed every corpus image, add the L2-normalised embeddings to `image_index`, and record
# which corpus image key sits at each sequential index position.
index_cntr = 0
index_id_to_image_id_map = {}

for batch in tqdm(corpus_dataloader_image, desc="Indexing Corpus: IMAGES"):
    image_inputs = batch['image']
    image_inputs['pixel_values'] = image_inputs['pixel_values'].to(device_map)

    with torch.no_grad():
        image_embeds = image_encoder(**image_inputs).image_embeds
        # Unit-normalise each row so the inner-product index scores are cosine similarities.
        image_embeds = image_embeds / torch.linalg.vector_norm(image_embeds, ord=2, dim=1, keepdim=True)

    image_index.add(image_embeds.cpu())

    # Positions are handed out in insertion order, continuing from the running counter.
    image_keys = batch['image-key']
    for index_id, image_key in enumerate(image_keys, start=index_cntr):
        index_id_to_image_id_map[index_id] = image_key
    index_cntr += len(image_keys)

Indexing Corpus: IMAGES: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:03<00:00,  4.51it/s]


In [15]:
# Embed every corpus caption, add the L2-normalised embeddings to `text_index`, and record
# which corpus text key sits at each sequential index position.
index_cntr = 0
index_id_to_text_id_map = {}

for batch in tqdm(corpus_dataloader_text, desc="Indexing Corpus: TEXT"):
    text_inputs = batch['text']
    for field in ('input_ids', 'attention_mask'):
        text_inputs[field] = text_inputs[field].to(device_map)

    with torch.no_grad():
        text_embeds = text_encoder(**text_inputs).text_embeds
        # Unit-normalise each row so the inner-product index scores are cosine similarities.
        text_embeds = text_embeds / torch.linalg.vector_norm(text_embeds, ord=2, dim=1, keepdim=True)

    text_index.add(text_embeds.cpu())

    # Positions are handed out in insertion order, continuing from the running counter.
    text_keys = batch['text-key']
    for index_id, text_key in enumerate(text_keys, start=index_cntr):
        index_id_to_text_id_map[index_id] = text_key
    index_cntr += len(text_keys)

Indexing Corpus: TEXT: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:01<00:00, 15.77it/s]


In [16]:
def _image_id_for(index_id):
    """Return the corpus image key recorded for one index position."""
    return index_id_to_image_id_map[index_id]


def _text_id_for(index_id):
    """Return the corpus text key recorded for one index position."""
    return index_id_to_text_id_map[index_id]


# Element-wise translators applied to the index arrays returned by a FAISS search.
map_func_image = np.vectorize(_image_id_for)
map_func_text = np.vectorize(_text_id_for)

In [17]:
# Image -> text retrieval: embed each query image, search the text index, save the ranked
# candidates per query, then compute and save Recall@K.
output_img_2_text = []

# The `retrieval_dataloader` built in the earlier cell is reused here; rebuilding the
# dataset and loader in this cell only created a second, identical worker pool.

# Never ask for more neighbours than the index holds: FAISS fills missing results with -1,
# which is not a key of `index_id_to_text_id_map` and would raise KeyError in the mapping.
top_k = min(1000, text_index.ntotal)

for batch in tqdm(retrieval_dataloader, desc="Retrieval Task: IMG-2-TEXT"):
    with torch.no_grad():
        batch['image']['pixel_values'] = batch['image']['pixel_values'].to(device_map)
        image_embeds = image_encoder(**batch['image']).image_embeds
        # Unit-normalise so the inner-product scores returned below are cosine similarities.
        image_embeds = image_embeds / torch.linalg.vector_norm(image_embeds, ord=2, dim=1, keepdim=True)

    D, I = text_index.search(image_embeds.cpu(), k=top_k)
    # Translate index positions back to corpus text ids.
    I = map_func_text(I)

    # Deliberately NOT named `batch_size`: that is the notebook-wide config value, and
    # overwriting it here leaked the last (partial) batch length into every later cell
    # that builds a DataLoader.
    num_queries = len(batch['image-id'])
    for i in range(num_queries):
        output_img_2_text.append({
            'id': batch['id'][i],
            'image-id': batch['image-id'][i],
            'text-id': batch['text-id'][i],
            'text-raw': batch['text-raw'][i],
            'top_1000_ret_cands': I[i].tolist(),
            'top_1000_ret_cands_cos_sims': D[i].tolist()
        })

# NOTE(review): the file names say 'mscoco-5k' while `lasco_data_path` points at flickr-30k;
# left unchanged because downstream consumers may rely on these names — confirm.
with open(os.path.join(out_dir, results_name, 'outputs'+'-mscoco-5k-[image-2-text]-'+dataset_split+'.json'), "w") as json_file:
    json.dump(output_img_2_text, json_file, indent=4)


# Ground truth for an image query is its paired text id.
ground_truths = np.array([row['text-id'] for row in output_img_2_text])
retrieved_candidates = np.array([row['top_1000_ret_cands'] for row in output_img_2_text])

# One single-entry dict per cutoff, in ascending order (same layout as before), in percent.
metrics = [
    {"Recall@{}".format(cutoff): 100*calculate_recall(ground_truths, retrieved_candidates, cutoff)}
    for cutoff in (1, 5, 10, 50, 100, 500, 1000)
]

with open(os.path.join(out_dir, results_name, 'metrics'+'-mscoco-5k-[image-2-text]-'+dataset_split+'.json'), "w") as json_file:
    json.dump(metrics, json_file, indent=4)

Retrieval Task: IMG-2-TEXT: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:02<00:00,  7.73it/s]


In [None]:
# Display the Recall@K list currently bound to `metrics`.
metrics

In [18]:
# Text -> image retrieval: embed each query caption, search the image index, save the
# ranked candidates per query, then compute and save Recall@K.
output_text_2_img = []

# The already-existing `retrieval_dataloader` is reused here. Rebuilding it in this cell
# was redundant, and it picked up whatever `batch_size` happened to hold at that moment
# instead of the configured value.

# Never ask for more neighbours than the index holds: FAISS fills missing results with -1,
# which is not a key of `index_id_to_image_id_map` and would raise KeyError in the mapping.
top_k = min(1000, image_index.ntotal)

# Progress label fixed: this loop is the text-to-image direction (it used to say IMG-2-TEXT).
for batch in tqdm(retrieval_dataloader, desc="Retrieval Task: TEXT-2-IMG"):
    with torch.no_grad():
        batch['text']['input_ids'] = batch['text']['input_ids'].to(device_map)
        batch['text']['attention_mask'] = batch['text']['attention_mask'].to(device_map)

        text_embeds = text_encoder(**batch['text']).text_embeds
        # Unit-normalise so the inner-product scores returned below are cosine similarities.
        text_embeds = text_embeds / torch.linalg.vector_norm(text_embeds, ord=2, dim=1, keepdim=True)

    D, I = image_index.search(text_embeds.cpu(), k=top_k)
    # Translate index positions back to corpus image ids.
    I = map_func_image(I)

    # Deliberately NOT named `batch_size`: that is the notebook-wide config value and must
    # not be overwritten with the length of the last (partial) batch.
    num_queries = len(batch['text-id'])
    for i in range(num_queries):
        output_text_2_img.append({
            'id': batch['id'][i],
            'image-id': batch['image-id'][i],
            'text-id': batch['text-id'][i],
            'text-raw': batch['text-raw'][i],
            'top_1000_ret_cands': I[i].tolist(),
            'top_1000_ret_cands_cos_sims': D[i].tolist()
        })

# NOTE(review): the file names say 'mscoco-5k' while `lasco_data_path` points at flickr-30k;
# left unchanged because downstream consumers may rely on these names — confirm.
with open(os.path.join(out_dir, results_name, 'outputs'+'-mscoco-5k-[text-2-image]-'+dataset_split+'.json'), "w") as json_file:
    json.dump(output_text_2_img, json_file, indent=4)


# Ground truth for a text query is its paired image id.
ground_truths = np.array([row['image-id'] for row in output_text_2_img])
retrieved_candidates = np.array([row['top_1000_ret_cands'] for row in output_text_2_img])

# One single-entry dict per cutoff, in ascending order (same layout as before), in percent.
metrics = [
    {"Recall@{}".format(cutoff): 100*calculate_recall(ground_truths, retrieved_candidates, cutoff)}
    for cutoff in (1, 5, 10, 50, 100, 500, 1000)
]

with open(os.path.join(out_dir, results_name, 'metrics'+'-mscoco-5k-[text-2-image]-'+dataset_split+'.json'), "w") as json_file:
    json.dump(metrics, json_file, indent=4)

Retrieval Task: IMG-2-TEXT: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:02<00:00, 12.17it/s]


In [19]:
# Display the Recall@K list currently bound to `metrics`.
metrics

[{'Recall@1': 67.10000000000001},
 {'Recall@5': 89.0},
 {'Recall@10': 93.8},
 {'Recall@50': 98.8},
 {'Recall@100': 99.5},
 {'Recall@500': 100.0},
 {'Recall@1000': 100.0}]