In [1]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset
import torchvision as tv

import os, sys
import numpy as np
import pickle
from PIL import Image
import math
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances, cosine_distances

from model_phrase_mlm import TextEncoder, ImageEncoder
from transformers import BertConfig

from misc.config_iu import Config
# Instantiate the project configuration object (misc.config_iu.Config); later cells
# read dataset_root, max_length, hidden_dim, CUDA and dropout settings from it.
cfg  = Config()
# NOTE(review): the GPU index is hard-coded; set_device below will fail on a machine
# that does not expose a CUDA device with this index -- adjust before re-running.
cfg.GPU_ID = 2

# Make the selected GPU the current CUDA device for subsequent .cuda() calls.
torch.cuda.set_device(cfg.GPU_ID)

In [2]:
# Side length (in pixels) handed to the crop transforms below.
MAX_DIM = 2048

# Training-time augmentation pipeline, kept for reference only (disabled in this
# evaluation notebook). The timing remarks are the original author's measurements.
# train_transform = tv.transforms.Compose([
#     tv.transforms.RandomRotation(15), # rotation will cost 0.1s for each 10 images
#     tv.transforms.RandomCrop(MAX_DIM, pad_if_needed=True), # 0.6s for each 10 images
#     tv.transforms.ColorJitter(brightness=[0.5, 1.8] # colorjitter will cost 0.32s for each 10 images
#                               , contrast=[0.5, 1.8]
#                               , saturation=[0.5, 1.8]),
#     tv.transforms.ToTensor(), 
#     tv.transforms.Normalize(0.5, 0.5)
# ])

# Deterministic evaluation pipeline: center crop to MAX_DIM, convert to a tensor,
# then normalize with mean 0.5 and std 0.5. Used as the default transform of IUDataset.
val_transform = tv.transforms.Compose([
    tv.transforms.CenterCrop(MAX_DIM),
    tv.transforms.ToTensor(),  
    tv.transforms.Normalize(0.5, 0.5)
])

class IUDataset(Dataset):
    """Image/report pairs indexed by study uid.

    Every item is a tuple ``(image, caption, cap_mask, uid, cap_lens)``; ``None`` is
    returned instead when the image cannot be opened or transformed, so the dataset
    must be paired with a collate function that drops ``None`` samples.

    Args:
        root: directory that contains the image files.
        dataset: dict providing the keys 'classes', 'data_dict', 'data_split',
            'idx2word' and 'word2idx'. Each ``data_dict[uid]`` entry must provide
            'filenames' and 'token_ids'. The two vocabulary dicts are modified in
            place (a '[MASK]' entry is added).
        max_length: caption length budget; captions are padded/truncated to
            ``max_length + 1`` positions.
        transform: callable applied to the opened PIL image. When it is ``None`` a
            loaded copy of the PIL image is returned unchanged.
        mode: only 'test' is implemented; it serves the train, val and test uids
            concatenated in that order.

    Raises:
        NotImplementedError: if ``mode`` is anything other than 'test'.
    """

    def __init__(self, root, dataset, max_length, transform=val_transform, mode='test'):
        super().__init__()

        # Fail fast: previously an unsupported mode left ``self.keys`` undefined and
        # only surfaced later as an AttributeError inside __len__/__getitem__.
        if mode != 'test':
            raise NotImplementedError(f"{mode} not supported")

        self.root = root  # image directory
        self.transform = transform
        self.mode = mode

        self.classes = dataset['classes']  # multi-label list
        self.datadict = dataset['data_dict']  # uid -> {'filenames': [...], 'token_ids': [...], ...}
        splits = dataset['data_split']
        self.keys = np.concatenate(
            [splits['train_uids'], splits['val_uids'], splits['test_uids']])  # uid list

        # NOTE(review): the '[MASK]' id is hard-coded to 8410 -- presumably the size of
        # the pickled vocabulary before the mask token is appended; confirm if the
        # vocabulary file ever changes.
        self.idx2word = dataset['idx2word']
        self.idx2word[8410] = '[MASK]'
        self.word2ids = dataset['word2idx']
        self.word2ids['[MASK]'] = 8410
        self.__sep_id__ = dataset['word2idx']['[SEP]']
        self.vocab_size = len(dataset['word2idx'])  # counted after '[MASK]' was added
        self.max_length = max_length + 1

#         self.err_log = os.path.join(log_dir, 'err.log') # create error log
#         if not os.path.exists(self.err_log):
#             with open(self.err_log, 'w') as f:
#                 f.write('Epoch 0:\n\n')

    def __len__(self):
        """Number of uids served by this dataset."""
        return len(self.keys)

    def __getitem__(self, idx):
        """Return ``(image, caption, cap_mask, uid, cap_lens)`` for ``self.keys[idx]``.

        ``caption`` is an int array of ``self.max_length`` token ids (zero padded, or
        truncated with the last position overwritten by the '[SEP]' id), ``cap_mask``
        is a boolean array marking the real tokens and ``cap_lens`` is their count.
        Returns ``None`` if the image file cannot be read or transformed.
        """
        uid = self.keys[idx]

        # One of the study's images is drawn at random, so repeated passes over the
        # dataset are not deterministic unless numpy's RNG is seeded.
        image_id = np.random.choice(self.datadict[uid]['filenames'])
        image_path = os.path.join(self.root, image_id)  # original used 'jpg', try 'png'

        try:
            with Image.open(image_path) as img:
                if self.transform:
                    image = self.transform(img)
                else:
                    # Previously ``image`` stayed unbound without a transform and the
                    # return statement raised UnboundLocalError. copy() loads the pixel
                    # data so the image stays usable after the file is closed.
                    image = img.copy()
        except Exception as ex:
#             with open(self.err_log, 'a+') as f:
#                 f.write('%s\nERR_IMG %s\n' % (ex, image_path))
            print(ex)
            print(image_path)
            return None  ## return None, collate_fn will ignore this broken sample

        token_ids = np.array(self.datadict[uid]['token_ids'])
        n_tokens = min(len(token_ids), self.max_length)

        # Always fill fixed-dtype buffers so padded and truncated captions share the
        # same dtype (the truncated branch used to return a slice of the raw array,
        # whose dtype could differ and break batch collation).
        caption = np.zeros(self.max_length, dtype='int')
        cap_mask = np.zeros(self.max_length, dtype=bool)
        caption[:n_tokens] = token_ids[:n_tokens]
        cap_mask[:n_tokens] = True
        if len(token_ids) > self.max_length:
            # Over-long caption: close the truncated sequence with '[SEP]'.
            caption[-1] = self.__sep_id__

        cap_lens = cap_mask.sum(-1)
        return image, caption, cap_mask, uid, cap_lens
    
def build_dataset(mode='test', cfg=None):
    """Load the pickled dataset description and wrap it in an ``IUDataset``.

    Args:
        mode: dataset split selector; only 'test' is implemented.
        cfg: configuration object providing ``dataset_root`` (directory holding
            'cleaned_dataset_v2.pickle' and 'images/images_normalized') and
            ``max_length``.

    Returns:
        An ``IUDataset`` built with ``val_transform``.

    Raises:
        NotImplementedError: if ``mode`` is not 'test'.
        ValueError: if ``cfg`` is not supplied.
    """
    # Validate the arguments before touching the disk: the pickle used to be read
    # in full even when the mode was about to be rejected.
    if mode != 'test':
        raise NotImplementedError(f"{mode} not supported")
    if cfg is None:
        raise ValueError("cfg must be provided (needs dataset_root and max_length)")

    data_dir = cfg.dataset_root
    img_dir = os.path.join(data_dir, 'images', 'images_normalized')
    # SECURITY NOTE: pickle.load can execute arbitrary code; only point dataset_root
    # at a trusted location.
    with open(os.path.join(data_dir, 'cleaned_dataset_v2.pickle'), 'rb') as f:
        dataset = pickle.load(f)

    return IUDataset(img_dir, dataset,
                     max_length=cfg.max_length, mode=mode,
                     transform=val_transform)
        
## Collate function that tolerates samples reported as None (e.g. corrupted images). ##
## The resulting batch holds (batch size - number of broken samples) items. ##
def collate_fn_ignore_none(batch):
    """Drop ``None`` samples, then defer to PyTorch's default collation."""
    valid_samples = [sample for sample in batch if sample is not None]
    return torch.utils.data.dataloader.default_collate(valid_samples)

In [37]:
# Build the evaluation dataset and wrap it in a DataLoader.
# NOTE(review): shuffle=True without a fixed seed means the sample order -- and hence
# the composition of the retrieval folds scored further down -- changes on every run.
test_set = build_dataset(mode='test', cfg=cfg)
print('Testing set %d is loaded.' % len(test_set))
test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=50,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn_ignore_none,  # silently drops samples whose image failed to load
    pin_memory=False,
    drop_last=False,
)
print('Vocab size is %d.' % test_set.vocab_size)
print('max length is %d' % test_set.max_length)

Testing set 3666 is loaded.
Vocab size is 8411.
max length is 160


In [23]:
# Pull a single batch for interactive inspection.
# Use the builtin next(): the Python-2-style ``iterator.next()`` method is not
# available on DataLoader iterators in newer PyTorch releases.
it = iter(test_loader)
img, cap, mask, uid, lens = next(it)

In [42]:
# Transformer configuration for the text encoder; the vocabulary size comes from the
# dataset, dropout rates come from cfg, everything else is fixed here.
bert_config = BertConfig(vocab_size=test_loader.dataset.vocab_size, hidden_size=512, num_hidden_layers=3,
                    num_attention_heads=8, intermediate_size=2048, hidden_act='gelu',
                    hidden_dropout_prob=cfg.hidden_dropout_prob, attention_probs_dropout_prob=cfg.attention_probs_dropout_prob,
                    max_position_embeddings=512, layer_norm_eps=1e-12,
                    initializer_range=0.02, type_vocab_size=2, pad_token_id=0)

##### change the checkpoint path here #####
# NOTE(review): absolute, machine-specific path -- must be edited to reproduce elsewhere.
cfg.text_encoder_path = '/media/My1TBSSD1/IPMI2021/output/MIMIC_phrase_2021_02_26_23_25_51/Model/text_encoder28.pth'
# cfg.text_encoder_path = '/media/My1TBSSD1/IPMI2021/output/MIMIC_phrase_ft_2021_03_02_02_25_32/Model/text_encoder29.pth'
# ################### encoders ################################# #      
image_encoder = ImageEncoder(output_channels=cfg.hidden_dim)
text_encoder = TextEncoder(bert_config=bert_config, output_channels=cfg.hidden_dim)

if cfg.CUDA:
    text_encoder = text_encoder.cuda()
    image_encoder = image_encoder.cuda()

if cfg.text_encoder_path != '':
    # The image checkpoint sits next to the text checkpoint under a mirrored file name.
    # Only the file name is rewritten: replacing on the whole path would also alter any
    # directory component that happens to contain 'text_encoder'.
    ckpt_dir, text_ckpt_name = os.path.split(cfg.text_encoder_path)
    img_encoder_path = os.path.join(ckpt_dir, text_ckpt_name.replace('text_encoder', 'image_encoder'))
    print('Load image encoder checkpoint from:', img_encoder_path)
    # Checkpoints keep the weights under the 'model' key; load on CPU first and let
    # load_state_dict copy them into the (possibly CUDA) module.
    state_dict = torch.load(img_encoder_path, map_location='cpu')
    image_encoder.load_state_dict(state_dict['model'])

    text_encoder_path = cfg.text_encoder_path
    print('Load text encoder checkpoint from:', text_encoder_path)
    state_dict = torch.load(text_encoder_path, map_location='cpu')
    text_encoder.load_state_dict(state_dict['model'])

Load image encoder checkpoint from: /media/My1TBSSD1/IPMI2021/output/MIMIC_phrase_2021_02_26_23_25_51/Model/image_encoder28.pth
Load text encoder checkpoint from: /media/My1TBSSD1/IPMI2021/output/MIMIC_phrase_2021_02_26_23_25_51/Model/text_encoder28.pth


In [6]:
@torch.no_grad()
def testing(cnn_model, trx_model, dataloader):
    """Encode every batch of ``dataloader`` and collect the global embeddings.

    Both models are switched to eval mode. Batches are moved to the GPU when the
    notebook-level ``cfg.CUDA`` flag is set.

    Args:
        cnn_model: image encoder; called as ``cnn_model(images)`` and expected to
            return six values, the second being the global image feature.
        trx_model: text encoder; called as ``trx_model(captions, masks)`` and expected
            to return four values, the last being the global text feature.
        dataloader: iterable (with ``len``) yielding
            ``(images, captions, masks, uids, cap_lens)`` batches.

    Returns:
        ``(v_feat, t_feat, uids)``: the image features and text features concatenated
        along axis 0 as numpy arrays, and the list of sample uids in the same order.
    """
    cnn_model.eval()
    trx_model.eval()
    #####################################
    v_feat = []
    t_feat = []
    uids = []

    # Iterate over the loader directly: the old ``iterator.next()`` call is not
    # available on DataLoader iterators in newer PyTorch releases.
    for real_imgs, captions, masks, class_ids, cap_lens in tqdm(dataloader):
        if cfg.CUDA:
            real_imgs, captions, masks = real_imgs.cuda(), captions.cuda(), masks.cuda()
        _, v_g, _, _, _, _ = cnn_model(real_imgs)
        _, _, _, t_g = trx_model(captions, masks)
        v_feat.append(v_g.detach().cpu().numpy())
        t_feat.append(t_g.detach().cpu().numpy())
        # uids are collated into a tensor when numeric but stay a plain sequence
        # when they are strings; accept both.
        uids.extend(class_ids.tolist() if hasattr(class_ids, 'tolist') else class_ids)

    v_feat = np.concatenate(v_feat, axis=0)
    t_feat = np.concatenate(t_feat, axis=0)

    return v_feat, t_feat, uids

In [43]:
v_feat, t_feat, uids = testing(image_encoder, text_encoder, test_loader)
print(v_feat.shape, t_feat.shape, len(uids))

100%|██████████| 74/74 [01:11<00:00,  1.04it/s]


(3666, 512) (3666, 512) 3666


In [44]:
## I2T (image -> text) Recall@{1,5,10}, averaged over consecutive folds of 1k pairs ##
ks = 1000  # fold size
# Derive the fold count from the data instead of hard-coding it; only complete folds
# are scored, the trailing remainder is ignored (as before).
n_folds = len(v_feat) // ks
assert n_folds > 0, 'need at least %d samples to form one fold' % ks
coss = []
tp1s = []
tp5s = []
tp10s = []
for i in range(n_folds):
    fold = slice(i * ks, (i + 1) * ks)
    # Rows are image queries, columns are text candidates of the same fold.
    cos = np.array(cosine_similarity(v_feat[fold], t_feat[fold]))
    coss.append(cos)
    tp = cos.argsort()[:, -10:][:, ::-1]  # indices of the 10 most similar candidates, best first
    sn = tp.shape[0]
    # The correct candidate of query j is candidate j (features are index-aligned).
    hits = tp == np.arange(sn).reshape(sn, 1)
    top1 = hits[:, :1].any(axis=1).sum() / sn
    top5 = hits[:, :5].any(axis=1).sum() / sn
    top10 = hits[:, :10].any(axis=1).sum() / sn
    tp1s.append(top1)
    tp5s.append(top5)
    tp10s.append(top10)

print('%.4f\t%.4f\t%.4f' % (np.mean(tp1s), np.mean(tp5s), np.mean(tp10s)) )

0.0497	0.1310	0.1883


In [45]:
## T2I (text -> image) Recall@{1,5,10}, averaged over consecutive folds of 1k pairs ##
ks = 1000  # fold size
# Derive the fold count from the data instead of hard-coding it; only complete folds
# are scored, the trailing remainder is ignored (as before).
n_folds = len(t_feat) // ks
assert n_folds > 0, 'need at least %d samples to form one fold' % ks
coss = []
tp1s = []
tp5s = []
tp10s = []
for i in range(n_folds):
    fold = slice(i * ks, (i + 1) * ks)
    # Rows are text queries, columns are image candidates of the same fold.
    cos = np.array(cosine_similarity(t_feat[fold], v_feat[fold]))
    coss.append(cos)
    tp = cos.argsort()[:, -10:][:, ::-1]  # indices of the 10 most similar candidates, best first
    sn = tp.shape[0]
    # The correct candidate of query j is candidate j (features are index-aligned).
    hits = tp == np.arange(sn).reshape(sn, 1)
    top1 = hits[:, :1].any(axis=1).sum() / sn
    top5 = hits[:, :5].any(axis=1).sum() / sn
    top10 = hits[:, :10].any(axis=1).sum() / sn
    tp1s.append(top1)
    tp5s.append(top5)
    tp10s.append(top10)

print('%.4f\t%.4f\t%.4f' % (np.mean(tp1s), np.mean(tp5s), np.mean(tp10s)) )

0.0583	0.1353	0.1873


In [21]:
## I2T (image -> text) Recall@{1,5,10}, averaged over consecutive folds of 100 pairs ##
## (the previous header said "1k", but the fold size used here is 100) ##
ks = 100  # fold size
# Derive the fold count from the data instead of hard-coding it; only complete folds
# are scored, the trailing remainder is ignored (as before).
n_folds = len(v_feat) // ks
assert n_folds > 0, 'need at least %d samples to form one fold' % ks
coss = []
tp1s = []
tp5s = []
tp10s = []
for i in range(n_folds):
    fold = slice(i * ks, (i + 1) * ks)
    # Rows are image queries, columns are text candidates of the same fold.
    cos = np.array(cosine_similarity(v_feat[fold], t_feat[fold]))
    coss.append(cos)
    tp = cos.argsort()[:, -10:][:, ::-1]  # indices of the 10 most similar candidates, best first
    sn = tp.shape[0]
    # The correct candidate of query j is candidate j (features are index-aligned).
    hits = tp == np.arange(sn).reshape(sn, 1)
    top1 = hits[:, :1].any(axis=1).sum() / sn
    top5 = hits[:, :5].any(axis=1).sum() / sn
    top10 = hits[:, :10].any(axis=1).sum() / sn
    tp1s.append(top1)
    tp5s.append(top5)
    tp10s.append(top10)

print('%.4f\t%.4f\t%.4f' % (np.mean(tp1s), np.mean(tp5s), np.mean(tp10s)) )

0.1592	0.3394	0.4367


In [22]:
## T2I (text -> image) Recall@{1,5,10}, averaged over consecutive folds of 100 pairs ##
## (the previous header said "1k", but the fold size used here is 100) ##
ks = 100  # fold size
# Derive the fold count from the data instead of hard-coding it; only complete folds
# are scored, the trailing remainder is ignored (as before).
n_folds = len(t_feat) // ks
assert n_folds > 0, 'need at least %d samples to form one fold' % ks
coss = []
tp1s = []
tp5s = []
tp10s = []
for i in range(n_folds):
    fold = slice(i * ks, (i + 1) * ks)
    # Rows are text queries, columns are image candidates of the same fold.
    cos = np.array(cosine_similarity(t_feat[fold], v_feat[fold]))
    coss.append(cos)
    tp = cos.argsort()[:, -10:][:, ::-1]  # indices of the 10 most similar candidates, best first
    sn = tp.shape[0]
    # The correct candidate of query j is candidate j (features are index-aligned).
    hits = tp == np.arange(sn).reshape(sn, 1)
    top1 = hits[:, :1].any(axis=1).sum() / sn
    top5 = hits[:, :5].any(axis=1).sum() / sn
    top10 = hits[:, :10].any(axis=1).sum() / sn
    tp1s.append(top1)
    tp5s.append(top5)
    tp10s.append(top10)

print('%.4f\t%.4f\t%.4f' % (np.mean(tp1s), np.mean(tp5s), np.mean(tp10s)) )

0.1608	0.3567	0.4717


In [13]:
## I2T (image -> text) Recall@{1,5,10} with the whole set as a single candidate pool ##
ks = len(v_feat)  # pool size, taken from the data rather than a hard-coded count
# Rows are image queries, columns are text candidates.
cos = np.array(cosine_similarity(v_feat, t_feat))
print(cos.shape)
tp = cos.argsort()[:, -10:][:, ::-1]  # indices of the 10 most similar candidates, best first
sn = tp.shape[0]
# The correct candidate of query j is candidate j (features are index-aligned).
hits = tp == np.arange(sn).reshape(sn, 1)
top1 = hits[:, :1].any(axis=1).sum() / sn
top5 = hits[:, :5].any(axis=1).sum() / sn
top10 = hits[:, :10].any(axis=1).sum() / sn

# top1/top5/top10 are already scalars, so no averaging is needed before printing.
print('%.4f\t%.4f\t%.4f' % (top1, top5, top10) )

(3666, 3666)
0.0221	0.0679	0.0957


In [14]:
## T2I (text -> image) Recall@{1,5,10} with the whole set as a single candidate pool ##
ks = len(t_feat)  # pool size, taken from the data rather than a hard-coded count
# Rows are text queries, columns are image candidates.
cos = np.array(cosine_similarity(t_feat, v_feat))
print(cos.shape)
tp = cos.argsort()[:, -10:][:, ::-1]  # indices of the 10 most similar candidates, best first
sn = tp.shape[0]
# The correct candidate of query j is candidate j (features are index-aligned).
hits = tp == np.arange(sn).reshape(sn, 1)
top1 = hits[:, :1].any(axis=1).sum() / sn
top5 = hits[:, :5].any(axis=1).sum() / sn
top10 = hits[:, :10].any(axis=1).sum() / sn

# top1/top5/top10 are already scalars, so no averaging is needed before printing.
print('%.4f\t%.4f\t%.4f' % (top1, top5, top10) )

(3666, 3666)
0.0278	0.0742	0.1047
