In [None]:
import os
import warnings

import torch
import torchvision.datasets as dset
from torch.utils.data import DataLoader
from transformers import logging

# Project-local modules: the star imports supply the dataset wrapper, the two
# encoders and the helper functions used below (kept after transformers'
# `logging` so that import order is unchanged).
from encoder_models import *
from utils_clip import *

# Kept after the star imports so `tqdm` is always tqdm's own class.
from tqdm import tqdm

# Keep notebook output readable: ignore Python warnings and let transformers
# report only errors.
warnings.filterwarnings("ignore")
logging.set_verbosity_error()

# Initialization: configuration, test data and trained encoders

In [None]:
# Configuration, COCO test split and the two trained encoders.
config = load_config()

checkpoint_path = config['paths']['checkpoint']
vit_trans_name = 'google/vit-base-patch16-224'   # backbone for encoder_1 (image side)
bert_model_name = 'bert-base-uncased'             # backbone for encoder_2 (text side)
# Presumably has to match the projection size stored in the checkpoint, since
# the weights below are loaded into encoders built with this value.
# TODO: make this dynamic (read it from the checkpoint/config).
embed_dim = 128

device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f'Using: {device}')

coco_dataset_tst = dset.CocoCaptions(
    root=config['data']['test_root'],
    annFile=config['data']['test_ann']
)

test_dataset = CocoCaptionDataset(coco_dataset_tst, mode="test")

# Instantiate both encoders, then restore their weights from the checkpoint.
encoder_1 = Transformer_One(vit_trans_name, embed_dim, device=device)
encoder_2 = Transformer_Two(bert_model_name, embed_dim, device=device)
encoder_1, encoder_2, _, _, loss_state_dct, epoch_load, tr_loss, vld_loss = load_model(
    checkpoint_path, encoder_1, encoder_2, device
)
# The logit scale is stored in log space; exponentiate to get the actual
# multiplier that find_top_k_matches receives.
logit_sc = loss_state_dct['logit_scale'].exp()

# Inference only: place on the target device AND switch to eval mode so that
# dropout (and any other train-only behaviour) is off while embedding.
encoder_1.to(device).eval()
encoder_2.to(device).eval()

# Never ask for more loader workers than this machine has CPUs.
num_workers = min(16, os.cpu_count() or 1)
dataloader = DataLoader(test_dataset, batch_size=1000, shuffle=False, drop_last=False, num_workers=num_workers)

# Precompute image embeddings (cached on disk)

In [None]:
# Image embeddings are cached as .pt files under `save_path`. The encoder pass
# over the whole test set only runs when no cached files are found.
# Note: the cache is not keyed on the checkpoint; delete it after retraining.
save_path = config['paths']['img_embedd']
# On a fresh checkout the cache directory may not exist yet. Create it so that
# os.listdir cannot raise FileNotFoundError; precompute_img_emb is then always
# handed an existing directory.
os.makedirs(save_path, exist_ok=True)
pt_files = [f for f in os.listdir(save_path) if f.endswith('.pt')]

if not pt_files:
    precompute_img_emb(encoder_1, dataloader, save_path)

im_em = load_embeddings(save_path)
# Cheap sanity check that the cached rows line up one-to-one with the test set
# (catches a cache built from a different split/dataset).
assert im_em.shape[0] == len(test_dataset), "Dimension mismatch!"
print(im_em.shape)

# Precompute text embeddings for the class names

In [None]:
# Text side of zero-shot classification: embed the class labels once with
# encoder_2 so that each image only needs a similarity lookup against them.
path = config['paths']['coco_classes']
(class_embed, classes) = precompute_class_emb(
    encoder_2,
    path,
)

# Pick 5 images at random and plot their top-5 zero-shot class predictions

In [None]:
# Zero-shot demo: sample a few test images, take the best-matching class
# labels for each image embedding (via find_top_k_matches) and plot them.
N_QUERY = 5       # number of images to show
TOP_K = 5         # class labels shown per image
SAMPLE_SEED = 0   # change for a different (but still reproducible) set of images

# randperm (rather than randint) guarantees distinct images; the seeded
# generator makes the figure reproducible on Restart & Run All.
sample_gen = torch.Generator().manual_seed(SAMPLE_SEED)
random_images = torch.randperm(len(test_dataset), generator=sample_gen)[:N_QUERY]
query_images = im_em[random_images]

im_plt = []
txt_plt = []
probs = []

# NOTE(review): im_em rows presumably follow test_dataset's order; the raw
# image is fetched from coco_dataset_tst with the same index, which assumes
# CocoCaptionDataset preserves its indexing — confirm.
with torch.no_grad():  # pure inference: no autograd graph needed
    for i, img_idx in enumerate(random_images.tolist()):
        top_idxs, sims = find_top_k_matches(class_embed, query_images[i].unsqueeze(0), logit_sc, k=TOP_K, device=device)
        this_image, _ = coco_dataset_tst[img_idx]
        relevant_txt = [classes[idx.item()] for idx in top_idxs]
        im_plt.append(this_image)
        txt_plt.append(relevant_txt)
        probs.append(sims)

plot_images_and_bars(im_plt, txt_plt, probs)