# Imports

In [1]:
import numpy as np
from functools import partial

In [2]:
from copy import deepcopy
from functools import partial
from pprint import pprint
import os.path as osp

In [3]:
import sacred
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from torch.backends import cudnn

In [4]:
from torch.utils.data import DataLoader, RandomSampler, BatchSampler
from typing import Dict, List, NamedTuple, Optional

In [5]:
from models.matcher import MatchERT
from models.ingredient import model_ingredient, get_model

In [6]:
from sacred import SETTINGS
from sacred.utils import apply_backspaces_and_linefeeds
from sacred import Ingredient

In [7]:
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader, Dataset
from utils.metrics import *



In [8]:
from utils import pickle_load
from sacred import Experiment
from utils.data.dataset_ingredient import data_ingredient, get_loaders
from utils.data.dataset import FeatureDataset
ex = sacred.Experiment('RRT Evaluation', ingredients=[data_ingredient, model_ingredient], interactive=True)

In [9]:
ex = sacred.Experiment('Prepare Top-K (VIQUAE FOR RTT)', interactive=True)
# Filter backspaces and linefeeds
SETTINGS.CAPTURE_MODE = 'sys'
ex.captured_out_filter = apply_backspaces_and_linefeeds

In [10]:
# Descriptor name. Not read again in the visible cells: later calls pass the literal 'r50_gldv1'.
feature_name = 'r50_gldv1'
# Split prefix used to build the '<set_name>_query.txt', '<set_name>_selection.txt'
# and '<set_name>_nn_inds_<desc>.pkl' file names further down.
set_name = 'tuto'
# Ground-truth pickle file name for that split.
gnd_name = 'gnd_'+ set_name+'.pkl'

In [11]:
# When True, the device selection below falls back to the CPU even if CUDA is available.
cpu = False
# Read by the cudnn cell below, which only acts on the value 'deterministic'.
cudnn_flag = 'benchmark'
# Not referenced by any other visible cell.
temp_dir = osp.join('logs', 'temp')
# Checkpoint restored into the model further down; set to None to skip loading.
# NOTE(review): absolute, user-specific path — the notebook is not portable as written.
resume = '/mnt/beegfs/home/smessoud/RerankingTransformer/RRT_GLD/rrt_gld_ckpts/r50_gldv1.pt'
# Passed to torch.manual_seed below.
seed = 0

In [12]:
# Dataset folder name, joined onto the data root to form `data_dir`.
dataset_name = 'viquae_for_rrt'
# NOTE(review): absolute, user-specific root — the same path is also spelled out literally in later cells.
data_dir = osp.join('/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data', dataset_name)

In [13]:
# Use the first CUDA device unless CUDA is unavailable or `cpu` is set.
device = torch.device('cuda:0' if torch.cuda.is_available() and not cpu else 'cpu')
# Seed PyTorch's global RNG; as the cell's last expression, the returned Generator is echoed.
torch.manual_seed(seed)

<torch._C.Generator at 0x7f9a773d57f0>

In [14]:
# Only the 'deterministic' setting is applied here.
# NOTE(review): `cudnn_flag` is configured as 'benchmark' above, so this cell currently
# changes nothing — confirm whether cudnn.benchmark was meant to be enabled.
if cudnn_flag == 'deterministic':
    setattr(cudnn, cudnn_flag, True)

In [15]:
def read_file(filename):
    """Read a text file and return its lines without line terminators.

    Args:
        filename: Path of the file to read. It is opened in text mode with
            the platform's default encoding.

    Returns:
        list[str]: The file content split with ``str.splitlines``.
    """
    with open(filename) as handle:
        content = handle.read()
    return content.splitlines()

In [16]:
def get_sets(desc_name,
        train_data_dir, test_data_dir,
        train_txt, test_txt, test_gnd_file,
        max_sequence_len=500):
    """Build the train and test ``FeatureDataset`` objects from sample list files.

    Every list file holds one sample per line in the form
    ``<name>;;<int>;;<int>;;<int>``.

    Args:
        desc_name: Descriptor name forwarded to ``FeatureDataset``.
        train_data_dir: Directory containing ``train_txt``.
        test_data_dir: Directory containing the test lists and ground truth.
        train_txt: File name of the training sample list.
        test_txt: Pair ``(query_list_name, gallery_list_name)``.
        test_gnd_file: Ground-truth pickle file name, or ``None`` to skip it.
        max_sequence_len: Forwarded to every ``FeatureDataset``.

    Returns:
        ``((train_set, query_train_set), (query_set, gallery_set))``. Both
        training datasets are built from the same samples; only ``query_set``
        receives the ground truth (as ``gnd_data``).
    """
    def parse_samples(rows):
        # Turn '<name>;;a;;b;;c' rows into (name, int(a), int(b), int(c)) tuples.
        parsed = []
        for row in rows:
            fields = row.split(';;')
            parsed.append((fields[0], int(fields[1]), int(fields[2]), int(fields[3])))
        return parsed

    # Training side: two independent datasets over the same sample list.
    train_samples = parse_samples(read_file(osp.join(train_data_dir, train_txt)))
    train_set = FeatureDataset(train_data_dir, train_samples, desc_name, max_sequence_len)
    query_train_set = FeatureDataset(train_data_dir, train_samples, desc_name, max_sequence_len)

    # Test side: optional ground truth, then the query and gallery lists.
    if test_gnd_file is None:
        test_gnd_data = None
    else:
        test_gnd_data = pickle_load(osp.join(test_data_dir, test_gnd_file))
    query_rows = read_file(osp.join(test_data_dir, test_txt[0]))
    gallery_rows = read_file(osp.join(test_data_dir, test_txt[1]))
    query_samples = parse_samples(query_rows)
    gallery_samples = parse_samples(gallery_rows)
    gallery_set = FeatureDataset(test_data_dir, gallery_samples, desc_name, max_sequence_len)
    query_set = FeatureDataset(test_data_dir, query_samples, desc_name, max_sequence_len, gnd_data=test_gnd_data)

    return (train_set, query_train_set), (query_set, gallery_set)

In [17]:
class MetricLoaders(NamedTuple):
    """Bundle of data loaders (plus the class count) returned by ``get_loaders``."""
    # Loader over the training dataset, driven by the selected batch sampler.
    train: DataLoader
    # Filled with ``len(train_set.categories)`` by ``get_loaders``.
    num_classes: int
    # Loader over the query dataset (the one carrying the ground truth).
    query: DataLoader
    # Loader over the training samples, batched with the test batch size.
    query_train: DataLoader
    # Loader over the gallery dataset; optional.
    gallery: Optional[DataLoader] = None

In [18]:
def get_loaders(desc_name, train_data_dir,
    batch_size=8, test_batch_size=8,
    num_workers=8, pin_memory=True,
    sampler='random', recalls=None,
    num_candidates=100, max_sequence_len=500):
    """Create the train/query/gallery data loaders for the current split.

    The list and ground-truth file names are derived from the notebook
    globals ``set_name`` and ``gnd_name``; train and test data are both read
    from ``train_data_dir``.

    Args:
        desc_name: Descriptor name forwarded to ``get_sets``.
        train_data_dir: Directory holding the sample lists and ground truth.
        batch_size: Batch size of the training sampler.
        test_batch_size: Batch size of the query/gallery/query-train loaders.
        num_workers: ``DataLoader`` worker count.
        pin_memory: ``DataLoader`` pin_memory option.
        sampler: ``'random'`` or ``'triplet'``.
        recalls: Cut-offs returned alongside the loaders; defaults to a fresh
            ``[1, 5, 10]`` so callers never share a mutable default.
        num_candidates: Forwarded to ``TripletSampler`` (triplet mode only).
        max_sequence_len: Forwarded to ``get_sets``.

    Returns:
        ``(MetricLoaders, recalls)``.

    Raises:
        ValueError: If ``sampler`` is not a supported choice. Raised before
            any dataset is loaded.
    """
    # Fail fast on a bad sampler name instead of after loading every dataset.
    if sampler not in ('random', 'triplet'):
        raise ValueError('Invalid choice of sampler ({}).'.format(sampler))
    if recalls is None:
        recalls = [1, 5, 10]

    query_txt = set_name + '_query.txt'
    (train_set, query_train_set), (query_set, gallery_set) = get_sets(desc_name,
        train_data_dir=train_data_dir,
        test_data_dir=train_data_dir,
        train_txt=query_txt,
        test_txt=(query_txt, set_name + '_selection.txt'),
        test_gnd_file=gnd_name,
        max_sequence_len=max_sequence_len)

    if sampler == 'random':
        train_sampler = BatchSampler(RandomSampler(train_set), batch_size=batch_size, drop_last=False)
    else:
        # NOTE(review): TripletSampler is not imported by name in the visible cells —
        # confirm it is provided (e.g. through the star import) before using this mode.
        train_nn_inds = osp.join(train_data_dir, set_name+'_nn_inds_%s.pkl'%desc_name)
        train_sampler = TripletSampler(train_set.targets, batch_size, train_nn_inds, num_candidates)

    # Options shared by every loader.
    loader_opts = dict(num_workers=num_workers, pin_memory=pin_memory)
    train_loader = DataLoader(train_set, batch_sampler=train_sampler, **loader_opts)
    query_train_loader = DataLoader(query_train_set, batch_size=test_batch_size, **loader_opts)
    query_loader = DataLoader(query_set, batch_size=test_batch_size, **loader_opts)
    gallery_loader = DataLoader(gallery_set, batch_size=test_batch_size, **loader_opts)

    loaders = MetricLoaders(
        train=train_loader,
        query_train=query_train_loader,
        query=query_loader,
        gallery=gallery_loader,
        num_classes=len(train_set.categories))
    return loaders, recalls

In [19]:
# Training and evaluation data share one folder; reuse `data_dir` (it resolves to the
# same path) rather than keeping a second hard-coded copy in sync.
train_data_dir = data_dir

In [20]:
#train_txt = set_name+'_query.txt'
#query_lines     = read_file(osp.join(train_data_dir, train_txt))
#train_samples   = [(line.split(';;')[0], int(line.split(';;')[1]), int(line.split(';;')[2]), int(line.split(';;')[3])) for line in train_lines]


In [21]:
#test_txt = ('test_query.txt', 'test_selection.txt')
#desc_name = 'r50_gldv1'
#query_lines   = read_file(osp.join(train_data_dir, test_txt[0]))
#gallery_lines = read_file(osp.join(train_data_dir, test_txt[1]))
#query_samples   = [(line.split(';;')[0], int(line.split(';;')[1]), int(line.split(';;')[2]), int(line.split(';;')[3])) for line in query_lines]
#gallery_samples = [(line.split(';;')[0], int(line.split(';;')[1]), int(line.split(';;')[2]), int(line.split(';;')[3])) for line in gallery_lines]
#gallery_set = FeatureDataset(train_data_dir, gallery_samples, desc_name, max_sequence_len)
#query_set   = FeatureDataset(train_data_dir, query_samples,   desc_name, max_sequence_len)

In [22]:
#gallery_loader

In [23]:
def mean_average_precision_viquae_rerank(
    model: nn.Module,
    cache_nn_inds: torch.Tensor,
    query_global: torch.Tensor, query_local: torch.Tensor, query_mask: torch.Tensor, query_scales: torch.Tensor, query_positions: torch.Tensor,
    gallery_global: torch.Tensor, gallery_local: torch.Tensor, gallery_mask: torch.Tensor, gallery_scales: torch.Tensor, gallery_positions: torch.Tensor,
    ks: List[int],
    gnd) -> Dict[str, float]:
    """Re-score each query's cached candidates with ``model`` and compute the metrics.

    Args:
        model: Called once per (query, candidate) pair with batch size one as
            ``model(q_global, q_local, q_mask, q_scales, q_positions,
            g_global, g_local, g_mask, g_scales, g_positions)``. Its first
            parameter determines the device used for scoring.
        cache_nn_inds: 2-D tensor ``(num_queries, K)`` of per-query candidate
            indices. Only the first ``min(100, K)`` columns are re-scored.
        query_global, query_local, query_mask, query_scales, query_positions:
            Query features, indexed by query along dim 0.
        gallery_global, gallery_local, gallery_mask, gallery_scales,
        gallery_positions: Gallery features. They are reshaped to
            ``(num_queries, 100, ...)`` using the trailing sizes of the
            matching query tensor, so they must hold exactly 100 candidates
            per query.
        ks: Cut-offs forwarded to ``compute_metrics`` as ``kappas``.
        gnd: Ground truth; ``gnd['gnd'][i]['junk']`` is read for every query
            and ``gnd['gnd']`` is forwarded to ``compute_metrics``.

    Returns:
        The value returned by ``compute_metrics('viquae', ...)``.
    """
    # Fixed number of gallery slots per query assumed by the reshape below.
    candidates_per_query = 100

    device = next(model.parameters()).device
    query_global    = query_global.to(device)
    query_local     = query_local.to(device)
    query_mask      = query_mask.to(device)
    query_scales    = query_scales.to(device)
    query_positions = query_positions.to(device)

    num_samples, top_k = cache_nn_inds.size()
    top_k = min(100, top_k)

    # Regroup the gallery so that entry [q, c] is candidate c of query q.
    gallery_global = gallery_global.reshape(
        query_global.size(dim=0), candidates_per_query, query_global.size(dim=1))
    gallery_local = gallery_local.reshape(
        query_local.size(dim=0), candidates_per_query, query_local.size(dim=1), query_local.size(dim=2))
    gallery_mask = gallery_mask.reshape(
        query_mask.size(dim=0), candidates_per_query, query_mask.size(dim=1))
    gallery_scales = gallery_scales.reshape(
        query_scales.size(dim=0), candidates_per_query, query_scales.size(dim=1))
    gallery_positions = gallery_positions.reshape(
        query_positions.size(dim=0), candidates_per_query, query_positions.size(dim=1), query_positions.size(dim=2))

    ########################################################################################
    ## Evaluation
    # Private numpy copy so the caller's index tensor is never modified.
    eval_nn_inds = cache_nn_inds.detach().cpu().numpy().copy()

    # Exclude the junk images as in DELG (https://github.com/tensorflow/models/blob/44cad43aadff9dd12b00d4526830f7ea0796c047/research/delf/delf/python/detect_to_retrieve/image_reranking.py#L190)
    # NOTE(review): the exploratory cell further down reads gnd['gnd'][i]['r_neg'] instead
    # of 'junk' — confirm which key the ViQuAE ground truth actually provides.
    for i in range(num_samples):
        junk_ids = gnd['gnd'][i]['junk']
        all_ids = eval_nn_inds[i]
        is_junk = np.isin(all_ids, junk_ids)
        # Stable partition: non-junk candidates first, junk ones at the end.
        eval_nn_inds[i] = np.concatenate([all_ids[~is_junk], all_ids[is_junk]])
    eval_nn_inds = torch.from_numpy(eval_nn_inds)

    scores = []
    # Inference only: no_grad keeps the forward passes from recording autograd graphs.
    with torch.no_grad():
        for i in tqdm(range(top_k)):
            nnids = eval_nn_inds[:, i]  # i-th candidate of every query
            topk_scores = []
            for iterator in range(num_samples):
                candidate = int(nnids[iterator])
                # unsqueeze(0) builds the batch-of-one input while keeping dtype and device.
                iter_scores = model(
                    query_global[iterator].unsqueeze(dim=0),
                    query_local[iterator].unsqueeze(dim=0),
                    query_mask[iterator].unsqueeze(dim=0),
                    query_scales[iterator].unsqueeze(dim=0),
                    query_positions[iterator].unsqueeze(dim=0),
                    gallery_global[iterator, candidate].unsqueeze(dim=0).to(device),
                    gallery_local[iterator, candidate].unsqueeze(dim=0).to(device),
                    gallery_mask[iterator, candidate].unsqueeze(dim=0).to(device),
                    gallery_scales[iterator, candidate].unsqueeze(dim=0).to(device),
                    gallery_positions[iterator, candidate].unsqueeze(dim=0).to(device))
                topk_scores.append(iter_scores.detach().cpu())

            # One score per query for this rank position.
            scores.append(torch.stack(topk_scores, dim=0).squeeze(1))
            # Release cached GPU blocks once per rank position rather than once per pair.
            torch.cuda.empty_cache()

    scores = torch.stack(scores, -1) # nb_queries x 100
    _, indices = torch.sort(scores, dim=-1, descending=True)
    closest_indices = torch.gather(eval_nn_inds, -1, indices)
    # Only the re-scored prefix is reordered; any remaining columns keep their order.
    ranks = eval_nn_inds.clone()
    ranks[:, :top_k] = closest_indices
    ranks = ranks.cpu().numpy().T
    # pickle_save('eval_nn_inds.pkl', ranks.T)
    out = compute_metrics('viquae', ranks, gnd['gnd'], kappas=ks)

    ########################################################################################

    return out

In [24]:
# Build the datasets directly; train and test data are read from the same folder.
(train_set, query_train_set), (query_set, gallery_set) = get_sets(
    desc_name='r50_gldv1',
    train_data_dir=train_data_dir,
    test_data_dir=train_data_dir,
    train_txt=set_name + '_query.txt',
    test_txt=(set_name + '_query.txt', set_name + '_selection.txt'),
    test_gnd_file='gnd_' + set_name + '.pkl',
    max_sequence_len=500,
)

In [25]:
# Batch sizes for the hand-built loaders in the next cells.
batch_size      = 16
test_batch_size = 16
# NOTE(review): not read by any later visible cell; the datasets above were built with 500.
max_sequence_len = 100
sampler = 'random'
# Only the 'random' sampler is handled here; any other value leaves `train_sampler` undefined.
if sampler == 'random':
   train_sampler = BatchSampler(RandomSampler(train_set), batch_size=batch_size, drop_last=False)

In [26]:
num_workers = 8  # number of workers used to load the data
pin_memory  = True  # use the pin_memory option of DataLoader 
# Not read by any later visible cell.
num_candidates = 100
# Cut-offs passed as `ks` to the rerank function in the last evaluation cell.
recalls = [1, 5, 10]

In [27]:
# Training loaders built by hand; no later visible cell reads them.
train_loader = DataLoader(train_set, batch_sampler=train_sampler, num_workers=num_workers, pin_memory=pin_memory)
query_train_loader = DataLoader(query_train_set, batch_size=test_batch_size, num_workers=num_workers, pin_memory=pin_memory)

In [28]:
# Query and gallery loaders consumed by the step-by-step feature extraction cells below.
query_loader   = DataLoader(query_set, batch_size=test_batch_size, num_workers=num_workers, pin_memory=pin_memory)
gallery_loader = DataLoader(gallery_set, batch_size=test_batch_size, num_workers=num_workers, pin_memory=pin_memory)

In [29]:
# Same loaders again, this time through the helper. Reuse `train_data_dir` rather
# than repeating the absolute dataset path a third time.
loaders, recall_ks = get_loaders(desc_name='r50_gldv1',
    train_data_dir=train_data_dir,
    batch_size=36, test_batch_size=36,
    num_workers=8, pin_memory=True,
    sampler='random', recalls=[1, 5, 10],
    num_candidates=100)

In [30]:
model_ingredient = Ingredient('model', interactive=True)

In [31]:
def get_model(num_global_features, num_local_features, seq_len, dim_K, dim_feedforward, nhead, num_encoder_layers, dropout, activation, normalize_before):
    """Instantiate ``MatchERT`` from the notebook's hyper-parameter names.

    The arguments are renamed to the constructor's keywords:
    ``num_global_features`` -> ``d_global``, ``num_local_features`` ->
    ``d_model`` and ``dim_K`` -> ``d_K``; all others are passed under their
    own name.

    Returns:
        The newly constructed ``MatchERT`` instance.
    """
    matcher_kwargs = dict(
        d_global=num_global_features,
        d_model=num_local_features,
        seq_len=seq_len,
        d_K=dim_K,
        nhead=nhead,
        num_encoder_layers=num_encoder_layers,
        dim_feedforward=dim_feedforward,
        dropout=dropout,
        activation=activation,
        normalize_before=normalize_before,
    )
    return MatchERT(**matcher_kwargs)

In [32]:
name = 'rrt'
num_global_features = 2048  
num_local_features = 128  
seq_len = 1004
dim_K = 256
dim_feedforward = 1024
nhead = 4
num_encoder_layers = 6
dropout = 0.0 
activation = "relu"
normalize_before = False

In [33]:
model = get_model(num_global_features,num_local_features,seq_len,dim_K,dim_feedforward,nhead,num_encoder_layers,dropout,activation,normalize_before)

In [34]:
# Restore the pretrained weights when a checkpoint path is configured.
if resume is not None:
   # Load onto the CPU; the model is moved to `device` in the next cell.
   # NOTE(review): torch.load unpickles the file — only point `resume` at trusted checkpoints.
   checkpoint = torch.load(resume, map_location=torch.device('cpu'))
   # The weights are stored under the 'state' key; strict=True requires an exact key match.
   model.load_state_dict(checkpoint['state'], strict=True)

In [35]:
model.to(device)
model.eval()

MatchERT(
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): _LinearWithBias(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=1024, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (linear2): Linear(in_features=1024, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.0, inplace=False)
        (dropout2): Dropout(p=0.0, inplace=False)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): _LinearWithBias(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=1024, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (linear2): Lin

In [36]:
loaders.query.dataset.desc_name, loaders.query.dataset.desc_name, loaders.query.dataset.data_dir

('r50_gldv1',
 'r50_gldv1',
 '/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt')

In [37]:
loaders.query

<torch.utils.data.dataloader.DataLoader at 0x7f9a18cb6290>

In [38]:
# Cached candidate indices, stored next to the query data as '<set_name>_nn_inds_<desc_name>.pkl'.
nn_inds_path = osp.join(loaders.query.dataset.data_dir, set_name+'_nn_inds_%s.pkl'%loaders.query.dataset.desc_name)
# Converted to an int64 tensor; the next cell shows its shape is (100, 100).
cache_nn_inds = torch.from_numpy(pickle_load(nn_inds_path)).long()

In [39]:
cache_nn_inds.size()

torch.Size([100, 100])

In [40]:
num_samples, top_k = cache_nn_inds.size()
top_k = min(top_k, 100)
top_k

100

In [41]:
def _collect_loader_features(loader, desc):
    """Iterate ``loader`` once and concatenate its five feature tensors along dim 0."""
    columns = [[] for _ in range(5)]
    for entry in tqdm(loader, desc=desc, leave=False, ncols=80):
        # Each batch unpacks into seven items; the last two are not needed here.
        b_global, b_local, b_mask, b_scales, b_positions, _, _ = entry
        for column, batch in zip(columns, (b_global, b_local, b_mask, b_scales, b_positions)):
            column.append(batch.cpu())
    return tuple(torch.cat(column, 0) for column in columns)


def evaluate_viquae(
        model: nn.Module,
        cache_nn_inds: torch.Tensor,
        query_loader: DataLoader,
        gallery_loader: DataLoader,
        recall: List[int]):
    """Extract all query/gallery features and evaluate the re-ranking of ``model``.

    Args:
        model: Matching model; switched to eval mode here.
        cache_nn_inds: Per-query candidate indices, forwarded unchanged.
        query_loader: Loader whose batches unpack into seven items, the first
            five being global, local, mask, scales and positions tensors. Its
            ``dataset.gnd_data`` supplies the ground truth.
        gallery_loader: Loader over the gallery with the same batch layout.
        recall: Cut-offs forwarded as ``ks``.

    Returns:
        The result of ``mean_average_precision_viquae_rerank``.

    NOTE(review): the gallery is passed on as one flat tensor and the rerank
    function reshapes it to 100 candidates per query, which only works when
    every query has exactly 100 gallery rows.
    """
    model.eval()

    with torch.no_grad():
        (query_global, query_local, query_mask,
         query_scales, query_positions) = _collect_loader_features(
            query_loader, 'Extracting query features')
        (gallery_global, gallery_local, gallery_mask,
         gallery_scales, gallery_positions) = _collect_loader_features(
            gallery_loader, 'Extracting gallery features')

        torch.cuda.empty_cache()
        metrics = mean_average_precision_viquae_rerank(
            model=model, cache_nn_inds=cache_nn_inds,
            query_global=query_global, query_local=query_local, query_mask=query_mask,
            query_scales=query_scales, query_positions=query_positions,
            gallery_global=gallery_global, gallery_local=gallery_local, gallery_mask=gallery_mask,
            gallery_scales=gallery_scales, gallery_positions=gallery_positions,
            ks=recall,
            gnd=query_loader.dataset.gnd_data,
        )
    return metrics

In [44]:
# Step-by-step replay of `evaluate_viquae`, using the hand-built loaders.
model.eval()
device = next(model.parameters()).device
# NOTE(review): `to_device` is defined but not called in any visible cell.
to_device = lambda x: x.to(device, non_blocking=True)

# Per-batch accumulators, filled by the two extraction loops below.
query_global, query_local, query_mask, query_scales, query_positions, query_names = [], [], [], [], [], []
gallery_global, gallery_local, gallery_mask, gallery_scales, gallery_positions, gallery_names = [], [], [], [], [], []

In [45]:
# Collect the query features batch by batch (each batch unpacks into seven items).
for entry in tqdm(query_loader, desc='Extracting query features', leave=False, ncols=80):
    q_global, q_local, q_mask, q_scales, q_positions, _, q_names = entry
    query_global.append(q_global.cpu())
    query_local.append(q_local.cpu())
    query_mask.append(q_mask.cpu())
    query_scales.append(q_scales.cpu())
    query_positions.append(q_positions.cpu())
    query_names.extend(list(q_names))
    torch.cuda.empty_cache()

                                                                                

In [46]:
# Concatenate the per-batch lists along dim 0.
# NOTE(review): this rebinds the list names to tensors, so the cell is not idempotent —
# re-run the accumulator setup and the extraction loop before running it again.
query_global    = torch.cat(query_global, 0)
query_local     = torch.cat(query_local, 0)
query_mask      = torch.cat(query_mask, 0)
query_scales    = torch.cat(query_scales, 0)
query_positions = torch.cat(query_positions, 0)

In [47]:
# Collect the gallery features batch by batch (same seven-item batch layout as the queries).
for entry in tqdm(gallery_loader, desc='Extracting gallery features', leave=False, ncols=80):
    g_global, g_local, g_mask, g_scales, g_positions, _, g_names = entry
    gallery_global.append(g_global.cpu())
    gallery_local.append(g_local.cpu())
    gallery_mask.append(g_mask.cpu())
    gallery_scales.append(g_scales.cpu())
    gallery_positions.append(g_positions.cpu())
    gallery_names.extend(list(g_names))
    torch.cuda.empty_cache()

                                                                                

In [48]:
# Concatenate the per-batch lists along dim 0, giving one flat gallery over all queries.
# NOTE(review): rebinds the list names to tensors, so this cell is not idempotent.
gallery_global    = torch.cat(gallery_global, 0)
gallery_local     = torch.cat(gallery_local, 0)
gallery_mask      = torch.cat(gallery_mask, 0)
gallery_scales    = torch.cat(gallery_scales, 0)
gallery_positions = torch.cat(gallery_positions, 0)

In [49]:
gallery_global.shape, gallery_local.shape, gallery_mask.shape, gallery_scales.shape, gallery_positions.shape

(torch.Size([4838, 2048]),
 torch.Size([4838, 500, 128]),
 torch.Size([4838, 500]),
 torch.Size([4838, 500]),
 torch.Size([4838, 500, 2]))

In [50]:
shape = list(gallery_global.shape)
shape[1:1] = [100]
shape

[4838, 100, 2048]

In [51]:
shape = list(gallery_global.shape)
shape.insert(1, 100)
shape

[4838, 100, 2048]

In [52]:
torch.zeros(shape).size(dim=0)

4838

In [53]:
gallery_local[0]

tensor([[ 0.0181,  0.1067, -0.0501,  ..., -0.0032, -0.0262,  0.0568],
        [ 0.0064,  0.0597, -0.1070,  ..., -0.0430, -0.0586, -0.0196],
        [ 0.0378, -0.0858,  0.0433,  ..., -0.0163,  0.1774,  0.0182],
        ...,
        [ 0.0129,  0.1898, -0.0403,  ...,  0.0155,  0.0798, -0.0445],
        [-0.0825, -0.0737, -0.1378,  ...,  0.0845, -0.0098,  0.1156],
        [-0.0951, -0.0647, -0.2040,  ..., -0.0597,  0.0242, -0.0652]])

In [54]:
# Move the query features to the model's device; the gallery tensors are not moved here.
device = next(model.parameters()).device
query_global    = query_global.to(device)
query_local     = query_local.to(device)
query_mask      = query_mask.to(device)
query_scales    = query_scales.to(device)
query_positions = query_positions.to(device)

# Number of queries and number of candidates to re-score (at most 100).
num_samples, top_k = cache_nn_inds.size()
top_k = min(100, top_k)

In [55]:
# Ground truth for the selected split.
# NOTE(review): the same file is loaded twice, giving two separate objects;
# `gnd_data` is read for 'simlist' below, `gnd` for the junk filtering and the metrics.
gnd_data = pickle_load(osp.join(data_dir, gnd_name))
gnd = pickle_load(osp.join(data_dir, gnd_name))

In [56]:
# Length of each query's 'simlist' entry, used below as its number of gallery rows.
_sizes = [len(gnd_data['simlist'][i]) for i in range(len(gnd_data['simlist']))]
# The total (4838 in the echoed output) equals the row count of the flat gallery tensors.
np.sum(_sizes)

4838

In [57]:
def fill_in_and_pad(gallery_in, query, sizes, num_candidates=100):
    """Scatter a flat, per-query concatenated gallery into a zero-padded block.

    ``gallery_in`` is read sequentially: the first ``sizes[0]`` rows belong to
    query 0, the next ``sizes[1]`` rows to query 1, and so on. Each query's
    rows are copied into its own slice of the output and the unused slots
    stay zero.

    Args:
        gallery_in: Tensor whose dim 0 enumerates all candidates of all
            queries back to back.
        query: Tensor whose shape provides the output layout; only its shape
            is used.
        sizes: Sequence with the number of candidates of every query.
        num_candidates: Slots reserved per query (second output dimension).
            Candidates beyond this number are dropped for that query.

    Returns:
        Tensor of shape ``(query.shape[0], num_candidates, *query.shape[1:])``
        created with ``torch.zeros`` (default dtype, default device); copied
        values are cast to that dtype.

    Raises:
        IndexError: If ``sizes`` asks for more rows than ``gallery_in`` holds.
    """
    out_shape = list(query.shape)
    out_shape.insert(1, num_candidates)
    gallery_out = torch.zeros(out_shape)

    offset = 0  # first row of the current query inside gallery_in
    for i in range(gallery_out.size(dim=0)):
        group_size = max(int(sizes[i]), 0)
        kept = min(group_size, num_candidates)
        if offset + kept > gallery_in.size(dim=0):
            raise IndexError(
                'sizes require row {} but gallery_in only has {} rows'.format(
                    offset + kept - 1, gallery_in.size(dim=0)))
        if kept > 0:
            gallery_out[i, :kept] = gallery_in[offset:offset + kept]
        # Advance by the full group size (not by `kept`) so that an oversized
        # group cannot shift the rows of every following query.
        offset += group_size
    return gallery_out

In [58]:
# Regroup the flat gallery tensors into one 100-slot block per query.
# NOTE(review): the padded tensors come from torch.zeros, so mask/scales lose their original
# dtype here; the scoring cells below cast them back before calling the model.
# NOTE(review): rebinds the gallery_* names, so this cell must not be run twice in a row.
gallery_global    = fill_in_and_pad(gallery_global, query_global, _sizes)
gallery_local     = fill_in_and_pad(gallery_local, query_local, _sizes)
gallery_mask      = fill_in_and_pad(gallery_mask, query_mask, _sizes)
gallery_scales    = fill_in_and_pad(gallery_scales, query_scales, _sizes)
gallery_positions = fill_in_and_pad(gallery_positions, query_positions, _sizes)

In [59]:
eval_nn_inds = deepcopy(cache_nn_inds.cpu().data.numpy())

In [60]:
# Exclude the junk images as in DELG (https://github.com/tensorflow/models/blob/44cad43aadff9dd12b00d4526830f7ea0796c047/research/delf/delf/python/detect_to_retrieve/image_reranking.py#L190)
# For every query, move the candidates listed under 'r_neg' to the end of its row while
# keeping the relative order inside both groups.
# NOTE(review): `mean_average_precision_viquae_rerank` reads the 'junk' key instead of
# 'r_neg' — confirm which one is intended.
for i in range(num_samples):
    junk_ids = gnd['gnd'][i]['r_neg']
    all_ids = eval_nn_inds[i]
    # Boolean mask of the candidates that appear in junk_ids, and its negation.
    pos = np.in1d(all_ids, junk_ids)
    neg = np.array([not x for x in pos])
    # Positions of the kept candidates first, then positions of the junk ones.
    new_ids = np.concatenate([np.arange(len(all_ids))[neg], np.arange(len(all_ids))[pos]])
    new_ids = all_ids[new_ids]
    eval_nn_inds[i] = new_ids
# Back to a tensor so it can be indexed and gathered with torch below.
eval_nn_inds = torch.from_numpy(eval_nn_inds)

In [61]:
# Smoke test: score query 0 against its first-ranked candidate.
nnids = eval_nn_inds[:, 0]
iterator = 0
print('iterator: ', iterator)
# Pick the candidate's features out of the padded per-query gallery blocks.
index_global = gallery_global[iterator, nnids[iterator]]
index_local = gallery_local[iterator, nnids[iterator]]
index_mask = gallery_mask[iterator, nnids[iterator]]
index_scales = gallery_scales[iterator, nnids[iterator]]
index_positions = gallery_positions[iterator, nnids[iterator]]

# Add a batch dimension of one and cast each tensor to the dtype used for the model call.
index_global = index_global.unsqueeze(dim=0)
index_global = index_global.type(torch.float32)

index_local = index_local.unsqueeze(dim=0)
index_local = index_local.type(torch.float32)

index_mask = index_mask.unsqueeze(dim=0)
index_mask = index_mask.type(torch.bool)

index_scales = index_scales.unsqueeze(dim=0)
index_scales = index_scales.type(torch.int64)

index_positions = index_positions.unsqueeze(dim=0)
index_positions = index_positions.type(torch.float32)

# Query side: already on the model's device, only the batch dimension is added.
q_global = query_global[iterator].unsqueeze(dim=0)
q_local = query_local[iterator].unsqueeze(dim=0)
q_mask = query_mask[iterator].unsqueeze(dim=0)
q_scales = query_scales[iterator].unsqueeze(dim=0)
q_positions = query_positions[iterator].unsqueeze(dim=0)

# NOTE(review): runs with autograd enabled (the echoed score below carries a grad_fn).
iter_scores = model(
q_global, q_local, q_mask, q_scales, q_positions,
    index_global.to(device),
    index_local.to(device),
    index_mask.to(device),
    index_scales.to(device),
    index_positions.to(device))

iterator:  0


In [62]:
iter_scores

tensor([-3.5857], device='cuda:0', grad_fn=<ViewBackward>)

In [137]:
# Score every (query, candidate) pair rank by rank: column i of eval_nn_inds holds the
# i-th candidate of every query.
scores = []
# Inference only: without no_grad each of the forward passes records an autograd graph.
with torch.no_grad():
    for i in tqdm(range(top_k)):
        nnids = eval_nn_inds[:, i]
        topk_scores =  []
        for iterator in range(nnids.size(dim=0)):
            # Candidate features from the padded gallery, as a batch of one, cast to
            # the dtypes used for the model call.
            index_global = gallery_global[iterator, nnids[iterator]]
            index_local = gallery_local[iterator, nnids[iterator]]
            index_mask = gallery_mask[iterator, nnids[iterator]]
            index_scales = gallery_scales[iterator, nnids[iterator]]
            index_positions = gallery_positions[iterator, nnids[iterator]]

            index_global = index_global.unsqueeze(dim=0)
            index_global = index_global.type(torch.float32)

            index_local = index_local.unsqueeze(dim=0)
            index_local = index_local.type(torch.float32)

            index_mask = index_mask.unsqueeze(dim=0)
            index_mask = index_mask.type(torch.bool)

            index_scales = index_scales.unsqueeze(dim=0)
            index_scales = index_scales.type(torch.int64)

            index_positions = index_positions.unsqueeze(dim=0)
            index_positions = index_positions.type(torch.float32)

            q_global = query_global[iterator].unsqueeze(dim=0)
            q_local = query_local[iterator].unsqueeze(dim=0)
            q_mask = query_mask[iterator].unsqueeze(dim=0)
            q_scales = query_scales[iterator].unsqueeze(dim=0)
            q_positions = query_positions[iterator].unsqueeze(dim=0)

            iter_scores = model(
                q_global, q_local, q_mask, q_scales, q_positions,
                index_global.to(device),
                index_local.to(device),
                index_mask.to(device),
                index_scales.to(device),
                index_positions.to(device))

            topk_scores.append(iter_scores.cpu().data)

        # One score per query for this rank position; stacked in torch, no numpy round-trip.
        current_scores = torch.stack(topk_scores, dim=0).squeeze(1)
        torch.cuda.empty_cache()
        scores.append(current_scores.cpu().data)

100%|██████████| 100/100 [01:05<00:00,  1.54it/s]


In [138]:
# Combine the per-rank score vectors into one matrix with a row per query.
scores = torch.stack(scores, -1) # nb_queries x 100
# Sort each query's candidates by decreasing score and map the order back to gallery ids.
closest_dists, indices = torch.sort(scores, dim=-1, descending=True)
closest_indices = torch.gather(eval_nn_inds, -1, indices)
# Overwrite the first top_k columns with the re-ranked ids; later columns stay as they were.
ranks = deepcopy(eval_nn_inds)
ranks[:, :top_k] = deepcopy(closest_indices)
# compute_metrics receives the transposed array (one column per query).
ranks = ranks.cpu().data.numpy().T
# pickle_save('eval_nn_inds.pkl', ranks.T)
out = compute_metrics('viquae', ranks, gnd['gnd'], kappas=[1, 5, 10])

{'map': 8.78, 'map@1': 2.07, 'map@5': 4.25, 'map@10': 5.97, 'mrr': 13.09, 'mrr@1': 11.0, 'mrr@5': 12.75, 'mrr@10': 13.03, 'precision': 2.18, 'precision@1': 11.0, 'precision@5': 8.2, 'precision@10': 7.5}


In [None]:
# Same evaluation through the function, fed with the tensors prepared cell by cell above.
# NOTE(review): this cell has no execution count, i.e. it was not run in the saved notebook.
# NOTE(review): the function reads gnd['gnd'][i]['junk'], while the manual cell above used
# 'r_neg'; it also passes the padded gallery tensors to the model without the dtype casts
# applied in the manual scoring loop — confirm both before relying on this call.
torch.cuda.empty_cache()
evaluate_function = partial(mean_average_precision_viquae_rerank, model=model, cache_nn_inds=cache_nn_inds,
    query_global=query_global, query_local=query_local, query_mask=query_mask, query_scales=query_scales, query_positions=query_positions, 
    gallery_global=gallery_global, gallery_local=gallery_local, gallery_mask=gallery_mask, gallery_scales=gallery_scales, gallery_positions=gallery_positions, 
    ks=recalls, 
    gnd=query_loader.dataset.gnd_data,
)
metrics = evaluate_function()

In [None]:
query_global.shape, gallery_global.shape

In [None]:
#eval_function = partial(evaluate_viquae, model=model, 
#        cache_nn_inds=cache_nn_inds,
#        recall=recall_ks, query_loader=loaders.query, gallery_loader=loaders.gallery)

In [None]:
#metrics = eval_function()
#pprint(metrics)
#best_val = (0, metrics, deepcopy(model.state_dict()))