# Imports

In [1]:
import numpy as np
from functools import partial

In [2]:
from copy import deepcopy
from functools import partial
from pprint import pprint
import os.path as osp

In [3]:
import sacred
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from torch.backends import cudnn

In [4]:
from torch.utils.data import DataLoader, RandomSampler, BatchSampler
from typing import NamedTuple, Optional, List

In [5]:
from models.matcher import MatchERT
from models.ingredient import model_ingredient, get_model


In [6]:
from sacred import SETTINGS
from sacred.utils import apply_backspaces_and_linefeeds
from sacred import Ingredient

In [7]:
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader, Dataset
# from visdom_logger import VisdomLogger
from utils.metrics import *

In [8]:
from utils import pickle_load
from sacred import Experiment
from utils.data.dataset_ingredient import data_ingredient, get_loaders
from utils.data.dataset import FeatureDataset
# from utils.training import evaluate_time as evaluate
#from utils.training import evaluate
ex = sacred.Experiment('RRT Evaluation', ingredients=[data_ingredient, model_ingredient], interactive=True)

In [9]:
ex = sacred.Experiment('Prepare Top-K (VIQUAE FOR RTT)', interactive=True)
# Filter backspaces and linefeeds
SETTINGS.CAPTURE_MODE = 'sys'
ex.captured_out_filter = apply_backspaces_and_linefeeds

In [10]:
feature_name = 'r50_gldv1'
set_name = 'test'
gnd_name = 'gnd_'+ set_name+'.pkl'

In [11]:
cpu = False
cudnn_flag = 'benchmark'
temp_dir = osp.join('logs', 'temp')
resume = '/mnt/beegfs/home/smessoud/RerankingTransformer/RRT_GLD/rrt_gld_ckpts/r50_gldv1.pt'
seed = 0

In [12]:
dataset_name = 'viquae_for_rrt'
data_dir = osp.join('/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data', dataset_name)

In [13]:
device = torch.device('cuda:0' if torch.cuda.is_available() and not cpu else 'cpu')
torch.manual_seed(seed)

<torch._C.Generator at 0x7f4ba45897d0>

In [14]:
if cudnn_flag == 'deterministic':
    setattr(cudnn, cudnn_flag, True)

In [15]:
def read_file(filename):
    with open(filename) as f:
        lines = f.read().splitlines()
    return lines

In [16]:
def get_sets(desc_name, 
        train_data_dir, test_data_dir, 
        train_txt, test_txt, test_gnd_file, 
        max_sequence_len=500):
    ####################################################################################################################################
    train_lines     = read_file(osp.join(train_data_dir, train_txt))
    train_samples   = [(line.split(';;')[0], int(line.split(';;')[1]), int(line.split(';;')[2]), int(line.split(';;')[3])) for line in train_lines]
    train_set       = FeatureDataset(train_data_dir, train_samples, desc_name, max_sequence_len)
    query_train_set = FeatureDataset(train_data_dir, train_samples, desc_name, max_sequence_len)
    ####################################################################################################################################
    test_gnd_data = None if test_gnd_file is None else pickle_load(osp.join(test_data_dir, test_gnd_file))
    query_lines   = read_file(osp.join(test_data_dir, test_txt[0]))
    gallery_lines = read_file(osp.join(test_data_dir, test_txt[1]))
    query_samples   = [(line.split(';;')[0], int(line.split(';;')[1]), int(line.split(';;')[2]), int(line.split(';;')[3])) for line in query_lines]
    gallery_samples = [(line.split(';;')[0], int(line.split(';;')[1]), int(line.split(';;')[2]), int(line.split(';;')[3])) for line in gallery_lines]
    gallery_set = FeatureDataset(test_data_dir, gallery_samples, desc_name, max_sequence_len)
    query_set   = FeatureDataset(test_data_dir, query_samples,   desc_name, max_sequence_len, gnd_data=test_gnd_data)
        
    return (train_set, query_train_set), (query_set, gallery_set)

In [17]:
class MetricLoaders(DataLoader):
    train: DataLoader
    num_classes: int
    query: DataLoader
    query_train: DataLoader
    gallery: Optional[DataLoader] = None

In [18]:
class MetricLoaders(NamedTuple):
    train: DataLoader
    num_classes: int
    query: DataLoader
    query_train: DataLoader
    gallery: Optional[DataLoader] = None

In [19]:
def get_loaders(desc_name, train_data_dir, 
    batch_size=36, test_batch_size=36, 
    num_workers=8, pin_memory=True, 
    sampler='random', recalls=[1, 5, 10],
    num_candidates=100):

    (train_set, query_train_set), (query_set, gallery_set) = get_sets(desc_name, 
        train_data_dir=train_data_dir,
        test_data_dir=train_data_dir,
        train_txt='test_query.txt',
        test_txt=('test_query.txt', 'test_gallery.txt'),
        test_gnd_file='gnd_test.pkl', 
        max_sequence_len=500)

    if sampler == 'random':
        train_sampler = BatchSampler(RandomSampler(train_set), batch_size=batch_size, drop_last=False)
    elif sampler == 'triplet':
        train_nn_inds = osp.join(train_data_dir, 'nn_inds_%s.pkl'%desc_name)
        train_sampler = TripletSampler(train_set.targets, batch_size, train_nn_inds, num_candidates)
    else:
        raise ValueError('Invalid choice of sampler ({}).'.format(sampler))
    train_loader = DataLoader(train_set, batch_sampler=train_sampler, num_workers=num_workers, pin_memory=pin_memory)
    query_train_loader = DataLoader(query_train_set, batch_size=test_batch_size, num_workers=num_workers, pin_memory=pin_memory)
        
    query_loader   = DataLoader(query_set, batch_size=test_batch_size, num_workers=num_workers, pin_memory=pin_memory)
    gallery_loader = DataLoader(gallery_set, batch_size=test_batch_size, num_workers=num_workers, pin_memory=pin_memory)

    return MetricLoaders(train=train_loader, query_train=query_train_loader, query=query_loader, gallery=gallery_loader, num_classes=len(train_set.categories)), recalls

In [20]:
(train_set, query_train_set), (query_set, gallery_set) = get_sets('r50_gldv1',
            '/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt',
            '/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt',
            set_name+'_query.txt',
            (set_name+'_query.txt',set_name+'_gallery.txt'),
            'gnd_'+set_name+'.pkl',
            500)

In [21]:
batch_size      = 16
test_batch_size = 16
max_sequence_len = 500
sampler = 'random'
if sampler == 'random':
   train_sampler = BatchSampler(RandomSampler(train_set), batch_size=batch_size, drop_last=False)

In [22]:
num_workers = 8  # number of workers used ot load the data
pin_memory  = True  # use the pin_memory option of DataLoader 
num_candidates = 100
recalls = [1, 5, 6, 10]

In [23]:
train_loader = DataLoader(train_set, batch_sampler=train_sampler, num_workers=num_workers, pin_memory=pin_memory)
query_train_loader = DataLoader(query_train_set, batch_size=test_batch_size, num_workers=num_workers, pin_memory=pin_memory)

In [24]:
query_loader   = DataLoader(query_set, batch_size=test_batch_size, num_workers=num_workers, pin_memory=pin_memory)
gallery_loader = DataLoader(gallery_set, batch_size=test_batch_size, num_workers=num_workers, pin_memory=pin_memory)

In [25]:
loaders, recall_ks = get_loaders(desc_name='r50_gldv1',
    train_data_dir='/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt', 
    batch_size=36, test_batch_size=36, 
    num_workers=8, pin_memory=True, 
    sampler='random', recalls=[1, 5, 10],
    num_candidates=100)

In [26]:
model_ingredient = Ingredient('model', interactive=True)

In [27]:
def get_model(num_global_features, num_local_features, seq_len, dim_K, dim_feedforward, nhead, num_encoder_layers, dropout, activation, normalize_before):
    return MatchERT(d_global=num_global_features, d_model=num_local_features, seq_len=seq_len, d_K=dim_K, nhead=nhead, num_encoder_layers=num_encoder_layers, 
            dim_feedforward=dim_feedforward, dropout=dropout, activation=activation, normalize_before=normalize_before)

In [28]:
name = 'rrt'
num_global_features = 2048  
num_local_features = 128  
seq_len = 1004
dim_K = 256
dim_feedforward = 1024
nhead = 4
num_encoder_layers = 6
dropout = 0.0 
activation = "relu"
normalize_before = False

In [29]:
model = get_model(num_global_features,num_local_features,seq_len,dim_K,dim_feedforward,nhead,num_encoder_layers,dropout,activation,normalize_before)

In [30]:
if resume is not None:
   checkpoint = torch.load(resume, map_location=torch.device('cpu'))
   model.load_state_dict(checkpoint['state'], strict=True)

In [31]:
model.to(device)
model.eval()

MatchERT(
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): _LinearWithBias(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=1024, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (linear2): Linear(in_features=1024, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.0, inplace=False)
        (dropout2): Dropout(p=0.0, inplace=False)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): _LinearWithBias(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=1024, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (linear2): Lin

In [32]:
loaders.query.dataset.desc_name, loaders.query.dataset.desc_name, loaders.query.dataset.data_dir

('r50_gldv1',
 'r50_gldv1',
 '/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt')

In [33]:
loaders.query

<torch.utils.data.dataloader.DataLoader at 0x7f4b4bee4650>

In [34]:
nn_inds_path = osp.join(loaders.query.dataset.data_dir, set_name+'_nn_inds_%s.pkl'%loaders.query.dataset.desc_name)
cache_nn_inds = torch.from_numpy(pickle_load(nn_inds_path)).long()

In [35]:
cache_nn_inds.size()

torch.Size([1257, 100])

In [36]:
num_samples, top_k = cache_nn_inds.size()
top_k = min(top_k, 100)
top_k

100

In [37]:
def evaluate(
        model: nn.Module,
        cache_nn_inds: torch.Tensor,
        query_loader: DataLoader,
        gallery_loader: DataLoader,
        recall: List[int]):
    model.eval()
    device = next(model.parameters()).device
    to_device = lambda x: x.to(device, non_blocking=True)

    query_global, query_local, query_mask, query_scales, query_positions, query_names = [], [], [], [], [], []
    gallery_global, gallery_local, gallery_mask, gallery_scales, gallery_positions, gallery_names = [], [], [], [], [], []

    with torch.no_grad():
        for entry in tqdm(query_loader, desc='Extracting query features', leave=False, ncols=80):
            q_global, q_local, q_mask, q_scales, q_positions, _, q_names = entry
            query_global.append(q_global.cpu())
            query_local.append(q_local.cpu())
            query_mask.append(q_mask.cpu())
            query_scales.append(q_scales.cpu())
            query_positions.append(q_positions.cpu())
            query_names.extend(list(q_names))

        query_global    = torch.cat(query_global, 0)
        query_local     = torch.cat(query_local, 0)
        query_mask      = torch.cat(query_mask, 0)
        query_scales    = torch.cat(query_scales, 0)
        query_positions = torch.cat(query_positions, 0)

        for entry in tqdm(gallery_loader, desc='Extracting gallery features', leave=False, ncols=80):
            g_global, g_local, g_mask, g_scales, g_positions, _, g_names = entry
            gallery_global.append(g_global.cpu())
            gallery_local.append(g_local.cpu())
            gallery_mask.append(g_mask.cpu())
            gallery_scales.append(g_scales.cpu())
            gallery_positions.append(g_positions.cpu())
            gallery_names.extend(list(g_names))
            
        gallery_global    = torch.cat(gallery_global, 0)
        gallery_local     = torch.cat(gallery_local, 0)
        gallery_mask      = torch.cat(gallery_mask, 0)
        gallery_scales    = torch.cat(gallery_scales, 0)
        gallery_positions = torch.cat(gallery_positions, 0)

        torch.cuda.empty_cache()
        evaluate_function = partial(mean_average_precision_revisited_rerank, model=model, cache_nn_inds=cache_nn_inds,
            query_global=query_global, query_local=query_local, query_mask=query_mask, query_scales=query_scales, query_positions=query_positions, 
            gallery_global=gallery_global, gallery_local=gallery_local, gallery_mask=gallery_mask, gallery_scales=gallery_scales, gallery_positions=gallery_positions, 
            ks=recall, 
            gnd=query_loader.dataset.gnd_data,
        )
        metrics = evaluate_function()
    return metrics 

In [43]:
import numpy as np
import bisect, torch
import os.path as osp
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import functional as F

from delf import datum_io
from delf import feature_io


class FeatureDataset(Dataset):
    def __init__(self, 
            data_dir: str,
            samples: list,
            desc_name: str, 
            max_sequence_len: int,
            gnd_data=None,
            is_gallery=False):
        self.data_dir = data_dir
        self.desc_name = desc_name
        self.categories = sorted(list(set([int(entry[1]) for entry in samples])))
        self.cat_to_label = dict(zip(self.categories, range(len(self.categories))))
        self.samples = [(entry[0], self.cat_to_label[entry[1]], entry[2], entry[3]) for entry in samples]
        self.targets = [entry[1] for entry in self.samples]
        self.gnd_data = gnd_data
        self.max_sequence_len = max_sequence_len
        self.scales = [0.5, 0.70710677, 1., 1.41421354, 2., 2.82842708, 4.]
        self.is_gallery = is_gallery
  
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index, load_image=False):
        '''
        Output
            global_desc: (2048, )
            local_desc: (max_sequence_len, 128)
            local_mask: (max_sequence_len, )
            scale_inds: (max_sequence_len, )
            positions: (max_sequence_len, 2)
            label: int
            name: str
        '''
        if self.is_gallery:
            ddd
            e
            
        else:         
            image_path, label, width, height = self.samples[index]
            image_name  = osp.splitext(osp.basename(image_path))[0]
        global_path = osp.join(self.data_dir, 'delg_%s'%self.desc_name, image_name+'.delg_global')
        local_path  = osp.join(self.data_dir, 'delg_%s'%self.desc_name, image_name+'.delg_local')
        assert(osp.exists(global_path) and osp.exists(local_path))
        global_desc = datum_io.ReadFromFile(global_path)
        locations, scales, desc, attention, _ = feature_io.ReadFromFile(local_path)

        local_mask = torch.ones(self.max_sequence_len, dtype=torch.bool)
        local_desc = np.zeros((self.max_sequence_len, 128), dtype=np.float32)
        scale_inds = torch.zeros(self.max_sequence_len).long()
        seq_len = min(desc.shape[0], self.max_sequence_len)
        local_desc[:seq_len] = desc[:seq_len]
        local_mask[:seq_len] = False
        scale_inds[:seq_len] = torch.as_tensor([bisect.bisect_right(self.scales, s) for s in scales[:seq_len]]).long() - 1

        ###############################################
        # Sine embedding
        positions = torch.zeros(self.max_sequence_len, 2).float()
        normx = locations[:, 1]/float(width)
        normy = locations[:, 0]/float(height)
        positions[:seq_len] = torch.from_numpy(np.stack([normx, normy], -1)).float()[:seq_len]
        ##############################################
        
        global_desc = torch.from_numpy(global_desc).float()
        local_desc = torch.from_numpy(local_desc).float()
        
        if load_image:
            image = Image.open(osp.join(self.data_dir, image_path)).convert('RGB')
            image = image.resize((512, 512))
            return F.to_tensor(image), global_desc, local_desc, local_mask, scale_inds, positions, label, image_name
        else:
            return global_desc, local_desc, local_mask, scale_inds, positions, label, image_name

In [38]:
for entry in tqdm(query_loader, desc='Extracting query features', leave=False, ncols=80):
    print(entry)

Extracting query features:   5%|▊                | 4/79 [00:02<02:17,  1.83s/it]

[tensor([[-0.0162, -0.0013,  0.0063,  ..., -0.0145, -0.0001,  0.0082],
        [-0.0099, -0.0030, -0.0368,  ...,  0.0020, -0.0415,  0.0206],
        [-0.0171, -0.0417,  0.0097,  ...,  0.0153, -0.0200,  0.0203],
        ...,
        [-0.0114, -0.0201,  0.0056,  ..., -0.0127,  0.0148,  0.0005],
        [-0.0043, -0.0110, -0.0149,  ..., -0.0388,  0.0107,  0.0375],
        [ 0.0356, -0.0568,  0.0036,  ..., -0.0120, -0.0098, -0.0033]]), tensor([[[ 1.7209e-02,  5.5958e-02, -4.9315e-02,  ..., -2.7391e-02,
           1.1261e-01,  1.4778e-02],
         [ 7.2273e-03, -1.0292e-01,  5.4289e-02,  ..., -5.6300e-02,
           7.2462e-02,  3.1394e-02],
         [-4.1257e-03, -6.9661e-02,  1.2347e-01,  ..., -8.5311e-02,
           9.6801e-02,  5.7327e-02],
         ...,
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0

Extracting query features:   9%|█▌               | 7/79 [00:02<01:32,  1.29s/it]

[tensor([[ 1.8353e-02,  9.4984e-03,  1.9350e-02,  ..., -2.7219e-02,
          7.7322e-03,  3.1169e-02],
        [-1.8768e-02, -1.3033e-02, -3.5214e-02,  ...,  2.2803e-02,
         -1.0488e-02,  1.0525e-02],
        [-3.9730e-02, -3.5537e-02,  1.2602e-02,  ..., -7.5975e-05,
         -2.0404e-02, -1.9551e-02],
        ...,
        [ 1.2958e-02,  1.1159e-02, -1.8014e-02,  ..., -2.9696e-02,
          2.3081e-02, -3.4312e-02],
        [-2.3296e-02, -3.0603e-02,  2.8361e-02,  ..., -1.0669e-03,
          7.7911e-03, -2.5719e-02],
        [-2.6021e-02, -3.6260e-02, -8.4965e-03,  ..., -2.6328e-02,
          2.9091e-02,  1.7441e-02]]), tensor([[[-8.3378e-02, -8.3537e-02,  8.3029e-02,  ..., -8.7802e-02,
           4.2215e-02,  9.1555e-02],
         [-2.1533e-02, -1.1049e-03,  4.2682e-03,  ..., -7.9041e-02,
           7.4096e-02,  4.6654e-03],
         [ 1.3030e-02, -1.0785e-02,  7.6041e-02,  ..., -2.4306e-02,
           2.0201e-01,  1.5075e-01],
         ...,
         [-8.4002e-02, -1.3864e-01,  

Extracting query features:  11%|█▉               | 9/79 [00:04<01:20,  1.16s/it]

[tensor([[ 0.0014, -0.0514, -0.0003,  ...,  0.0260, -0.0160, -0.0079],
        [ 0.0174, -0.0223, -0.0176,  ..., -0.0066,  0.0028,  0.0329],
        [-0.0169, -0.0104,  0.0096,  ...,  0.0047,  0.0099,  0.0425],
        ...,
        [ 0.0424, -0.0042,  0.0232,  ...,  0.0157,  0.0061,  0.0198],
        [ 0.0248, -0.0353,  0.0422,  ..., -0.0010, -0.0121,  0.0144],
        [-0.0022, -0.0170, -0.0165,  ...,  0.0014, -0.0287,  0.0497]]), tensor([[[ 1.9440e-02, -1.3273e-01,  3.3401e-02,  ...,  1.1695e-01,
           1.1231e-01,  1.2493e-01],
         [-5.0590e-02, -2.8670e-02,  1.6828e-01,  ...,  8.8443e-02,
           1.4155e-01, -4.3146e-03],
         [-2.2882e-02, -9.7790e-02, -5.0341e-02,  ...,  1.9036e-02,
          -1.4438e-01, -2.3692e-02],
         ...,
         [-5.0051e-02, -5.2895e-02, -5.1873e-02,  ...,  2.8052e-02,
          -1.0179e-01, -4.4653e-02],
         [-4.7761e-02, -3.6818e-02, -4.4989e-02,  ...,  2.1315e-01,
           1.0343e-01,  1.6068e-03],
         [-4.3767e-03,  3

Extracting query features:  19%|███             | 15/79 [00:04<00:38,  1.66it/s]

[tensor([[ 0.0339,  0.0172,  0.0059,  ...,  0.0098,  0.0133, -0.0289],
        [ 0.0339,  0.0172,  0.0059,  ...,  0.0098,  0.0133, -0.0289],
        [ 0.0339,  0.0172,  0.0059,  ...,  0.0098,  0.0133, -0.0289],
        ...,
        [-0.0152,  0.0093, -0.0064,  ...,  0.0007, -0.0007,  0.0187],
        [ 0.0221,  0.0237, -0.0103,  ..., -0.0042,  0.0079,  0.0282],
        [-0.0026, -0.0318, -0.0285,  ..., -0.0099,  0.0022,  0.0110]]), tensor([[[-6.7854e-03,  9.4361e-02, -1.3527e-01,  ...,  2.4072e-01,
           1.2675e-01, -1.7645e-02],
         [-6.9047e-03, -2.8390e-02,  1.5552e-01,  ..., -1.9563e-01,
          -1.5476e-02,  4.8301e-02],
         [-9.1595e-03, -9.3754e-02, -2.9893e-02,  ...,  7.3210e-02,
          -5.3392e-04,  3.7405e-02],
         ...,
         [-9.2488e-02,  6.9475e-04, -1.2048e-01,  ..., -2.2914e-03,
           4.1748e-02, -1.3660e-01],
         [-2.2852e-02, -1.0689e-01,  2.1268e-02,  ...,  8.0516e-02,
          -1.3669e-01, -1.1589e-01],
         [-7.2119e-03, -1

Extracting query features:  25%|████            | 20/79 [00:06<00:28,  2.05it/s]

[tensor([[ 0.0115, -0.0237, -0.0003,  ..., -0.0041, -0.0315,  0.0384],
        [ 0.0115, -0.0237, -0.0003,  ..., -0.0041, -0.0315,  0.0384],
        [ 0.0079,  0.0013,  0.0404,  ...,  0.0290,  0.0328,  0.0061],
        ...,
        [ 0.0149, -0.0252, -0.0238,  ..., -0.0017, -0.0029,  0.0186],
        [-0.0339, -0.0126, -0.0126,  ...,  0.0007,  0.0239,  0.0395],
        [-0.0051,  0.0179,  0.0409,  ...,  0.0275, -0.0134, -0.0028]]), tensor([[[-0.1660, -0.1330, -0.0097,  ..., -0.1494,  0.0657,  0.0153],
         [-0.1361, -0.0052,  0.0993,  ..., -0.1556,  0.0582,  0.0170],
         [-0.0155,  0.0510, -0.0268,  ..., -0.0316, -0.1293, -0.0330],
         ...,
         [-0.1129, -0.1715,  0.0129,  ..., -0.1438, -0.0030, -0.0765],
         [ 0.0210, -0.1657,  0.0464,  ..., -0.1151, -0.0019,  0.0133],
         [-0.0515,  0.2398,  0.0040,  ..., -0.0666,  0.0361,  0.0482]],

        [[-0.1660, -0.1330, -0.0097,  ..., -0.1494,  0.0657,  0.0153],
         [-0.1361, -0.0052,  0.0993,  ..., -0.1556,

Extracting query features:  29%|████▋           | 23/79 [00:06<00:19,  2.84it/s]

[tensor([[-0.0446, -0.0046, -0.0196,  ...,  0.0001,  0.0036,  0.0299],
        [-0.0341, -0.0085, -0.0020,  ...,  0.0185, -0.0040,  0.0206],
        [-0.0513, -0.0095, -0.0181,  ..., -0.0034, -0.0116, -0.0006],
        ...,
        [-0.0112,  0.0111,  0.0284,  ..., -0.0016, -0.0143,  0.0235],
        [-0.0110, -0.0063,  0.0292,  ...,  0.0103, -0.0075,  0.0017],
        [-0.0069,  0.0244,  0.0048,  ...,  0.0257, -0.0329, -0.0089]]), tensor([[[ 1.7513e-02, -3.8702e-02, -3.1553e-02,  ...,  2.3596e-01,
           1.1900e-01,  2.0688e-02],
         [-2.4768e-02,  1.2546e-01, -6.1105e-03,  ...,  2.7171e-01,
           4.7032e-02, -8.9659e-03],
         [ 7.2460e-02, -6.7673e-02,  5.0094e-02,  ..., -4.7188e-02,
           4.3773e-02, -1.1510e-01],
         ...,
         [ 7.8845e-02,  2.9774e-02,  1.9386e-02,  ..., -1.5836e-03,
          -1.2237e-01, -5.4291e-03],
         [-9.4362e-02,  1.2942e-01, -4.8457e-02,  ...,  7.3248e-02,
          -6.1698e-02,  2.1083e-02],
         [-1.4474e-02,  2

Extracting query features:  32%|█████           | 25/79 [00:08<00:28,  1.91it/s]

[tensor([[-5.4635e-03,  2.3918e-02,  3.9072e-03,  ...,  3.5740e-02,
          5.0756e-02, -4.5329e-03],
        [-5.2382e-03, -1.3395e-02,  1.8875e-03,  ..., -9.5184e-03,
          1.1699e-02,  2.9202e-02],
        [-2.0126e-02, -4.7620e-02,  2.2942e-02,  ...,  3.1967e-02,
         -1.1952e-02,  4.2707e-02],
        ...,
        [-5.3436e-03, -4.0452e-02,  7.2088e-03,  ..., -5.2289e-03,
         -5.6310e-02,  2.3366e-02],
        [ 1.6049e-02, -1.1305e-02, -1.7944e-02,  ...,  4.4254e-03,
         -2.1614e-02,  2.4473e-02],
        [-1.1295e-02, -5.1032e-02,  1.7590e-03,  ..., -2.5957e-05,
         -5.1616e-03,  1.6226e-02]]), tensor([[[-0.0246, -0.0361,  0.0217,  ..., -0.0273, -0.0379, -0.0264],
         [-0.0738, -0.1164,  0.0205,  ..., -0.0661, -0.0466, -0.1460],
         [ 0.0267, -0.0799,  0.1668,  ..., -0.1550, -0.0531, -0.0513],
         ...,
         [-0.0323,  0.0917, -0.0505,  ..., -0.1556,  0.0893,  0.0353],
         [-0.0582,  0.0879,  0.0892,  ...,  0.1061,  0.0452, -0.0816

Extracting query features:  39%|██████▎         | 31/79 [00:09<00:13,  3.45it/s]

[tensor([[ 0.0298, -0.0028,  0.0123,  ..., -0.0248, -0.0171, -0.0417],
        [ 0.0102, -0.0105,  0.0085,  ..., -0.0193, -0.0071, -0.0076],
        [-0.0374, -0.0464,  0.0356,  ..., -0.0258,  0.0142,  0.0546],
        ...,
        [-0.0058, -0.0168, -0.0223,  ..., -0.0144, -0.0032,  0.0432],
        [ 0.0124, -0.0002, -0.0040,  ...,  0.0392, -0.0113,  0.0097],
        [ 0.0013, -0.0224,  0.0142,  ...,  0.0214, -0.0048, -0.0091]]), tensor([[[-0.0239, -0.0699,  0.0120,  ..., -0.0527, -0.0812,  0.0383],
         [ 0.0067, -0.1157,  0.0232,  ...,  0.1799,  0.0263,  0.1193],
         [-0.0269, -0.0584,  0.1340,  ...,  0.1298,  0.0393, -0.0618],
         ...,
         [-0.0257,  0.0113,  0.0286,  ..., -0.0812, -0.0404,  0.0015],
         [ 0.0460, -0.0674,  0.0216,  ...,  0.1082, -0.1537,  0.0566],
         [ 0.0718,  0.0059, -0.0426,  ..., -0.0145, -0.0561,  0.0138]],

        [[-0.0705,  0.0732,  0.1890,  ..., -0.0644, -0.0291,  0.0014],
         [-0.0609,  0.0431,  0.2669,  ..., -0.1068,

Extracting query features:  42%|██████▋         | 33/79 [00:10<00:18,  2.47it/s]

[tensor([[ 0.0114,  0.0028, -0.0102,  ..., -0.0056,  0.0279, -0.0086],
        [ 0.0114,  0.0028, -0.0102,  ..., -0.0056,  0.0279, -0.0086],
        [ 0.0114,  0.0028, -0.0102,  ..., -0.0056,  0.0279, -0.0086],
        ...,
        [-0.0373, -0.0269,  0.0144,  ...,  0.0489,  0.0227, -0.0325],
        [-0.0225,  0.0048, -0.0210,  ..., -0.0241,  0.0001,  0.0058],
        [ 0.0320, -0.0159,  0.0016,  ...,  0.0011, -0.0387,  0.0574]]), tensor([[[-0.0571, -0.1341,  0.0446,  ..., -0.0484, -0.0113, -0.0280],
         [-0.0389,  0.0040,  0.0847,  ...,  0.1179,  0.1127, -0.1772],
         [-0.0266, -0.0790,  0.1381,  ..., -0.0374,  0.1125, -0.0704],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[-0.0571, -0.1341,  0.0446,  ..., -0.0484, -0.0113, -0.0280],
         [-0.0389,  0.0040,  0.0847,  ...,  0.1179,

Extracting query features:  49%|███████▉        | 39/79 [00:11<00:10,  3.84it/s]

[tensor([[ 0.0101, -0.0613, -0.0353,  ..., -0.0343, -0.0221,  0.0193],
        [ 0.0352, -0.0054, -0.0084,  ...,  0.0419, -0.0239, -0.0082],
        [-0.0034, -0.0469, -0.0138,  ...,  0.0028,  0.0402, -0.0329],
        ...,
        [-0.0049,  0.0080,  0.0123,  ...,  0.0399,  0.0016,  0.0111],
        [-0.0130, -0.0324,  0.0202,  ...,  0.0027, -0.0246,  0.0399],
        [ 0.0176,  0.0183,  0.0257,  ...,  0.0010,  0.0069,  0.0683]]), tensor([[[-0.0965, -0.0456, -0.0011,  ..., -0.0638, -0.0788, -0.0609],
         [-0.1075, -0.0612, -0.0068,  ..., -0.0577, -0.0450, -0.0423],
         [-0.1662, -0.0021,  0.0436,  ..., -0.0575, -0.0576,  0.0945],
         ...,
         [-0.0794, -0.0568, -0.1010,  ..., -0.0225, -0.0697, -0.1020],
         [ 0.0474,  0.0087, -0.0801,  ...,  0.0410,  0.0064,  0.1960],
         [-0.0227, -0.1119,  0.0988,  ...,  0.0196, -0.0343, -0.0313]],

        [[-0.0093,  0.0418, -0.0563,  ...,  0.0200,  0.0828, -0.0756],
         [-0.0136, -0.1384,  0.0269,  ..., -0.0495,

Extracting query features:  54%|████████▋       | 43/79 [00:12<00:10,  3.56it/s]

[tensor([[ 0.0156, -0.0235,  0.0221,  ..., -0.0202,  0.0055,  0.0166],
        [-0.0296,  0.0065, -0.0047,  ..., -0.0227,  0.0128,  0.0117],
        [-0.0090,  0.0104,  0.0204,  ...,  0.0238, -0.0158,  0.0109],
        ...,
        [-0.0031,  0.0050, -0.0408,  ...,  0.0396,  0.0136,  0.0059],
        [-0.0131, -0.0188,  0.0107,  ..., -0.0411, -0.0212, -0.0065],
        [-0.0195, -0.0168,  0.0012,  ...,  0.0012,  0.0113,  0.0071]]), tensor([[[ 0.0189, -0.0375,  0.0112,  ...,  0.0150, -0.0258, -0.0446],
         [-0.0945, -0.0808,  0.0517,  ..., -0.0085,  0.0385,  0.0300],
         [ 0.0291, -0.0662,  0.0501,  ...,  0.0593,  0.0517, -0.0972],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[-0.1082, -0.0089, -0.0726,  ..., -0.0713, -0.0553, -0.1291],
         [ 0.0090, -0.0266, -0.0740,  ...,  0.0100,

Extracting query features:  61%|█████████▋      | 48/79 [00:13<00:07,  4.33it/s]

[tensor([[ 0.0246, -0.0056, -0.0090,  ...,  0.0026, -0.0078,  0.0053],
        [ 0.0246, -0.0056, -0.0090,  ...,  0.0026, -0.0078,  0.0053],
        [ 0.0159, -0.0439, -0.0205,  ..., -0.0136, -0.0211,  0.0348],
        ...,
        [-0.0136, -0.0090, -0.0061,  ...,  0.0257, -0.0347, -0.0261],
        [-0.0132, -0.0135, -0.0182,  ..., -0.0243,  0.0036,  0.0197],
        [-0.0132, -0.0135, -0.0182,  ..., -0.0243,  0.0036,  0.0197]]), tensor([[[ 2.0404e-02, -8.3335e-02,  1.4865e-01,  ..., -3.4991e-03,
           1.8933e-01,  8.7847e-02],
         [-2.1826e-05, -1.0137e-01,  9.0906e-02,  ..., -6.1755e-02,
           2.0346e-01,  1.0591e-01],
         [-7.3310e-03, -9.2728e-02,  6.3167e-02,  ...,  1.6887e-02,
           1.3625e-01,  7.5680e-02],
         ...,
         [-8.1631e-02,  1.3109e-01,  3.4995e-02,  ..., -1.0411e-01,
           5.6814e-03,  9.2246e-02],
         [-1.6940e-01, -1.1182e-01,  6.0297e-02,  ..., -7.3432e-02,
           2.8684e-02, -7.6336e-02],
         [ 7.3441e-03,  8

Extracting query features:  63%|██████████▏     | 50/79 [00:15<00:10,  2.72it/s]

[tensor([[-0.0115, -0.0126,  0.0165,  ...,  0.0255,  0.0209,  0.0337],
        [ 0.0071, -0.0126, -0.0163,  ..., -0.0199,  0.0372,  0.0494],
        [-0.0586, -0.0448, -0.0260,  ..., -0.0121, -0.0100, -0.0103],
        ...,
        [-0.0269,  0.0119,  0.0388,  ..., -0.0216,  0.0072,  0.0030],
        [-0.0046,  0.0224, -0.0180,  ..., -0.0082,  0.0276, -0.0214],
        [ 0.0035, -0.0356, -0.0103,  ...,  0.0017, -0.0517,  0.0235]]), tensor([[[-3.7733e-02, -9.0507e-02,  4.4485e-02,  ..., -1.3104e-01,
           8.1174e-02, -2.4398e-03],
         [ 1.7870e-03, -1.3424e-01,  1.0700e-01,  ...,  1.5500e-02,
          -3.6547e-02, -1.6777e-01],
         [-4.3808e-02, -5.0493e-02, -1.3679e-02,  ..., -4.8036e-02,
           6.1270e-02,  6.3921e-03],
         ...,
         [ 6.2603e-02, -7.4421e-02, -6.0736e-02,  ..., -5.5511e-02,
           1.4301e-02, -4.5927e-02],
         [-3.2953e-02, -4.2730e-02,  9.5609e-02,  ..., -1.6589e-01,
          -7.2331e-02, -1.0391e-01],
         [ 1.4714e-03, -1

Extracting query features:  70%|███████████▏    | 55/79 [00:15<00:06,  3.70it/s]

[tensor([[ 2.1894e-02,  1.6739e-02,  1.6247e-02,  ..., -3.5543e-03,
          3.2740e-02,  1.1845e-02],
        [-3.4889e-03,  9.9571e-03, -1.7896e-02,  ..., -6.4725e-03,
          4.3118e-02, -3.4374e-03],
        [-2.7423e-05, -1.2489e-02,  2.7307e-02,  ...,  1.3168e-02,
         -8.3900e-03, -1.9452e-02],
        ...,
        [-6.7583e-03, -2.0856e-02, -1.3843e-02,  ...,  2.1081e-03,
          2.9367e-02,  1.8925e-02],
        [-2.7292e-02,  2.9215e-02,  5.6893e-03,  ..., -1.0902e-02,
         -1.0940e-02,  1.6029e-02],
        [-1.4015e-02, -3.4707e-02, -1.1920e-02,  ...,  5.6461e-03,
         -1.0190e-02,  2.7918e-02]]), tensor([[[ 0.0304,  0.1265,  0.0238,  ..., -0.1370,  0.1579,  0.0578],
         [ 0.0523,  0.0102,  0.0055,  ..., -0.0794, -0.0967,  0.0103],
         [-0.1025,  0.0377, -0.0075,  ..., -0.0467,  0.0504, -0.0243],
         ...,
         [ 0.0074, -0.1272, -0.0948,  ..., -0.0137, -0.1725,  0.0347],
         [-0.0009,  0.1000, -0.0734,  ...,  0.0222, -0.0330, -0.1288

Extracting query features:  72%|███████████▌    | 57/79 [00:16<00:06,  3.52it/s]

[tensor([[-0.0155, -0.0490,  0.0416,  ...,  0.0085,  0.0158,  0.0062],
        [-0.0143, -0.0025,  0.0009,  ..., -0.0131, -0.0122,  0.0258],
        [-0.0221, -0.0500,  0.0101,  ..., -0.0024, -0.0099,  0.0586],
        ...,
        [-0.0107, -0.0192,  0.0002,  ..., -0.0051,  0.0069,  0.0028],
        [-0.0107, -0.0192,  0.0002,  ..., -0.0051,  0.0069,  0.0028],
        [-0.0107, -0.0192,  0.0002,  ..., -0.0051,  0.0069,  0.0028]]), tensor([[[ 0.1684,  0.0586,  0.0239,  ..., -0.0162, -0.0403, -0.0403],
         [-0.1702,  0.0553, -0.0248,  ...,  0.1914,  0.0171, -0.0814],
         [ 0.1387,  0.0624, -0.0387,  ..., -0.0675, -0.0143, -0.0855],
         ...,
         [ 0.0514, -0.0243, -0.0084,  ...,  0.0249,  0.0404,  0.0364],
         [-0.1146, -0.0333,  0.0066,  ...,  0.0861,  0.0941,  0.0904],
         [-0.0429, -0.0697, -0.0322,  ..., -0.0266,  0.0699,  0.0735]],

        [[ 0.0008,  0.0355, -0.1268,  ...,  0.0168,  0.0289, -0.0607],
         [ 0.0587,  0.0726,  0.0036,  ..., -0.1582,

Extracting query features:  75%|███████████▉    | 59/79 [00:17<00:05,  3.48it/s]

[tensor([[ 0.0031, -0.0318,  0.0316,  ..., -0.0065, -0.0046, -0.0333],
        [ 0.0089, -0.0081, -0.0384,  ...,  0.0276, -0.0236, -0.0115],
        [-0.0459,  0.0098,  0.0117,  ...,  0.0269, -0.0105, -0.0114],
        ...,
        [ 0.0016, -0.0057, -0.0012,  ...,  0.0090, -0.0126,  0.0364],
        [ 0.0016, -0.0057, -0.0012,  ...,  0.0090, -0.0126,  0.0364],
        [ 0.0193, -0.0088,  0.0669,  ..., -0.0039, -0.0227,  0.0241]]), tensor([[[-0.0670,  0.0328, -0.1227,  ...,  0.0240,  0.0723,  0.1383],
         [-0.0166, -0.0256, -0.0009,  ...,  0.1868,  0.1373, -0.1362],
         [-0.0996,  0.0563, -0.0807,  ...,  0.1310,  0.0613,  0.2167],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[-0.0349, -0.1399, -0.0304,  ...,  0.0254, -0.1576,  0.0191],
         [-0.1210,  0.0392, -0.0791,  ..., -0.0134,

Extracting query features:  81%|████████████▉   | 64/79 [00:18<00:04,  3.08it/s]

[tensor([[ 0.0508, -0.0212, -0.0006,  ...,  0.0216,  0.0052,  0.0309],
        [-0.0078, -0.0018, -0.0371,  ..., -0.0038,  0.0022,  0.0156],
        [-0.0365, -0.0136, -0.0108,  ...,  0.0077,  0.0031,  0.0183],
        ...,
        [-0.0176,  0.0141,  0.0152,  ..., -0.0159,  0.0235, -0.0205],
        [-0.0205, -0.0262, -0.0033,  ..., -0.0342, -0.0046,  0.0161],
        [ 0.0230, -0.0194,  0.0185,  ..., -0.0117, -0.0066,  0.0054]]), tensor([[[ 2.1057e-02, -3.1634e-02,  4.2201e-02,  ...,  7.4848e-03,
           6.9685e-02,  5.6428e-02],
         [ 1.3418e-02,  2.4903e-02,  2.2280e-02,  ..., -5.9683e-02,
           1.2966e-01, -4.6822e-02],
         [ 5.2379e-03,  4.5938e-02,  7.4622e-02,  ..., -9.7540e-02,
           1.5496e-01, -2.7975e-02],
         ...,
         [-1.1064e-02, -3.2048e-02, -1.0487e-01,  ..., -6.6189e-02,
           8.4910e-02, -2.8883e-02],
         [ 3.5333e-02,  4.5756e-02,  3.1003e-02,  ..., -5.0994e-02,
          -1.0908e-01,  1.3876e-01],
         [-5.4903e-02,  1

Extracting query features:  84%|█████████████▎  | 66/79 [00:18<00:03,  3.51it/s]

[tensor([[-0.0102,  0.0218,  0.0313,  ...,  0.0164,  0.0066, -0.0177],
        [ 0.0043, -0.0133, -0.0151,  ...,  0.0087,  0.0003,  0.0288],
        [-0.0018, -0.0340,  0.0055,  ...,  0.0152, -0.0074,  0.0034],
        ...,
        [ 0.0120,  0.0082,  0.0052,  ..., -0.0015, -0.0002,  0.0152],
        [ 0.0349,  0.0182,  0.0047,  ..., -0.0126, -0.0133,  0.0114],
        [ 0.0246,  0.0029,  0.0047,  ...,  0.0122,  0.0132, -0.0056]]), tensor([[[-1.0036e-01, -2.0984e-02,  1.1513e-01,  ...,  7.9754e-02,
           1.3279e-01, -5.9348e-02],
         [-1.4621e-01,  4.5775e-02,  1.5653e-01,  ..., -1.0559e-01,
          -3.2110e-02,  1.5614e-01],
         [-2.4236e-03, -9.5620e-02, -4.3752e-02,  ...,  4.5184e-03,
           4.6032e-02, -4.9231e-02],
         ...,
         [ 9.6214e-03, -1.5160e-01, -1.1732e-01,  ...,  3.8078e-02,
           4.0244e-02, -3.1325e-02],
         [ 4.5461e-02, -1.2253e-01,  1.4072e-02,  ...,  3.6694e-02,
          -8.1308e-02,  1.0091e-01],
         [-7.8753e-02, -1

[tensor([[ 0.0253,  0.0192,  0.0007,  ...,  0.0212,  0.0212,  0.0143],
        [-0.0236,  0.0069,  0.0070,  ...,  0.0210,  0.0344,  0.0279],
        [-0.0057, -0.0211, -0.0056,  ...,  0.0090, -0.0062,  0.0018],
        ...,
        [ 0.0113, -0.0286,  0.0436,  ..., -0.0468, -0.0187,  0.0075],
        [ 0.0113, -0.0286,  0.0436,  ..., -0.0468, -0.0187,  0.0075],
        [ 0.0202, -0.0198, -0.0247,  ..., -0.0425, -0.0268, -0.0099]]), tensor([[[-5.1862e-02, -3.4314e-02,  7.0362e-02,  ..., -5.4797e-02,
          -2.7188e-02, -4.5747e-02],
         [-1.0171e-01, -1.7552e-02,  9.9746e-02,  ..., -6.5732e-02,
           6.4400e-02, -1.2654e-02],
         [-9.7510e-03,  5.4382e-02, -6.6799e-02,  ...,  5.6094e-02,
           1.0254e-01,  9.9895e-03],
         ...,
         [ 8.3630e-02, -7.0886e-02, -9.2674e-02,  ..., -4.6242e-02,
           1.2375e-01, -1.3009e-01],
         [ 3.6420e-03,  1.0387e-01,  1.6344e-01,  ...,  3.2685e-02,
          -1.5680e-02, -2.8776e-02],
         [-9.3485e-02,  2

Extracting query features:  91%|██████████████▌ | 72/79 [00:19<00:02,  3.41it/s]

[tensor([[ 0.0021,  0.0161, -0.0010,  ...,  0.0529,  0.0001, -0.0020],
        [ 0.0086, -0.0272, -0.0122,  ..., -0.0331, -0.0108,  0.0312],
        [ 0.0086, -0.0272, -0.0122,  ..., -0.0331, -0.0108,  0.0312],
        ...,
        [-0.0050, -0.0335,  0.0233,  ...,  0.0025, -0.0275,  0.0344],
        [-0.0169,  0.0263,  0.0150,  ...,  0.0524,  0.0265,  0.0109],
        [-0.0150,  0.0097,  0.0237,  ..., -0.0028, -0.0190, -0.0215]]), tensor([[[-0.0587,  0.0369, -0.0783,  ...,  0.2120,  0.0758, -0.0566],
         [-0.0535,  0.0993, -0.0539,  ...,  0.1470,  0.0700,  0.0380],
         [ 0.0019,  0.0663, -0.0201,  ...,  0.1835,  0.0987, -0.0692],
         ...,
         [-0.0740, -0.0229, -0.1722,  ...,  0.0372,  0.0998,  0.0979],
         [ 0.0470, -0.0671, -0.1546,  ...,  0.0658, -0.0665,  0.1821],
         [-0.0031,  0.2130, -0.0845,  ...,  0.1298, -0.0045, -0.0045]],

        [[-0.0575, -0.1075, -0.0258,  ...,  0.1303, -0.0256, -0.1435],
         [-0.1597, -0.1349, -0.0592,  ..., -0.0142,

Extracting query features:  94%|██████████████▉ | 74/79 [00:20<00:01,  3.97it/s]

[tensor([[-0.0096, -0.0371, -0.0285,  ...,  0.0322,  0.0081,  0.0414],
        [-0.0234, -0.0165, -0.0149,  ..., -0.0092,  0.0321,  0.0207],
        [-0.0246, -0.0098,  0.0217,  ...,  0.0433,  0.0033,  0.0268],
        ...,
        [ 0.0289, -0.0197,  0.0065,  ..., -0.0201, -0.0183, -0.0493],
        [ 0.0184, -0.0078, -0.0040,  ..., -0.0027, -0.0145,  0.0243],
        [ 0.0184, -0.0078, -0.0040,  ..., -0.0027, -0.0145,  0.0243]]), tensor([[[ 0.0817, -0.0189, -0.0076,  ..., -0.0231,  0.0345,  0.0636],
         [ 0.0481, -0.0322, -0.0248,  ..., -0.0124,  0.0679,  0.0912],
         [-0.0816, -0.1463, -0.0992,  ...,  0.1024,  0.1618,  0.0184],
         ...,
         [-0.1366,  0.0222, -0.1319,  ...,  0.0217, -0.1149,  0.0038],
         [-0.0689, -0.0787, -0.0324,  ...,  0.0167, -0.0276,  0.0136],
         [ 0.0189,  0.1381, -0.0811,  ..., -0.1190,  0.1031,  0.0408]],

        [[-0.1066, -0.0780, -0.0502,  ...,  0.0929,  0.1264,  0.0546],
         [-0.0221,  0.0234, -0.0724,  ..., -0.0036,

Extracting query features:  96%|███████████████▍| 76/79 [00:20<00:00,  4.16it/s]

[tensor([[ 0.0256,  0.0178, -0.0074,  ..., -0.0040,  0.0246,  0.0223],
        [-0.0056, -0.0425, -0.0020,  ..., -0.0003, -0.0089,  0.0207],
        [ 0.0205, -0.0027,  0.0107,  ...,  0.0325,  0.0368,  0.0325],
        ...,
        [-0.0178,  0.0065, -0.0450,  ..., -0.0289, -0.0161, -0.0030],
        [ 0.0049,  0.0393,  0.0208,  ..., -0.0046,  0.0170,  0.0086],
        [-0.0042, -0.0032, -0.0034,  ...,  0.0173,  0.0237,  0.0089]]), tensor([[[-0.0108,  0.0332,  0.1039,  ...,  0.1034,  0.1023, -0.1376],
         [-0.0347, -0.1071, -0.0427,  ...,  0.1281,  0.1952, -0.1256],
         [ 0.0333, -0.0772, -0.0215,  ..., -0.1366,  0.1644,  0.0670],
         ...,
         [ 0.0330, -0.0626, -0.1078,  ..., -0.0731, -0.0083, -0.0617],
         [-0.0927,  0.0579,  0.0306,  ...,  0.0047, -0.0064, -0.1209],
         [-0.0753, -0.0076, -0.1170,  ...,  0.1225, -0.0955,  0.0271]],

        [[-0.0346, -0.0423, -0.0261,  ...,  0.1305, -0.0137,  0.0707],
         [-0.0203, -0.0257, -0.0321,  ...,  0.0283,

                                                                                

In [39]:
eval_function = partial(evaluate, model=model, 
        cache_nn_inds=cache_nn_inds,
        recall=recall_ks, query_loader=loaders.query, gallery_loader=loaders.gallery)

In [40]:
metrics = eval_function()

                                                                                

ValueError: Caught ValueError in DataLoader worker process 5.
Original Traceback (most recent call last):
  File "/mnt/beegfs/home/smessoud/anaconda3/envs/rrt/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 185, in _worker_loop
    data = fetcher.fetch(index)
  File "/mnt/beegfs/home/smessoud/anaconda3/envs/rrt/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/mnt/beegfs/home/smessoud/anaconda3/envs/rrt/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/mnt/beegfs/home/smessoud/RerankingTransformer/RRT_GLD/utils/data/dataset.py", line 55, in __getitem__
    local_desc[:seq_len] = desc[:seq_len]
ValueError: could not broadcast input array from shape (0,) into shape (0,128)


In [None]:
pprint(metrics)

In [None]:
use_aqe = False
aqe_params = {'k': 2, 'alpha': 0.3}

save_nn_inds = True

In [None]:
with open(osp.join(data_dir,  set_name+'_query.txt')) as fid:
    query_lines   = fid.read().splitlines()

In [None]:
len(query_lines)

In [None]:
with open(osp.join(data_dir, set_name+'_gallery.txt')) as fid:
    gallery_lines = fid.read().splitlines()

In [None]:
query_feats = []
for i in tqdm(range(len(query_lines))):
    name = osp.splitext(osp.basename(query_lines[i].split(';;')[0]))[0]
    path = osp.join(data_dir, 'delg_' + feature_name, name + '.delg_global')
    query_feats.append(datum_io.ReadFromFile(path))

In [None]:
query_feats = np.stack(query_feats, axis=0)
query_feats = query_feats / LA.norm(query_feats, axis=-1)[:, None]

In [None]:
query_feats.shape

In [None]:
selection_lines = np.genfromtxt('/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt/'+
                                set_name+'_selection_imgs.txt', dtype='str')
selection_lines.shape

In [None]:
selection_lines[0][:10]

In [None]:
# wiki_img   = '.'.join((wiki_item['image'].split('.')[:-1]))

In [None]:
selection_index_feats = []
for i in tqdm(range(len(selection_lines))):
    index_feats = []
    for image_file in selection_lines[i]:
        name = '.'.join((image_file.split('.')[:-1]))
        path = osp.join(data_dir, 'delg_' + feature_name, name + '.delg_global')
        index_feats.append(datum_io.ReadFromFile(path))
    selection_index_feats.append(index_feats)

In [None]:
min([len(selection_index_feats[i]) for i in range(len(selection_index_feats))])

In [None]:
selection_index_feats = np.array(selection_index_feats)
selection_index_feats.shape

In [None]:
query_feats[0].shape

In [None]:
selection_sims = []
for i in range(len(selection_index_feats)):
    index_feats = np.stack(selection_index_feats[i], axis=0)
    index_feats = index_feats / LA.norm(index_feats, axis=-1)[:, None]
    selection_sims.append(np.matmul(query_feats[i], index_feats.T))

In [None]:
sims = np.array(selection_sims)
sims.shape

In [None]:
sims

In [None]:
if use_aqe:
    ## WARNING: I WAS TOO LAZY TO CORRECT IT
    ## IF YOU WANNA USE AQE PARAMATER - ADAPT THE CODE FOR VIQUAE-RRT
    alpha = aqe_params['alpha']
    nn_inds = np.argsort(-sims, -1)
    query_aug = deepcopy(query_feats)
    for i in range(len(query_feats)):
        new_q = [query_feats[i]]
        for j in range(aqe_params['k']):
            nn_id = nn_inds[i, j]
            weight = sims[i, nn_id] ** aqe_params['alpha']
            new_q.append(weight * index_feats[nn_id])
        new_q = np.stack(new_q, 0)
        new_q = np.mean(new_q, axis=0)
        query_aug[i] = new_q/LA.norm(new_q, axis=-1)
    sims = np.matmul(query_aug, index_feats.T)

In [None]:
selection_index_feats[0].shape

In [None]:
nn_inds = np.argsort(-sims, -1)
nn_dists = deepcopy(sims)
for i in range(query_feats.shape[0]):
    index_feats = selection_index_feats[i]
    for j in range(index_feats.shape[0]):
        nn_dists[i, j] = sims[i, nn_inds[i, j]]

In [None]:
nn_inds.shape

In [None]:
if save_nn_inds:
    output_path = osp.join(data_dir, set_name + '_nn_inds_%s.pkl' % feature_name)
    pickle_save(output_path, nn_inds)

In [None]:
def compute_ap(ranks, nres):
    """
    Computes average precision for given ranked indexes.
    
    Arguments
    ---------
    ranks : zerro-based ranks of positive images
    nres  : number of positive images
    
    Returns
    -------
    ap    : average precision
    """

    # number of images ranked by the system
    nimgranks = len(ranks)

    # accumulate trapezoids in PR-plot
    ap = 0

    recall_step = 1. / nres

    for j in np.arange(nimgranks):
        rank = ranks[j]

        if rank == 0:
            precision_0 = 1.
        else:
            precision_0 = float(j) / rank

        precision_1 = float(j + 1) / (rank + 1)

        ap += (precision_0 + precision_1) * recall_step / 2.

    return ap

In [None]:
def compute_map(ranks, gnd, kappas=[]):
    """
    Computes the mAP for a given set of returned results.

         Usage: 
           map = compute_map (ranks, gnd) 
                 computes mean average precsion (map) only
        
           map, aps, pr, prs = compute_map (ranks, gnd, kappas) 
                 computes mean average precision (map), average precision (aps) for each query
                 computes mean precision at kappas (pr), precision at kappas (prs) for each query
        
         Notes:
         1) ranks starts from 0, ranks.shape = db_size X #queries
         2) The junk results (e.g., the query itself) should be declared in the gnd stuct array
         3) If there are no positive images for some query, that query is excluded from the evaluation
    """

    map = 0.
    nq = len(gnd) # number of queries
    aps = np.zeros(nq)
    pr = np.zeros(len(kappas))
    prs = np.zeros((nq, len(kappas)))
    nempty = 0

    for i in np.arange(nq):
        qgnd = np.array(gnd[i]['ok'])
        qgndj = np.array(gnd[i]['junk'])

        # no positive images, skip from the average
        if qgnd.shape[0] == 0:
            aps[i] = float('nan')
            prs[i, :] = float('nan')
            nempty += 1
            continue

        # sorted positions of positive and junk images (0 based)
        pos  = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgnd)]
        if len(pos) == 0:
            pos = np.array(gnd[i]['ok']).max() * np.ones_like(gnd[i]['ok'])
        
        junk = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgndj)]

        k = 0
        ij = 0
        if len(junk):
            # decrease positions of positives based on the number of
            # junk images appearing before them
            ip = 0
            while (ip < len(pos)):
                while (ij < len(junk) and pos[ip] > junk[ij]):
                    k += 1
                    ij += 1
                pos[ip] = pos[ip] - k
                ip += 1

        # compute ap
        ap = compute_ap(pos, len(qgnd))
        map = map + ap
        aps[i] = ap

        # compute precision @ k
        pos += 1 # get it to 1-based
        for j in np.arange(len(kappas)):
            """
            if len(pos) == 0:
                max_pos = kappas[j]
            else: max_pos = max(pos)
            """
            kq = min(max(pos), kappas[j]); 
            prs[i, j] = (pos <= kq).sum() / kq
        pr = pr + prs[i, :]

    map = map / (nq - nempty)
    pr = pr / (nq - nempty)

    return map, aps, pr, prs

In [None]:
gnd_data = pickle_load(osp.join(data_dir, gnd_name))

In [None]:
def compute_metrics(dataset, ranks, gnd, kappas=[1, 5, 10]):
    print(ranks.shape)
    
    # old evaluation protocol
    if dataset.startswith('classic'):
        map, aps, _, _ = compute_map(ranks, gnd)
        out = {'map': np.around(map*100, decimals=3)}
        print('>> {}: mAP {:.2f}'.format(dataset, out['map']))

    # new evaluation protocol
    elif dataset.startswith('viquae'):
        
        gnd_t = []
        for i in range(len(gnd)):
            g = {}
            g['ok'] = np.concatenate([gnd[i]['hard']])
            g['junk'] = np.concatenate([gnd[i]['junk']])
            gnd_t.append(g)
        mapH, apsH, mprH, prsH = compute_map(ranks, gnd_t, kappas)


        out = {
            'H_map': np.around(mapH*100, decimals=2),
            'H_mp':  np.around(mprH*100, decimals=2),
        }

        print('>> {}: mAP H: {}'.format(dataset, out['H_map']))
        print('>> {}: mP@k{} H: {}'.format(dataset, kappas, out['H_mp']))

    return out

In [None]:
np.arange(nn_inds.T.shape[0])[np.in1d(nn_inds.T[:,i], gnd_data['gnd'][0]['junk'])]

In [None]:
compute_metrics('viquae', nn_inds.T, gnd_data['gnd'], kappas=[1,5,6,10])