## Imports

In [1]:
import numpy as np
import os, math
import os.path as osp
from copy import deepcopy
from functools import partial
from pprint import pprint
import pickle, json, random

In [2]:
import sacred
import torch
import torch.nn as nn
import torch.nn.functional as F
from sacred import SETTINGS
from sacred.utils import apply_backspaces_and_linefeeds
from torch.backends import cudnn
from torch.optim import SGD, Adam, AdamW, lr_scheduler
from tqdm import tqdm

from torch.backends import cudnn
# from visdom_logger import VisdomLogger

In [3]:
from torch.utils.data import DataLoader, RandomSampler, BatchSampler
from typing import NamedTuple, Optional, List

In [4]:
from models.matcher import MatchERT, PosMatchERT
from models.ingredient import model_ingredient, get_model

In [5]:
from models.ingredient import model_ingredient, get_model
from utils import state_dict_to_cpu, num_of_trainable_params
from utils import pickle_load, pickle_save
#from utils.data.utils import TripletSampler
from utils import BinaryCrossEntropyWithLogits
from utils.data.dataset_ingredient import data_ingredient, get_loaders
from utils.training import train_one_epoch, evaluate_viquae
from utils.metrics import AverageMeter

In [6]:
from utils import pickle_load
from sacred import Experiment
from utils.data.dataset_ingredient import data_ingredient, get_loaders
from utils.data.dataset import FeatureDataset

In [7]:
from functools import partial
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from sacred import Experiment
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader, Dataset

In [8]:
from torch.utils.tensorboard import SummaryWriter
# TensorBoard event writer created with default settings (no log directory given).
# NOTE(review): the training-loop cell further down creates a fresh SummaryWriter
# under the same name, so this instance may never receive any scalars.
writer = SummaryWriter()

In [9]:
# Sacred experiment bound to the data and model ingredients. `interactive=True`
# is passed because the experiment is created from a notebook session.
ex = sacred.Experiment('RRT Training', ingredients=[data_ingredient, model_ingredient], interactive=True)
# Capture stdout/stderr at the `sys` level and clean backspaces / linefeeds
# (progress-bar redraws) out of the captured output.
SETTINGS.CAPTURE_MODE = 'sys'
ex.captured_out_filter = apply_backspaces_and_linefeeds

## Data Utility Functions

In [10]:
def read_file(filename):
    """Return the content of the text file `filename` as a list of lines.

    Line terminators are stripped; the file is opened with the platform
    default encoding.
    """
    with open(filename) as handle:
        return handle.read().splitlines()

In [11]:
def _with_prefix(prefixed, name):
    """Return `name` prepended with '<prefixed>_' when a prefix is configured."""
    return name if prefixed is None else prefixed+'_'+name


def _load_gnd(data_dir, gnd_file, prefixed):
    """Load the (optionally prefixed) ground-truth pickle, or return None when no file is given."""
    if gnd_file is None:
        return None
    return pickle_load(osp.join(data_dir, _with_prefix(prefixed, gnd_file)))


def _load_samples(data_dir, txt_name, split_char, prefixed):
    """Parse a sample list file into (name, int, int, int) tuples, one per line."""
    samples = []
    for line in read_file(osp.join(data_dir, _with_prefix(prefixed, txt_name))):
        # Split once per line; the first field stays a string, the next three are integers.
        fields = line.split(split_char)
        samples.append((fields[0], int(fields[1]), int(fields[2]), int(fields[3])))
    return samples


def get_sets(desc_name, 
        train_data_dir, test_data_dir, train_txt, 
        test_txt, train_gnd_file,  test_gnd_file,
        max_sequence_len, split_char, prefixed):
    """Build the training and evaluation `FeatureDataset` objects.

    Args:
        desc_name: descriptor name forwarded to `FeatureDataset`.
        train_data_dir, test_data_dir: directories holding the list files and
            ground-truth pickles of the respective split.
        train_txt: either a single list-file name (the same samples then back
            both training datasets) or a 2-item tuple/list
            `(query_list, train_list)`.
        test_txt: pair `(query_list, gallery_list)` of list-file names.
        train_gnd_file, test_gnd_file: ground-truth pickle names, or None.
        max_sequence_len: forwarded to `FeatureDataset`.
        split_char: field separator used inside the list files.
        prefixed: optional prefix; when set, every file name is looked up as
            '<prefixed>_<name>'.

    Returns:
        `((train_set, query_train_set), (query_set, gallery_set))`.
    """
    train_gnd_data = _load_gnd(train_data_dir, train_gnd_file, prefixed)

    # Check the container type instead of only `len(train_txt) == 2`, so that a
    # two-character file name is not mistaken for a (query, train) pair.
    if isinstance(train_txt, (tuple, list)) and len(train_txt) == 2:
        train_samples   = _load_samples(train_data_dir, train_txt[1], split_char, prefixed)
        train_q_samples = _load_samples(train_data_dir, train_txt[0], split_char, prefixed)
    else:
        train_samples   = _load_samples(train_data_dir, train_txt, split_char, prefixed)
        train_q_samples = train_samples
    train_set       = FeatureDataset(train_data_dir, train_samples,   desc_name, max_sequence_len, gnd_data=train_gnd_data)
    query_train_set = FeatureDataset(train_data_dir, train_q_samples, desc_name, max_sequence_len, gnd_data=train_gnd_data)

    test_gnd_data   = _load_gnd(test_data_dir, test_gnd_file, prefixed)
    query_samples   = _load_samples(test_data_dir, test_txt[0], split_char, prefixed)
    gallery_samples = _load_samples(test_data_dir, test_txt[1], split_char, prefixed)
    # Only the query set carries ground truth; the gallery set is built without it.
    gallery_set     = FeatureDataset(test_data_dir, gallery_samples, desc_name, max_sequence_len)
    query_set       = FeatureDataset(test_data_dir, query_samples,   desc_name, max_sequence_len, gnd_data=test_gnd_data)

    return (train_set, query_train_set), (query_set, gallery_set)


In [12]:
def get_loaders(desc_name, train_data_dir, 
    batch_size, test_batch_size, 
    num_workers, pin_memory, 
    sampler, recalls, set_name, 
    eval_set_name, train_gnd_file,
    prefixed, num_candidates=100):
    """Create the train / query / gallery data loaders for one split.

    Training and evaluation data both come from `train_data_dir`, using the
    list files '<set_name>_query.txt' and '<set_name>_selection.txt'.

    NOTE(review): `test_gnd_file` and `split_char` are not parameters; they are
    read from notebook globals at call time, so the config cell defining them
    must have been executed first.

    Args:
        sampler: 'random' (plain shuffled batches) or 'triplet'
            (`TripletSampler`, which needs `train_gnd_file` to be set).
        recalls: returned unchanged as the second element of the result.
        num_candidates: forwarded to `TripletSampler`.

    Returns:
        `(MetricLoaders, recalls)`.

    Raises:
        ValueError: for an unknown `sampler`.
    """
    query_list_name = set_name+'_query.txt'
    train_sets, eval_sets = get_sets(desc_name, 
        train_data_dir=train_data_dir,
        test_data_dir=train_data_dir,
        train_txt=query_list_name,
        test_txt=(query_list_name, set_name+'_selection.txt'),
        test_gnd_file=test_gnd_file, 
        train_gnd_file=train_gnd_file,
        split_char=split_char,
        prefixed=prefixed,
        max_sequence_len=500)
    train_set, query_train_set = train_sets
    query_set, gallery_set = eval_sets

    if sampler not in ('random', 'triplet'):
        raise ValueError('Invalid choice of sampler ({}).'.format(sampler))

    if sampler == 'random':
        train_sampler = BatchSampler(RandomSampler(train_set), batch_size=batch_size, drop_last=False)
    else:
        # Cached nearest-neighbour indices live next to the data, optionally prefixed.
        nn_inds_name = set_name + '_nn_inds_%s.pkl'%desc_name
        if prefixed is not None:
            nn_inds_name = prefixed+'_'+nn_inds_name
        train_nn_inds = osp.join(train_data_dir, nn_inds_name)
        gnd_data = train_set.gnd_data['gnd']
        train_sampler = TripletSampler(query_train_set.targets, batch_size, train_nn_inds, num_candidates, gnd_data)

    worker_opts = dict(num_workers=num_workers, pin_memory=pin_memory)
    train_loader       = DataLoader(train_set, batch_sampler=train_sampler, **worker_opts)
    query_train_loader = DataLoader(query_train_set, batch_size=test_batch_size, **worker_opts)
    query_loader       = DataLoader(query_set, batch_size=test_batch_size, **worker_opts)
    gallery_loader     = DataLoader(gallery_set, batch_size=test_batch_size, **worker_opts)

    loaders = MetricLoaders(
        train=train_loader,
        query_train=query_train_loader,
        query=query_loader,
        gallery=gallery_loader,
        num_classes=len(train_set.categories),
        set_name=set_name,
        eval_set_name=eval_set_name,
        prefixed=prefixed)
    return loaders, recalls


In [13]:
class MetricLoaders(NamedTuple):
    """Bundle of data loaders and split metadata, as built by `get_loaders`."""
    # Loader over the training set (driven by the configured batch sampler).
    train: DataLoader
    # Number of categories of the training set.
    num_classes: int
    # Loader over the evaluation queries.
    query: DataLoader
    # Loader over the training-side query set.
    query_train: DataLoader
    # Optional file-name prefix used when the data files were resolved.
    prefixed: Optional[str] = None
    # Name of the split the loaders were built from.
    set_name: str = ''
    # Name of the evaluation split.
    eval_set_name: str = ''
    # Loader over the evaluation gallery.
    gallery: Optional[DataLoader] = None

## Training Parameters

In [14]:
# --- Optimisation -----------------------------------------------------------
# `epochs` is passed to get_optimizer_scheduler. NOTE(review): the
# training-loop cell iterates over a hard-coded range(100) instead of this value.
epochs = 5
lr = 0.0001
# momentum / nesterov are only used by the 'sgd' optimizer branch.
momentum = 0.
nesterov = False
weight_decay = 5e-4
# 'sgd' or 'adam'; any other value selects AdamW.
optim = 'adamw'
# 'cos', 'warmcos', 'multistep' or 'warmstep'; anything else falls back to StepLR.
scheduler = 'multistep'
# Gradient-norm clipping threshold; clipping is applied only when > 0.
max_norm = 0.0
seed = 0

# --- Logging / runtime ------------------------------------------------------
# visdom_port is only referenced by commented-out Visdom code in this notebook.
visdom_port = None
visdom_freq = 100
cpu = False  # Force training on CPU
# 'deterministic' or 'benchmark': name of the torch.backends.cudnn attribute to toggle.
cudnn_flag = 'benchmark'
temp_file = 'temp'

# When True, 1-D parameters are placed in a parameter group with weight_decay 0.
no_bias_decay = False
# Only 'bce' is supported by get_loss.
loss = 'bce'
# Milestones and decay factor for the 'multistep' scheduler.
scheduler_tau = [16, 18]
scheduler_gamma = 0.1

temp_dir = osp.join('outputs', 'temp')

# Checkpoint to start from. The second assignment overrides the first.
# NOTE(review): this is an absolute, machine-specific path; set it to None or
# to a local checkpoint when running elsewhere.
resume = None
resume = '/mnt/beegfs/home/smessoud/RerankingTransformer/RRT_GLD/rrt_gld_ckpts/r50_gldv2.pt'
# Mutually independent switches read by the parameter-freezing cell.
classifier = False
transformer = False
last_layers = False

In [15]:
# Display the effective run settings as a tuple for a quick visual check.
epochs, cpu, cudnn_flag, visdom_port, visdom_freq, temp_dir, seed, no_bias_decay, max_norm, resume

(5,
 False,
 'benchmark',
 None,
 100,
 'outputs/temp',
 0,
 False,
 0.0,
 '/mnt/beegfs/home/smessoud/RerankingTransformer/RRT_GLD/rrt_gld_ckpts/r50_gldv2.pt')

## Various Training Utility Functions

In [16]:
def get_optimizer_scheduler(parameters, optim, loader_length, epochs, lr, momentum, nesterov, weight_decay, scheduler, scheduler_tau, scheduler_gamma, lr_step=None):
    """Create the optimizer and its learning-rate scheduler.

    Args:
        parameters: parameters or parameter-group dicts handed to the optimizer.
        optim: 'sgd' or 'adam'; any other value selects AdamW.
        loader_length: number of batches per epoch (used by per-iteration schedules).
        epochs: number of training epochs; 0 disables scheduling entirely.
        lr, weight_decay: forwarded to the optimizer.
        momentum, nesterov: used by SGD only; Nesterov is enabled only together
            with a non-zero momentum.
        scheduler: 'cos', 'warmcos', 'multistep' or 'warmstep'; any other value
            falls back to a StepLR with step size `epochs * loader_length`.
        scheduler_tau, scheduler_gamma: milestones and decay factor for 'multistep'.
        lr_step: number of epochs between 10x decays for 'warmstep'
            (required for that schedule).

    Returns:
        `(optimizer, (scheduler, update_per_iteration))`; the flag tells the
        caller whether the scheduler is meant to be stepped per iteration
        (True) or per epoch (False). Both entries are None when `epochs == 0`.

    Raises:
        ValueError: if `scheduler == 'warmstep'` and `lr_step` is None.
    """
    if optim == 'sgd':
        optimizer = SGD(parameters, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=bool(nesterov and momentum))
    elif optim == 'adam':
        optimizer = Adam(parameters, lr=lr, weight_decay=weight_decay)
    else:
        optimizer = AdamW(parameters, lr=lr, weight_decay=weight_decay)

    if epochs == 0:
        scheduler = None
        update_per_iteration = None
    elif scheduler == 'cos':
        # Cosine annealing over epochs (stepped once per epoch).
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0.000001)
        update_per_iteration = False
    elif scheduler == 'warmcos':
        # Linear warm-up over the first 3 epochs capped by a half-cosine decay.
        warm_cosine = lambda i: min((i + 1) / 3, (1 + math.cos(math.pi * i / epochs)) / 2)
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_cosine)
        update_per_iteration = False
    elif scheduler == 'multistep':
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=scheduler_tau, gamma=scheduler_gamma)
        update_per_iteration = False
    elif scheduler == 'warmstep':
        # Fail early with a clear message instead of a TypeError raised from
        # inside the lambda when the scheduler is constructed.
        if lr_step is None:
            raise ValueError("scheduler 'warmstep' requires lr_step (epochs between decays)")
        # 100-iteration linear warm-up, then a 10x decay every `lr_step` epochs.
        warm_step = lambda i: min((i + 1) / 100, 1) * 0.1 ** (i // (lr_step * loader_length))
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_step)
        update_per_iteration = True
    else:
        scheduler = lr_scheduler.StepLR(optimizer, epochs * loader_length)
        update_per_iteration = True

    return optimizer, (scheduler, update_per_iteration)


In [17]:
def get_loss(loss):
    """Return the loss module registered under the name `loss`.

    Only 'bce' is supported; any other name raises an Exception.
    """
    if loss != 'bce':
        raise Exception('Unsupported loss {}'.format(loss))
    return BinaryCrossEntropyWithLogits()

In [18]:
class TripletSampler():
    """Batch sampler that yields flat index lists laid out as consecutive
    (anchor, positive, negative) triplets.

    Args:
        labels: per-sample labels; only their count is used here (it must match
            the cached nearest-neighbour table and sizes the validity mask).
        batch_size: minimum number of indices per batch; a batch is emitted as
            soon as it holds at least this many indices.
        nn_inds_path: path of the pickled nearest-neighbour index table.
        num_candidates: stored on the instance; not used by the sampler itself.
        gnd_data: per-query ground-truth records, indexable by query position.
        min_pos: minimum number of positives and of negatives a query needs to
            be eligible.
    """
    def __init__(self, labels, batch_size, nn_inds_path, num_candidates, gnd_data, min_pos=1):
        self.batch_size     = batch_size
        self.num_candidates = num_candidates
        self.cache_nn_inds  = pickle_load(nn_inds_path)
        self.labels = labels
        self.gnd_data = gnd_data
        print('nn_inds_path: ', nn_inds_path)
        print('labels len: ', len(labels))
        assert (len(self.cache_nn_inds) == len(labels))
        #############################################################################
        ## Collect valid tuples
        # NOTE(review): eligibility is decided from 'r_easy' / 'r_junk', while
        # __iter__ samples from 'g_easy' / 'g_junk'. Presumably both pairs are
        # non-empty together; confirm against the ground-truth file schema.
        valids = np.zeros_like(labels)
        for i in range(len(self.cache_nn_inds)):
            positives = self.gnd_data[i]['r_easy']
            negatives = self.gnd_data[i]['r_junk']
            if len(positives) < min_pos or len(negatives) < min_pos:
                continue
            valids[i] = 1
        self.valids = np.where(valids > 0)[0]
        self.num_samples = len(self.valids)

    def __iter__(self):
        """Yield batches of indices in a fresh random query order."""
        batch = []
        for cand in torch.randperm(self.num_samples).tolist():
            query_idx = self.valids[cand]
            record = self.gnd_data[query_idx]

            positive_inds = record['g_easy']
            negative_inds = record['g_junk']
            assert(len(positive_inds) > 0)
            assert(len(negative_inds) > 0)

            # Draw one positive and one negative without shuffling the lists in
            # place, so the caller's ground-truth data is left untouched.
            batch.append(record['anchor_idx'])
            batch.append(random.choice(positive_inds))
            batch.append(random.choice(negative_inds))

            if len(batch) >= self.batch_size:
                yield batch
                batch = []

        if len(batch) > 0:
            yield batch

    def __len__(self):
        """Number of batches produced by one pass of `__iter__`."""
        # A batch is emitted once it holds >= batch_size indices and grows by 3
        # per query, so it contains ceil(batch_size / 3) triplets (at least one).
        triplets_per_batch = max(-(-self.batch_size // 3), 1)
        return (self.num_samples + triplets_per_batch - 1) // triplets_per_batch


In [19]:
def get_model(num_global_features, num_local_features, seq_len, dim_K, dim_feedforward, nhead, num_encoder_layers, dropout, activation, normalize_before, use_pos):
    """Instantiate the matching transformer.

    `use_pos` selects `PosMatchERT`; otherwise `MatchERT` is built. Both
    receive the same keyword arguments.
    """
    matcher_cls = PosMatchERT if use_pos else MatchERT
    return matcher_cls(
        d_global=num_global_features,
        d_model=num_local_features,
        seq_len=seq_len,
        d_K=dim_K,
        nhead=nhead,
        num_encoder_layers=num_encoder_layers,
        dim_feedforward=dim_feedforward,
        dropout=dropout,
        activation=activation,
        normalize_before=normalize_before)

In [20]:
# Use the first CUDA device unless CUDA is unavailable or `cpu` forces CPU execution.
device = torch.device('cuda:0' if torch.cuda.is_available() and not cpu else 'cpu')
# callback = VisdomLogger(port=visdom_port) if visdom_port else None
# Only the 'deterministic' flag is applied here (sets cudnn.deterministic = True);
# the 'benchmark' flag is toggled around each epoch by the training-loop cell.
if cudnn_flag == 'deterministic':
    setattr(cudnn, cudnn_flag, True)

## Config Parameters

In [21]:
# Dataset configuration for the small 'tuto' split.
# NOTE(review): `name` is reassigned to 'rrt' by the model-hyperparameter cell below.
name = 'tuto_viquae_tuto_r50_gldv2'
set_name = 'tuto'
# NOTE(review): get_loaders derives its own list-file names from `set_name`;
# these two tuples are not passed to it anywhere in this notebook.
train_txt = ('tuto_query.txt', 'tuto_gallery.txt')
test_txt = ('tuto_query.txt', 'tuto_selection.txt')
train_data_dir = 'data/viquae_for_rrt'
test_data_dir  = 'data/viquae_for_rrt'
# Read as a global by get_loaders (it is not one of its parameters).
test_gnd_file = 'gnd_tuto.pkl'
train_gnd_file = 'gnd_'+set_name+'.pkl'
#train_gnd_file = 'gnd_'+set_name+'.pkl'
desc_name = 'r50_gldv2'
sampler = 'triplet'
# Field separator of the list files; also read as a global by get_loaders.
split_char  = ';;'
# File-name prefix: data files are looked up as '<prefixed>_<name>'.
prefixed = 'non_humans'

In [22]:
# Alternative configuration (train split, evaluated on dev), kept disabled inside
# a string literal. Nothing below is executed; the cell only echoes the string.
"""name = 'train_viquae_dev_r50_gldv2'
set_name = 'train'
eval_set_name = 'dev'
train_txt = ('train_query.txt', 'train_gallery.txt')
test_txt = ('dev_query.txt', 'dev_selection.txt')
train_data_dir = 'data/viquae_for_rrt'
test_data_dir  = 'data/viquae_for_rrt'
train_gnd_file = 'gnd_train.pkl'
test_gnd_file = 'gnd_dev.pkl'
desc_name = 'r50_gldv2'
sampler = 'triplet'
split_char  = ';;'
prefixed = 'non_humans'"""

"name = 'train_viquae_dev_r50_gldv2'\nset_name = 'train'\neval_set_name = 'dev'\ntrain_txt = ('train_query.txt', 'train_gallery.txt')\ntest_txt = ('dev_query.txt', 'dev_selection.txt')\ntrain_data_dir = 'data/viquae_for_rrt'\ntest_data_dir  = 'data/viquae_for_rrt'\ntrain_gnd_file = 'gnd_train.pkl'\ntest_gnd_file = 'gnd_dev.pkl'\ndesc_name = 'r50_gldv2'\nsampler = 'triplet'\nsplit_char  = ';;'\nprefixed = 'non_humans'"

In [23]:
# Data-loading configuration.
batch_size      = 36
test_batch_size = 36
max_sequence_len = 500
num_workers = 8  # number of workers used to load the data
pin_memory  = True  # use the pin_memory option of DataLoader 
num_candidates = 100
# Cut-offs (k) passed to the evaluation as `recalls` / `ks`.
recalls = [1, 5, 10]

In [24]:
# Model hyperparameters, passed positionally to get_model in this order.
# NOTE(review): this overwrites the `name` set in the dataset configuration cell.
name = 'rrt'
num_global_features = 2048  
num_local_features = 128  
seq_len = 1004
dim_K = 256
dim_feedforward = 1024
nhead = 4
num_encoder_layers = 6
dropout = 0.2
activation = "relu"
normalize_before = False
# True selects PosMatchERT, False selects MatchERT (see get_model).
use_pos = False

In [25]:
# Overrides the relative `train_data_dir` from the dataset configuration cell.
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
train_data_dir = '/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt'

In [26]:
torch.manual_seed(seed)
# Build the loaders from the values declared in the configuration cells instead
# of repeating them as literals, so a config change takes effect here.
# NOTE(review): `sampler` and `train_gnd_file` are deliberately left as the
# original literals because they differ from the config cell ('triplet',
# 'gnd_<set_name>.pkl'). train_one_epoch slices every batch as
# anchor/positive/negative ([0::3], [1::3], [2::3]), a layout the 'random'
# sampler does not produce. Confirm which sampler is intended.
loaders, recall_ks = get_loaders(desc_name=desc_name,
    train_data_dir=train_data_dir, 
    batch_size=batch_size, test_batch_size=test_batch_size, 
    num_workers=num_workers, pin_memory=pin_memory, 
    sampler='random', recalls=recalls, 
    set_name=set_name, eval_set_name=None,
    train_gnd_file=None, prefixed=prefixed, num_candidates=num_candidates)

In [27]:
# Resolve the cached nearest-neighbour index file of the evaluation queries
# ('[<prefixed>_]<set_name>_nn_inds_<desc_name>.pkl' inside the query data
# directory) and load it as a LongTensor.
query_dataset = loaders.query.dataset
nn_inds_path = loaders.set_name + '_nn_inds_%s.pkl'%query_dataset.desc_name
if loaders.prefixed is not None:
    nn_inds_path = loaders.prefixed+'_'+nn_inds_path
nn_inds_path = osp.join(query_dataset.data_dir, nn_inds_path)
cache_nn_inds = torch.from_numpy(pickle_load(nn_inds_path)).long()

In [56]:
## Load Model
# Seed right before construction, presumably so the weight initialisation is
# reproducible. `get_model` is the notebook-local definition, which shadows the
# function of the same name imported from models.ingredient.
# NOTE(review): an identical cell appears again in the Training section and
# rebuilds `model` from scratch.
torch.manual_seed(seed+1)

model = get_model(num_global_features,
                  num_local_features,
                  seq_len,dim_K,
                  dim_feedforward,
                  nhead,
                  num_encoder_layers,
                  dropout,
                  activation,
                  normalize_before,
                  use_pos)

In [29]:
## Freeze Parameters
def _set_trainable(module, trainable):
    """Set `requires_grad` to `trainable` on every parameter of `module`."""
    for param in module.parameters():
        param.requires_grad = trainable


# Sub-modules that are toggled together by the `transformer` / `last_layers` switches.
_HEAD_MODULE_NAMES = ('classifier', 'seg_encoder', 'scale_encoder', 'remap')

if classifier:
    # Train the classification layers only: freeze everything, then unfreeze them.
    _set_trainable(model, False)
    _set_trainable(model.classifier, True)

if transformer:
    # Train everything except the head modules: unfreeze all, then freeze the heads.
    _set_trainable(model, True)
    for _name in _HEAD_MODULE_NAMES:
        _set_trainable(getattr(model, _name), False)

if last_layers:
    # Train the head modules only: freeze all, then unfreeze the heads.
    _set_trainable(model, False)
    for _name in _HEAD_MODULE_NAMES:
        _set_trainable(getattr(model, _name), True)

In [30]:
# Restore model weights from the checkpoint, loading tensors onto the CPU.
# strict=True requires the checkpoint keys to match the model exactly.
# `checkpoint` is kept around because the optimizer cell reads its 'optim' entry.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
if resume is not None:
    checkpoint = torch.load(resume, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['state'], strict=True)
print('# of trainable parameters: ', num_of_trainable_params(model))
class_loss = get_loss(loss)

# of trainable parameters:  2243201


In [31]:
torch.manual_seed(seed+2)
# Move the model to the target device and wrap it for multi-GPU data parallelism.
model.to(device)
model = nn.DataParallel(model)
# Parameter groups: with `no_bias_decay`, 1-D parameters get weight_decay 0.
parameters = []
if no_bias_decay:
    parameters.append({'params': [par for par in model.parameters() if par.dim() != 1]})
    parameters.append({'params': [par for par in model.parameters() if par.dim() == 1], 'weight_decay': 0})
else:
    parameters.append({'params': model.parameters()})
# NOTE(review): this rebinds the global `scheduler` from the config string
# ('multistep') to the tuple (scheduler_object, update_per_iteration). Running
# this cell, or its later duplicate, a second time passes that tuple as the
# scheduler *name*, which matches no branch and silently falls back to StepLR.
optimizer, scheduler = get_optimizer_scheduler(parameters=parameters, loader_length=len(loaders.train),
                                               optim=optim, epochs=epochs, lr=lr,
                                               momentum=momentum, nesterov=nesterov, weight_decay=weight_decay,
                                               scheduler=scheduler, scheduler_tau=scheduler_tau, 
                                               scheduler_gamma=scheduler_gamma, lr_step=None)
# Restore the optimizer state when the checkpoint carries one, then free the checkpoint.
if resume is not None and checkpoint.get('optim', None) is not None:
    optimizer.load_state_dict(checkpoint['optim'])
    del checkpoint


## Evaluation Utility Functions

In [32]:
def read_entry_once(
    query_loader: DataLoader,
    gallery_loader: DataLoader):
    """Read every batch of both loaders once and gather the features on the CPU.

    Each loader entry is expected to hold seven items: global features, local
    features, mask, scales, positions, one ignored item, and the sample names.

    Returns:
        `(query_feats, gallery_feats)`; each is the list
        `[global, local, mask, scales, positions, names]`, where the first five
        are tensors concatenated along dimension 0 and `names` is a flat list.
    """
    def _collect(loader, desc):
        # One accumulator per tensor field, in the order of the returned list.
        chunks = [[], [], [], [], []]
        names = []
        for entry in tqdm(loader, desc=desc, leave=False, ncols=80):
            feat_global, feat_local, mask, scales, positions, _, batch_names = entry
            for bucket, tensor in zip(chunks, (feat_global, feat_local, mask, scales, positions)):
                bucket.append(tensor.cpu())
            names.extend(list(batch_names))
            torch.cuda.empty_cache()
        return [torch.cat(bucket, 0) for bucket in chunks] + [names]

    print("READING ENTRIES FOR THE FIRST TIME")

    with torch.no_grad():
        query_feats = _collect(query_loader, 'Extracting query features')
        gallery_feats = _collect(gallery_loader, 'Extracting gallery features')

    return query_feats, gallery_feats

In [33]:
from utils.metrics import mean_average_precision_viquae_rerank

def fast_evaluate_viquae(
    model: nn.Module,
    cache_nn_inds: torch.Tensor,
    query_loader: DataLoader,
    gallery_loader: DataLoader,
    recall: List[int],
    query_feats, 
    gallery_feats):
    """Evaluate re-ranking quality, reusing already extracted features when given.

    When `query_feats` is empty, the features of both loaders are read once via
    `read_entry_once`; otherwise the passed feature lists are used as-is.

    Returns:
        `(metrics, query_feats, gallery_feats)`, so the caller can pass the
        features back in on the next call and skip the extraction.
    """
    model.eval()
    # Kept from the original implementation: resolves the model's device
    # (this raises if the model has no parameters).
    device = next(model.parameters()).device

    if len(query_feats) == 0:
        query_feats, gallery_feats = read_entry_once(query_loader, gallery_loader)

    (query_global, query_local, query_mask,
     query_scales, query_positions, query_names) = query_feats
    (gallery_global, gallery_local, gallery_mask,
     gallery_scales, gallery_positions, gallery_names) = gallery_feats

    torch.cuda.empty_cache()

    metrics = mean_average_precision_viquae_rerank(
        model=model,
        cache_nn_inds=cache_nn_inds,
        query_global=query_global,
        query_local=query_local,
        query_mask=query_mask,
        query_scales=query_scales,
        query_positions=query_positions,
        gallery_global=gallery_global,
        gallery_local=gallery_local,
        gallery_mask=gallery_mask,
        gallery_scales=gallery_scales,
        gallery_positions=gallery_positions,
        ks=recall,
        gnd=query_loader.dataset.gnd_data,
    )

    return metrics, query_feats, gallery_feats

In [36]:
# Empty feature caches: fast_evaluate_viquae extracts the features itself when
# it receives an empty `query_feats`.
query_feats, gallery_feats = [], []

In [37]:
torch.manual_seed(seed+3)
# setup partial function to simplify call
# NOTE(review): the partial captures the list objects bound to `query_feats` /
# `gallery_feats` at this point (the empty lists). Rebinding those names later
# does not update the partial, so every eval_function() call re-reads the features.

eval_function = partial(fast_evaluate_viquae, model=model, 
    cache_nn_inds=cache_nn_inds,
    recall=recall_ks, query_loader=loaders.query, gallery_loader=loaders.gallery,
    query_feats=query_feats, gallery_feats=gallery_feats)

In [38]:
# Baseline evaluation of the loaded model before any training in this notebook.
# The returned feature lists are kept so later evaluations can reuse them.
result, query_feats, gallery_feats = eval_function()
pprint(result)
# (epoch, metrics, weights snapshot) of the best model seen so far.
best_val = (0, result, deepcopy(model.state_dict()))

Extracting query features:   0%|                          | 0/2 [00:00<?, ?it/s]

READING ENTRIES FOR THE FIRST TIME


100%|██████████| 57/57 [00:33<00:00,  1.71it/s]                                 


{'map': 29.87, 'mrr': 39.06, 'precision': 11.82, 'hit_rate': 85.96, 'recall': 85.09, 'map@1': 11.9, 'mrr@1': 28.07, 'precision@1': 28.07, 'hit_rate@1': 28.07, 'recall@1': 11.9, 'map@5': 17.57, 'mrr@5': 36.08, 'precision@5': 17.54, 'hit_rate@5': 49.12, 'recall@5': 21.97, 'map@10': 19.82, 'mrr@10': 37.97, 'precision@10': 14.56, 'hit_rate@10': 61.4, 'recall@10': 28.76}
{'hit_rate': 85.96,
 'hit_rate@1': 28.07,
 'hit_rate@10': 61.4,
 'hit_rate@5': 49.12,
 'map': 29.87,
 'map@1': 11.9,
 'map@10': 19.82,
 'map@5': 17.57,
 'mrr': 39.06,
 'mrr@1': 28.07,
 'mrr@10': 37.97,
 'mrr@5': 36.08,
 'precision': 11.82,
 'precision@1': 28.07,
 'precision@10': 14.56,
 'precision@5': 17.54,
 'recall': 85.09,
 'recall@1': 11.9,
 'recall@10': 28.76,
 'recall@5': 21.97}


In [42]:
# Re-seed the torch RNG ahead of the training section.
torch.manual_seed(seed+4)

<torch._C.Generator at 0x7f69d9c3d7b0>

In [43]:
# Checkpoint destination. NOTE(review): within this notebook it is only
# referenced by the commented-out torch.save call in the training loop.
save_name = osp.join(temp_dir, 'rrt_tuto_viquae_dev_r50_gldv2.pt')

In [44]:
# Show the baseline metrics computed by the evaluation cell above.
result

{'map': 29.87,
 'mrr': 39.06,
 'precision': 11.82,
 'hit_rate': 85.96,
 'recall': 85.09,
 'map@1': 11.9,
 'mrr@1': 28.07,
 'precision@1': 28.07,
 'hit_rate@1': 28.07,
 'recall@1': 11.9,
 'map@5': 17.57,
 'mrr@5': 36.08,
 'precision@5': 17.54,
 'hit_rate@5': 49.12,
 'recall@5': 21.97,
 'map@10': 19.82,
 'mrr@10': 37.97,
 'precision@10': 14.56,
 'hit_rate@10': 61.4,
 'recall@10': 28.76}

## Training

In [60]:
def train_one_epoch(
        model: nn.Module,
        loader: DataLoader,
        class_loss: nn.Module,
        optimizer: Optimizer,
        # the second entry indicates if the scheduler should step per iteration or epoch
        scheduler: (_LRScheduler, bool), 
        max_norm: float,
        epoch: int,
        writer: SummaryWriter,
        # callback: VisdomLogger,
        freq: int,
        ex: Experiment = None) -> None:
    """Train `model` for one epoch on triplet-structured batches.

    Every batch is expected to be laid out as consecutive (anchor, positive,
    negative) samples: indices 0::3 are anchors, 1::3 positives and 2::3
    negatives. Anchor/positive pairs are labelled 1 and anchor/negative pairs 0.

    Args:
        model: network called with the anchor features followed by the
            candidate features; it returns one logit per pair.
        loader: yields 7-item entries (global features, local features, mask,
            scales, positions and two ignored items).
        class_loss: loss applied to `(logits, labels)`; its result is averaged.
        optimizer: optimizer stepped once per batch.
        scheduler: `(scheduler, update_per_iteration)`. Only read for logging:
            the stepping code is commented out, so the learning rate is not
            changed inside this function.
        max_norm: gradient-norm clipping threshold, applied only when > 0.
        epoch: zero-based epoch index, used for logging.
        writer: TensorBoard writer receiving the loss / accuracy / lr scalars.
        freq: unused.
        ex: optional Sacred experiment that receives the per-iteration metrics.
    """
    model.train()
    device = next(model.parameters()).device
    to_device = lambda x: x.to(device, non_blocking=True)
    loader_length = len(loader)
    train_losses = AverageMeter(device=device, length=loader_length)
    train_accs = AverageMeter(device=device, length=loader_length)
    pbar = tqdm(loader, ncols=80, desc='Training   [{:03d}]'.format(epoch))
    for i, entry in enumerate(pbar):
        global_feats, local_feats, local_mask, scales, positions, _, _ = entry
        global_feats, local_feats, local_mask, scales, positions = map(to_device, (global_feats, local_feats, local_mask, scales, positions))

        # Anchor vs. positive pairs, then anchor vs. negative pairs.
        p_logits = model(global_feats[0::3], local_feats[0::3], local_mask[0::3], scales[0::3], positions[0::3],
            global_feats[1::3], local_feats[1::3], local_mask[1::3], scales[1::3], positions[1::3])
        n_logits = model(global_feats[0::3], local_feats[0::3], local_mask[0::3], scales[0::3], positions[0::3],
            global_feats[2::3], local_feats[2::3], local_mask[2::3], scales[2::3], positions[2::3])

        logits = torch.cat([p_logits, n_logits], 0)
        bsize = logits.size(0)
        # assert (bsize % 2 == 0)
        # First half of `logits` are positive pairs (label 1), second half negatives (label 0).
        labels = logits.new_ones(logits.size()).float()
        labels[(bsize//2):] = 0
        loss = class_loss(logits, labels).mean()
        acc = ((torch.sigmoid(logits) > 0.5).long() == labels.long()).float().mean()

        ##############################################
        optimizer.zero_grad()
        loss.backward()
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        # Per-iteration scheduler stepping is left disabled, as in the original:
        # if scheduler[-1]:
        #     scheduler[0].step()

        train_losses.append(loss)
        train_accs.append(acc)

        # add_scalar documents `global_step` as an integer, so log against the
        # global iteration index rather than the fractional epoch position.
        global_step = epoch * loader_length + i
        writer.add_scalar('loss/step', train_losses.last_avg.item(),  global_step)
        writer.add_scalar('accu/step', train_accs.last_avg.item(),    global_step)
        writer.add_scalar('lr/step',   scheduler[0].get_last_lr()[0], global_step)

    # Per-epoch scheduler stepping is left disabled, as in the original:
    # if not scheduler[-1]:
    #     scheduler[0].step()

    if ex is not None:
        # Separate loop names so the training-loop `i` and `loss` are not clobbered.
        for j, (logged_loss, logged_acc) in enumerate(zip(train_losses.values_list, train_accs.values_list)):
            logged_step = epoch + j / loader_length
            ex.log_scalar('train.loss', logged_loss, step=logged_step)
            ex.log_scalar('train.acc',  logged_acc, step=logged_step)

    print('STATISTICS FOR EPOCH ', epoch)
    # Fractional epoch position of the last processed batch.
    step = epoch + i / loader_length
    print('step/loss/accu/lr:', step, train_losses.last_avg.item(), train_accs.last_avg.item(), scheduler[0].get_last_lr()[0])
    writer.add_scalar('loss/epoch', train_losses.last_avg.item(),  epoch)
    writer.add_scalar('accu/epoch', train_accs.last_avg.item(),    epoch)
    writer.add_scalar('lr/epoch',   scheduler[0].get_last_lr()[0], epoch)



In [70]:
## Load Model
# Rebuild a fresh model for training with the same seed and hyperparameters as
# the earlier "Load Model" cell.
# NOTE(review): this discards the model object used for the baseline
# evaluation, including any requires_grad changes made by the freezing cell.
torch.manual_seed(seed+1)

model = get_model(num_global_features,
                  num_local_features,
                  seq_len,dim_K,
                  dim_feedforward,
                  nhead,
                  num_encoder_layers,
                  dropout,
                  activation,
                  normalize_before,
                  use_pos)

In [71]:
# Restore the checkpoint weights into the freshly built model (tensors loaded on
# the CPU; strict=True requires an exact key match). `checkpoint` is kept for the
# optimizer cell that follows.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
if resume is not None:
    checkpoint = torch.load(resume, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['state'], strict=True)
print('# of trainable parameters: ', num_of_trainable_params(model))
class_loss = get_loss(loss)

# of trainable parameters:  2243201


In [72]:
torch.manual_seed(seed+2)
# Move the model to the target device and wrap it for multi-GPU data parallelism.
model.to(device)
model = nn.DataParallel(model)
# Parameter groups: with `no_bias_decay`, 1-D parameters get weight_decay 0.
parameters = []
if no_bias_decay:
    parameters.append({'params': [par for par in model.parameters() if par.dim() != 1]})
    parameters.append({'params': [par for par in model.parameters() if par.dim() == 1], 'weight_decay': 0})
else:
    parameters.append({'params': model.parameters()})
# NOTE(review): if the earlier optimizer cell has already run, `scheduler` is no
# longer the config string but the tuple it returned. Passing that tuple as the
# scheduler name matches no branch in get_optimizer_scheduler, which then
# silently falls back to StepLR. Re-run the training-parameters cell first, or
# keep the config string under a separate name.
optimizer, scheduler = get_optimizer_scheduler(parameters=parameters, loader_length=len(loaders.train),
                                               optim=optim, epochs=epochs, lr=lr,
                                               momentum=momentum, nesterov=nesterov, weight_decay=weight_decay,
                                               scheduler=scheduler, scheduler_tau=scheduler_tau, 
                                               scheduler_gamma=scheduler_gamma, lr_step=None)
# Restore the optimizer state when the checkpoint carries one, then free the checkpoint.
if resume is not None and checkpoint.get('optim', None) is not None:
    optimizer.load_state_dict(checkpoint['optim'])
    del checkpoint


In [74]:
# Load the prefixed training ground-truth pickle for manual inspection.
# NOTE(review): the string concatenation fails when `prefixed` is None.
gnd =  pickle_load(osp.join(train_data_dir, prefixed+'_'+train_gnd_file))

In [75]:
# Inspect the 'junk' and 'easy' lists of ground-truth entry 10.
gnd['gnd'][10]['junk'], gnd['gnd'][10]['easy'], 

(['512px-Richard_Rodríguez_in_2017_(36917473891)_(cropped).jpg',
  '512px-Gil_Rondon_-_Houston_Astros_-_1976.jpg',
  '512px-Tommy_Hunter_on_August_10,_2011.jpg',
  '512px-Piniella.jpg',
  '512px-Mike_Trout_(6157725038).jpg',
  '512px-Brett_Myers_on_July_30,_2012.jpg',
  '512px-Tylocephale.jpg',
  '512px-Vp69_insig.jpg',
  '512px-Curtitoma_lawrenciana_001.jpg',
  '512px-VWA_Schwarz_WK1.jpg',
  '512px-Jason_Tyner.jpg',
  '512px-Cleveland_Indians_primary_logo.svg.png',
  '512px-Brandon_Snyder_on_May_7,_2012.jpg',
  '512px-Dave_Jageler_2010.jpg',
  '512px-Houston_Astros_cap_logo.svg.png',
  '512px-Orioles_pitcher_Jason_Garcia_in_2015.jpg',
  '512px-Joaquin_Benoit.jpg',
  '512px-AirAmericaPilotCap.jpg',
  '512px-Imposter_trevally.PNG',
  '512px-Scott_Moore_on_July_2,_2012.jpg',
  '512px-Francisco_Liriano_on_August_27,_2012.jpg',
  '512px-20130731-0309_Jonathan_Villar.jpg',
  '512px-Carangoides_bartholomaei.png',
  '512px-Juan_Centeno_on_June_8,_2015.jpg',
  '512px-Justin_Ruggiano_on_June_10

In [73]:
# NOTE(review): reassigning `lr` here does not change the optimizer, which was
# already built by the cell above; it only affects optimizers created afterwards.
lr = 0.00001
writer = SummaryWriter()
# NOTE(review): the epoch count is hard-coded and ignores the `epochs` config value.
for epoch in range(100):
    # Enable cudnn.benchmark for training when the 'benchmark' flag is configured.
    if cudnn_flag == 'benchmark':
        setattr(cudnn, cudnn_flag, True)

    torch.cuda.empty_cache()
    train_one_epoch(model=model, loader=loaders.train, class_loss=class_loss, writer=writer, optimizer=optimizer, scheduler=scheduler, max_norm=max_norm, epoch=epoch, freq=visdom_freq, ex=None)
    
    # Validation phase: cudnn.benchmark is switched off again. The validation
    # code itself is currently disabled (kept inside the string literal below),
    # so `best_val` is never updated and no checkpoint is written.
    if cudnn_flag == 'benchmark':
        setattr(cudnn, cudnn_flag, False)
    
    torch.cuda.empty_cache()
    """result, query_feats, gallery_feats = fast_evaluate_viquae(model=model,
                                                              cache_nn_inds=cache_nn_inds,
                                                              recall=recall_ks,
                                                              query_loader=loaders.query, 
                                                              gallery_loader=loaders.gallery,
                                                              query_feats=query_feats, 
                                                              gallery_feats=gallery_feats)
    
    print('Validation [{:03d}]'.format(epoch)), pprint(result)
    #ex.log_scalar('val.map', result['map'], step=epoch + 1)

    if result['map'] >= best_val[1]['map']:
        print('New best model in epoch %d.'%epoch)
        best_val = (epoch + 1, result, deepcopy(model.state_dict()))
        #torch.save({'state': state_dict_to_cpu(best_val[2]), 'optim': optimizer.state_dict()}, save_name)
    """

# Make sure all pending TensorBoard events reach disk before closing the writer.
writer.flush()
writer.close()

Training   [000]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.01s/it]
Training   [001]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  0
step/loss/accu/lr: 0.5 1.6215338706970215 0.4285714626312256 1e-05


Training   [001]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.03it/s]
Training   [002]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  1
step/loss/accu/lr: 1.5 1.4147487878799438 0.5 1e-05


Training   [002]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.02it/s]
Training   [003]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  2
step/loss/accu/lr: 2.5 1.424875020980835 0.5 1e-05


Training   [003]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.03it/s]
Training   [004]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  3
step/loss/accu/lr: 3.5 1.1853445768356323 0.4285714626312256 1e-05


Training   [004]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.01it/s]
Training   [005]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  4
step/loss/accu/lr: 4.5 1.139065146446228 0.5714285969734192 1e-05


Training   [005]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.04it/s]
Training   [006]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  5
step/loss/accu/lr: 5.5 1.5126442909240723 0.4285714626312256 1e-05


Training   [006]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.05it/s]
Training   [007]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  6
step/loss/accu/lr: 6.5 1.04560124874115 0.6428571939468384 1e-05


Training   [007]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.03it/s]
Training   [008]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  7
step/loss/accu/lr: 7.5 1.2580307722091675 0.5 1e-05


Training   [008]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.02it/s]
Training   [009]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  8
step/loss/accu/lr: 8.5 0.7183050513267517 0.5714285969734192 1e-05


Training   [009]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.02it/s]
Training   [010]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  9
step/loss/accu/lr: 9.5 1.2300529479980469 0.4285714626312256 1e-05


Training   [010]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.00it/s]
Training   [011]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  10
step/loss/accu/lr: 10.5 1.120389461517334 0.4285714626312256 1e-05


Training   [011]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.02it/s]
Training   [012]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  11
step/loss/accu/lr: 11.5 1.262104868888855 0.3571428656578064 1e-05


Training   [012]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.04it/s]
Training   [013]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  12
step/loss/accu/lr: 12.5 1.2034233808517456 0.4285714626312256 1e-05


Training   [013]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.01it/s]
Training   [014]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  13
step/loss/accu/lr: 13.5 1.1507567167282104 0.5 1e-05


Training   [014]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.02it/s]
Training   [015]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  14
step/loss/accu/lr: 14.5 0.7018086910247803 0.5714285969734192 1e-05


Training   [015]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.02it/s]
Training   [016]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  15
step/loss/accu/lr: 15.5 0.9369977712631226 0.4285714626312256 1e-05


Training   [016]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.01it/s]
Training   [017]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  16
step/loss/accu/lr: 16.5 0.6191698908805847 0.7142857313156128 1e-05


Training   [017]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.02it/s]
Training   [018]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  17
step/loss/accu/lr: 17.5 1.0341030359268188 0.4285714626312256 1e-05


Training   [018]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.04it/s]
Training   [019]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  18
step/loss/accu/lr: 18.5 0.814167857170105 0.6428571939468384 1e-05


Training   [019]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.02it/s]
Training   [020]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  19
step/loss/accu/lr: 19.5 1.0765260457992554 0.3571428656578064 1e-05


Training   [020]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.02it/s]
Training   [021]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  20
step/loss/accu/lr: 20.5 0.9144597053527832 0.4285714626312256 1e-05


Training   [021]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.02it/s]
Training   [022]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  21
step/loss/accu/lr: 21.5 1.014511227607727 0.7142857313156128 1e-05


Training   [022]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.02it/s]
Training   [023]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  22
step/loss/accu/lr: 22.5 0.9838535189628601 0.5714285969734192 1e-05


Training   [023]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.01it/s]
Training   [024]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  23
step/loss/accu/lr: 23.5 0.8933554291725159 0.5714285969734192 1e-05


Training   [024]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.02it/s]
Training   [025]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  24
step/loss/accu/lr: 24.5 0.7912283539772034 0.6428571939468384 1e-05


Training   [025]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.03it/s]
Training   [026]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  25
step/loss/accu/lr: 25.5 1.3435410261154175 0.3571428656578064 1e-05


Training   [026]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.02it/s]
Training   [027]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  26
step/loss/accu/lr: 26.5 0.7364081144332886 0.5714285969734192 1e-05


Training   [027]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.04it/s]
Training   [028]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  27
step/loss/accu/lr: 27.5 0.8039392828941345 0.4285714626312256 1e-05


Training   [028]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.02it/s]
Training   [029]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  28
step/loss/accu/lr: 28.5 0.6972805261611938 0.5714285969734192 1e-05


Training   [029]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.01it/s]
Training   [030]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  29
step/loss/accu/lr: 29.5 0.8120235800743103 0.5714285969734192 1e-05


Training   [030]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.02it/s]
Training   [031]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  30
step/loss/accu/lr: 30.5 0.7269722819328308 0.4285714626312256 1e-05


Training   [031]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.02it/s]
Training   [032]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  31
step/loss/accu/lr: 31.5 0.6220663785934448 0.6428571939468384 1e-05


Training   [032]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.01it/s]
Training   [033]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  32
step/loss/accu/lr: 32.5 0.6355329751968384 0.5714285969734192 1e-05


Training   [033]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.01it/s]
Training   [034]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  33
step/loss/accu/lr: 33.5 0.8395267724990845 0.5 1e-05


Training   [034]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.02it/s]
Training   [035]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  34
step/loss/accu/lr: 34.5 1.1701263189315796 0.3571428656578064 1e-05


Training   [035]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.02it/s]
Training   [036]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  35
step/loss/accu/lr: 35.5 1.3416813611984253 0.1428571492433548 1e-05


Training   [036]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.01it/s]
Training   [037]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  36
step/loss/accu/lr: 36.5 0.7090498805046082 0.6428571939468384 1e-05


Training   [037]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.01it/s]
Training   [038]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  37
step/loss/accu/lr: 37.5 1.1308234930038452 0.2142857313156128 1e-05


Training   [038]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.01it/s]
Training   [039]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  38
step/loss/accu/lr: 38.5 1.0764905214309692 0.5 1e-05


Training   [039]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.03it/s]
Training   [040]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  39
step/loss/accu/lr: 39.5 0.6631113290786743 0.5714285969734192 1e-05


Training   [040]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.01it/s]
Training   [041]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  40
step/loss/accu/lr: 40.5 0.731777548789978 0.5 1e-05


Training   [041]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.01it/s]
Training   [042]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  41
step/loss/accu/lr: 41.5 1.4842009544372559 0.2857142984867096 1e-05


Training   [042]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.03it/s]
Training   [043]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  42
step/loss/accu/lr: 42.5 0.6503535509109497 0.6428571939468384 1e-05


Training   [043]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.03it/s]
Training   [044]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  43
step/loss/accu/lr: 43.5 0.7890868782997131 0.7142857313156128 1e-05


Training   [044]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.01it/s]
Training   [045]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  44
step/loss/accu/lr: 44.5 0.895501434803009 0.3571428656578064 1e-05


Training   [045]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.00s/it]
Training   [046]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  45
step/loss/accu/lr: 45.5 0.8790010213851929 0.5714285969734192 1e-05


Training   [046]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.00it/s]
Training   [047]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  46
step/loss/accu/lr: 46.5 0.794258713722229 0.5714285969734192 1e-05


Training   [047]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.03it/s]
Training   [048]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  47
step/loss/accu/lr: 47.5 0.7359521985054016 0.5 1e-05


Training   [048]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.01it/s]
Training   [049]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  48
step/loss/accu/lr: 48.5 0.9871084094047546 0.4285714626312256 1e-05


Training   [049]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.00it/s]
Training   [050]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  49
step/loss/accu/lr: 49.5 0.7269472479820251 0.6428571939468384 1e-05


Training   [050]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.00it/s]
Training   [051]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  50
step/loss/accu/lr: 50.5 0.9669936299324036 0.5714285969734192 1e-05


Training   [051]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.03it/s]
Training   [052]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  51
step/loss/accu/lr: 51.5 0.8156548142433167 0.5714285969734192 1e-05


Training   [052]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.02it/s]
Training   [053]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  52
step/loss/accu/lr: 52.5 0.9371192455291748 0.3571428656578064 1e-05


Training   [053]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.00it/s]
Training   [054]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  53
step/loss/accu/lr: 53.5 0.7723878026008606 0.6428571939468384 1e-05


Training   [054]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.00it/s]
Training   [055]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  54
step/loss/accu/lr: 54.5 1.026771068572998 0.2142857313156128 1e-05


Training   [055]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.01it/s]
Training   [056]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  55
step/loss/accu/lr: 55.5 0.6101655960083008 0.5 1e-05


Training   [056]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.02it/s]
Training   [057]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  56
step/loss/accu/lr: 56.5 0.8070975542068481 0.5714285969734192 1e-05


Training   [057]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.01it/s]
Training   [058]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  57
step/loss/accu/lr: 57.5 0.9162538051605225 0.5 1e-05


Training   [058]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.01it/s]
Training   [059]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  58
step/loss/accu/lr: 58.5 1.2092806100845337 0.2857142984867096 1e-05


Training   [059]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.02it/s]
Training   [060]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  59
step/loss/accu/lr: 59.5 0.5732988119125366 0.6428571939468384 1e-05


Training   [060]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.01s/it]
Training   [061]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  60
step/loss/accu/lr: 60.5 0.8002015352249146 0.5714285969734192 1e-05


Training   [061]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.01it/s]
Training   [062]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  61
step/loss/accu/lr: 61.5 0.9226334691047668 0.6428571939468384 1e-05


Training   [062]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.01it/s]
Training   [063]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  62
step/loss/accu/lr: 62.5 0.5477396249771118 0.785714328289032 1e-05


Training   [063]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.00it/s]
Training   [064]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  63
step/loss/accu/lr: 63.5 0.6722312569618225 0.6428571939468384 1e-05


Training   [064]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.01s/it]
Training   [065]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  64
step/loss/accu/lr: 64.5 0.7612463235855103 0.5714285969734192 1e-05


Training   [065]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.00s/it]
Training   [066]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  65
step/loss/accu/lr: 65.5 0.9548956155776978 0.3571428656578064 1e-05


Training   [066]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.00it/s]
Training   [067]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  66
step/loss/accu/lr: 66.5 0.7547075152397156 0.5 1e-05


Training   [067]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.01it/s]
Training   [068]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  67
step/loss/accu/lr: 67.5 0.5091008543968201 0.7142857313156128 1e-05


Training   [068]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.01it/s]
Training   [069]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  68
step/loss/accu/lr: 68.5 0.8256224393844604 0.4285714626312256 1e-05


Training   [069]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.03it/s]
Training   [070]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  69
step/loss/accu/lr: 69.5 0.7686232924461365 0.4285714626312256 1e-05


Training   [070]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.02it/s]
Training   [071]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  70
step/loss/accu/lr: 70.5 0.7785658240318298 0.5 1e-05


Training   [071]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.03it/s]
Training   [072]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  71
step/loss/accu/lr: 71.5 0.8169893622398376 0.5714285969734192 1e-05


Training   [072]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.01it/s]
Training   [073]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  72
step/loss/accu/lr: 72.5 1.037577509880066 0.3571428656578064 1e-05


Training   [073]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.00it/s]
Training   [074]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  73
step/loss/accu/lr: 73.5 0.6855646967887878 0.7142857313156128 1e-05


Training   [074]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.01s/it]
Training   [075]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  74
step/loss/accu/lr: 74.5 0.7270674109458923 0.6428571939468384 1e-05


Training   [075]: 100%|███████████████████████████| 2/2 [00:01<00:00,  1.01it/s]
Training   [076]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  75
step/loss/accu/lr: 75.5 0.8589959740638733 0.4285714626312256 1e-05


Training   [076]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.03s/it]
Training   [077]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  76
step/loss/accu/lr: 76.5 1.0446467399597168 0.4285714626312256 1e-05


Training   [077]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.02s/it]
Training   [078]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  77
step/loss/accu/lr: 77.5 1.2050901651382446 0.3571428656578064 1e-05


Training   [078]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.01s/it]
Training   [079]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  78
step/loss/accu/lr: 78.5 0.6993643045425415 0.6428571939468384 1e-05


Training   [079]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.02s/it]
Training   [080]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  79
step/loss/accu/lr: 79.5 0.8614778518676758 0.3571428656578064 1e-05


Training   [080]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.03s/it]
Training   [081]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  80
step/loss/accu/lr: 80.5 0.8356241583824158 0.5 1e-05


Training   [081]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.02s/it]
Training   [082]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  81
step/loss/accu/lr: 81.5 0.6651732921600342 0.5714285969734192 1e-05


Training   [082]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.03s/it]
Training   [083]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  82
step/loss/accu/lr: 82.5 0.6636885404586792 0.5 1e-05


Training   [083]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.03s/it]
Training   [084]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  83
step/loss/accu/lr: 83.5 1.092616319656372 0.4285714626312256 1e-05


Training   [084]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.02s/it]
Training   [085]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  84
step/loss/accu/lr: 84.5 0.771196722984314 0.5714285969734192 1e-05


Training   [085]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.03s/it]
Training   [086]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  85
step/loss/accu/lr: 85.5 0.9302533268928528 0.2857142984867096 1e-05


Training   [086]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.03s/it]
Training   [087]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  86
step/loss/accu/lr: 86.5 0.8359860181808472 0.4285714626312256 1e-05


Training   [087]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.04s/it]
Training   [088]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  87
step/loss/accu/lr: 87.5 0.8054654598236084 0.3571428656578064 1e-05


Training   [088]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.03s/it]
Training   [089]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  88
step/loss/accu/lr: 88.5 0.8607926964759827 0.5 1e-05


Training   [089]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.03s/it]
Training   [090]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  89
step/loss/accu/lr: 89.5 0.8338539600372314 0.5714285969734192 1e-05


Training   [090]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.02s/it]
Training   [091]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  90
step/loss/accu/lr: 90.5 0.9473115801811218 0.2857142984867096 1e-05


Training   [091]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.03s/it]
Training   [092]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  91
step/loss/accu/lr: 91.5 0.6963279247283936 0.4285714626312256 1e-05


Training   [092]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.04s/it]
Training   [093]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  92
step/loss/accu/lr: 92.5 0.7513659000396729 0.5714285969734192 1e-05


Training   [093]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.04s/it]
Training   [094]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  93
step/loss/accu/lr: 93.5 0.7478512525558472 0.7142857313156128 1e-05


Training   [094]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.03s/it]
Training   [095]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  94
step/loss/accu/lr: 94.5 0.6597540974617004 0.5714285969734192 1e-05


Training   [095]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.04s/it]
Training   [096]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  95
step/loss/accu/lr: 95.5 0.8626301884651184 0.4285714626312256 1e-05


Training   [096]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.02s/it]
Training   [097]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  96
step/loss/accu/lr: 96.5 1.0137149095535278 0.2857142984867096 1e-05


Training   [097]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.03s/it]
Training   [098]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  97
step/loss/accu/lr: 97.5 0.8371903300285339 0.5714285969734192 1e-05


Training   [098]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.01s/it]
Training   [099]:   0%|                                   | 0/2 [00:00<?, ?it/s]

STATISTICS FOR EPOCH  98
step/loss/accu/lr: 98.5 1.004355788230896 0.5714285969734192 1e-05


Training   [099]: 100%|███████████████████████████| 2/2 [00:02<00:00,  1.02s/it]

STATISTICS FOR EPOCH  99
step/loss/accu/lr: 99.5 0.9489193558692932 0.3571428656578064 1e-05





## Test Triplet Sampler

In [40]:
# Reload the prefixed ground-truth pickle for the sampler checks below.
gnd = pickle_load(osp.join(train_data_dir, ''.join([prefixed, '_', train_gnd_file])))

In [41]:
# List the field names stored for the first ground-truth record.
gnd['gnd'][0].keys()

dict_keys(['easy', 'hard', 'junk', 'neg', 'r_easy', 'r_hard', 'r_junk', 'r_neg', 'g_easy', 'g_hard', 'g_junk', 'g_neg', 'provenance_entity', 'ir_order', 'r_ir_order', 'rank_img_dict', 'img_rank_dict', 'anchor_idx', 'is_human'])

In [48]:
# Pick one ground-truth record and show the lengths of its
# 'junk', 'easy', 'g_junk' and 'g_easy' entries, in that order.
i = 3
tuple(len(gnd['gnd'][i][key]) for key in ('junk', 'easy', 'g_junk', 'g_easy'))

(35, 1, 35, 1)

In [None]:
# NOTE(review): this cell has no execution count (it was never run) and refers
# to `self`, so it looks like a fragment pasted from a class method rather than
# runnable notebook code — confirm, and delete or adapt it.
# The first pair reads the 'r_easy' / 'r_junk' entries of record `i`.
positives = self.gnd_data[i]['r_easy']
negatives = self.gnd_data[i]['r_junk']
# The second pair reads the 'g_easy' / 'g_junk' entries, but indexes with
# `query_idx` instead of `i` — TODO confirm which index is intended.
positive_inds = self.gnd_data[query_idx]['g_easy']
negative_inds = self.gnd_data[query_idx]['g_junk']

In [54]:
def _parse_sample_line(line, sep):
    """Turn one listing line into a ``(str, int, int, int)`` sample tuple.

    The line is split once on ``sep``; the first field is kept as a string and
    the next three are converted with ``int``. Any further fields are ignored.

    Raises:
        IndexError: if the line has fewer than four fields.
        ValueError: if fields 1-3 are not valid integers.
    """
    fields = line.split(sep)
    return (fields[0], int(fields[1]), int(fields[2]), int(fields[3]))


# The ground-truth pickle is optional.
# NOTE(review): it is loaded here WITHOUT the `prefixed` prefix, while the
# inspection cells load `prefixed+'_'+train_gnd_file` — confirm which is intended.
train_gnd_data  = None if train_gnd_file is None else pickle_load(osp.join(train_data_dir, train_gnd_file))

# Listing files: train_txt[1] is used for the training samples and train_txt[0]
# for the training queries; both get the optional `prefixed` prefix.
train_lines_txt = train_txt[1] if prefixed is None else prefixed+'_'+train_txt[1]
train_lines     = read_file(osp.join(train_data_dir, train_lines_txt))
train_q_lines_txt = train_txt[0] if prefixed is None else prefixed+'_'+train_txt[0]
train_q_lines   = read_file(osp.join(train_data_dir, train_q_lines_txt))

# Parse both listings with the same helper (one split per line).
train_samples   = [_parse_sample_line(line, split_char) for line in train_lines]
train_q_samples = [_parse_sample_line(line, split_char) for line in train_q_lines]

# Both datasets share the same data directory, descriptor name, sequence length and ground truth.
train_set       = FeatureDataset(train_data_dir, train_samples,   desc_name, max_sequence_len, gnd_data=train_gnd_data)
query_train_set = FeatureDataset(train_data_dir, train_q_samples, desc_name, max_sequence_len, gnd_data=train_gnd_data)


In [55]:
# Read the training gallery listing. The file name gets the optional `prefixed`
# prefix in the same way as the other listing files: no prefix when `prefixed`
# is None (the plain concatenation would raise TypeError in that case).
train_g_lines_txt = 'train_gallery.txt' if prefixed is None else prefixed+'_'+'train_gallery.txt'
train_g_lines   = read_file(osp.join(train_data_dir, train_g_lines_txt))

In [56]:
# Number of lines in the train, train-query and train-gallery listings.
tuple(map(len, (train_lines, train_q_lines, train_g_lines)))

(5564, 57, 42678)

In [57]:
# Path of the nearest-neighbour index pickle: `<set_name>_nn_inds_<desc_name>.pkl`,
# with the optional `prefixed` prefix applied the same way as for the listing
# files (no prefix when `prefixed` is None instead of a TypeError).
nn_inds_file = set_name + '_nn_inds_%s.pkl'%desc_name
if prefixed is not None:
    nn_inds_file = prefixed+'_'+nn_inds_file
train_nn_inds = osp.join(train_data_dir, nn_inds_file)

# Per-query ground-truth records held by the training dataset.
gnd_data = train_set.gnd_data['gnd']

# Build the sampler from the query targets, the nn_inds path and the ground truth.
train_sampler = TripletSampler(query_train_set.targets, batch_size, train_nn_inds, num_candidates, gnd_data)

nn_inds_path:  /mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt/non_humans_tuto_nn_inds_r50_gldv2.pkl
labels len:  57


In [58]:
# Iterate over the sampler once and print each yielded item as a sanity check.
for item in train_sampler:
    print(item)

[2326, 609, 3627, 4026, 5127, 3215, 838, 3761, 3529, 5509, 1796, 863, 4091, 2071, 428, 1551, 97, 723, 543, 5348, 3246, 4696, 3439, 137, 11, 2465, 2966, 62, 2398, 1083, 2238, 2517, 1207, 4274, 1625, 400]
[4609, 2698, 1464, 3850, 1412, 3924, 3329, 1159, 5429, 2238, 5242, 999, 3423, 4198, 187, 592, 395, 5468, 1205, 4234, 1351, 961, 755, 661, 592, 5417, 4176, 438, 1614, 3960, 2295, 4781, 5520, 2884, 402, 4977]
[629, 4676, 906, 861, 2691, 4718, 2947, 3456, 2642, 4752, 314, 2274, 4754, 4384, 2106, 3010, 5470, 4222, 690, 1385, 2036, 4023, 4100, 4063, 3016, 1296, 2287, 3423, 3395, 81, 605, 1958, 179, 3918, 2675, 1415]
[2236, 2876, 822, 3342, 474, 756, 4696, 5502, 2288, 302, 1638, 3356, 1533, 2078, 1407, 929, 1128, 3816, 1551, 531, 4749, 2765, 2796, 4522, 4992, 450, 247, 4253, 5199, 1305, 5396, 5433, 713, 275, 511, 2062]
[690, 3193, 1500, 2493, 1402, 1459, 1720, 4395, 1601, 876, 1444, 445, 3315, 3708, 1867, 4685, 4104, 386]
