In [1]:
import numpy as np

In [2]:
import os, math
import os.path as osp
from copy import deepcopy
from functools import partial
from pprint import pprint

In [3]:
import sacred
import torch
import torch.nn as nn
import torch.nn.functional as F
from sacred import SETTINGS
from sacred.utils import apply_backspaces_and_linefeeds
from torch.backends import cudnn
from torch.optim import SGD, Adam, AdamW, lr_scheduler
from tqdm import tqdm

from torch.backends import cudnn
# from visdom_logger import VisdomLogger

In [4]:
from torch.utils.data import DataLoader, RandomSampler, BatchSampler
from typing import NamedTuple, Optional, List

In [5]:
from models.matcher import MatchERT
from models.ingredient import model_ingredient, get_model

In [6]:
from models.ingredient import model_ingredient, get_model
from utils import state_dict_to_cpu, num_of_trainable_params
from utils import pickle_load, pickle_save
from utils.data.utils import TripletSampler
from utils import BinaryCrossEntropyWithLogits
from utils.data.dataset_ingredient import data_ingredient, get_loaders
from utils.training import train_one_epoch, evaluate_viquae



In [7]:
from utils import pickle_load
from sacred import Experiment
from utils.data.dataset_ingredient import data_ingredient, get_loaders
from utils.data.dataset import FeatureDataset

In [8]:
# Sacred experiment wired to the data and model ingredients, created in interactive mode.
ex = sacred.Experiment(
    'RRT Training',
    ingredients=[data_ingredient, model_ingredient],
    interactive=True,
)
# Capture output through sys and filter backspaces / linefeeds out of the captured log.
SETTINGS.CAPTURE_MODE = 'sys'
ex.captured_out_filter = apply_backspaces_and_linefeeds

In [9]:
# ---- Optimisation ----
epochs = 15
lr = 1e-4
momentum = 0.
nesterov = False
weight_decay = 5e-4
optim = 'adamw'
scheduler = 'multistep'
# NOTE(review): if the scheduler is stepped once per epoch (multistep is built with
# update_per_iteration=False), milestones [16, 18] are never reached with epochs = 15,
# so the learning rate stays constant for the whole run. Confirm this is intended.
scheduler_tau = [16, 18]
scheduler_gamma = 0.1
max_norm = 0.0
no_bias_decay = False
loss = 'bce'
seed = 0

# ---- Runtime / logging ----
visdom_port = None
visdom_freq = 100
cpu = False  # Force training on CPU
cudnn_flag = 'benchmark'
temp_dir = osp.join('outputs', 'temp')

# Checkpoint to resume from. Set the RRT_RESUME environment variable to point elsewhere
# instead of editing this machine-specific absolute path.
resume = os.environ.get(
    'RRT_RESUME',
    '/mnt/beegfs/home/smessoud/RerankingTransformer/RRT_GLD/rrt_gld_ckpts/r50_gldv1.pt',
)

In [10]:
# Echo the run configuration so it is recorded in the cell output.
(
    epochs, cpu, cudnn_flag, visdom_port, visdom_freq,
    temp_dir, seed, no_bias_decay, max_norm, resume,
)

(15,
 False,
 'benchmark',
 None,
 100,
 'outputs/temp',
 0,
 False,
 0.0,
 '/mnt/beegfs/home/smessoud/RerankingTransformer/RRT_GLD/rrt_gld_ckpts/r50_gldv1.pt')

In [11]:
def get_optimizer_scheduler(parameters, optim, loader_length, epochs, lr, momentum, nesterov, weight_decay, scheduler, scheduler_tau, scheduler_gamma, lr_step=None):
    """Build the optimizer and its learning-rate scheduler from the run configuration.

    Args:
        parameters: parameters or parameter-group dicts handed to the optimizer.
        optim: 'sgd' or 'adam'; any other value selects AdamW.
        loader_length: number of batches per epoch; used by the per-iteration schedules.
        epochs: number of training epochs; 0 disables the scheduler entirely.
        lr, momentum, weight_decay: optimizer hyper-parameters (momentum is SGD-only).
        nesterov: request Nesterov momentum for SGD; only honoured when momentum is
            non-zero, because torch's SGD rejects nesterov=True without momentum.
        scheduler: 'cos', 'warmcos', 'multistep' or 'warmstep'; any other value selects a
            StepLR whose step size is epochs * loader_length iterations.
        scheduler_tau, scheduler_gamma: milestones and decay factor for 'multistep'.
        lr_step: number of epochs between 10x decays for 'warmstep'; required by it.

    Returns:
        (optimizer, (scheduler, update_per_iteration)). `update_per_iteration` is True when
        the schedule is expressed in iterations (step it every batch) and False when it is
        expressed in epochs; scheduler and flag are both None when epochs == 0.

    Raises:
        ValueError: scheduler == 'warmstep' and epochs != 0 but lr_step is None (this used
            to surface as an opaque TypeError from inside LambdaLR).
    """
    if optim == 'sgd':
        optimizer = SGD(parameters, lr=lr, momentum=momentum, weight_decay=weight_decay,
                        nesterov=bool(nesterov and momentum))
    elif optim == 'adam':
        optimizer = Adam(parameters, lr=lr, weight_decay=weight_decay)
    else:
        optimizer = AdamW(parameters, lr=lr, weight_decay=weight_decay)

    if epochs == 0:
        return optimizer, (None, None)

    if scheduler == 'cos':
        # Stepped once per epoch, so T_max is expressed in epochs.
        lr_sched = lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0.000001)
        update_per_iteration = False
    elif scheduler == 'warmcos':
        # Linear warm-up over the first 3 epochs, then half-cosine decay reaching 0 at `epochs`.
        def warm_cosine(i):
            return min((i + 1) / 3, (1 + math.cos(math.pi * i / epochs)) / 2)
        lr_sched = lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_cosine)
        update_per_iteration = False
    elif scheduler == 'multistep':
        lr_sched = lr_scheduler.MultiStepLR(optimizer, milestones=scheduler_tau, gamma=scheduler_gamma)
        update_per_iteration = False
    elif scheduler == 'warmstep':
        if lr_step is None:
            raise ValueError("scheduler='warmstep' requires lr_step (epochs between learning-rate decays)")
        # Linear warm-up over the first 100 iterations, then x0.1 every lr_step epochs.
        def warm_step(i):
            return min((i + 1) / 100, 1) * 0.1 ** (i // (lr_step * loader_length))
        lr_sched = lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_step)
        update_per_iteration = True
    else:
        lr_sched = lr_scheduler.StepLR(optimizer, epochs * loader_length)
        update_per_iteration = True

    return optimizer, (lr_sched, update_per_iteration)



In [12]:
def get_loss(loss):
    """Return the training criterion selected by `loss`.

    Only 'bce' (BinaryCrossEntropyWithLogits from utils) is supported; any other
    name raises an Exception naming the unsupported loss.
    """
    if loss != 'bce':
        raise Exception('Unsupported loss {}'.format(loss))
    return BinaryCrossEntropyWithLogits()

In [13]:
# Use the first GPU unless CUDA is unavailable or CPU training is forced via `cpu`.
use_gpu = torch.cuda.is_available() and not cpu
device = torch.device('cuda:0' if use_gpu else 'cpu')
# callback = VisdomLogger(port=visdom_port) if visdom_port else None
# 'deterministic' is applied once here; 'benchmark' is toggled around training and
# validation in the training-loop cell.
if cudnn_flag == 'deterministic':
    cudnn.deterministic = True

In [14]:
def read_file(filename):
    """Return the lines of the text file `filename`, without line terminators."""
    with open(filename) as handle:
        return handle.read().splitlines()

In [15]:
# ---- Dataset description (tutorial split) ----
# NOTE(review): `name` mentions gldv2 while desc_name is r50_gldv1; `name` is also
# rebound to 'rrt' in the model-config cell. Confirm which one is meant.
name = 'tuto_viquae_tuto_r50_gldv2'
set_name = 'tuto'          # prefix of the '<set_name>_nn_inds_<desc_name>.pkl' file
desc_name = 'r50_gldv1'    # descriptor name, also part of the nn_inds file name
split_char = ';;'          # field separator inside the split .txt files

# Split files, joined with the data directory in get_sets
train_txt = 'tuto.txt'
test_txt = ('tuto_query.txt', 'tuto_selection.txt')   # (queries, gallery)
test_gnd_file = 'gnd_tuto.pkl'

sampler = 'triplet'
# train_data_dir = 'data/viquae_for_rrt'
# test_data_dir  = 'data/viquae_for_rrt'

In [16]:
def _parse_samples(lines, split_char):
    """Parse split-file lines into (name, int, int, int) tuples, skipping blank lines.

    Each line is split once on `split_char`; the first field is kept as a string and the
    next three are converted to int. Fields beyond the fourth are ignored.
    """
    samples = []
    for line in lines:
        if not line.strip():
            # A blank line (e.g. a stray empty line in the file) used to raise IndexError.
            continue
        fields = line.split(split_char)
        samples.append((fields[0], int(fields[1]), int(fields[2]), int(fields[3])))
    return samples


def get_sets(desc_name, 
        train_data_dir, test_data_dir, 
        train_txt, test_txt, test_gnd_file, 
        max_sequence_len, split_char):
    """Build the train and test FeatureDatasets from split text files.

    Args:
        desc_name: descriptor name forwarded to FeatureDataset.
        train_data_dir / test_data_dir: directories the split files are read from.
        train_txt: training split file name.
        test_txt: (query split file, gallery split file).
        test_gnd_file: ground-truth pickle for the query set, or None for no ground truth.
        max_sequence_len: forwarded to FeatureDataset.
        split_char: field separator used inside the split files.

    Returns:
        ((train_set, query_train_set), (query_set, gallery_set)). train_set and
        query_train_set are built from the same samples (get_loaders batches them
        differently); only query_set receives the ground-truth data.
    """
    train_samples = _parse_samples(read_file(osp.join(train_data_dir, train_txt)), split_char)
    train_set = FeatureDataset(train_data_dir, train_samples, desc_name, max_sequence_len)
    query_train_set = FeatureDataset(train_data_dir, train_samples, desc_name, max_sequence_len)

    # NOTE(review): pickle_load presumably unpickles the file; only use trusted ground-truth files.
    test_gnd_data = None if test_gnd_file is None else pickle_load(osp.join(test_data_dir, test_gnd_file))
    query_samples = _parse_samples(read_file(osp.join(test_data_dir, test_txt[0])), split_char)
    gallery_samples = _parse_samples(read_file(osp.join(test_data_dir, test_txt[1])), split_char)
    gallery_set = FeatureDataset(test_data_dir, gallery_samples, desc_name, max_sequence_len)
    query_set = FeatureDataset(test_data_dir, query_samples, desc_name, max_sequence_len, gnd_data=test_gnd_data)

    return (train_set, query_train_set), (query_set, gallery_set)


In [17]:
# Sentinel default meaning "argument not passed: read the notebook global of the same name".
_USE_NOTEBOOK_GLOBAL = object()


def get_loaders(desc_name, train_data_dir, 
    batch_size, test_batch_size, 
    num_workers, pin_memory, 
    sampler, recalls, set_name,
    num_candidates=100, *,
    train_txt=_USE_NOTEBOOK_GLOBAL, test_txt=_USE_NOTEBOOK_GLOBAL,
    test_gnd_file=_USE_NOTEBOOK_GLOBAL, max_sequence_len=_USE_NOTEBOOK_GLOBAL,
    split_char=_USE_NOTEBOOK_GLOBAL):
    """Build the train / query-train / query / gallery DataLoaders for one split.

    This shadows the get_loaders imported from utils.data.dataset_ingredient.

    The split file names, ground-truth file, sequence length and field separator used to
    be read implicitly from notebook globals. They can now be passed as keyword
    arguments; when omitted, the notebook globals of the same name are used, so existing
    calls behave as before. The test split is read from `train_data_dir` as well.

    Args:
        desc_name: descriptor name, forwarded to get_sets and used in the nn_inds file name.
        train_data_dir: directory holding the split files (also used for the test split).
        batch_size: training batch size.
        test_batch_size: batch size of the query-train, query and gallery loaders.
        num_workers, pin_memory: forwarded to every DataLoader.
        sampler: 'random' (RandomSampler wrapped in a BatchSampler) or 'triplet'
            (TripletSampler fed with the '<set_name>_nn_inds_<desc_name>.pkl' file).
        recalls: returned unchanged alongside the loaders.
        set_name: nn_inds file prefix ('' means no prefix); stored on the result.
        num_candidates: forwarded to TripletSampler.
        train_txt, test_txt, test_gnd_file, max_sequence_len, split_char: see get_sets.

    Returns:
        (MetricLoaders, recalls).

    Raises:
        ValueError: unknown `sampler`.
        NameError: a keyword argument was omitted and no notebook global of that name exists.
    """
    notebook_globals = globals()

    def _from_notebook(value, global_name):
        # Fall back to the notebook global only when the caller did not pass the argument.
        if value is not _USE_NOTEBOOK_GLOBAL:
            return value
        if global_name not in notebook_globals:
            raise NameError("name '{0}' is not defined; pass {0}=... to get_loaders".format(global_name))
        return notebook_globals[global_name]

    train_txt = _from_notebook(train_txt, 'train_txt')
    test_txt = _from_notebook(test_txt, 'test_txt')
    test_gnd_file = _from_notebook(test_gnd_file, 'test_gnd_file')
    max_sequence_len = _from_notebook(max_sequence_len, 'max_sequence_len')
    split_char = _from_notebook(split_char, 'split_char')

    (train_set, query_train_set), (query_set, gallery_set) = get_sets(desc_name=desc_name, 
        train_data_dir=train_data_dir, test_data_dir=train_data_dir, 
        train_txt=train_txt, test_txt=test_txt, test_gnd_file=test_gnd_file, 
        max_sequence_len=max_sequence_len, split_char=split_char)

    if sampler == 'random':
        train_sampler = BatchSampler(RandomSampler(train_set), batch_size=batch_size, drop_last=False)
    elif sampler == 'triplet':
        # '<set_name>_' prefix only when set_name is non-empty.
        prefix = set_name + '_' if set_name != '' else ''
        train_nn_inds = osp.join(train_data_dir, prefix + 'nn_inds_%s.pkl' % desc_name)
        train_sampler = TripletSampler(train_set.targets, batch_size, train_nn_inds, num_candidates)
    else:
        raise ValueError('Invalid choice of sampler ({}).'.format(sampler))

    train_loader = DataLoader(train_set, batch_sampler=train_sampler, num_workers=num_workers, pin_memory=pin_memory)
    query_train_loader = DataLoader(query_train_set, batch_size=test_batch_size, num_workers=num_workers, pin_memory=pin_memory)
    query_loader = DataLoader(query_set, batch_size=test_batch_size, num_workers=num_workers, pin_memory=pin_memory)
    gallery_loader = DataLoader(gallery_set, batch_size=test_batch_size, num_workers=num_workers, pin_memory=pin_memory)

    # MetricLoaders is defined in a later cell; it is resolved at call time.
    return MetricLoaders(train=train_loader, query_train=query_train_loader, query=query_loader, gallery=gallery_loader, num_classes=len(train_set.categories), set_name=set_name), recalls


In [18]:
# Directory holding the split .txt files and the nn_inds pickles (and, presumably, the
# features read by FeatureDataset). Set VIQUAE_RRT_DATA_DIR to point elsewhere instead of
# editing this machine-specific absolute path.
train_data_dir = os.environ.get(
    'VIQUAE_RRT_DATA_DIR',
    '/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt',
)

In [19]:
class MetricLoaders(NamedTuple):
    """Bundle of DataLoaders returned by get_loaders.

    NamedTuple fields are positional, so keep this field order stable; get_loaders
    builds instances with keyword arguments.
    """
    train: DataLoader  # training batches produced by the configured batch sampler
    num_classes: int  # get_loaders sets this to len(train_set.categories)
    query: DataLoader  # evaluation queries (the dataset carrying the ground truth)
    query_train: DataLoader  # training samples iterated in test_batch_size batches
    set_name: str = ''  # split prefix, used to locate the nn_inds pickle
    gallery: Optional[DataLoader] = None  # evaluation gallery; None when not provided

In [20]:
# ---- DataLoader settings ----
batch_size = 32
test_batch_size = 32
num_workers = 8      # number of workers used to load the data
pin_memory = True    # use the pin_memory option of DataLoader

# ---- Sequence / sampling settings ----
max_sequence_len = 500
num_candidates = 100
recalls = [1, 5, 6, 10]

In [21]:
torch.manual_seed(seed)
# Use the config-cell values rather than repeating literals, so desc_name and
# num_candidates only need changing in one place.
loaders, recall_ks = get_loaders(desc_name, train_data_dir,
    batch_size, test_batch_size,
    num_workers, pin_memory,
    sampler, recalls, set_name,
    num_candidates=num_candidates)

cache_nn_inds len:  100
cache_nn_inds len:  100
labels len:  10100


In [22]:
def get_model(num_global_features, num_local_features, seq_len, dim_K, dim_feedforward, nhead, num_encoder_layers, dropout, activation, normalize_before):
    """Instantiate a MatchERT reranker from the notebook's hyper-parameter names.

    Shadows the get_model imported from models.ingredient; the notebook-level names are
    mapped onto MatchERT's keyword arguments (e.g. num_local_features -> d_model).
    """
    matcher_kwargs = dict(
        d_global=num_global_features,
        d_model=num_local_features,
        seq_len=seq_len,
        d_K=dim_K,
        nhead=nhead,
        num_encoder_layers=num_encoder_layers,
        dim_feedforward=dim_feedforward,
        dropout=dropout,
        activation=activation,
        normalize_before=normalize_before,
    )
    return MatchERT(**matcher_kwargs)

In [23]:
# ---- MatchERT hyper-parameters ----
name = 'rrt'  # NOTE: rebinds the dataset-level `name` assigned in the dataset-config cell
num_global_features = 2048
num_local_features = 128
seq_len = 1004
dim_K = 256
dim_feedforward = 1024
nhead = 4
num_encoder_layers = 6
dropout = 0.0
activation = 'relu'
normalize_before = False

In [24]:
torch.manual_seed(seed + 1)  # distinct seed offset for model initialisation
model = get_model(
    num_global_features=num_global_features,
    num_local_features=num_local_features,
    seq_len=seq_len,
    dim_K=dim_K,
    dim_feedforward=dim_feedforward,
    nhead=nhead,
    num_encoder_layers=num_encoder_layers,
    dropout=dropout,
    activation=activation,
    normalize_before=normalize_before,
)

In [25]:
if resume is not None:
    # NOTE(review): torch.load unpickles the checkpoint; only resume from trusted files.
    checkpoint = torch.load(resume, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['state'], strict=True)
print('# of trainable parameters: ', num_of_trainable_params(model))
class_loss = get_loss(loss)

# Build the nn_inds path with the same rule get_loaders uses: add the '<set_name>_'
# prefix only when set_name is non-empty (previously '_' was always prepended).
query_dataset = loaders.query.dataset
nn_prefix = loaders.set_name + '_' if loaders.set_name != '' else ''
nn_inds_path = osp.join(query_dataset.data_dir, nn_prefix + 'nn_inds_%s.pkl' % query_dataset.desc_name)
# NOTE(review): pickle_load presumably unpickles the file; only load trusted nn_inds files.
cache_nn_inds = torch.from_numpy(pickle_load(nn_inds_path)).long()

# of trainable parameters:  2243201


In [26]:
torch.manual_seed(seed+2)
model.to(device)
# Wrap only once: re-running this cell must not nest DataParallel(DataParallel(model)).
if not isinstance(model, nn.DataParallel):
    model = nn.DataParallel(model)

parameters = []
if no_bias_decay:
    # Exempt 1-D tensors (presumably biases and normalisation weights) from weight decay.
    parameters.append({'params': [par for par in model.parameters() if par.dim() != 1]})
    parameters.append({'params': [par for par in model.parameters() if par.dim() == 1], 'weight_decay': 0})
else:
    parameters.append({'params': model.parameters()})

# `scheduler` starts as the config-cell string but is rebound below to the
# (scheduler, update_per_iteration) tuple that the training loop consumes. Remember the
# configured name so re-running this cell rebuilds the same schedule. Otherwise the tuple
# would be passed back in as the name and silently fall through to StepLR.
if isinstance(scheduler, str):
    scheduler_name = scheduler
optimizer, scheduler = get_optimizer_scheduler(parameters=parameters, loader_length=len(loaders.train),
                                               optim=optim, epochs=epochs, lr=lr,
                                               momentum=momentum, nesterov=nesterov, weight_decay=weight_decay,
                                               scheduler=scheduler_name, scheduler_tau=scheduler_tau,
                                               scheduler_gamma=scheduler_gamma, lr_step=None)
# NOTE(review): `checkpoint` is deleted only when it carried optimizer state. After that,
# re-running this cell requires re-running the checkpoint-loading cell first.
if resume is not None and checkpoint.get('optim', None) is not None:
    optimizer.load_state_dict(checkpoint['optim'])
    del checkpoint


In [27]:
torch.manual_seed(seed + 3)
# Pre-bind the model, cached neighbour indices, recall cut-offs and evaluation loaders
# so validation can be run as a plain `eval_function()` call.
eval_function = partial(
    evaluate_viquae,
    model=model,
    cache_nn_inds=cache_nn_inds,
    recall=recall_ks,
    query_loader=loaders.query,
    gallery_loader=loaders.gallery,
)

In [28]:
# result = eval_function()
# pprint(result)
# best_val = (0, result, deepcopy(model.state_dict()))

In [29]:
# Trailing ';' keeps the returned torch.Generator repr out of the notebook output.
torch.manual_seed(seed + 4);

<torch._C.Generator at 0x7f1c83ef5810>

In [30]:
"""
for epoch in range(1):
    if cudnn_flag == 'benchmark':
        setattr(cudnn, cudnn_flag, True)

    torch.cuda.empty_cache()
    train_one_epoch(model=model, loader=loaders.train, class_loss=class_loss, optimizer=optimizer, scheduler=scheduler, max_norm=max_norm, epoch=epoch, freq=visdom_freq, ex=None)
    
    # validation
    if cudnn_flag == 'benchmark':
        setattr(cudnn, cudnn_flag, False)
    
    torch.cuda.empty_cache()
    result = eval_function()
    
    print('Validation [{:03d}]'.format(epoch)), pprint(result)
    #ex.log_scalar('val.map', result['map'], step=epoch + 1)

    if result['map'] >= best_val[1]['map']:
        print('New best model in epoch %d.'%epoch)
        best_val = (epoch + 1, result, deepcopy(model.state_dict()))
        torch.save({'state': state_dict_to_cpu(best_val[2]), 'optim': optimizer.state_dict()}, save_name)
"""

"\nfor epoch in range(1):\n    if cudnn_flag == 'benchmark':\n        setattr(cudnn, cudnn_flag, True)\n\n    torch.cuda.empty_cache()\n    train_one_epoch(model=model, loader=loaders.train, class_loss=class_loss, optimizer=optimizer, scheduler=scheduler, max_norm=max_norm, epoch=epoch, freq=visdom_freq, ex=None)\n    \n    # validation\n    if cudnn_flag == 'benchmark':\n        setattr(cudnn, cudnn_flag, False)\n    \n    torch.cuda.empty_cache()\n    result = eval_function()\n    \n    print('Validation [{:03d}]'.format(epoch)), pprint(result)\n    #ex.log_scalar('val.map', result['map'], step=epoch + 1)\n\n    if result['map'] >= best_val[1]['map']:\n        print('New best model in epoch %d.'%epoch)\n        best_val = (epoch + 1, result, deepcopy(model.state_dict()))\n        torch.save({'state': state_dict_to_cpu(best_val[2]), 'optim': optimizer.state_dict()}, save_name)\n"

In [31]:
# Rebuild the training split by hand (same parsing as get_sets) for inspection.
train_lines = read_file(osp.join(train_data_dir, train_txt))
# Split each line once (the original split it four times) and skip blank lines, which
# would otherwise raise IndexError. Fields beyond the fourth are ignored.
train_samples = [
    (fields[0], int(fields[1]), int(fields[2]), int(fields[3]))
    for fields in (line.split(split_char) for line in train_lines if line.strip())
]
train_set = FeatureDataset(train_data_dir, train_samples, desc_name, max_sequence_len)
query_train_set = FeatureDataset(train_data_dir, train_samples, desc_name, max_sequence_len)

In [32]:
# Rebuild the triplet sampler by hand (mirrors the 'triplet' branch of get_loaders) to inspect it.
s_name = set_name + '_' if set_name != '' else ''
train_nn_inds = osp.join(train_data_dir, '{}nn_inds_{}.pkl'.format(s_name, desc_name))
train_sampler = TripletSampler(train_set.targets, batch_size, train_nn_inds, num_candidates)

cache_nn_inds len:  100
cache_nn_inds len:  100
labels len:  10100


In [33]:
# Print every batch of indices the sampler yields (nothing is printed when it is empty).
for batch_indices in train_sampler:
    print(batch_indices)

In [34]:
# Number of batches the triplet sampler yields. The recorded output is 0, so a training
# DataLoader built on this sampler would produce no batches.
# NOTE(review): the sampler printed "cache_nn_inds len: 100" but "labels len: 10100";
# that size mismatch may be why no batches are produced. Check the nn_inds pickle.
len(train_sampler)

0