In [1]:
import numpy as np
import os.path as osp
import pickle, json, random

In [2]:
import os, math
import os.path as osp
from copy import deepcopy
from functools import partial
from pprint import pprint

In [3]:
import sacred
import torch
import torch.nn as nn
import torch.nn.functional as F
from sacred import SETTINGS
from sacred.utils import apply_backspaces_and_linefeeds
from torch.backends import cudnn
from torch.optim import SGD, Adam, AdamW, lr_scheduler
from tqdm import tqdm

from torch.backends import cudnn
# from visdom_logger import VisdomLogger

In [4]:
from torch.utils.data import DataLoader, RandomSampler, BatchSampler
from typing import NamedTuple, Optional, List

In [5]:
from models.matcher import MatchERT, PosMatchERT
from models.ingredient import model_ingredient, get_model

In [6]:
from models.ingredient import model_ingredient, get_model
from utils import state_dict_to_cpu, num_of_trainable_params
from utils import pickle_load, pickle_save
#from utils.data.utils import TripletSampler
from utils import BinaryCrossEntropyWithLogits
from utils.data.dataset_ingredient import data_ingredient, get_loaders
from utils.training import train_one_epoch, evaluate_viquae

In [7]:
from utils import pickle_load
from sacred import Experiment
from utils.data.dataset_ingredient import data_ingredient, get_loaders
from utils.data.dataset import FeatureDataset

In [8]:
# Sacred experiment combining the dataset and model ingredients.
# interactive=True is the sacred switch that allows creating an Experiment from an
# interactive session such as this notebook.
# NOTE(review): apart from commented-out code, `ex` is not used by the cells below
# (the training calls pass ex=None).
ex = sacred.Experiment('RRT Training', ingredients=[data_ingredient, model_ingredient], interactive=True)
# Capture stdout at the Python `sys` level and filter backspaces and linefeeds
# from the captured text.
SETTINGS.CAPTURE_MODE = 'sys'
ex.captured_out_filter = apply_backspaces_and_linefeeds

In [85]:
# ---- Optimisation hyper-parameters -------------------------------------------
# `epochs` is forwarded to get_optimizer_scheduler() only.
# NOTE(review): the training loop further down iterates over a hard-coded range(5)
# and does not read this value.
epochs = 1
lr = 0.0001
# momentum / nesterov are only consumed by the 'sgd' optimizer branch.
momentum = 0.
nesterov = False
weight_decay = 5e-4
# 'sgd' or 'adam'; any other value (including 'adamw') selects AdamW.
optim = 'adamw'
# 'cos', 'warmcos', 'multistep' or 'warmstep'; any other value selects StepLR.
scheduler = 'multistep'
# Passed to train_one_epoch; the inlined copy of the loop at the end of the notebook
# clips gradients only when this is > 0.
max_norm = 0.0
# Base seed: the cells below call torch.manual_seed(seed + k).
seed = 0

# visdom_port is only referenced by commented-out VisdomLogger code;
# visdom_freq is passed as `freq` to train_one_epoch.
visdom_port = None
visdom_freq = 100
cpu = False  # set to True to force training on CPU even when CUDA is available
# 'deterministic' sets cudnn.deterministic once; 'benchmark' toggles cudnn.benchmark
# on for training and off for validation inside the training loop.
cudnn_flag = 'benchmark'
# Only referenced by a commented-out `save_name` expression below.
temp_dir = osp.join('outputs', 'temp')

# When True, 1-D parameters are placed in a parameter group with weight_decay 0.
no_bias_decay = False
# Loss name understood by get_loss(); only 'bce' is supported.
loss = 'bce'
# Milestones and decay factor for the 'multistep' scheduler.
scheduler_tau = [16, 18]
scheduler_gamma = 0.1

# Checkpoint to initialise the model from; set to None to skip loading.
# NOTE(review): absolute, user-specific cluster path - not portable.
resume = '/mnt/beegfs/home/smessoud/RerankingTransformer/RRT_GLD/rrt_gld_ckpts/r50_gldv2.pt'
#resume = None

In [86]:
epochs, cpu, cudnn_flag, visdom_port, visdom_freq, temp_dir, seed, no_bias_decay, max_norm, resume

(1,
 False,
 'benchmark',
 None,
 100,
 'outputs/temp',
 0,
 False,
 0.0,
 '/mnt/beegfs/home/smessoud/RerankingTransformer/RRT_GLD/rrt_gld_ckpts/r50_gldv2.pt')

In [87]:
def get_optimizer_scheduler(parameters, optim, loader_length, epochs, lr, momentum, nesterov, weight_decay, scheduler, scheduler_tau, scheduler_gamma, lr_step=None):
    """Build the optimizer and its learning-rate scheduler.

    Args:
        parameters: parameters or parameter-group dicts, forwarded to the optimizer.
        optim: 'sgd' or 'adam'; any other value selects AdamW.
        loader_length: number of batches per epoch (used by per-iteration schedules).
        epochs: number of epochs; 0 disables scheduling altogether.
        lr: base learning rate.
        momentum: SGD momentum (ignored by Adam/AdamW).
        nesterov: request Nesterov momentum for SGD; only honoured when
            ``momentum`` is non-zero.
        weight_decay: weight decay forwarded to the optimizer.
        scheduler: 'cos', 'warmcos', 'multistep' or 'warmstep'; any other value
            selects ``StepLR`` with a step of ``epochs * loader_length``.
        scheduler_tau: milestones for the 'multistep' schedule.
        scheduler_gamma: decay factor for the 'multistep' schedule.
        lr_step: number of epochs between 10x decays of the 'warmstep' schedule;
            required for that schedule only.

    Returns:
        ``(optimizer, (lr_scheduler, update_per_iteration))``. ``update_per_iteration``
        is True when the scheduler must be stepped after every batch, False when it
        is stepped once per epoch, and None (with a None scheduler) when ``epochs == 0``.

    Raises:
        ValueError: if ``scheduler == 'warmstep'`` and ``lr_step * loader_length``
            is missing or zero.
    """
    if optim == 'sgd':
        optimizer = SGD(parameters, lr=lr, momentum=momentum, weight_decay=weight_decay,
                        nesterov=bool(nesterov and momentum))
    elif optim == 'adam':
        optimizer = Adam(parameters, lr=lr, weight_decay=weight_decay)
    else:
        optimizer = AdamW(parameters, lr=lr, weight_decay=weight_decay)

    if epochs == 0:
        lr_schedule = None
        update_per_iteration = None
    elif scheduler == 'cos':
        # Cosine annealing over epochs (stepped once per epoch).
        lr_schedule = lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0.000001)
        update_per_iteration = False
    elif scheduler == 'warmcos':
        # Linear warm-up over the first 3 epochs, capped by a cosine decay.
        warm_cosine = lambda i: min((i + 1) / 3, (1 + math.cos(math.pi * i / epochs)) / 2)
        lr_schedule = lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_cosine)
        update_per_iteration = False
    elif scheduler == 'multistep':
        lr_schedule = lr_scheduler.MultiStepLR(optimizer, milestones=scheduler_tau, gamma=scheduler_gamma)
        update_per_iteration = False
    elif scheduler == 'warmstep':
        # LambdaLR evaluates the lambda at construction time, so a missing or zero
        # period would otherwise surface as an obscure TypeError/ZeroDivisionError.
        if not lr_step or not loader_length:
            raise ValueError(
                "scheduler 'warmstep' requires a non-zero lr_step and loader_length "
                "(got lr_step={!r}, loader_length={!r})".format(lr_step, loader_length))
        # Linear warm-up over the first 100 iterations, then a 10x decay every
        # lr_step epochs.
        warm_step = lambda i: min((i + 1) / 100, 1) * 0.1 ** (i // (lr_step * loader_length))
        lr_schedule = lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_step)
        update_per_iteration = True
    else:
        lr_schedule = lr_scheduler.StepLR(optimizer, epochs * loader_length)
        update_per_iteration = True

    return optimizer, (lr_schedule, update_per_iteration)



In [88]:
def get_loss(loss):
    """Return the loss module registered under the name `loss`.

    Only 'bce' is supported; any other name raises an Exception.
    """
    if loss != 'bce':
        raise Exception('Unsupported loss {}'.format(loss))
    return BinaryCrossEntropyWithLogits()

In [89]:
# Train on the first GPU when CUDA is available and the `cpu` override is off.
device = torch.device('cuda:0' if torch.cuda.is_available() and not cpu else 'cpu')
# callback = VisdomLogger(port=visdom_port) if visdom_port else None
# Only the 'deterministic' flag is applied here; the 'benchmark' flag is toggled
# inside the training loop instead.
if cudnn_flag == 'deterministic':
    setattr(cudnn, cudnn_flag, True)

In [90]:
def read_file(filename):
    """Return the content of the text file `filename` as a list of lines, without line terminators."""
    with open(filename) as handle:
        content = handle.read()
    return content.splitlines()

In [91]:
# ---- Dataset configuration ("tuto" subset) -----------------------------------
# NOTE(review): `name` is rebound to 'rrt' by the model configuration cell below.
name = 'tuto_viquae_tuto_r50_gldv2'
# Prefix of the ground-truth and nearest-neighbour file names.
set_name = 'tuto'
# (query list, gallery list): get_sets() reads index 0 as queries and index 1 as gallery.
train_txt = ('tuto_query.txt', 'tuto_gallery.txt')
test_txt = ('tuto_query.txt', 'tuto_selection.txt')
# NOTE(review): train_data_dir is overridden with an absolute path in a later cell, and
# get_loaders() passes train_data_dir as the test directory, so test_data_dir is not
# read by the loader construction in this notebook.
train_data_dir = 'data/viquae_for_rrt'
test_data_dir  = 'data/viquae_for_rrt'
test_gnd_file = 'gnd_tuto.pkl'
train_gnd_file = 'gnd_'+set_name+'.pkl'
#train_gnd_file = 'gnd_'+set_name+'.pkl'
# Descriptor name, used to build the '<set_name>_nn_inds_<desc_name>.pkl' file name.
desc_name = 'r50_gldv2'
# 'random' or 'triplet' (see get_loaders).
sampler = 'triplet'
# Field separator of the lines in the *_query / *_gallery / *_selection text files.
split_char  = ';;'

In [92]:
"""name = 'train_viquae_dev_r50_gldv2'
set_name = 'train'
train_txt = (set_name+'_query.txt', set_name+'_gallery.txt')
test_txt = ('dev_query.txt', 'dev_selection.txt')
train_data_dir = 'data/viquae_for_rrt'
test_data_dir  = 'data/viquae_for_rrt'
train_gnd_file = 'gnd_train.pkl'
test_gnd_file = 'gnd_dev.pkl'
desc_name = 'r50_gldv2'
sampler = 'triplet'
split_char  = ';;'"""

"name = 'train_viquae_dev_r50_gldv2'\nset_name = 'train'\ntrain_txt = (set_name+'_query.txt', set_name+'_gallery.txt')\ntest_txt = ('dev_query.txt', 'dev_selection.txt')\ntrain_data_dir = 'data/viquae_for_rrt'\ntest_data_dir  = 'data/viquae_for_rrt'\ntrain_gnd_file = 'gnd_train.pkl'\ntest_gnd_file = 'gnd_dev.pkl'\ndesc_name = 'r50_gldv2'\nsampler = 'triplet'\nsplit_char  = ';;'"

In [93]:
len(train_txt)

2

In [94]:
def _parse_sample_lines(lines, split_char):
    """Parse `<str><sep><int><sep><int><sep><int>` lines into (str, int, int, int) tuples.

    Fields beyond the fourth are ignored.
    """
    samples = []
    for line in lines:
        fields = line.split(split_char)  # split once instead of once per field
        samples.append((fields[0], int(fields[1]), int(fields[2]), int(fields[3])))
    return samples


def get_sets(desc_name, 
        train_data_dir, test_data_dir, 
        train_txt, test_txt, train_gnd_file,
        test_gnd_file, max_sequence_len,
        split_char):
    """Build the training and evaluation FeatureDataset objects.

    Args:
        desc_name: descriptor name forwarded to FeatureDataset.
        train_data_dir, test_data_dir: directories holding the list / ground-truth files.
        train_txt, test_txt: (query list file, gallery list file) pairs.
        train_gnd_file, test_gnd_file: pickled ground-truth file names, or None to
            build the corresponding datasets without ground truth.
        max_sequence_len: forwarded to FeatureDataset.
        split_char: field separator used in the list files.

    Returns:
        ((train_set, query_train_set), (query_set, gallery_set)). train_set is built
        from the training gallery list and query_train_set from the training query
        list; both share the training ground truth. gallery_set is built without
        ground truth.
    """
    # Training split: index 1 of train_txt is the gallery list, index 0 the query list.
    train_gnd_data  = None if train_gnd_file is None else pickle_load(osp.join(train_data_dir, train_gnd_file))
    train_lines     = read_file(osp.join(train_data_dir, train_txt[1]))
    train_q_lines   = read_file(osp.join(train_data_dir, train_txt[0]))
    train_samples   = _parse_sample_lines(train_lines, split_char)
    train_q_samples = _parse_sample_lines(train_q_lines, split_char)
    train_set       = FeatureDataset(train_data_dir, train_samples,   desc_name, max_sequence_len, gnd_data=train_gnd_data)
    query_train_set = FeatureDataset(train_data_dir, train_q_samples, desc_name, max_sequence_len, gnd_data=train_gnd_data)

    # Evaluation split: index 0 of test_txt is the query list, index 1 the gallery list.
    test_gnd_data   = None if test_gnd_file is None else pickle_load(osp.join(test_data_dir, test_gnd_file))
    query_lines     = read_file(osp.join(test_data_dir, test_txt[0]))
    gallery_lines   = read_file(osp.join(test_data_dir, test_txt[1]))
    query_samples   = _parse_sample_lines(query_lines, split_char)
    gallery_samples = _parse_sample_lines(gallery_lines, split_char)
    gallery_set = FeatureDataset(test_data_dir, gallery_samples, desc_name, max_sequence_len)
    query_set   = FeatureDataset(test_data_dir, query_samples,   desc_name, max_sequence_len, gnd_data=test_gnd_data)

    return (train_set, query_train_set), (query_set, gallery_set)


In [95]:
def get_loaders(desc_name, train_data_dir, 
    batch_size, test_batch_size, 
    num_workers, pin_memory, 
    sampler, recalls, set_name,
    num_candidates=300,):
    """Create the train / query-train / query / gallery data loaders.

    Besides its arguments, this function reads the notebook globals `train_txt`,
    `test_txt`, `train_gnd_file`, `test_gnd_file`, `max_sequence_len` and
    `split_char`, and uses `train_data_dir` for both the training and the
    evaluation files. It replaces the `get_loaders` imported from
    `utils.data.dataset_ingredient` at the top of the notebook.

    Args:
        desc_name: descriptor name (also part of the nn_inds file name).
        train_data_dir: directory holding all data files.
        batch_size: training batch size handed to the batch sampler.
        test_batch_size: batch size of the three evaluation-style loaders.
        num_workers, pin_memory: forwarded to every DataLoader.
        sampler: 'random' or 'triplet'.
        recalls: returned unchanged next to the loaders.
        set_name: prefix of the nn_inds file; stored on the returned MetricLoaders.
        num_candidates: forwarded to TripletSampler.

    Returns:
        (MetricLoaders, recalls)

    Raises:
        ValueError: if `sampler` is neither 'random' nor 'triplet'.
    """
    train_pair, eval_pair = get_sets(
        desc_name=desc_name,
        train_data_dir=train_data_dir,
        test_data_dir=train_data_dir,
        train_txt=train_txt,
        test_txt=test_txt,
        train_gnd_file=train_gnd_file,
        test_gnd_file=test_gnd_file,
        max_sequence_len=max_sequence_len,
        split_char=split_char,
    )
    train_set, query_train_set = train_pair
    query_set, gallery_set = eval_pair

    if sampler == 'triplet':
        # The sampler draws triplets from the training ground truth and checks that the
        # cached nearest-neighbour file has one row per training query.
        train_nn_inds = osp.join(train_data_dir, set_name + '_nn_inds_%s.pkl'%desc_name)
        gnd_data = train_set.gnd_data['gnd']
        train_sampler = TripletSampler(query_train_set.targets, batch_size, train_nn_inds, num_candidates, gnd_data)
    elif sampler == 'random':
        train_sampler = BatchSampler(RandomSampler(train_set), batch_size=batch_size, drop_last=False)
    else:
        raise ValueError('Invalid choice of sampler ({}).'.format(sampler))

    # Options shared by all four loaders.
    worker_opts = dict(num_workers=num_workers, pin_memory=pin_memory)

    train_loader = DataLoader(train_set, batch_sampler=train_sampler, **worker_opts)
    query_train_loader = DataLoader(query_train_set, batch_size=test_batch_size, **worker_opts)
    query_loader = DataLoader(query_set, batch_size=test_batch_size, **worker_opts)
    gallery_loader = DataLoader(gallery_set, batch_size=test_batch_size, **worker_opts)

    loaders = MetricLoaders(
        train=train_loader,
        query_train=query_train_loader,
        query=query_loader,
        gallery=gallery_loader,
        num_classes=len(train_set.categories),
        set_name=set_name,
    )
    return loaders, recalls


In [96]:
# Overrides the relative train_data_dir set in the dataset configuration cell.
# get_loaders() uses this directory for the evaluation files as well.
# NOTE(review): absolute, user-specific cluster path - not portable.
train_data_dir = '/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt'

In [97]:
class MetricLoaders(NamedTuple):
    """Bundle of data loaders (plus metadata) returned by `get_loaders`."""
    # Training loader; get_loaders builds it with a batch sampler.
    train: DataLoader
    # get_loaders fills this with len(train_set.categories).
    num_classes: int
    # Evaluation queries.
    query: DataLoader
    # Loader over the training queries, batched with the test batch size.
    query_train: DataLoader
    # Dataset prefix; later used to rebuild the '<set_name>_nn_inds_<desc>.pkl' path.
    set_name: str = ''
    # Evaluation gallery (optional).
    gallery: Optional[DataLoader] = None

In [98]:
# With the triplet sampler every query contributes 3 indices (anchor, positive,
# negative), so 72 corresponds to 24 triplets per batch.
batch_size      = 72
test_batch_size = 72
# Read as a global by get_loaders() and forwarded to FeatureDataset.
max_sequence_len = 500
num_workers = 8  # number of workers used to load the data
pin_memory  = True  # use the pin_memory option of DataLoader 
# NOTE(review): the get_loaders() call below passes the literal 100 rather than this variable.
num_candidates = 100
# Cut-offs forwarded as `ks` to the evaluation metric.
recalls = [1, 5, 10]

In [99]:
36*2

72

In [100]:
class TripletSampler():
    """Batch sampler yielding flat lists of (anchor, positive, negative) dataset indices.

    Each batch is laid out as [a0, p0, n0, a1, p1, n1, ...], so consumers can recover
    anchors, positives and negatives with the strides 0::3, 1::3 and 2::3.

    Args:
        labels: one label per training query; only its length (and, via
            ``np.zeros_like``, its shape/dtype) is used.
        batch_size: a batch is emitted as soon as it holds at least this many
            indices; batches grow three indices at a time.
        nn_inds_path: pickle file whose length must equal ``len(labels)``.
        num_candidates: stored on the instance; not used by the sampler itself.
        gnd_data: per-query ground truth, indexable by query position. Entries
            provide 'r_easy' / 'r_junk' (used for the validity check) and
            'anchor_idx' / 'g_easy' / 'g_junk' (used for sampling).
        min_pos: minimum length required of both 'r_easy' and 'r_junk' for a
            query to be sampled.
    """
    def __init__(self, labels, batch_size, nn_inds_path, num_candidates, gnd_data, min_pos=1):
        self.batch_size     = batch_size
        self.num_candidates = num_candidates
        # NOTE(review): pickle_load on a path - only use with trusted files.
        self.cache_nn_inds  = pickle_load(nn_inds_path)
        self.labels = labels
        self.gnd_data = gnd_data
        print('nn_inds_path: ', nn_inds_path)
        print('labels len: ', len(labels))
        assert (len(self.cache_nn_inds) == len(labels))
        #############################################################################
        ## Collect valid tuples: queries with enough positives AND enough negatives
        valids = np.zeros_like(labels)
        for i in range(len(self.cache_nn_inds)):
            positives = self.gnd_data[i]['r_easy']
            negatives = self.gnd_data[i]['r_junk']
            if len(positives) < min_pos or len(negatives) < min_pos:
                continue
            valids[i] = 1
        self.valids = np.where(valids > 0)[0]
        self.num_samples = len(self.valids)

    def __iter__(self):
        """Yield batches of indices, visiting every valid query once in random order."""
        batch = []
        cands = torch.randperm(self.num_samples).tolist()
        for cand in cands:
            query_idx = self.valids[cand]
            entry = self.gnd_data[query_idx]
            anchor_idx = entry['anchor_idx']

            positive_inds = entry['g_easy']
            negative_inds = entry['g_junk']
            assert(len(positive_inds) > 0)
            assert(len(negative_inds) > 0)

            # Draw one positive and one negative uniformly by index. Unlike an
            # in-place shuffle this leaves the shared ground-truth lists untouched.
            positive_idx = positive_inds[random.randrange(len(positive_inds))]
            negative_idx = negative_inds[random.randrange(len(negative_inds))]

            batch.append(anchor_idx)
            batch.append(positive_idx)
            batch.append(negative_idx)

            if len(batch) >= self.batch_size:
                yield batch
                batch = []

        if len(batch) > 0:
            yield batch

    def __len__(self):
        """Number of batches produced by one pass of ``__iter__``."""
        # A batch is closed after ceil(batch_size / 3) triplets (at least one), so
        # the count must be derived from triplets rather than from raw indices;
        # otherwise it is wrong whenever batch_size is not a multiple of 3.
        triplets_per_batch = max(1, -(-self.batch_size // 3))
        return -(-self.num_samples // triplets_per_batch)


In [101]:
# Inspection only: load a cached nearest-neighbour index file and display its shape.
# nn_inds_path and cache_nn_inds are both rebound later by the checkpoint cell.
# NOTE(review): this path uses the 'r50_gldv1' descriptor suffix while the rest of the
# notebook uses 'r50_gldv2' - confirm which file is intended.
nn_inds_path = '/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt/'+set_name+'_nn_inds_r50_gldv1.pkl'
cache_nn_inds  = pickle_load(nn_inds_path)
cache_nn_inds.shape

(120, 100)

In [102]:
# Build all data loaders with the notebook-local get_loaders() defined above.
# NOTE(review): the descriptor name and num_candidates are passed as literals here
# instead of the `desc_name` / `num_candidates` configuration variables.
torch.manual_seed(seed)
loaders, recall_ks = get_loaders('r50_gldv2', train_data_dir, 
    batch_size, test_batch_size, 
    num_workers, pin_memory, 
    sampler, recalls, set_name,
    num_candidates=100)

nn_inds_path:  /mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt/tuto_nn_inds_r50_gldv2.pkl
labels len:  120


In [103]:
def get_model(num_global_features, num_local_features, seq_len, dim_K, dim_feedforward, nhead, num_encoder_layers, dropout, activation, normalize_before, use_pos):
    """Instantiate the matching transformer.

    Returns a PosMatchERT when `use_pos` is truthy and a MatchERT otherwise; both
    receive exactly the same keyword arguments. This definition replaces the
    `get_model` imported from `models.ingredient` at the top of the notebook.
    """
    model_cls = PosMatchERT if use_pos else MatchERT
    return model_cls(
        d_global=num_global_features,
        d_model=num_local_features,
        seq_len=seq_len,
        d_K=dim_K,
        nhead=nhead,
        num_encoder_layers=num_encoder_layers,
        dim_feedforward=dim_feedforward,
        dropout=dropout,
        activation=activation,
        normalize_before=normalize_before,
    )

In [104]:
# ---- Model configuration (passed positionally to get_model) ------------------
# NOTE(review): rebinds `name`, which the dataset configuration cell set earlier.
name = 'rrt'
num_global_features = 2048  # forwarded as d_global
num_local_features = 128  # forwarded as d_model
seq_len = 1004
dim_K = 256  # forwarded as d_K
dim_feedforward = 1024
nhead = 4
num_encoder_layers = 6
dropout = 0.2
activation = "relu"
normalize_before = False
# False selects MatchERT, True selects PosMatchERT.
use_pos = False

In [105]:
"""
name = 'vrrt'
seq_len = 1004
dim_K = 256
dim_feedforward = 1024
nhead = 8
num_encoder_layers = 8
dropout = 0.4 
activation = "relu"
normalize_before = False
use_pos = True
"""

'\nname = \'vrrt\'\nseq_len = 1004\ndim_K = 256\ndim_feedforward = 1024\nnhead = 8\nnum_encoder_layers = 8\ndropout = 0.4 \nactivation = "relu"\nnormalize_before = False\nuse_pos = True\n'

In [106]:
# Seed torch before building the model, then instantiate it from the configuration above.
torch.manual_seed(seed+1)
model = get_model(num_global_features,num_local_features,seq_len,dim_K,dim_feedforward,nhead,num_encoder_layers,dropout,activation,normalize_before,use_pos)

In [107]:
model.eval()

MatchERT(
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): _LinearWithBias(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=1024, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
        (linear2): Linear(in_features=1024, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.2, inplace=False)
        (dropout2): Dropout(p=0.2, inplace=False)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): _LinearWithBias(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=1024, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
        (linear2): Lin

In [108]:
# Optionally initialise the model from the checkpoint at `resume`. The file is loaded
# onto the CPU and must hold the weights under the 'state' key; strict=True requires
# the checkpoint keys to match the model exactly.
if resume is not None:
    checkpoint = torch.load(resume, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['state'], strict=True)
print('# of trainable parameters: ', num_of_trainable_params(model))
class_loss = get_loss(loss)
# Rebuild the nearest-neighbour index path from the query dataset's own attributes
# and load it as a long tensor (rebinds the names used by the inspection cell above).
nn_inds_path = osp.join(loaders.query.dataset.data_dir, loaders.set_name + '_nn_inds_%s.pkl'%loaders.query.dataset.desc_name)
cache_nn_inds = torch.from_numpy(pickle_load(nn_inds_path)).long()

# of trainable parameters:  2243201


In [109]:
(129, 2243201) 

(129, 2243201)

In [110]:
# Move the model to the target device, wrap it in DataParallel and build the
# optimizer / scheduler.
# NOTE(review): this cell is not idempotent - it rebinds `model` to a DataParallel
# wrapper and `scheduler` (a config string) to a (scheduler, update_per_iteration)
# tuple, so re-running it without re-running the configuration cells changes its behaviour.
torch.manual_seed(seed+2)
model.to(device)
model = nn.DataParallel(model)
parameters = []
if no_bias_decay:
    # Two parameter groups: 1-D parameters are exempt from weight decay.
    parameters.append({'params': [par for par in model.parameters() if par.dim() != 1]})
    parameters.append({'params': [par for par in model.parameters() if par.dim() == 1], 'weight_decay': 0})
else:
    parameters.append({'params': model.parameters()})
optimizer, scheduler = get_optimizer_scheduler(parameters=parameters, loader_length=len(loaders.train),
                                               optim=optim, epochs=epochs, lr=lr,
                                               momentum=momentum, nesterov=nesterov, weight_decay=weight_decay,
                                               scheduler=scheduler, scheduler_tau=scheduler_tau, 
                                               scheduler_gamma=scheduler_gamma, lr_step=None)
# Restore the optimizer state when the checkpoint provides one.
# NOTE(review): `checkpoint` is only deleted inside this branch; when the checkpoint has
# no 'optim' entry it stays in memory.
if resume is not None and checkpoint.get('optim', None) is not None:
    optimizer.load_state_dict(checkpoint['optim'])
    del checkpoint


In [111]:
from utils.metrics import *
def evaluate_viquae(
        model: nn.Module,
        cache_nn_inds: torch.Tensor,
        query_loader: DataLoader,
        gallery_loader: DataLoader,
        recall: List[int]):
    """Extract query and gallery features and compute the reranking metrics.

    This notebook-local definition replaces the `evaluate_viquae` imported from
    `utils.training` at the top of the notebook.

    Args:
        model: model handed to the metric function; switched to eval mode here.
        cache_nn_inds: cached nearest-neighbour indices handed to the metric function.
        query_loader: loader whose batches are 7-tuples
            (global, local, mask, scales, positions, <ignored>, names); its dataset
            must expose `gnd_data`.
        gallery_loader: loader with the same batch layout.
        recall: cut-offs forwarded as `ks`.

    Returns:
        Whatever `mean_average_precision_viquae_rerank` returns.
    """
    def collect(loader, desc):
        """Concatenate the five tensor fields of every batch of `loader` on the CPU."""
        columns = [[] for _ in range(5)]
        for entry in tqdm(loader, desc=desc, leave=False, ncols=80):
            # The sixth field and the names are not needed for evaluation.
            b_global, b_local, b_mask, b_scales, b_positions, _, _ = entry
            for column, tensor in zip(columns, (b_global, b_local, b_mask, b_scales, b_positions)):
                column.append(tensor.cpu())
            torch.cuda.empty_cache()
        return [torch.cat(column, 0) for column in columns]

    model.eval()

    with torch.no_grad():
        query_global, query_local, query_mask, query_scales, query_positions = collect(
            query_loader, 'Extracting query features')
        gallery_global, gallery_local, gallery_mask, gallery_scales, gallery_positions = collect(
            gallery_loader, 'Extracting gallery features')

        torch.cuda.empty_cache()
        metrics = mean_average_precision_viquae_rerank(
            model=model, cache_nn_inds=cache_nn_inds,
            query_global=query_global, query_local=query_local, query_mask=query_mask,
            query_scales=query_scales, query_positions=query_positions,
            gallery_global=gallery_global, gallery_local=gallery_local, gallery_mask=gallery_mask,
            gallery_scales=gallery_scales, gallery_positions=gallery_positions,
            ks=recall,
            gnd=query_loader.dataset.gnd_data,
        )
    return metrics



In [112]:
def read_entry_once(
    query_loader: DataLoader,
    gallery_loader: DataLoader):
    """Read every batch of both loaders once and return the concatenated features.

    Each batch must be a 7-tuple (global, local, mask, scales, positions, <ignored>,
    names). The five tensor fields are moved to the CPU and concatenated along
    dimension 0; the names are collected into a flat list.

    Returns:
        (query_feats, gallery_feats), each a list
        [global, local, mask, scales, positions, names].
    """
    def drain(loader, desc):
        tensor_columns = [[] for _ in range(5)]
        names = []
        for entry in tqdm(loader, desc=desc, leave=False, ncols=80):
            b_global, b_local, b_mask, b_scales, b_positions, _, b_names = entry
            for column, tensor in zip(tensor_columns, (b_global, b_local, b_mask, b_scales, b_positions)):
                column.append(tensor.cpu())
            names.extend(list(b_names))
            torch.cuda.empty_cache()
        return [torch.cat(column, 0) for column in tensor_columns] + [names]

    print("READING ENTRIES FOR THE FIRST TIME")

    with torch.no_grad():
        # Queries are fully read and concatenated before the gallery is touched.
        query_feats = drain(query_loader, 'Extracting query features')
        gallery_feats = drain(gallery_loader, 'Extracting gallery features')

    return query_feats, gallery_feats

In [113]:
from utils.metrics import mean_average_precision_viquae_rerank

def fast_evaluate_viquae(
    model: nn.Module,
    cache_nn_inds: torch.Tensor,
    query_loader: DataLoader,
    gallery_loader: DataLoader,
    recall: List[int],
    query_feats, 
    gallery_feats):
    """Compute the reranking metrics, reusing previously extracted features.

    Args:
        model: model handed to the metric function; switched to eval mode here.
        cache_nn_inds: cached nearest-neighbour indices handed to the metric function.
        query_loader, gallery_loader: only iterated when no cached features are
            supplied; `query_loader.dataset.gnd_data` is always read.
        recall: cut-offs forwarded as `ks`.
        query_feats, gallery_feats: lists as returned by `read_entry_once`
            ([global, local, mask, scales, positions, names]), or empty to trigger
            extraction.

    Returns:
        (metrics, query_feats, gallery_feats). Pass the returned feature lists to
        the next call to skip extraction; the argument lists themselves are not
        filled in place.
    """
    model.eval()

    # Extract when either cache is missing; a half-filled cache could not be unpacked below.
    if len(query_feats) == 0 or len(gallery_feats) == 0:
        query_feats, gallery_feats = read_entry_once(query_loader, gallery_loader)

    query_global, query_local, query_mask, query_scales, query_positions, query_names = query_feats
    gallery_global, gallery_local, gallery_mask, gallery_scales, gallery_positions, gallery_names = gallery_feats

    torch.cuda.empty_cache()

    # Evaluation needs no gradients; evaluate_viquae runs the same metric function
    # under no_grad as well.
    with torch.no_grad():
        metrics = mean_average_precision_viquae_rerank(
            model=model, cache_nn_inds=cache_nn_inds,
            query_global=query_global, query_local=query_local, query_mask=query_mask,
            query_scales=query_scales, query_positions=query_positions,
            gallery_global=gallery_global, gallery_local=gallery_local, gallery_mask=gallery_mask,
            gallery_scales=gallery_scales, gallery_positions=gallery_positions,
            ks=recall,
            gnd=query_loader.dataset.gnd_data,
        )

    return metrics, query_feats, gallery_feats

In [114]:
query_feats, gallery_feats = [], []

In [115]:
torch.manual_seed(seed+3)
# setup partial function to simplify call
# NOTE(review): the partial captures the `query_feats` / `gallery_feats` list objects
# that exist when this cell runs. fast_evaluate_viquae rebinds its local names instead
# of filling those lists, so if they are empty here every eval_function() call
# re-extracts the features.
#query_feats, gallery_feats = [], []

eval_function = partial(fast_evaluate_viquae, model=model, 
    cache_nn_inds=cache_nn_inds,
    recall=recall_ks, query_loader=loaders.query, gallery_loader=loaders.gallery,
    query_feats=query_feats, gallery_feats=gallery_feats)

In [116]:
# Baseline evaluation before any training. The returned feature lists are kept in the
# notebook namespace so the training loop can pass them back and skip re-extraction.
result, query_feats, gallery_feats = eval_function()
pprint(result)
# best_val = (epoch counter, metrics, snapshot of the weights); 0 marks the baseline.
best_val = (0, result, deepcopy(model.state_dict()))

Extracting query features:   0%|                          | 0/2 [00:00<?, ?it/s]

READING ENTRIES FOR THE FIRST TIME


100%|██████████| 100/100 [02:43<00:00,  1.63s/it]                               

{'map': 31.63, 'mrr': 44.02, 'precision': 13.83, 'hit_rate': 85.0, 'recall': 84.58, 'map@1': 11.82, 'mrr@1': 33.33, 'precision@1': 33.33, 'hit_rate@1': 33.33, 'recall@1': 11.82, 'map@5': 18.28, 'mrr@5': 41.64, 'precision@5': 19.83, 'hit_rate@5': 55.0, 'recall@5': 22.65, 'map@10': 21.45, 'mrr@10': 43.08, 'precision@10': 16.67, 'hit_rate@10': 65.0, 'recall@10': 31.73}
{'hit_rate': 85.0,
 'hit_rate@1': 33.33,
 'hit_rate@10': 65.0,
 'hit_rate@5': 55.0,
 'map': 31.63,
 'map@1': 11.82,
 'map@10': 21.45,
 'map@5': 18.28,
 'mrr': 44.02,
 'mrr@1': 33.33,
 'mrr@10': 43.08,
 'mrr@5': 41.64,
 'precision': 13.83,
 'precision@1': 33.33,
 'precision@10': 16.67,
 'precision@5': 19.83,
 'recall': 84.58,
 'recall@1': 11.82,
 'recall@10': 31.73,
 'recall@5': 22.65}





In [117]:
torch.manual_seed(seed+4)

<torch._C.Generator at 0x7f0ab8eb5810>

In [118]:
#save_name = osp.join(temp_dir, '{}_{}.pt'.format(ex.current_run.config['model']['name'],
#                                                         ex.current_run.config['dataset']['name']))

In [119]:
result

{'map': 31.63,
 'mrr': 44.02,
 'precision': 13.83,
 'hit_rate': 85.0,
 'recall': 84.58,
 'map@1': 11.82,
 'mrr@1': 33.33,
 'precision@1': 33.33,
 'hit_rate@1': 33.33,
 'recall@1': 11.82,
 'map@5': 18.28,
 'mrr@5': 41.64,
 'precision@5': 19.83,
 'hit_rate@5': 55.0,
 'recall@5': 22.65,
 'map@10': 21.45,
 'mrr@10': 43.08,
 'precision@10': 16.67,
 'hit_rate@10': 65.0,
 'recall@10': 31.73}

In [120]:
save_name = 'temp_outputs/models/rrt_tuto_viquae_dev_r50_gldv2.pt'

In [121]:
# NOTE(review): the number of epochs is hard-coded here; the `epochs` configuration
# value (used to build the scheduler) is not read by this loop.
for epoch in range(5):
    # cudnn autotuning is switched on for training and off again for validation.
    if cudnn_flag == 'benchmark':
        setattr(cudnn, cudnn_flag, True)

    torch.cuda.empty_cache()
    train_one_epoch(model=model, loader=loaders.train, class_loss=class_loss, optimizer=optimizer, scheduler=scheduler, max_norm=max_norm, epoch=epoch, freq=visdom_freq, ex=None)

    # validation
    if cudnn_flag == 'benchmark':
        setattr(cudnn, cudnn_flag, False)

    torch.cuda.empty_cache()
    # Feed the cached features back in so they are extracted at most once.
    result, query_feats, gallery_feats = fast_evaluate_viquae(
        model=model,
        cache_nn_inds=cache_nn_inds,
        recall=recall_ks,
        query_loader=loaders.query,
        gallery_loader=loaders.gallery,
        query_feats=query_feats,
        gallery_feats=gallery_feats)

    print('Validation [{:03d}]'.format(epoch))
    pprint(result)
    #ex.log_scalar('val.map', result['map'], step=epoch + 1)

    # Ties count as an improvement. best_val stores the 1-based epoch (0 is the
    # pre-training baseline) while the message prints the 0-based loop index.
    if result['map'] >= best_val[1]['map']:
        print('New best model in epoch %d.'%epoch)
        best_val = (epoch + 1, result, deepcopy(model.state_dict()))
        # torch.save does not create missing directories.
        save_dir = osp.dirname(save_name)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save({'state': state_dict_to_cpu(best_val[2]), 'optim': optimizer.state_dict()}, save_name)

Training   [000]: 100%|███████████████████████████| 5/5 [00:09<00:00,  1.91s/it]
100%|██████████| 100/100 [03:17<00:00,  1.98s/it]
Training   [001]:   0%|                                   | 0/5 [00:00<?, ?it/s]

{'map': 31.38, 'mrr': 43.39, 'precision': 13.83, 'hit_rate': 85.0, 'recall': 84.58, 'map@1': 11.61, 'mrr@1': 32.5, 'precision@1': 32.5, 'hit_rate@1': 32.5, 'recall@1': 11.61, 'map@5': 17.82, 'mrr@5': 40.74, 'precision@5': 19.33, 'hit_rate@5': 53.33, 'recall@5': 21.61, 'map@10': 21.2, 'mrr@10': 42.45, 'precision@10': 16.67, 'hit_rate@10': 65.0, 'recall@10': 31.88}
Validation [000]
{'hit_rate': 85.0,
 'hit_rate@1': 32.5,
 'hit_rate@10': 65.0,
 'hit_rate@5': 53.33,
 'map': 31.38,
 'map@1': 11.61,
 'map@10': 21.2,
 'map@5': 17.82,
 'mrr': 43.39,
 'mrr@1': 32.5,
 'mrr@10': 42.45,
 'mrr@5': 40.74,
 'precision': 13.83,
 'precision@1': 32.5,
 'precision@10': 16.67,
 'precision@5': 19.33,
 'recall': 84.58,
 'recall@1': 11.61,
 'recall@10': 31.88,
 'recall@5': 21.61}


Training   [001]: 100%|███████████████████████████| 5/5 [00:11<00:00,  2.34s/it]
100%|██████████| 100/100 [02:44<00:00,  1.65s/it]
Training   [002]:   0%|                                   | 0/5 [00:00<?, ?it/s]

{'map': 31.43, 'mrr': 44.18, 'precision': 13.83, 'hit_rate': 85.0, 'recall': 84.58, 'map@1': 11.75, 'mrr@1': 34.17, 'precision@1': 34.17, 'hit_rate@1': 34.17, 'recall@1': 11.75, 'map@5': 17.71, 'mrr@5': 41.35, 'precision@5': 19.17, 'hit_rate@5': 51.67, 'recall@5': 20.62, 'map@10': 21.44, 'mrr@10': 43.25, 'precision@10': 17.17, 'hit_rate@10': 65.0, 'recall@10': 32.6}
Validation [001]
{'hit_rate': 85.0,
 'hit_rate@1': 34.17,
 'hit_rate@10': 65.0,
 'hit_rate@5': 51.67,
 'map': 31.43,
 'map@1': 11.75,
 'map@10': 21.44,
 'map@5': 17.71,
 'mrr': 44.18,
 'mrr@1': 34.17,
 'mrr@10': 43.25,
 'mrr@5': 41.35,
 'precision': 13.83,
 'precision@1': 34.17,
 'precision@10': 17.17,
 'precision@5': 19.17,
 'recall': 84.58,
 'recall@1': 11.75,
 'recall@10': 32.6,
 'recall@5': 20.62}


Training   [002]: 100%|███████████████████████████| 5/5 [00:09<00:00,  1.90s/it]
100%|██████████| 100/100 [01:12<00:00,  1.38it/s]
Training   [003]:   0%|                                   | 0/5 [00:00<?, ?it/s]

{'map': 31.35, 'mrr': 43.71, 'precision': 13.83, 'hit_rate': 85.0, 'recall': 84.58, 'map@1': 11.54, 'mrr@1': 33.33, 'precision@1': 33.33, 'hit_rate@1': 33.33, 'recall@1': 11.54, 'map@5': 17.7, 'mrr@5': 40.86, 'precision@5': 19.67, 'hit_rate@5': 51.67, 'recall@5': 20.78, 'map@10': 21.34, 'mrr@10': 42.91, 'precision@10': 17.25, 'hit_rate@10': 66.67, 'recall@10': 32.85}
Validation [002]
{'hit_rate': 85.0,
 'hit_rate@1': 33.33,
 'hit_rate@10': 66.67,
 'hit_rate@5': 51.67,
 'map': 31.35,
 'map@1': 11.54,
 'map@10': 21.34,
 'map@5': 17.7,
 'mrr': 43.71,
 'mrr@1': 33.33,
 'mrr@10': 42.91,
 'mrr@5': 40.86,
 'precision': 13.83,
 'precision@1': 33.33,
 'precision@10': 17.25,
 'precision@5': 19.67,
 'recall': 84.58,
 'recall@1': 11.54,
 'recall@10': 32.85,
 'recall@5': 20.78}


Training   [003]: 100%|███████████████████████████| 5/5 [00:09<00:00,  1.90s/it]
100%|██████████| 100/100 [01:12<00:00,  1.38it/s]
Training   [004]:   0%|                                   | 0/5 [00:00<?, ?it/s]

{'map': 31.28, 'mrr': 43.62, 'precision': 13.83, 'hit_rate': 85.0, 'recall': 84.58, 'map@1': 11.36, 'mrr@1': 32.5, 'precision@1': 32.5, 'hit_rate@1': 32.5, 'recall@1': 11.36, 'map@5': 17.77, 'mrr@5': 41.0, 'precision@5': 20.17, 'hit_rate@5': 52.5, 'recall@5': 21.36, 'map@10': 21.32, 'mrr@10': 42.75, 'precision@10': 17.58, 'hit_rate@10': 65.83, 'recall@10': 32.37}
Validation [003]
{'hit_rate': 85.0,
 'hit_rate@1': 32.5,
 'hit_rate@10': 65.83,
 'hit_rate@5': 52.5,
 'map': 31.28,
 'map@1': 11.36,
 'map@10': 21.32,
 'map@5': 17.77,
 'mrr': 43.62,
 'mrr@1': 32.5,
 'mrr@10': 42.75,
 'mrr@5': 41.0,
 'precision': 13.83,
 'precision@1': 32.5,
 'precision@10': 17.58,
 'precision@5': 20.17,
 'recall': 84.58,
 'recall@1': 11.36,
 'recall@10': 32.37,
 'recall@5': 21.36}


Training   [004]: 100%|███████████████████████████| 5/5 [00:09<00:00,  1.94s/it]
100%|██████████| 100/100 [01:12<00:00,  1.37it/s]

{'map': 31.18, 'mrr': 43.86, 'precision': 13.83, 'hit_rate': 85.0, 'recall': 84.58, 'map@1': 11.39, 'mrr@1': 33.33, 'precision@1': 33.33, 'hit_rate@1': 33.33, 'recall@1': 11.39, 'map@5': 17.65, 'mrr@5': 41.28, 'precision@5': 20.33, 'hit_rate@5': 52.5, 'recall@5': 21.27, 'map@10': 21.32, 'mrr@10': 42.95, 'precision@10': 17.67, 'hit_rate@10': 65.0, 'recall@10': 32.18}
Validation [004]
{'hit_rate': 85.0,
 'hit_rate@1': 33.33,
 'hit_rate@10': 65.0,
 'hit_rate@5': 52.5,
 'map': 31.18,
 'map@1': 11.39,
 'map@10': 21.32,
 'map@5': 17.65,
 'mrr': 43.86,
 'mrr@1': 33.33,
 'mrr@10': 42.95,
 'mrr@5': 41.28,
 'precision': 13.83,
 'precision@1': 33.33,
 'precision@10': 17.67,
 'precision@5': 20.33,
 'recall': 84.58,
 'recall@1': 11.39,
 'recall@10': 32.18,
 'recall@5': 21.27}





In [122]:
# Scratch setup so the training-loop cell further down can be stepped through by hand.
# NOTE(review): the star import hides where names used by that cell (e.g. AverageMeter)
# come from — presumably utils.training re-exports them; confirm and import explicitly.
from utils.training import *
# Short alias read by the loop cell.
loader=loaders.train
# Self-assignments: no-ops when the name is already bound, NameError otherwise.
class_loss=class_loss
optimizer=optimizer
scheduler=scheduler
max_norm=max_norm
epoch=epoch
# Reporting interval (in batches) used by the loop cell's periodic print.
freq=visdom_freq
# Rebinds the notebook-level `ex` (the sacred Experiment created near the top) to None.
ex=None

In [123]:
# Debug peek at the two trailing items of a batch. These names are only bound once the
# training-loop cell has unpacked a batch (`..., llm, vvm = entry`); evaluating this
# cell before that raises NameError.
llm, vvm

NameError: name 'llm' is not defined

In [None]:
# Inline (cell-level) version of one training epoch.
# Reads from the notebook namespace: model, loader, class_loss, optimizer, scheduler,
# max_norm, epoch, freq, plus AverageMeter and tqdm.
# `scheduler` is indexed as a pair here: scheduler[0] is the LR scheduler object and
# scheduler[-1] is a flag selecting per-batch (truthy) or per-epoch (falsy) stepping.
model.train()
device = next(model.parameters()).device
to_device = lambda x: x.to(device, non_blocking=True)
loader_length = len(loader)
train_losses = AverageMeter(device=device, length=loader_length)
train_accs = AverageMeter(device=device, length=loader_length)
pbar = tqdm(loader, ncols=80, desc='Training   [{:03d}]'.format(epoch))
for i, entry in enumerate(pbar):
    # llm / vvm are unpacked but not used by the loop (and not moved to `device`);
    # they stay bound so they can be inspected from other cells.
    global_feats, local_feats, local_mask, scales, positions, llm, vvm = entry
    global_feats, local_feats, local_mask, scales, positions = map(to_device, (global_feats, local_feats, local_mask, scales, positions))

    # Each batch is consumed in strides of three: element 0 of every triple is scored
    # against element 1 (target 1) and against element 2 (target 0).
    feats = (global_feats, local_feats, local_mask, scales, positions)
    anchor = [t[0::3] for t in feats]
    positive = [t[1::3] for t in feats]
    negative = [t[2::3] for t in feats]

    p_logits = model(*anchor, *positive)
    n_logits = model(*anchor, *negative)

    logits = torch.cat([p_logits, n_logits], 0)
    # Build the targets from the two logit tensors themselves instead of zeroing the
    # second half of `logits`, so the labels stay aligned even if the two parts ever
    # differ in length.
    labels = torch.cat([torch.ones_like(p_logits), torch.zeros_like(n_logits)], 0).float()
    # Named `batch_loss` (not `loss`) so the notebook-level config string `loss` is
    # not overwritten by a tensor.
    batch_loss = class_loss(logits, labels).mean()
    acc = ((torch.sigmoid(logits) > 0.5).long() == labels.long()).float().mean()

    # Backward pass, optional gradient clipping, optimiser step.
    optimizer.zero_grad()
    batch_loss.backward()
    if max_norm > 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()

    # Per-batch LR schedule.
    if scheduler[-1]:
        scheduler[0].step()

    # Detach so the stored value does not keep the autograd graph alive.
    train_losses.append(batch_loss.detach())
    train_accs.append(acc)

    # Report every `freq` batches.
    if not (i + 1) % freq:
        step = epoch + i / loader_length
        print('step/loss/accu/lr:', step, train_losses.last_avg.item(), train_accs.last_avg.item(), scheduler[0].get_last_lr()[0])

# Per-epoch LR schedule: step once after the whole loader has been consumed.
if not scheduler[-1]:
    scheduler[0].step()


In [69]:
# Load the ground-truth file for the training split into `gnd`.
gnd_path = osp.join(train_data_dir, train_gnd_file)
gnd = pickle_load(gnd_path)

In [70]:
# List the fields stored in the first record of gnd['gnd'].
gnd['gnd'][0].keys()

dict_keys(['easy', 'hard', 'junk', 'neg', 'provenance_entity', 'ir_order', 'img_rank_dict', 'rank_img_dict', 'r_easy', 'r_hard', 'r_junk', 'r_neg', 'r_ir_order', 'anchor_idx', 'g_easy', 'g_hard', 'g_junk', 'g_neg'])

In [71]:
def _parse_sample_line(line):
    """Turn one text line into a ``(str, int, int, int)`` sample tuple.

    The line is split on the notebook-level ``split_char``; the first field is kept
    as a string and the next three are converted with ``int``. Extra fields are ignored.
    """
    fields = line.split(split_char)
    return (fields[0], int(fields[1]), int(fields[2]), int(fields[3]))


# Ground truth is optional: it is only loaded when a file name is configured.
if train_gnd_file is None:
    train_gnd_data = None
else:
    train_gnd_data = pickle_load(osp.join(train_data_dir, train_gnd_file))

# train_txt[1] supplies the training lines, train_txt[0] the query lines.
train_lines = read_file(osp.join(train_data_dir, train_txt[1]))
train_q_lines = read_file(osp.join(train_data_dir, train_txt[0]))

train_samples = list(map(_parse_sample_line, train_lines))
train_q_samples = list(map(_parse_sample_line, train_q_lines))

# Both datasets share the same directory, descriptor name, sequence length and ground truth.
train_set = FeatureDataset(train_data_dir, train_samples, desc_name, max_sequence_len, gnd_data=train_gnd_data)
query_train_set = FeatureDataset(train_data_dir, train_q_samples, desc_name, max_sequence_len, gnd_data=train_gnd_data)

In [72]:
# Read the gallery listing; unlike the other listings, its file name is fixed here
# rather than taken from train_txt.
gallery_list_path = osp.join(train_data_dir, 'train_gallery.txt')
train_g_lines = read_file(gallery_list_path)


In [73]:
# Line counts of the three listings, in order: train, query, gallery.
tuple(len(listing) for listing in (train_lines, train_q_lines, train_g_lines))

(5564, 120, 42678)

In [74]:
#s_name = set_name
#if s_name != '':
#    s_name = set_name + '_'
#def map_nnids_labels(train_data_dir, train_gnd_file, s_categories):
#    gnd =  pickle_load(osp.join(train_data_dir, train_gnd_file))
#    selection_gallery = gnd['simlist']
#    s_categories = s_categories.reshape(np.array(selection_gallery).shape)
#    selection_ids_to_cat_dict = [{k: s_categories[i][k] for k in range(len(selection_gallery[i]))} for i in range(len(selection_gallery))]
#    print(s_categories.shape)
#
#    return selection_ids_to_cat_dict
#
#s_path = train_data_dir+'/'+set_name+'_s_categories.txt'
#s_categories = np.loadtxt(s_path, dtype='int64')
#map_nnids_labels = map_nnids_labels(train_data_dir, train_gnd_file, s_categories)

# Path of the '<set_name>_nn_inds_<desc_name>.pkl' file under train_data_dir
# (presumably cached nearest-neighbour indices — confirm).
train_nn_inds = osp.join(train_data_dir, set_name + '_nn_inds_%s.pkl'%desc_name)
# NOTE(review): assumes train_set was built with ground truth (train_gnd_file not None);
# with gnd_data=None this subscript would presumably fail — confirm.
gnd_data = train_set.gnd_data['gnd']
# Sampler over the query set's targets; the "nn_inds_path" / "labels len" lines in this
# cell's output appear when it is constructed.
train_sampler = TripletSampler(query_train_set.targets, batch_size, train_nn_inds, num_candidates, gnd_data)

nn_inds_path:  /mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt/tuto_nn_inds_r50_gldv2.pkl
labels len:  120


In [75]:
#nn_inds_path =  osp.join(train_data_dir, 'training_' + s_name+'nn_inds_%s.pkl'%desc_name)
#cache_nn_inds = torch.from_numpy(pickle_load(nn_inds_path)).long()
#cache_nn_inds.shape

In [76]:
# Size of train_sampler.valids, and the positions whose value is > 0.
len(train_sampler.valids), np.where(train_sampler.valids > 0)[0]

(115,
 array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
         13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
         26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
         39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
         52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
         65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
         78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
         91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
        104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114]))

In [81]:
# Print every batch the sampler yields. The loop variable `item` stays bound to the
# last batch afterwards and is read by a later `len(item)` cell.
for item in train_sampler:
    print(item)

[690, 3193, 142, 4754, 4481, 2557, 1210, 4313, 4482, 4992, 450, 2275, 2765, 2796, 2641, 4752, 314, 4074, 2635, 3765, 2522, 3787, 1688, 4289, 2516, 1362, 5474, 743, 2118, 3898, 1141, 1430, 2273, 3010, 5470, 3566, 1015, 426, 5515, 4274, 1625, 5200, 2884, 3444, 1944, 132, 132, 4016, 278, 854, 3410, 5049, 5018, 448, 1551, 4038, 161, 4116, 504, 1589, 837, 4954, 1619, 2466, 2868, 1458, 3404, 2505, 5319, 3499, 488, 2241]
[629, 4676, 4450, 4091, 1806, 1826, 2462, 1358, 2127, 821, 2877, 1521, 3105, 228, 2477, 3103, 2652, 1321, 5256, 3701, 4250, 3329, 685, 1072, 4526, 2541, 1906, 1839, 5240, 3754, 438, 1000, 3134, 2238, 5242, 1341, 690, 3193, 1500, 779, 4281, 790, 4622, 4893, 3101, 1724, 4321, 4915, 302, 1638, 3502, 2493, 1402, 3210, 1741, 3033, 1315, 4693, 2937, 1790, 4023, 2046, 2550, 4253, 25, 1149, 5131, 240, 3967, 1073, 3720, 409]
[2238, 2517, 1732, 1976, 5527, 1654, 5396, 2903, 4112, 1533, 2078, 5305, 11, 1515, 2966, 876, 4134, 1390, 4970, 1849, 5032, 62, 2607, 3009, 1405, 3312, 2895, 2510

In [82]:
# Length reported by the sampler — presumably the number of batches per epoch
# (the training progress bars show the same count); confirm.
len(train_sampler)

5

In [83]:
# Scratch arithmetic shown next to the size of train_sampler.valids.
# NOTE(review): 95*9 and 9*3*36 are unexplained magic numbers — presumably a sanity
# check on sample/batch counts; name the factors or drop this cell.
95*9, len(train_sampler.valids), 9*3*36

(855, 115, 972)

In [84]:
# Size of the last batch: `item` is the variable left over from the
# `for item in train_sampler` cell, so this only works after that cell has run.
len(item)

57

In [None]:
# 'hard' field of the ground-truth record at index 10.
gnd['gnd'][10]['hard']

In [None]:
# Entries 6..16 of the row at index 1 of gnd['simlist'].
gnd['simlist'][1][6:17]