In [1]:
import argparse
import os
import ruamel.yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path
import pandas as pd 

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.data import DataLoader

from models.model_retrieval import ALBEF
from models.vit import interpolate_pos_embed
from models.tokenization_bert import BertTokenizer

import utils
from dataset import create_dataset, create_sampler, create_loader
from scheduler import create_scheduler
from optim import create_optimizer

import matplotlib.pyplot as plt 
from Retrieval import itm_eval,evaluation
def img_show(img):
    """Display a channel-first image tensor with matplotlib.

    The tensor is moved to host memory, rearranged from (C, H, W) to
    (H, W, C) and min-max rescaled to [0, 1] before being shown.

    Args:
        img: torch.Tensor laid out as (channels, height, width). It may live
            on any device and may require grad.
    """
    # .cpu() is required before .numpy() when the tensor lives on a GPU.
    img = torch.permute(img, dims=(1, 2, 0)).detach().cpu().numpy()
    lo = np.min(img)
    span = np.max(img) - lo
    if span > 0:
        img = (img - lo) / span
    else:
        # Constant image: avoid 0/0 (NaN) and show a uniform black frame.
        img = np.zeros(img.shape, dtype=np.float32)
    # Channel order is left untouched (no RGB<->BGR flip).
    plt.imshow(img)
    plt.show()

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class args:
    # Plain attribute namespace holding the settings the evaluation cell reads
    # (args.output_dir, args.checkpoint, args.text_encoder, args.device, args.seed).

    # --- process / reproducibility ---
    seed = 42
    world_size = 1
    # NOTE(review): hard-coded GPU index; adjust for the machine at hand.
    device = 'cuda:3'

    # --- model and checkpoint locations ---
    text_encoder = 'bert-base-uncased'
    output_dir = './output/Retrieval_coco_romixgen_mixup_textconcat_ratio05_4m_fix2/'
    checkpoint = './output/Retrieval_coco_romixgen_mixup_textconcat_ratio05_4m_fix2/checkpoint_2.pth'
    
    
#### main ####
# Re-evaluate a saved retrieval checkpoint and write the metrics to log.txt.

# Load the run configuration stored in the output directory.
# NOTE: yaml.Loader can construct arbitrary Python objects; only load trusted files.
with open(os.path.join(args.output_dir, 'config.yaml')) as config_file:
    config = yaml.load(config_file, Loader=yaml.Loader)

# Seed torch, numpy and the stdlib RNG with the same rank-offset seed.
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

device = args.device

# dataset
samplers = [None, None, None]
train_dataset, val_dataset, test_dataset = create_dataset('re', config)
train_loader, val_loader, test_loader = create_loader(
    [train_dataset, val_dataset, test_dataset],
    samplers,
    batch_size=[config['batch_size_train']] + [config['batch_size_test']] * 2,
    num_workers=[0, 0, 0],
    is_trains=[True, False, False],
    collate_fns=[None, None, None],
)

# tokenizer
tokenizer = BertTokenizer.from_pretrained(args.text_encoder)

# Model
model = ALBEF(config=config, text_encoder=args.text_encoder, tokenizer=tokenizer)

# Model checkpoint (loaded on CPU first; the model is moved to `device` afterwards)
checkpoint = torch.load(args.checkpoint, map_location='cpu')
state_dict = checkpoint['model']
# presumably adapts the stored position embeddings to the current encoders — confirm in models.vit
state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(
    state_dict['visual_encoder.pos_embed'], model.visual_encoder)
state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(
    state_dict['visual_encoder_m.pos_embed'], model.visual_encoder_m)
# Strip the 'bert.' segment from checkpoint keys. pop-then-assign keeps the
# entry even when the renamed key equals the original key (set-then-delete
# would drop it).
for key in list(state_dict.keys()):
    if 'bert' in key:
        state_dict[key.replace('bert.', '')] = state_dict.pop(key)
# strict=False: mismatched keys are reported in `msg` instead of raising.
msg = model.load_state_dict(state_dict, strict=False)

print('load checkpoint from %s' % args.checkpoint)
print(msg)

model = model.to(device)
model_without_ddp = model

# The validation pass is opt-in. `val_result` is always defined so that
# building `log_stats` below cannot raise NameError when it is skipped.
run_val_eval = False
val_result = {}
if run_val_eval:
    score_val_i2t, score_val_t2i = evaluation(model_without_ddp, val_loader, tokenizer, device, config)
    val_result = itm_eval(score_val_i2t, score_val_t2i, val_loader.dataset.txt2img, val_loader.dataset.img2txt)
    print(val_result)

score_test_i2t, score_test_t2i = evaluation(model_without_ddp, test_loader, tokenizer, device, config)
test_result = itm_eval(score_test_i2t, score_test_t2i, test_loader.dataset.txt2img, test_loader.dataset.img2txt)
print(test_result)

# NOTE(review): the checkpoint file is checkpoint_2.pth but the logged epoch is 4 — confirm.
epoch = 4

log_stats = {**{f'val_{k}': v for k, v in val_result.items()},
             **{f'test_{k}': v for k, v in test_result.items()},
             'epoch': epoch,
             }
# NOTE(review): mode "w" replaces any existing log.txt in the output directory.
with open(os.path.join(args.output_dir, "log.txt"), "w") as f:
    f.write(json.dumps(log_stats) + "\n")


In [None]:
# Collect the first log line of every (ratio, method) run into one DataFrame,
# one row per run, indexed by '<ratio>_<method>'.
import json

result = []
ratio = [0.01, 0.05]
method = ['romixgen', 'vanila']
for r in ratio:
    for m in method:
        # `with` closes the file handle; the logs are written with json.dumps,
        # so json.loads is the matching parser (it also accepts NaN/true/null,
        # which ast.literal_eval rejects).
        with open(f'./output/Retrieval_coco_small_{r}_{m}/log.txt') as log_file:
            line = json.loads(log_file.readline())
        result.append({f'{r}_{m}': line})

# Build every per-run frame first and concatenate once instead of growing
# the DataFrame inside a loop.
df = pd.concat([pd.DataFrame.from_dict(res, orient='index') for res in result])


# working 

In [24]:
# Build the retrieval datasets/loaders from the temp config with the text
# augmentation switched on, then pull one training batch for inspection.
# NOTE(review): this rebinds `yaml`, which the first cell bound to ruamel.yaml.
import yaml 
from dataset import create_dataset, create_sampler, create_loader
# NOTE: yaml.Loader can construct arbitrary Python objects; only load trusted files.
with open('./configs/Retrieval_coco_temp.yaml') as config_file:
    config = yaml.load(config_file, Loader=yaml.Loader)
config['romixgen']['text']['romixgen_true'] = True
config['romixgen']['text']['method'] = 'txtshuffle'

train_dataset, val_dataset, test_dataset = create_dataset('re', config)
samplers = [None, None, None]
train_loader, val_loader, test_loader = create_loader(
    [train_dataset, val_dataset, test_dataset],
    samplers,
    batch_size=[config['batch_size_train']] + [config['batch_size_test']] * 2,
    num_workers=[0, 0, 0],
    is_trains=[True, False, False],
    collate_fns=[None, None, None],
)
img_aug = train_dataset.romixgen.img_aug
data = next(iter(train_loader))
# presumably data[1] holds the batch's captions, so this is the first caption — TODO confirm
txt = data[1][0]

In [3]:
# Build the perturbation evaluation dataset from the temp config.
# This rebinds `test_dataset` for every cell executed afterwards.
from dataset.caption_dataset import re_eval_perturb_dataset
import yaml 
# NOTE: yaml.Loader can construct arbitrary Python objects; only load trusted files.
with open('./configs/Retrieval_coco_temp.yaml') as config_file:
    config = yaml.load(config_file, Loader=yaml.Loader)
pertur = None 
test_dataset = re_eval_perturb_dataset(config['test_file'],config['image_res'], config['image_root'],pertur=pertur)

In [27]:
# Override the romixgen settings on the dataset behind `train_loader`, then
# draw one training batch.
# NOTE(review): these assignments mutate the dataset in place and stay in
# effect for every cell executed afterwards.
train_loader.dataset.romixgen_ratio = 1 
train_loader.dataset.romixgen.txt_aug.method = 'mixconcat'
# assumes a batch unpacks into exactly three items — TODO confirm against the dataset
a,b,c= next(iter(train_loader))

In [31]:
# Two toy captions, each wrapped in a one-element list, used to try out the
# text-augmentation methods by hand.
txt1 = [
    "I love you so much no way",
]
txt2 = [
    "No I hate you mad no",
]

In [32]:
# Try mix_concat on the two toy captions (passed as one-element lists).
train_dataset.romixgen.txt_aug.mix_concat(txt1,txt2)

'I love you No I hate so much no you mad no way'

In [35]:
# Try txt_shuffle; unlike the other calls, it is given the bare strings (txt1[0], txt2[0]).
train_dataset.romixgen.txt_aug.txt_shuffle(txt1[0],txt2[0])

'I No love I you hate so you much mad no no way'

In [37]:
# Try conjunction_concat on the two toy captions (passed as one-element lists).
train_dataset.romixgen.txt_aug.conjunction_concat(txt1,txt2)

'No I hate you mad no furthermore I love you so much no way'

In [40]:
# FIXME: this call fails as written. replace_word() requires an additional
# positional argument `obj_cats` (see the TypeError recorded below); pass it
# once the expected value is known.
train_dataset.romixgen.txt_aug.replace_word(txt1,txt2)

TypeError: replace_word() missing 1 required positional argument: 'obj_cats'