In [14]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
from logging import log
import re
import numpy as np
import os.path as op
from hashing_module.triplet_loss import *
from torch.autograd.grad_mode import F

from torch.nn.modules import loss
from torch.utils.data.sampler import Sampler

import argparse
from oscar.modeling.modeling_bert import HashingformerALL,normal_label
from pytorch_transformers import BertTokenizer, BertConfig
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm
from torch.autograd import Variable
from oscar.utils.tsv_file import TSVFile
from torch.nn import CrossEntropyLoss
import json
import base64
import random
from pytorch_transformers import AdamW, WarmupLinearSchedule, WarmupConstantSchedule
from hashing_module.utils import calc_map_k
from oscar.utils.logger import setup_logger
from oscar.utils.misc import mkdir
from torch.nn import functional as F

In [15]:
class Opt():
    """Plain container of default hyper-parameters for this notebook.

    ``training_size``, ``query_size``, ``database_size``, ``bit`` and
    ``class_number`` are (re)assigned from the parsed command-line arguments
    in the setup cell further down.
    """
    def __init__(self) -> None:
        self.use_gpu = True
        # split sizes (number of image keys per split)
        self.training_size = 10000
        self.query_size = 2000
        # number of hash bits — assumed code length, see ``config.bit``
        self.bit = 64
        self.database_size = 18000
        # NOTE(review): gamma/eta/beta/alpha are not read anywhere in this
        # notebook; presumably loss weights used by the training code.
        # (``gamma`` used to be assigned twice with the same value.)
        self.gamma = 1
        self.eta = 1
        self.valid = True
        self.batch_size = 64
        self.margin = 0.4
        self.beta = 1
        self.alpha = 1
opt = Opt()

In [16]:
def set_random_seed(seed, deterministic=False):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Default: False (the cuDNN flags are left untouched).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        # previously documented but not implemented
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
set_random_seed(1000)
class RetrievalDataset(Dataset):
    """Image/text retrieval dataset used for cross-modal hashing.

    Every item combines an image's region features (from ``img_feat_file``)
    with two text segments: the image's tags from ``tagslabel`` (segment A)
    and the detector class names read from ``label.tsv`` next to the feature
    file (segment B).
    """

    def __init__(self, args, tokenizer, split="train"):
        """
        Args:
            args: configuration namespace. Reads ``tagslabel``, ``img_feat_file``,
                ``class_name``, ``split_keys``, ``query_size``, ``training_size``,
                ``database_size``; for the train split ``__getitem__`` also reads
                ``negative_number``, ``max_label_length`` and ``class_number``.
            tokenizer: tokenizer used for both text segments and class names.
            split: ``"train"``, ``"query"`` or anything else (retrieval database).
                The query split is the first ``query_size`` keys; the train and
                database splits both start immediately after it.
        """
        super(RetrievalDataset, self).__init__()
        with open(args.tagslabel, "r") as f:
            self.tagslabel = json.load(f)
        self.args = args
        self.split = split
        self.img_file = args.img_feat_file
        self.img_tsv = TSVFile(self.img_file)
        self.img_keys = list(self.tagslabel.keys())
        imgid2idx_file = op.join(op.dirname(self.img_file), 'imageid2idx.json')
        with open(imgid2idx_file) as f:
            self.image_id2idx = json.load(f)  # str(img_id) -> row index in img_tsv
        with open(args.class_name, "r") as f:
            self.class_name = np.array(json.load(f))
        if args.split_keys:
            with open(args.split_keys, "r") as f:
                self.img_keys = [str(i) for i in json.load(f)]
        else:
            # Same permutation as ``random.seed(279834); random.shuffle(...)``,
            # but on a private RNG so the caller's global ``random`` state
            # (e.g. from set_random_seed) is not overwritten.
            random.Random(279834).shuffle(self.img_keys)
        if split == "train":
            self.img_keys = self.img_keys[args.query_size:args.training_size + args.query_size]
        elif split == "query":
            self.img_keys = self.img_keys[:args.query_size]
        else:
            self.img_keys = self.img_keys[args.query_size:args.database_size + args.query_size]
        self.labels = self._load_od_labels(op.join(op.dirname(self.img_file), "label.tsv"))
        self.output_mode = 'classification'
        self.tokenizer = tokenizer
        self.max_seq_length = 35   # per text segment, including [CLS] and [SEP]
        self.max_img_seq_len = 70  # number of image-region positions

    def _load_od_labels(self, label_file):
        """Read detector outputs from ``label_file`` for the keys of this split.

        Returns a dict ``image_id -> {"image_h", "image_w", "class", "boxes"}``.
        Rows whose JSON is a bare list of objects get image_h=600, image_w=800.
        """
        wanted = set(self.img_keys)  # set lookup instead of a list scan per row
        self.label_tsv = TSVFile(label_file)
        labels = {}
        for line_no in tqdm(range(self.label_tsv.num_rows())):
            row = self.label_tsv.seek(line_no)
            image_id = row[0]
            if image_id not in wanted:
                continue
            results = json.loads(row[1])
            is_dict = isinstance(results, dict)
            objects = results['objects'] if is_dict else results
            labels[image_id] = {
                "image_h": results["image_h"] if is_dict else 600,
                "image_w": results["image_w"] if is_dict else 800,
                "class": [cur_d['class'] for cur_d in objects],
                "boxes": np.array([cur_d['rect'] for cur_d in objects], dtype=np.float32),
            }
        # Close and drop the handle; presumably so the dataset can be pickled
        # into DataLoader worker processes — TODO confirm against TSVFile.
        self.label_tsv._fp.close()
        self.label_tsv._fp = None
        return labels

    def get_od_labels(self, img_key):
        """Return the detector labels of ``img_key`` as one space-separated string."""
        if type(self.labels[img_key]) == str:
            od_labels = self.labels[img_key]
        else:
            od_labels = ' '.join(self.labels[img_key]['class'])
        return od_labels

    def class_tokenize(self, labels, max_length=15):
        """Turn each multi-hot row of ``labels`` into a padded token-id sequence.

        Args:
            labels: 2-D tensor; each row is multi-hot over ``self.class_name``.
            max_length: maximum number of tokens kept between [CLS] and [SEP].
        Returns:
            LongTensor of shape ``(labels.shape[0], max_length + 2)``,
            zero-filled after [SEP].
        """
        final_label = []
        for row in labels:
            this_label = torch.zeros(max_length + 2)
            class_names = self.class_name[np.asarray(row > 0)]
            # Space-separate the names; joining with "" fused neighbouring
            # class names into one bogus word (e.g. "sky"+"tree" -> "skytree").
            tokens = self.tokenizer.tokenize(" ".join(class_names))
            tokens = [self.tokenizer.cls_token] + tokens[:max_length] + [self.tokenizer.sep_token]
            input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
            this_label[:len(input_ids)] = torch.tensor(input_ids, dtype=this_label.dtype)
            final_label.append(this_label)
        return torch.stack(final_label).long()

    def tensorize_example(self, text_a, img_feat, text_b=None,
            cls_token_segment_id=0, pad_token_segment_id=0,
            sequence_a_segment_id=0, sequence_b_segment_id=1):
        """Build model inputs from two text segments and region features.

        Each text segment is tokenized on its own as ``[CLS] ... [SEP]`` and
        padded to ``self.max_seq_length``; segment A is followed by segment B
        and then by ``self.max_img_seq_len`` image positions.

        Args:
            text_a: first text (the image tags).
            img_feat: 2-D tensor of region features, one row per region.
            text_b: second text (detector labels). Empty or None yields a
                segment containing only ``[CLS] [SEP]``.
        Returns:
            (input_ids, attention_mask, segment_ids, img_feat): input_ids and
            segment_ids have length ``2 * max_seq_length``; attention_mask has
            length ``2 * max_seq_length + max_img_seq_len``; img_feat is
            truncated or zero-padded to ``max_img_seq_len`` rows.
        """
        max_text = self.max_seq_length - 2  # room for [CLS] and [SEP]

        # segment A
        tokens_a = self.tokenizer.tokenize(text_a)[:max_text]
        tokens = [self.tokenizer.cls_token] + tokens_a + [self.tokenizer.sep_token]
        segment_ids = [cls_token_segment_id] + [sequence_a_segment_id] * (len(tokens_a) + 1)

        # segment B; previously an empty text_b (an image without detected
        # objects) left tokens_b undefined and crashed with NameError below.
        tokens_b = self.tokenizer.tokenize(text_b)[:max_text] if text_b else []
        tokens_b = [self.tokenizer.cls_token] + tokens_b + [self.tokenizer.sep_token]
        segment_ids_b = [sequence_b_segment_id] * len(tokens_b)

        # pad segment A
        seq_len_a = len(tokens)
        seq_padding_len_a = self.max_seq_length - seq_len_a
        tokens += [self.tokenizer.pad_token] * seq_padding_len_a
        segment_ids += [pad_token_segment_id] * seq_padding_len_a
        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)

        # pad segment B
        seq_len_b = len(tokens_b)
        seq_padding_len_b = self.max_seq_length - seq_len_b
        tokens_b += [self.tokenizer.pad_token] * seq_padding_len_b
        segment_ids_b += [pad_token_segment_id] * seq_padding_len_b
        input_ids_b = self.tokenizer.convert_tokens_to_ids(tokens_b)

        # concatenate A and B
        input_ids = input_ids + input_ids_b
        segment_ids = segment_ids + segment_ids_b

        # image features: truncate or zero-pad to max_img_seq_len rows
        img_len = img_feat.shape[0]
        if img_len > self.max_img_seq_len:
            img_feat = img_feat[:self.max_img_seq_len, :]
            img_len = self.max_img_seq_len
            img_padding_len = 0
        else:
            img_padding_len = self.max_img_seq_len - img_len
            padding_matrix = torch.zeros((img_padding_len, img_feat.shape[1]), dtype=img_feat.dtype)
            img_feat = torch.cat((img_feat, padding_matrix), 0)

        # attention mask ("CLR" layout): text A, text B, then image regions
        attention_mask = ([1] * seq_len_a + [0] * seq_padding_len_a
                          + [1] * seq_len_b + [0] * seq_padding_len_b
                          + [1] * img_len + [0] * img_padding_len)

        input_ids = torch.tensor(input_ids, dtype=torch.long)
        attention_mask = torch.tensor(attention_mask, dtype=torch.long)
        segment_ids = torch.tensor(segment_ids, dtype=torch.long)
        return (input_ids, attention_mask, segment_ids, img_feat)

    def __getitem__(self, index):
        """Return ``(example_tuple, img_key)``.

        ``example_tuple`` is ``(input_ids, attention_mask, segment_ids,
        img_feat, label)``; for the train split it additionally holds the
        tokenized label sets (``negative_label``) and their multi-hot form
        (``raw_label``), whose first row is the item's own label.
        """
        img_key = self.img_keys[index]
        feature = self.get_image(img_key)
        tag_list = self.tagslabel[img_key]["tags"]
        # tags are either a list of words or an already-joined string
        caption = " ".join(tag_list).strip() if isinstance(tag_list, list) else tag_list
        od_labels = self.get_od_labels(img_key)
        example = self.tensorize_example(caption, feature, text_b=od_labels)
        label = torch.tensor(self.tagslabel[img_key]["label"], dtype=torch.long)
        if self.split == "train":
            raw_label = self.generate_samples(label, self.args.negative_number)
            negative_label = self.class_tokenize(raw_label, max_length=self.args.max_label_length)
            return tuple(list(example) + [label, negative_label, raw_label]), img_key
        return tuple(list(example) + [label]), img_key

    def generate_samples(self, label, negative_number=99):
        """Stack ``label`` with ``negative_number`` random label vectors that
        share no class with it.

        Each candidate switches every one of ``args.class_number`` classes on
        with probability ``5 / class_number``; candidates overlapping
        ``label`` are rejected and redrawn.

        Returns:
            tensor of shape ``(negative_number + 1, class_number)``; row 0 is ``label``.
        """
        p_on = 5 / self.args.class_number
        negative_samples = []
        while len(negative_samples) < negative_number:
            sample = torch.from_numpy(
                np.random.choice(2, self.args.class_number, p=[1 - p_on, p_on]))
            overlaps = (sample * label).sum() > 0
            if not overlaps:
                negative_samples.append(sample)
        return torch.stack([label] + negative_samples)

    def get_image(self, image_id):
        """Decode the float32 region features of ``image_id`` from the feature TSV.

        ``row[1]`` holds the number of boxes and the last column the
        base64-encoded feature buffer, reshaped to ``(num_boxes, -1)``.
        """
        image_idx = self.image_id2idx[str(image_id)]
        row = self.img_tsv.seek(image_idx)
        num_boxes = int(row[1])
        features = np.frombuffer(base64.b64decode(row[-1]),
                                 dtype=np.float32).reshape((num_boxes, -1)).copy()
        return torch.from_numpy(features)

    def __len__(self):
        return len(self.img_keys)


In [17]:
def calc_neighbor(label1, label2):
    """Boolean "shares at least one label" matrix between two label sets.

    Both inputs are assumed 2-D with one label vector per row (TODO confirm).
    Entry (i, j) is True when the inner product of row i of ``label1`` and
    row j of ``label2``, computed in float32, is positive.
    """
    left = label1.float()
    right = label2.float().transpose(0, 1)
    overlap = left @ right
    return overlap > 0
def save_checkpoint(model, tokenizer, args, epoch, max_trials=10):
    """Save weights, training args and tokenizer to
    ``<args.output_dir>/checkpoint-<epoch>``, retrying on failure.

    Args:
        model: passed to ``save_pretrained(model=..., save_directory=...)``.
        tokenizer: object exposing ``save_pretrained(directory)``.
        args: namespace providing ``output_dir``; it is also written with
            ``torch.save`` as ``training_args.bin``.
        epoch: used in the checkpoint directory name.
        max_trials: attempts before giving up (default 10, as before).
    """
    checkpoint_dir = op.join(args.output_dir, 'checkpoint-{}'.format(epoch))
    mkdir(checkpoint_dir)
    for trial in range(1, max_trials + 1):
        try:
            save_pretrained(model=model, save_directory=checkpoint_dir)
            torch.save(args, op.join(checkpoint_dir, 'training_args.bin'))
            tokenizer.save_pretrained(checkpoint_dir)
        except Exception as exc:
            # Catch Exception, not everything: a bare ``except:`` also swallowed
            # KeyboardInterrupt, and the reason for the failure was discarded.
            logger.warning("Checkpoint save attempt {}/{} failed: {!r}".format(trial, max_trials, exc))
            continue
        logger.info("Save checkpoint to {}".format(checkpoint_dir))
        return
    logger.info("Failed to save checkpoint after {} trials.".format(max_trials))
def log_loss_func(a, b, a_l, b_l):
    """Pairwise negative log-likelihood loss between two sets of codes.

    A pair (i, j) counts as similar when ``a_l[i]`` and ``b_l[j]`` have a
    positive inner product. With ``theta = 0.5 * <a_i, b_j>`` the loss is
    ``mean(log(1 + exp(theta)) - s_ij * theta)``.

    Returns:
        (loss, precision) where precision is ``f1_calc(logits, sim, 0)``.
    """
    logit_it = torch.matmul(a, b.t())
    sim_it = torch.matmul(a_l, b_l.t()) > 0
    theta_it = 0.5 * logit_it
    # softplus(x) == log(1 + exp(x)) but stays finite for large x, where the
    # explicit exp() overflowed to inf and made the loss inf.
    loss = torch.mean(torch.nn.functional.softplus(theta_it) - sim_it.to(theta_it.dtype) * theta_it)
    precision = f1_calc(logit_it, sim_it, 0)
    return loss, precision

In [18]:
def generate_code(model, query_dataloader, TorI="I", device=None):
    """Run ``model`` over a dataloader and collect its (continuous) code outputs.

    Args:
        model: callable taking ``input_ids``, ``token_type_ids``,
            ``attention_mask``, ``img_feats`` and ``modal`` keyword arguments.
        query_dataloader: iterable of ``(batch, keys)``; ``batch`` holds
            ``(input_ids, attention_mask, token_type_ids, img_feats, label, ...)``.
        TorI: ``"T"`` encodes with ``modal="t"``; any other value uses ``modal="i"``.
        device: target device for each batch; ``None`` keeps the previous
            behaviour of moving everything to ``"cuda"``.
    Returns:
        (hashing_bit, labels), both concatenated along dim 0 in dataloader
        order. The codes are not binarized here; callers apply ``torch.sign``.
    """
    device = "cuda" if device is None else device
    modal = "t" if TorI == "T" else "i"
    hashing_bit = []
    labels = []
    with torch.no_grad():
        for batch, _keys in tqdm(query_dataloader):
            cur_f = model(
                input_ids=batch[0].long().to(device),
                token_type_ids=batch[2].long().to(device),
                attention_mask=batch[1].long().to(device),
                img_feats=batch[3].to(device),
                modal=modal,
            )
            hashing_bit.append(cur_f)
            labels.append(batch[4].to(device))
    return torch.cat(hashing_bit, 0), torch.cat(labels, 0)

In [19]:
parser = argparse.ArgumentParser()
parser.add_argument("--num_train_epochs", default=1000, type=int,
                    help="Total number of training epochs to perform.")
# The default used to end in a stray '"' (".../class_name.json\""), so it never
# pointed at a real file.
parser.add_argument("--class_name", default='/raid/data_modal/MIR_Flickr_25k/class_name.json', type=str, required=False,
                    help="JSON file with the list of class names.")
parser.add_argument("--split_keys", default='', type=str, required=False,
                    help="Optional JSON list of image keys fixing the split order; "
                         "when empty the tagslabel keys are shuffled with a fixed seed.")
parser.add_argument("--tagslabel", default='MIR_Flickr_25k/img_tagslabel.json', type=str, required=False,
                    help="JSON file mapping each image key to its 'tags' and 'label'.")
parser.add_argument("--img_feat_file", default='/MIR_Flickr_25k/vinvl_data/vinvl_vg_x152c4/predictions.tsv', type=str, required=False,
                    help="The absolute address of the image feature file.")
parser.add_argument("--output_dir", default='output/log_aipr', type=str, required=False,
                    help="The output directory to save checkpoint and test results.")
parser.add_argument("--num_workers", default=4, type=int, help="Workers in dataloader.")
parser.add_argument("--eval_model_dir", type=str, default='',
                    help="Model directory for evaluation.")
parser.add_argument("--do_lower_case", action='store_true',
                    help="Set this flag if you are using an uncased model.")
parser.add_argument("--output_file", type=str, default='',
                    help="Path of an output file.")
parser.add_argument("--learning_rate", default=2e-5, type=float, help="The initial lr.")
parser.add_argument("--weight_decay", default=0.05, type=float, help="Weight decay.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam.")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup.")
parser.add_argument("--scheduler", default='linear', type=str, help="constant or linear.")
parser.add_argument("--bit", default=64, type=int, help="Number of hash bits (code length).")
parser.add_argument("--class_number", default=255, type=int, help="Number of label classes.")
parser.add_argument("--training_size", default=10000, type=int,
                    help="Number of training images (taken right after the query split).")
parser.add_argument("--query_size", default=2000, type=int,
                    help="Number of query images (the first keys of the split order).")
parser.add_argument("--database_size", default=18000, type=int,
                    help="Number of retrieval-database images (taken right after the query split).")
parser.add_argument("--no_pretrain", action='store_true',
                    help="Do not load weights from --eval_model_dir.")
parser.add_argument("--negative_number", default=9, type=int,
                    help="Number of negative label sets sampled per training item.")
parser.add_argument("--max_label_length", default=10, type=int,
                    help="Maximum number of class-name tokens per label sequence.")

# Alternative NUS-WIDE configuration, kept for reference:
# lists =[
#                 "--tagslabel",
#                 "/raid/data_modal/NUS-WIDE/hashing/img_tagslabel.json",
#                 "--img_feat_file",
#                 "/raid/data_modal/NUS-WIDE/hashing/use_for_vinvl/feature.tsv",
#                 "--do_lower_case",
#                 "--output_dir",
#                 "output/try_contrastive_COCO",
#                 "--eval_model_dir",
#                 "/raid/whf/ObjectContrastiveTransformer/output/try_contrastive_NUS_16/checkpoint-10",
#                 "--training_size","10500",
#                 "--query_size","2100",
#                 "--database_size","188321",
#                 "--class_number","21",
#                 "--class_name","/raid/data_modal/NUS-WIDE/hashing/class_name.json",
#                 "--bit","16"
#             ]
# args = parser.parse_args(lists)
# args.output = "output/try_contrastive_NUS_16"

# IAPR TC-12 configuration used for this run.
lists = [
    "--tagslabel",
    "/raid/data_modal/IAPR_TC-12/img_tagslabel.json",
    "--img_feat_file",
    "/raid/data_modal/IAPR_TC-12/vinvl/use_for_vinvl/feature.tsv",
    "--do_lower_case",
    "--output_dir",
    "output/try_contrastive_IAPR",
    "--eval_model_dir",
    "/raid/whf/ObjectContrastiveTransformer/output/IAPR_16/checkpoint-150",
    "--training_size", "10000",
    "--query_size", "2000",
    "--database_size", "18000",
    "--class_number", "255",
    "--class_name", "/raid/data_modal/IAPR_TC-12/class_name.json",
    "--bit", "16",
]

args = parser.parse_args(lists)
# Directory for the evaluation artefacts (pr.json); separate from output_dir.
args.output = "output/IAPR_16"

In [20]:
# Copy the command-line split sizes into ``opt``; --training_size -1 keeps the
# defaults from ``Opt``.
if args.training_size != -1:
    opt.training_size = args.training_size
    opt.query_size = args.query_size
    opt.database_size = args.database_size
mkdir(args.output_dir)
# Module-level ``logger`` (the old ``global logger`` statement here was a no-op).
logger = setup_logger("vlpretrain", args.output_dir, 0)
opt.bit = args.bit
opt.class_number = args.class_number
device = torch.device("cuda")

# Tokenizer and BERT config are loaded from the checkpoint under evaluation;
# the hashing-specific fields are then injected into the config.
config_class, tokenizer_class = BertConfig, BertTokenizer
checkpoint = args.eval_model_dir
tokenizer = tokenizer_class.from_pretrained(checkpoint, do_lower_case=args.do_lower_case)
config = config_class.from_pretrained(checkpoint)
config.class_number = opt.class_number
config.bit = args.bit
model = HashingformerALL(None, config)

In [21]:
if not args.no_pretrain:
    # Prefer ``pytorch_model.bin``; fall back to ``model.cpkt``.
    weights_file = op.join(checkpoint, "pytorch_model.bin")
    if not os.path.exists(weights_file):
        weights_file = op.join(checkpoint, "model.cpkt")
    sd = torch.load(weights_file, map_location="cpu")
    # strict=False tolerates key mismatches, so report them instead of
    # silently evaluating a partly randomly-initialised model.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    logger.info("Loaded {}: {} missing keys, {} unexpected keys".format(
        weights_file, len(missing), len(unexpected)))
    if missing:
        logger.info("Missing keys: {}".format(missing))
    if unexpected:
        logger.info("Unexpected keys: {}".format(unexpected))
model.to(device)
query_dataset = RetrievalDataset(args, tokenizer, "query")
retrieval_dataset = RetrievalDataset(args, tokenizer, "retrieval")


100%|██████████| 17585/17585 [00:01<00:00, 11397.60it/s]
100%|██████████| 17585/17585 [00:07<00:00, 2456.70it/s]


In [22]:
model.eval()
# Sequential (unshuffled) loaders so code rows line up with dataset order.
query_sampler, retrieval_sampler = SequentialSampler(query_dataset), SequentialSampler(retrieval_dataset)
query_dataloader = DataLoader(query_dataset, batch_size=512, num_workers=4, sampler=query_sampler)
retrieval_dataloader = DataLoader(retrieval_dataset, batch_size=512, num_workers=4, sampler=retrieval_sampler)

# Naming: q*/r* = query/retrieval split, *BX = image-modality codes ("I"),
# *BY = text-modality codes ("T").
qBX, query_L_i = generate_code(model, query_dataloader, TorI="I")
qBY, query_L_t = generate_code(model, query_dataloader, TorI="T")
rBX, retrieval_L_i = generate_code(model, retrieval_dataloader, TorI="I")
rBY, retrieval_L_t = generate_code(model, retrieval_dataloader, TorI="T")

100%|██████████| 4/4 [00:05<00:00,  1.47s/it]
100%|██████████| 4/4 [00:03<00:00,  1.02it/s]
100%|██████████| 31/31 [00:33<00:00,  1.08s/it]
100%|██████████| 31/31 [00:17<00:00,  1.72it/s]


In [23]:
from hashing_module.utils import calc_map_k_final

In [24]:
# Binarize the continuous codes with sign() and score both retrieval directions.
mapi2t, _, pr_i2t = calc_map_k_final(torch.sign(qBX), torch.sign(rBY), query_L_i, retrieval_L_t)  # image query -> text database
mapt2i, _, pr_t2i = calc_map_k_final(torch.sign(qBY), torch.sign(rBX), query_L_t, retrieval_L_i)  # text query -> image database
pr = dict(pr_i2t=pr_i2t, pr_t2i=pr_t2i)

  1%|          | 12/2000 [00:00<00:17, 113.00it/s]

calc map k


100%|██████████| 2000/2000 [00:17<00:00, 115.00it/s]
  1%|          | 12/2000 [00:00<00:16, 117.84it/s]

calc map k


100%|██████████| 2000/2000 [00:17<00:00, 115.45it/s]


In [25]:
# ``args.output`` is only assigned in the argument cell and never created
# (mkdir runs for ``output_dir`` only), so make sure it exists before writing.
os.makedirs(args.output, exist_ok=True)
with open(os.path.join(args.output, "pr.json"), "w") as f:
    json.dump(pr, f)

In [26]:
# Display the per-threshold TP/FP/TN/FN and precision/recall points for both directions.
pr

{'pr_i2t': [{'TP': 613,
   'FP': 55,
   'TN': 21717461,
   'FN': 9451871,
   'P': 0.9176646706586826,
   'R': 6.485067840368733e-05},
  {'TP': 9023,
   'FP': 781,
   'TN': 21716735,
   'FN': 9443461,
   'P': 0.9203386372909017,
   'R': 0.0009545639008751562},
  {'TP': 79675,
   'FP': 8760,
   'TN': 21708756,
   'FN': 9372809,
   'P': 0.900944196302369,
   'R': 0.008429001308015967},
  {'TP': 339599,
   'FP': 51137,
   'TN': 21666379,
   'FN': 9112885,
   'P': 0.8691264690225625,
   'R': 0.0359269584587501},
  {'TP': 966267,
   'FP': 183098,
   'TN': 21534418,
   'FN': 8486217,
   'P': 0.8406963845253683,
   'R': 0.1022236059854743},
  {'TP': 1968924,
   'FP': 445459,
   'TN': 21272057,
   'FN': 7483560,
   'P': 0.8154977897044504,
   'R': 0.2082969936791218},
  {'TP': 3253019,
   'FP': 930811,
   'TN': 20786705,
   'FN': 6199465,
   'P': 0.7775217922334321,
   'R': 0.3441443540131885},
  {'TP': 4662885,
   'FP': 1936236,
   'TN': 19781280,
   'FN': 4789599,
   'P': 0.7065918324576864,


In [12]:
# mAP for image->text and text->image retrieval, one per line.
# NOTE(review): this cell is In [12], i.e. its saved output predates the run above.
print(mapi2t, mapt2i, sep="\n")

tensor(0.6086, device='cuda:0')
tensor(0.5814, device='cuda:0')


In [13]:
# NOTE(review): duplicate of the earlier `pr` display; this cell is In [13], so its
# saved output comes from an earlier, out-of-order run — consider deleting it.
pr

{'pr_i2t': [{'TP': 701,
   'FP': 19,
   'TN': 21717497,
   'FN': 9451783,
   'P': 0.9736111111111111,
   'R': 7.416040058888225e-05},
  {'TP': 7177,
   'FP': 463,
   'TN': 21717053,
   'FN': 9445307,
   'P': 0.9393979057591623,
   'R': 0.000759271319581181},
  {'TP': 66733,
   'FP': 7610,
   'TN': 21709906,
   'FN': 9385751,
   'P': 0.8976366302140081,
   'R': 0.00705983739300696},
  {'TP': 291338,
   'FP': 46934,
   'TN': 21670582,
   'FN': 9161146,
   'P': 0.8612536656891495,
   'R': 0.030821316386253603},
  {'TP': 894223,
   'FP': 180538,
   'TN': 21536978,
   'FN': 8558261,
   'P': 0.832020328240418,
   'R': 0.09460190570013131},
  {'TP': 1934983,
   'FP': 470687,
   'TN': 21246829,
   'FN': 7517501,
   'P': 0.8043426571391754,
   'R': 0.20470629730767065},
  {'TP': 3367698,
   'FP': 1015851,
   'TN': 20701665,
   'FN': 6084786,
   'P': 0.768258322195098,
   'R': 0.3562765089049609},
  {'TP': 4857378,
   'FP': 2064380,
   'TN': 19653136,
   'FN': 4595106,
   'P': 0.7017549587835923