In [None]:
import argparse
import os
import ruamel.yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path
import pandas as pd 

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.data import DataLoader

from models.model_retrieval import ALBEF
from models.vit import interpolate_pos_embed
from models.tokenization_bert import BertTokenizer

import utils
from dataset import create_dataset, create_sampler, create_loader
from scheduler import create_scheduler
from optim import create_optimizer



In [None]:
class args:
    """Hard-coded stand-in for command-line options when running this notebook.

    `checkpoint` is built from `output_dir`, so both always refer to the same run
    directory. `world_size` is not read by any of the visible cells (the code asks
    `utils.get_world_size()` instead).
    """
    # Run directory: config.yaml is read from here and log.txt is appended here.
    output_dir = './output/Retrieval_coco_small_0.05_vanila/'
    checkpoint = output_dir + 'checkpoint_4.pth'
    # Name handed to BertTokenizer.from_pretrained and ALBEF(text_encoder=...).
    text_encoder = 'bert-base-uncased'
    # NOTE(review): hard-coded second GPU; will not work on a machine with < 2 GPUs.
    device = 'cuda:1'
    seed = 42
    world_size = 1
    
def _recall_at_k(ranks, k):
    """Return the percentage of entries of ``ranks`` that are strictly below ``k``."""
    return 100.0 * np.count_nonzero(ranks < k) / len(ranks)


@torch.no_grad()
def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt):
    """Recall@{1,5,10} for image->text and text->image retrieval.

    Args:
        scores_i2t: one row of scores per image, over all texts (numpy rows, as
            returned by ``evaluation``).
        scores_t2i: one row of scores per text, over all images.
        txt2img: indexable as ``txt2img[text_idx]`` -> ground-truth image index.
        img2txt: indexable as ``img2txt[image_idx]`` -> iterable of ground-truth
            text indices.

    Returns:
        dict with ``txt_r1``, ``txt_r5``, ``txt_r10``, ``txt_r_mean``, ``img_r1``,
        ``img_r5``, ``img_r10``, ``img_r_mean`` and ``r_mean``, all percentages.

    A query's rank is the 0-based position of its ground truth when candidates
    are sorted by descending score. Image queries use the best (smallest)
    position over all of their captions. An image with no captions gets rank
    1e20, so it never counts as a hit.
    """
    # Image -> text: best position among every caption belonging to the image.
    i2t_ranks = np.zeros(scores_i2t.shape[0])
    for img_idx, row in enumerate(scores_i2t):
        order = np.argsort(row)[::-1]
        i2t_ranks[img_idx] = min(
            (np.flatnonzero(order == txt_idx)[0] for txt_idx in img2txt[img_idx]),
            default=1e20,
        )
    tr1, tr5, tr10 = (_recall_at_k(i2t_ranks, k) for k in (1, 5, 10))

    # Text -> image: exactly one ground-truth image per caption.
    t2i_ranks = np.zeros(scores_t2i.shape[0])
    for txt_idx, row in enumerate(scores_t2i):
        order = np.argsort(row)[::-1]
        t2i_ranks[txt_idx] = np.flatnonzero(order == txt2img[txt_idx])[0]
    ir1, ir5, ir10 = (_recall_at_k(t2i_ranks, k) for k in (1, 5, 10))

    tr_mean = (tr1 + tr5 + tr10) / 3
    ir_mean = (ir1 + ir5 + ir10) / 3

    return {
        'txt_r1': tr1,
        'txt_r5': tr5,
        'txt_r10': tr10,
        'txt_r_mean': tr_mean,
        'img_r1': ir1,
        'img_r5': ir5,
        'img_r10': ir10,
        'img_r_mean': ir_mean,
        'r_mean': (tr_mean + ir_mean) / 2,
    }
    
def _shard_bounds(num_rows, num_tasks, rank):
    """Return the half-open row range ``(start, end)`` handled by process ``rank``.

    Rows are split into contiguous chunks of ``num_rows // num_tasks + 1`` rows;
    the last chunk is truncated at ``num_rows``.
    """
    step = num_rows // num_tasks + 1
    start = rank * step
    return start, min(num_rows, start + step)


@torch.no_grad()
def _itm_scores(model, text_hidden, text_mask, image_hidden, device):
    """Return column 1 of ``model.itm_head`` for each aligned (text, image) pair, on CPU.

    ``text_hidden``, ``text_mask`` and ``image_hidden`` share their leading (pair)
    dimension. Every image position is attended to (all-ones encoder mask).
    """
    image_hidden = image_hidden.to(device)
    image_mask = torch.ones(image_hidden.size()[:-1], dtype=torch.long, device=device)
    output = model.text_encoder(encoder_embeds=text_hidden.to(device),
                                attention_mask=text_mask.to(device),
                                encoder_hidden_states=image_hidden,
                                encoder_attention_mask=image_mask,
                                return_dict=True,
                                mode='fusion')
    return model.itm_head(output.last_hidden_state[:, 0, :])[:, 1].cpu()


@torch.no_grad()
def evaluation(model, data_loader, tokenizer, device, config):
    """Compute re-ranked image<->text score matrices for one retrieval split.

    Stage 1 encodes every caption and image once and builds a cheap similarity
    matrix from the normalized projection embeddings. Stage 2 takes, for each
    query, the top ``config['k_test']`` candidates by that similarity and
    re-scores them with the fusion encoder + ITM head. Every other entry keeps
    the fill value -100.

    Args:
        model: model exposing ``text_encoder``, ``text_proj``, ``visual_encoder``,
            ``vision_proj`` and ``itm_head``.
        data_loader: yields ``(image, img_id)`` batches; its dataset exposes
            ``text`` (the captions) and ``image`` (sized).
        tokenizer: callable returning a batch with ``input_ids`` / ``attention_mask``.
        device: device on which model computations run.
        config: dict that must contain ``'k_test'``.

    Returns:
        ``(score_matrix_i2t, score_matrix_t2i)`` as numpy arrays of shape
        ``(num_images, num_texts)`` and ``(num_texts, num_images)``.

    Note:
        Each process fills only its own shard of query rows (from
        ``utils.get_rank`` / ``utils.get_world_size``). No cross-process reduction
        is done here, so with more than one process the matrices are partial.
    """
    model.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Evaluation:'

    print('Computing features for evaluation...')
    start_time = time.time()

    # Stage 1a: unimodal text features, text_bs captions at a time. All cached
    # tensors (including the attention masks) are kept on CPU and moved back to
    # `device` per re-ranking batch.
    texts = data_loader.dataset.text
    num_text = len(texts)
    text_bs = 256
    text_feats, text_embeds, text_atts = [], [], []
    for i in range(0, num_text, text_bs):
        text = texts[i: min(num_text, i + text_bs)]
        text_input = tokenizer(text, padding='max_length', truncation=True, max_length=30,
                               return_tensors="pt").to(device)
        text_output = model.text_encoder(text_input.input_ids,
                                         attention_mask=text_input.attention_mask, mode='text')
        text_feat = text_output.last_hidden_state
        text_embed = F.normalize(model.text_proj(text_feat[:, 0, :]), dim=-1)
        text_feats.append(text_feat.cpu())
        text_embeds.append(text_embed.cpu())
        text_atts.append(text_input.attention_mask.cpu())
    text_feats = torch.cat(text_feats, dim=0)
    text_embeds = torch.cat(text_embeds, dim=0)
    text_atts = torch.cat(text_atts, dim=0)

    # Stage 1b: unimodal image features.
    image_feats, image_embeds = [], []
    for image, _img_id in data_loader:
        image_feat = model.visual_encoder(image.to(device))
        image_embed = F.normalize(model.vision_proj(image_feat[:, 0, :]), dim=-1)
        image_feats.append(image_feat.cpu())
        image_embeds.append(image_embed.cpu())
    image_feats = torch.cat(image_feats, dim=0)
    image_embeds = torch.cat(image_embeds, dim=0)

    sims_matrix = image_embeds @ text_embeds.t()
    num_images = len(data_loader.dataset.image)
    num_tasks = utils.get_world_size()
    rank = utils.get_rank()

    # Stage 2a: image -> text re-ranking. k is clamped to the number of
    # candidates so small splits do not make `topk` fail.
    score_matrix_i2t = torch.full((num_images, num_text), -100.0)
    k_i2t = min(config['k_test'], sims_matrix.size(1))
    start, end = _shard_bounds(sims_matrix.size(0), num_tasks, rank)
    for i, sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
        _, topk_idx = sims.topk(k=k_i2t, dim=0)
        score_matrix_i2t[start + i, topk_idx] = _itm_scores(
            model,
            text_feats[topk_idx],
            text_atts[topk_idx],
            image_feats[start + i].repeat(k_i2t, 1, 1),
            device)

    # Stage 2b: text -> image re-ranking, same scheme on the transposed similarities.
    sims_matrix_t2i = sims_matrix.t()
    score_matrix_t2i = torch.full((num_text, num_images), -100.0)
    k_t2i = min(config['k_test'], sims_matrix_t2i.size(1))
    start, end = _shard_bounds(sims_matrix_t2i.size(0), num_tasks, rank)
    for i, sims in enumerate(metric_logger.log_every(sims_matrix_t2i[start:end], 50, header)):
        _, topk_idx = sims.topk(k=k_t2i, dim=0)
        score_matrix_t2i[start + i, topk_idx] = _itm_scores(
            model,
            text_feats[start + i].repeat(k_t2i, 1, 1),
            text_atts[start + i].repeat(k_t2i, 1),
            image_feats[topk_idx],
            device)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Evaluation time {}'.format(total_time_str))

    # Both matrices were created on CPU, so they convert to numpy directly.
    return score_matrix_i2t.numpy(), score_matrix_t2i.numpy()

In [None]:
#### main ####
# Load the config saved next to the checkpoint. `with` guarantees the file handle
# is closed once parsing is done.
# NOTE: yaml.Loader is not the safe loader; only load config files you trust.
# NOTE(review): the legacy `ruamel.yaml.load(..., Loader=...)` function is
# reportedly removed in ruamel.yaml >= 0.18; pin an older version or migrate
# to the `YAML(typ=...)` API if this call fails.
with open(os.path.join(args.output_dir, 'config.yaml')) as config_file:
    config = yaml.load(config_file, Loader=yaml.Loader)

# Seed every RNG in use; the process rank offset keeps seeds distinct per process.
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

device = args.device

# Datasets and loaders for the retrieval ('re') task. create_dataset yields all
# three splits; only the val and test loaders are used further down.
samplers = [None] * 3
train_dataset, val_dataset, test_dataset = create_dataset('re', config)
batch_sizes = [config['batch_size_train'], config['batch_size_test'], config['batch_size_test']]
train_loader, val_loader, test_loader = create_loader(
    [train_dataset, val_dataset, test_dataset],
    samplers,
    batch_size=batch_sizes,
    num_workers=[0] * 3,
    is_trains=[True, False, False],
    collate_fns=[None] * 3,
)

# Tokenizer and model; both are built from the same text-encoder name.
tokenizer = BertTokenizer.from_pretrained(args.text_encoder)
model = ALBEF(config=config, text_encoder=args.text_encoder, tokenizer=tokenizer)

# Load the checkpoint onto CPU first; the model is moved to `device` afterwards.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(args.checkpoint, map_location='cpu')
state_dict = checkpoint['model']

# Adapt the stored positional embeddings of the online and momentum visual
# encoders via interpolate_pos_embed (presumably resizing them to this model's
# patch grid -- see models.vit).
pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'], model.visual_encoder)
state_dict['visual_encoder.pos_embed'] = pos_embed_reshaped
m_pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'], model.visual_encoder_m)
state_dict['visual_encoder_m.pos_embed'] = m_pos_embed_reshaped

# Drop the 'bert.' segment from text-encoder keys so they match this model's
# parameter names. Matching on 'bert.' (not just 'bert') means keys whose rename
# would be a no-op are left in place instead of being deleted.
for key in list(state_dict.keys()):
    if 'bert.' in key:
        state_dict[key.replace('bert.', '')] = state_dict.pop(key)

# strict=False: missing/unexpected keys are reported in `msg` instead of raising.
msg = model.load_state_dict(state_dict, strict=False)

print('load checkpoint from %s' % args.checkpoint)
print(msg)

# Single-process run: there is no DistributedDataParallel wrapper, so the
# "without DDP" handle is the model itself.
model_without_ddp = model = model.to(device)

# ITM re-ranked score matrices (numpy) for the validation and test splits.
(score_val_i2t, score_val_t2i) = evaluation(model_without_ddp, val_loader, tokenizer, device, config)
(score_test_i2t, score_test_t2i) = evaluation(model_without_ddp, test_loader, tokenizer, device, config)

In [9]:
# Turn each split's score matrices into Recall@K metrics and show them.
val_result = itm_eval(score_val_i2t, score_val_t2i, val_loader.dataset.txt2img, val_loader.dataset.img2txt)
print(val_result)
test_result = itm_eval(score_test_i2t, score_test_t2i, test_loader.dataset.txt2img, test_loader.dataset.img2txt)
print(test_result)

# Append one JSON line to <output_dir>/log.txt with split-prefixed metric keys
# (all val_* keys, then all test_* keys, then 'epoch').
# NOTE(review): the saved outputs of this cell show identical val and test
# metrics; confirm the config does not point both splits at the same annotations.
epoch = 'eval'
log_stats = {}
for split_name, split_result in (('val', val_result), ('test', test_result)):
    for metric_name, metric_value in split_result.items():
        log_stats[f'{split_name}_{metric_name}'] = metric_value
log_stats['epoch'] = epoch

with open(os.path.join(args.output_dir, "log.txt"), "a") as log_file:
    log_file.write(json.dumps(log_stats) + "\n")


{'txt_r1': 90.633608815427, 'txt_r5': 99.72451790633609, 'txt_r10': 100.0, 'txt_r_mean': 96.78604224058769, 'img_r1': 82.43392070484582, 'img_r5': 98.29295154185021, 'img_r10': 99.39427312775331, 'img_r_mean': 93.37371512481644, 'r_mean': 95.07987868270206}
{'txt_r1': 90.633608815427, 'txt_r5': 99.72451790633609, 'txt_r10': 100.0, 'txt_r_mean': 96.78604224058769, 'img_r1': 82.43392070484582, 'img_r5': 98.29295154185021, 'img_r10': 99.39427312775331, 'img_r_mean': 93.37371512481644, 'r_mean': 95.07987868270206}
