In [1]:
import os
import logging
import torch
import torch.utils.data as data
from utils.utils import compute_topk, collate_fn
from data.dataloader import ITM_Dataset
import yaml
from attrdict import AttrDict
from models.model_zoo import ViT_NumScore
from tqdm import tqdm

In [2]:
if __name__ == '__main__':
    with open('configs.yaml') as f:
        test_args = yaml.load(f, Loader=yaml.FullLoader)['test']
        args = AttrDict(test_args)

In [3]:
# prepare dataset
test_dataset = ITM_Dataset(args.image_root_path,
                            args.sentence_file_path,
                            'test',
                            args.max_length)
test_loader = data.DataLoader(test_dataset, 
                               args.batch_size, 
                               collate_fn=lambda b: collate_fn(b, args.max_length),
                               shuffle=False, 
                               num_workers=8,
                               pin_memory=True,
                               drop_last=True)
print('Data loaded')

ac_i2t_top1_best = 0.0
ac_i2t_top10_best = 0.0
ac_t2i_top1_best = 0.0
ac_t2i_top10_best = 0.0
i2t_model = '1.pth.tar'
model_file = os.path.join(args.checkpoint_dir, '1.pth.tar')
epoch = i2t_model.split('.')[0]
network = ViT_NumScore()
network = network.cuda()
network_dict = network.state_dict()
pretrained_dict = torch.load(model_file)['network']
# process keyword of pretrained model
prefix = 'module.image_model.'
pretrained_dict = {prefix + k[:] :v for k,v in pretrained_dict.items()}
pretrained_dict = {k:v for k,v in pretrained_dict.items() if k in network_dict}
network_dict.update(pretrained_dict)
network.load_state_dict(network_dict)

Data loaded


Some weights of the model checkpoint at /home/giang/.cache/torch/sentence_transformers/sbert.net_models_bert-base-nli-stsb-mean-tokens/0_BERT were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<All keys matched successfully>

In [4]:
def topk(sim, labels, k=[1, 10]):
    result = []
    maxk = max(k)
    size_total = len(labels)
    _, pred_index = sim.topk(maxk, 0, True, True)
    pred_labels = labels[pred_index]
    print(pred_labels)
    correct = pred_labels.eq(labels.view(1,-1).expand_as(pred_labels))
    for topk in k:
        correct_k = torch.sum(correct[:topk], dim=0)
        correct_k = torch.sum(correct_k > 0).float()
        result.append(correct_k * 100 / size_total)
    return result
# switch to evaluate mode
network.eval()
images_bank = []
text_bank = []
labels_bank = []
index = 0
with torch.no_grad():
    for images, input_ids, token_type_ids, attention_masks, labels in test_loader:
        images = images.cuda()
        labels = labels.cuda()
        input_ids = input_ids.cuda()
        token_type_ids = token_type_ids.cuda()
        attention_masks = attention_masks.cuda()

        interval = images.shape[0]
        image_embeddings, text_embeddings = network(images, input_ids, token_type_ids, attention_masks, val=True)
        images_bank.append(image_embeddings)
        text_bank.append(text_embeddings)
        labels_bank.append(labels)

        index = index + interval


    images_bank = torch.cat(images_bank[:index], dim=0)
    text_bank = torch.cat(text_bank[:index], dim=0)
    labels_bank = torch.cat(labels_bank[:index], dim=0)

    images_bank = images_bank[:100]
    text_bank = text_bank[:100]
    labels_bank = labels_bank[:100]

    scoring_i2t, scoring_t2i = network.scoring_i2t, network.scoring_t2i
    images_embeddings = images_bank
    text_embeddings = text_bank
    labels = labels_bank

    images_embeddings_norm = images_embeddings/images_embeddings.norm(dim=2)[:, :, None]
    text_embeddings_norm = text_embeddings/text_embeddings.norm(dim=1)[:, None]
    batch_size = images_embeddings.shape[0]
    i2t = []
    t2i = []
    for i in tqdm(range(batch_size)):
        item_i2t = torch.matmul(images_embeddings[i, :, :].unsqueeze(0), text_embeddings_norm.transpose(0, 1))
        item_t2i = torch.matmul(images_embeddings_norm[i, :, :].unsqueeze(0), text_embeddings.transpose(0, 1))

        item_i2t, item_t2i = item_i2t.transpose(1, 2), item_t2i.transpose(1, 2)
        item_i2t = scoring_i2t(item_i2t).squeeze().unsqueeze(0)
        item_t2i = scoring_t2i(item_t2i).squeeze(-1)

        i2t.append(item_i2t)
        t2i.append(item_t2i)
    i2t = torch.cat(i2t, dim=0)
    t2i = torch.cat(t2i, dim=0)
    t2i = t2i.transpose(0, 1)

    result = []
    result.extend(topk(i2t, labels, k=[1, 10]))
    result.extend(topk(t2i, labels, k=[1, 10]))


100%|██████████| 100/100 [00:00<00:00, 6463.81it/s]

tensor([[137, 304,  51, 137, 137,  97,  97, 318,  97,  97, 124, 124, 124, 124,
         124, 134,  34,  34, 134, 279, 279, 279, 395,  25, 395, 137, 304, 304,
         304, 304, 279, 279, 246, 279, 279, 304, 134, 134, 134, 259,  51,  25,
         279, 279, 395, 137, 304,  51, 259,  25, 318, 219,  25, 124, 318, 137,
         134, 137, 137, 137, 246,  25, 137,  51,  25,  89, 246, 279, 246, 279,
         395, 395,  25, 199, 134,  89,  34, 124,  89,  89,  97, 246,  97,  97,
          97,  51, 246, 124, 395, 246,  25,  25,  25, 137,  34, 124, 304, 134,
         304, 304],
        [137, 304,  51, 137, 137,  97,  97, 318,  97,  97, 124, 124, 124, 124,
         124, 134,  34,  34, 134, 279, 279, 279, 395,  25, 395, 137, 304, 304,
         304, 304, 279, 279, 246, 279, 279, 304, 134, 134, 134, 259,  51,  25,
         279, 279, 395, 137, 304,  51, 259,  25, 318, 219,  25, 124, 318, 137,
         134, 137, 137, 137, 246,  25, 137,  51,  25,  89, 246, 279, 246, 279,
         395, 395,  25, 199, 134


