In [34]:
import os
import sys
import json
import faiss
import copy
import yaml
import torch
import numpy as np
from torch.utils.data import DataLoader
from lightning.pytorch import seed_everything
from tqdm import tqdm

# Make the CoIR project importable. The checkout location can be overridden
# with the COIR_ROOT environment variable; it defaults to the original path.
COIR_ROOT = os.environ.get('COIR_ROOT', '/local/vondrick/nd2794/CoIR')
if COIR_ROOT not in sys.path:  # avoid duplicate entries when the cell is re-run
    sys.path.append(COIR_ROOT)
from src.datasets.lasco_datasets_inbatch import lasco_dataset_inbatch
from src.datamodules.lasco_data_module_inbatch import LASCODataModuleINBATCH

In [2]:
# Load the experiment configuration from YAML. The context manager closes the
# file handle deterministically (the bare open() call previously left it open).
# safe_load already returns a fresh object, so no extra deepcopy is needed.
with open('/local/vondrick/nd2794/CoIR/configs/config_inbatch.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

In [3]:
# Instantiate the project's LASCO in-batch data module from the loaded config.
datamodule = LASCODataModuleINBATCH(config)

In [4]:
# Run the data module's setup for the 'fit' stage.
# NOTE(review): presumably this builds the train/val datasets that the
# train_dataloader()/val_dataloader() calls rely on — confirm in the data module.
datamodule.setup('fit')



In [None]:
# Build the 'val' dataset directly (bypassing the data module) and wrap it in a
# small DataLoader for manual inspection. The batch size is fixed at 2 here;
# every other loader option is taken from config['dataloader'].
dataset = lasco_dataset_inbatch(config, 'val')

shared_keys = ('shuffle', 'num_workers', 'pin_memory', 'drop_last', 'persistent_workers')
loader_kwargs = {key: config['dataloader'][key] for key in shared_keys}

corpus_dataloader = DataLoader(
    dataset,
    batch_size=2,
    collate_fn=dataset.collate_fn,  # the dataset supplies its own batching logic
    **loader_kwargs,
)

In [None]:
%%time
# Peek at the first batch only: print the ids and pixel tensor shapes of the
# query and target images, then the shape of the tokenized query text.
for batch_idx, batch in enumerate(corpus_dataloader):
    for image_key in ('query-image', 'target-image'):
        print(batch[image_key + '-id'])
        print(batch[image_key]['pixel_values'].shape)
    print(batch['query-text']['input_ids'].shape)
    break  # one batch is enough for a sanity check

In [5]:
# Obtain the training and validation loaders from the data module.
train_dataloader = datamodule.train_dataloader()
val_dataloader = datamodule.val_dataloader()

In [11]:
%%time
# Sanity-check the first validation batch: image ids and pixel tensor shapes
# for query/target, followed by the shapes of the tokenized query text fields.
for batch_idx, batch in enumerate(val_dataloader):
    for image_key in ('query-image', 'target-image'):
        print(batch[image_key + '-id'])
        print(batch[image_key]['pixel_values'].shape)
    for text_field in ('input_ids', 'attention_mask'):
        print(batch['query-text'][text_field].shape)
    break  # only the first batch is inspected

[318114, 28714, 546721, 245874, 245874, 436127, 271639, 271639, 37846, 28714, 136795, 136795, 136795, 37846, 505144, 505144, 252294, 252294, 420339, 547783, 547783, 547783, 82933, 424333, 424333, 469174, 503600, 77628, 77628, 77628]
torch.Size([30, 3, 224, 224])
[306889, 426166, 513115, 515660, 50829, 208549, 283217, 201561, 412966, 277005, 430469, 479448, 508443, 228135, 485799, 485799, 311374, 405183, 412975, 130527, 443583, 382399, 526044, 432647, 224724, 493102, 404145, 562356, 222370, 562356]
torch.Size([30, 3, 224, 224])
torch.Size([30, 12])
torch.Size([30, 12])
CPU times: user 15.2 ms, sys: 4.93 ms, total: 20.1 ms
Wall time: 1.73 s


## Load the model

In [13]:
from src.models.clip.clip_inbatch import CLIPModelINBATCH

In [17]:
# Build the project's in-batch CLIP model from the same config used for the data.
model = CLIPModelINBATCH(config)

In [18]:
# A bare `model` that is not the last statement of the cell is not displayed,
# so only the single space printed below appears as this cell's output.
model
print(" ")

 


In [19]:
%%time
# Run one forward pass on the first training batch and keep the outputs for
# the loss walkthrough in the following cells. Gradients stay enabled because
# the later cells inspect tensors that carry a grad_fn.
for batch_idx, batch in enumerate(train_dataloader):
    # Call the module itself rather than .forward() so that any registered
    # forward hooks run, as recommended for torch.nn.Module subclasses.
    # NOTE(review): assumes CLIPModelINBATCH is an nn.Module/LightningModule — confirm.
    outs = model(batch)
    break  # a single batch is sufficient here

CPU times: user 32.8 s, sys: 8.9 s, total: 41.7 s
Wall time: 2.41 s


In [20]:
# Display the dictionary returned by the forward pass (embedding tensors keyed by name).
outs

{'query_image_embeds': tensor([[-5.2605e-03,  3.9427e-02,  3.0648e-02,  ...,  7.0527e-02,
          -3.1308e-02, -1.2429e-02],
         [-1.4792e-02, -3.0441e-02, -2.0111e-02,  ...,  3.2865e-02,
          -1.3410e-02, -9.9991e-03],
         [-1.9430e-02,  2.5124e-03, -1.9967e-02,  ...,  6.1337e-02,
          -4.4695e-05, -7.4703e-03],
         ...,
         [ 5.5767e-02, -1.0665e-02, -3.8368e-02,  ...,  2.4470e-02,
           3.4702e-03,  5.1626e-02],
         [-1.4817e-02,  7.0236e-03, -1.4444e-02,  ...,  1.2696e-01,
           9.4681e-03, -1.8979e-02],
         [-2.0481e-02,  6.2025e-03, -2.0419e-02,  ...,  6.9607e-02,
          -1.1671e-02, -6.7372e-03]], grad_fn=<DivBackward0>),
 'target_image_embeds': tensor([[-0.0017,  0.0273,  0.0187,  ...,  0.1060,  0.0158,  0.0092],
         [-0.0176, -0.0148, -0.0185,  ...,  0.0122, -0.0028,  0.0291],
         [ 0.0417,  0.0276,  0.0187,  ...,  0.0628,  0.0407, -0.0021],
         ...,
         [-0.0421,  0.0349, -0.0252,  ...,  0.0656, -0.012

In [25]:
# Pull the query-image embeddings out of the forward-pass outputs and show their shape.
query_image_embeds = outs['query_image_embeds']
query_image_embeds.shape

torch.Size([30, 512])

In [26]:
# Pull the target-image embeddings out of the forward-pass outputs and show their shape.
target_image_embeds = outs['target_image_embeds']
target_image_embeds.shape

torch.Size([30, 512])

In [27]:
# Pull the query-text embeddings out of the forward-pass outputs and show their shape.
query_text_embeds = outs['query_text_embeds']
query_text_embeds.shape

torch.Size([30, 512])

In [29]:
# Predicted target embedding: sum the query image and query text embeddings,
# then rescale every row to unit L2 norm.
fused_embeds = query_image_embeds + query_text_embeds
fused_norms = torch.linalg.vector_norm(fused_embeds, ord=2, dim=1, keepdim=True)
target_hat_embeds = fused_embeds / fused_norms
target_hat_embeds.shape

torch.Size([30, 512])

In [36]:
# Pairwise dot products in both directions: rows of A are predicted targets
# scored against every target image; rows of B are target images scored
# against every predicted target.
logits_per_A = target_hat_embeds @ target_image_embeds.T
logits_per_B = target_image_embeds @ target_hat_embeds.T

In [39]:
# Row i is labelled with column i, i.e. the positives sit on the diagonal.
# NOTE(review): this assumes each positive is unique within the batch; repeated
# target ids in a batch would be treated as negatives — confirm this is intended.
labels = torch.arange(logits_per_A.shape[0], device=logits_per_A.device)

In [41]:
# Cross-entropy over each direction of the similarity matrix against the
# diagonal labels.
loss_A, loss_B = (
    torch.nn.functional.cross_entropy(logits, labels)
    for logits in (logits_per_A, logits_per_B)
)

In [46]:
# Symmetric loss: the mean of the two directional cross-entropy losses.
loss = (loss_A + loss_B) / 2.0
loss

tensor(3.2303, grad_fn=<DivBackward0>)

In [40]:
# Show the target indices passed to the cross-entropy calls.
labels

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])

In [31]:
# Dot products between every predicted target embedding (rows) and every
# target image embedding (columns); then show the matrix shape.
dotp_mat = torch.matmul(target_hat_embeds, target_image_embeds.transpose(0, 1))
dotp_mat.shape

torch.Size([30, 30])

In [35]:
# Index vector 0..29; 30 equals the batch dimension of the embeddings shown earlier.
np.arange(30)

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])

In [32]:
# Display the full similarity matrix for visual inspection.
dotp_mat

tensor([[0.6270, 0.4415, 0.4345, 0.4391, 0.4124, 0.2560, 0.3535, 0.4331, 0.3560,
         0.2920, 0.4477, 0.4352, 0.3522, 0.4555, 0.4340, 0.4047, 0.4580, 0.4377,
         0.4331, 0.3976, 0.3747, 0.4091, 0.4103, 0.4692, 0.3665, 0.3629, 0.4012,
         0.4108, 0.3690, 0.4264],
        [0.4148, 0.7551, 0.3139, 0.3726, 0.4541, 0.3261, 0.4757, 0.4138, 0.3356,
         0.3092, 0.3997, 0.4046, 0.3372, 0.3980, 0.3917, 0.4206, 0.4235, 0.4062,
         0.3592, 0.4404, 0.3611, 0.3973, 0.3698, 0.4672, 0.3724, 0.3852, 0.4027,
         0.4159, 0.3743, 0.3699],
        [0.4894, 0.4009, 0.5321, 0.4418, 0.4498, 0.2858, 0.4127, 0.4471, 0.3860,
         0.3197, 0.4270, 0.4535, 0.3594, 0.4670, 0.4797, 0.4187, 0.4538, 0.5104,
         0.4502, 0.4521, 0.3450, 0.4125, 0.3811, 0.4937, 0.4038, 0.3660, 0.5092,
         0.4199, 0.3792, 0.4299],
        [0.5124, 0.4419, 0.4837, 0.5166, 0.4273, 0.2824, 0.4054, 0.4379, 0.4025,
         0.3119, 0.5026, 0.5525, 0.3930, 0.5310, 0.4520, 0.4316, 0.5034, 0.4540,
       