In [1]:
import os
import sys
import json
import faiss
import copy
import yaml
import numpy as np
import torch
from torch.utils.data import DataLoader
from lightning.pytorch import seed_everything
from tqdm import tqdm
from transformers import CLIPVisionModelWithProjection, CLIPTextModelWithProjection

sys.path.append('/local/vondrick/nd2794/CoIR')
from src.datasets.lasco_datasets import lasco_corpus_dataset, lasco_retrieval_dataset
from src.models.clip.clip import CLIPModel

In [2]:
# Load the retriever configuration. The context manager closes the file handle
# as soon as parsing finishes; `yaml.safe_load` already returns a freshly built
# object, so no defensive deep copy is needed.
with open('/local/vondrick/nd2794/CoIR/configs/retriever_config.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

In [6]:
# Torch device for the CLIP model: the configured GPU when CUDA is available,
# otherwise the CPU.
if torch.cuda.is_available():
    model_device = torch.device("cuda:{}".format(config['model_gpu_device_id']))
else:
    model_device = torch.device("cpu")

# String label for where the FAISS index is configured to live. The explicit
# comparison is kept so only values equal to True select the GPU label.
faiss_device = (
    'cuda:{}'.format(config['faiss_gpu_device_id'])
    if config['faiss_use_gpu'] == True
    else 'cpu'
)

print(
    'Model device: {}'.format(model_device),
    'FAISS device: {}'.format(faiss_device),
    sep='\n',
)

Model device: cuda:0
FAISS device: cuda:1


In [7]:
# Restore the Lightning CLIP wrapper from its checkpoint, mapping the stored
# tensors onto the model device.
model = CLIPModel.load_from_checkpoint(config['pl_ckpt_path'], map_location=model_device)

if config['eval_model_type'] == 'baseline':
    # Baseline evaluation: replace both encoders with the pretrained CLIP
    # weights read from the local checkpoint directory.
    model.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        pretrained_model_name_or_path=config['checkpoint_path'],
        local_files_only=True,
    )
    model.text_encoder = CLIPTextModelWithProjection.from_pretrained(
        pretrained_model_name_or_path=config['checkpoint_path'],
        local_files_only=True,
    )

# `from_pretrained` builds its modules on the CPU by default, so the swapped-in
# encoders would not share a device with inputs that are sent to `model_device`.
# Moving the whole model is a no-op when everything is already there.
model = model.to(model_device)

model.eval()
print('Model loaded')

Model loaded


In [8]:
# Inner-product flat index over `d`-dimensional vectors.
d = config['d']
index = faiss.IndexFlatIP(d)

# Optionally move the index onto the configured GPU. `res` is kept as a
# module-level name, exactly as before.
use_gpu_for_faiss = config['faiss_use_gpu'] == True
if use_gpu_for_faiss:
    res = faiss.StandardGpuResources()
    faiss_gpu_id = config['faiss_gpu_device_id']
    index = faiss.index_cpu_to_gpu(res, faiss_gpu_id, index)

In [9]:
# Dataset and loader for the retrieval corpus (the images that get indexed).
corpus_dataset = lasco_corpus_dataset(config)

# Loader options are taken verbatim from the `dataloader` section of the config.
# NOTE(review): `shuffle` and `drop_last` come from the config too; with
# `drop_last` enabled the final partial batch would never reach the index.
# Confirm both are False for evaluation.
corpus_loader_options = {
    option: config['dataloader'][option]
    for option in (
        'batch_size',
        'shuffle',
        'num_workers',
        'pin_memory',
        'drop_last',
        'persistent_workers',
    )
}

corpus_dataloader = DataLoader(
    dataset=corpus_dataset,
    collate_fn=corpus_dataset.collate_fn,
    **corpus_loader_options,
)

In [10]:
# Dataset and loader for the retrieval queries.
retrieval_dataset = lasco_retrieval_dataset(config)

# Loader options are taken verbatim from the `dataloader` section of the config.
retrieval_loader_options = {
    option: config['dataloader'][option]
    for option in (
        'batch_size',
        'shuffle',
        'num_workers',
        'pin_memory',
        'drop_last',
        'persistent_workers',
    )
}

retrieval_dataloader = DataLoader(
    dataset=retrieval_dataset,
    collate_fn=retrieval_dataset.collate_fn,
    **retrieval_loader_options,
)

In [15]:
# Embed every corpus image, add the embeddings to the FAISS index, and record
# which image key sits at each sequential FAISS row id.
#
# The counter and the map restart from zero here, so the index must be empty
# too. Without the reset, re-running this cell (for example after an interrupt)
# would leave earlier vectors in the index and shift every new row id away
# from the id recorded in the map.
index.reset()
index_cntr = 0
index_id_to_image_id_map = {}

# NOTE(review): inner-product scores only equal cosine similarity if
# `image-embeds` are L2-normalized. Confirm `model.image_forward` does that.
# Iterating the loader directly lets tqdm pick up its length for the progress bar.
for batch in tqdm(corpus_dataloader, desc='Indexing corpus'):
    with torch.no_grad():
        batch['image']['pixel_values'] = batch['image']['pixel_values'].to(model_device)
        image_embeds = model.image_forward(batch)
    index.add(image_embeds['image-embeds'].detach().cpu().numpy())

    # Rows were appended in batch order, so they occupy the next
    # len(image_keys) consecutive ids.
    image_keys = batch['image-key']
    index_id_to_image_id_map.update(
        zip(range(index_cntr, index_cntr + len(image_keys)), image_keys)
    )
    index_cntr += len(image_keys)

# Every indexed vector must have exactly one recorded image key.
assert index.ntotal == index_cntr == len(index_id_to_image_id_map), (
    'FAISS index size and id-to-image-key map are out of sync'
)

80it [02:23,  1.80s/it]


KeyboardInterrupt: 

In [16]:
def _lookup_image_id(faiss_row_id):
    """Return the image key recorded for one FAISS row id (KeyError if unknown)."""
    return index_id_to_image_id_map[faiss_row_id]


# Element-wise version of the lookup, applied to whole arrays of FAISS row ids.
map_func = np.vectorize(_lookup_image_id)

In [17]:
# Seed table with a single all-zero row of 52 int64 columns; per-batch result
# rows are concatenated below it.
results_table = np.zeros(shape=(1, 2 + 50), dtype=np.int64)

In [22]:
# Run every query through the retriever, search the FAISS index for its top-k
# neighbours, and collect one row per query:
#   [query-image-id, target-image-id, top-k retrieved image ids...]
top_k = 50

# Start from a fresh table holding only the all-zero seed row, so re-running
# this cell does not append onto rows collected by an earlier run.
results_table = np.zeros((1, top_k + 2), dtype='int64')
batch_result_tables = []

try:
    # Iterating the loader directly lets tqdm pick up its length for the progress bar.
    for batch in tqdm(retrieval_dataloader, desc='Retrieving'):
        with torch.no_grad():
            batch['query-image']['pixel_values'] = batch['query-image']['pixel_values'].to(model_device)
            batch['query-text']['input_ids'] = batch['query-text']['input_ids'].to(model_device)
            batch['query-text']['attention_mask'] = batch['query-text']['attention_mask'].to(model_device)
            outs = model.retriever_forward(batch)

            # Predicted target embedding: sum of the query image and query text
            # embeddings, rescaled to unit L2 norm per row.
            target_hat_embeds = outs['query_image_embeds'] + outs['query_text_embeds']
            target_hat_embeds = target_hat_embeds / torch.linalg.vector_norm(
                target_hat_embeds, ord=2, dim=1, keepdim=True
            )

        # NOTE(review): FAISS pads results with -1 when the index holds fewer
        # than `top_k` vectors; `map_func` would raise KeyError on those ids.
        distances, faiss_ids = index.search(target_hat_embeds.detach().cpu().numpy(), top_k)
        retrieved_image_ids = map_func(faiss_ids)

        batch_result_tables.append(
            np.concatenate(
                (
                    np.array(batch['query-image-id'], dtype='int64').reshape(-1, 1),
                    np.array(batch['target-image-id'], dtype='int64').reshape(-1, 1),
                    retrieved_image_ids,
                ),
                axis=1,
            )
        )
finally:
    # Join once instead of regrowing the table every batch. Doing it in
    # `finally` keeps the rows gathered so far if the loop is interrupted.
    results_table = np.concatenate([results_table, *batch_result_tables], axis=0)

40it [01:19,  1.99s/it]


KeyboardInterrupt: 

In [23]:
# Drop the all-zero seed row that `results_table` was initialised with.
# The check makes this cell safe to re-run: once the seed row is gone, the
# first row holds real ids and is left in place rather than silently discarded.
if results_table.shape[0] > 0 and not results_table[0].any():
    results_table = results_table[1:, :]

In [24]:
def calculate_recall(data, k):
    """
    Calculate recall@k for a table of retrieval results.

    :param data: 2-D array-like of shape (n_samples, 2 + n_retrieved).
                 Column 0 is the query key (not used here), column 1 is the
                 ground-truth candidate, and the remaining columns are the
                 retrieved candidates in rank order.
    :param k: number of top-ranked candidates to consider; must be >= 1.
              Values larger than the number of retrieved columns consider
              all of them.
    :return: fraction of rows whose ground truth appears among the first k
             retrieved candidates; NaN when `data` has no rows.
    :raises ValueError: if `data` is not 2-D with at least 3 columns, or if
                        k < 1.
    """
    data = np.asarray(data)
    if data.ndim != 2 or data.shape[1] < 3:
        raise ValueError(
            "data must be 2-D with at least 3 columns "
            "(key, ground truth, retrieved candidates...)"
        )
    if k < 1:
        # A negative k would otherwise slice from the wrong end of the ranking.
        raise ValueError("k must be a positive integer")
    if data.shape[0] == 0:
        # Recall over zero queries is undefined; return NaN without the
        # RuntimeWarning that np.mean emits on an empty array.
        return np.float64("nan")

    ground_truths = data[:, 1]
    top_k_candidates = data[:, 2:2 + k]

    # A row is a hit when its ground truth matches any of its top-k candidates.
    hits = (top_k_candidates == ground_truths[:, np.newaxis]).any(axis=1)

    return hits.mean()

In [25]:
# Alias (not a copy) of the collected results table, used as the input to the
# recall computation.
data = results_table

In [26]:
# Recall at the four reported cut-offs, computed over the same results table.
recall_1, recall_5, recall_10, recall_50 = (
    calculate_recall(data, cutoff) for cutoff in (1, 5, 10, 50)
)

# One line per cut-off.
print(
    f"Recall@1: {recall_1}",
    f"Recall@5: {recall_5}",
    f"Recall@10: {recall_10}",
    f"Recall@50: {recall_50}",
    sep="\n",
)

Recall@1: 0.00025
Recall@5: 0.001
Recall@10: 0.00225
Recall@50: 0.00675
