# Imports

In [1]:
import numpy as np
from functools import partial

In [2]:
from copy import deepcopy
from functools import partial
from pprint import pprint
import os.path as osp

In [3]:
import sacred
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from torch.backends import cudnn

In [4]:
from torch.utils.data import DataLoader, RandomSampler, BatchSampler
from typing import NamedTuple, Optional, List

In [5]:
from models.matcher import MatchERT
from models.ingredient import model_ingredient, get_model


In [6]:
from sacred import SETTINGS
from sacred.utils import apply_backspaces_and_linefeeds
from sacred import Ingredient

In [7]:
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader, Dataset
from utils.metrics import *



In [8]:
from utils import pickle_load
from sacred import Experiment
from utils.data.dataset_ingredient import data_ingredient, get_loaders
from utils.data.dataset import FeatureDataset
ex = sacred.Experiment('RRT Evaluation', ingredients=[data_ingredient, model_ingredient], interactive=True)

In [9]:
ex = sacred.Experiment('Prepare Top-K (VIQUAE FOR RTT)', interactive=True)
# Filter backspaces and linefeeds
SETTINGS.CAPTURE_MODE = 'sys'
ex.captured_out_filter = apply_backspaces_and_linefeeds

In [10]:
feature_name = 'r50_gldv1'
set_name = 'tuto'
gnd_name = 'gnd_'+ set_name+'.pkl'

In [11]:
cpu = False
cudnn_flag = 'benchmark'
temp_dir = osp.join('logs', 'temp')
resume = '/mnt/beegfs/home/smessoud/RerankingTransformer/RRT_GLD/rrt_gld_ckpts/r50_gldv1.pt'
seed = 0

In [12]:
dataset_name = 'viquae_for_rrt'
data_dir = osp.join('/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data', dataset_name)

In [13]:
device = torch.device('cuda:0' if torch.cuda.is_available() and not cpu else 'cpu')
torch.manual_seed(seed)

<torch._C.Generator at 0x7f9a419357f0>

In [14]:
if cudnn_flag == 'deterministic':
    setattr(cudnn, cudnn_flag, True)

In [15]:
def read_file(filename):
    with open(filename) as f:
        lines = f.read().splitlines()
    return lines

In [16]:
def get_sets(desc_name, 
        train_data_dir, test_data_dir, 
        train_txt, test_txt, test_gnd_file, 
        max_sequence_len=500):
    ####################################################################################################################################
    train_lines     = read_file(osp.join(train_data_dir, train_txt))
    train_samples   = [(line.split(';;')[0], int(line.split(';;')[1]), int(line.split(';;')[2]), int(line.split(';;')[3])) for line in train_lines]
    train_set       = FeatureDataset(train_data_dir, train_samples, desc_name, max_sequence_len)
    query_train_set = FeatureDataset(train_data_dir, train_samples, desc_name, max_sequence_len)
    ####################################################################################################################################
    test_gnd_data = None if test_gnd_file is None else pickle_load(osp.join(test_data_dir, test_gnd_file))
    query_lines   = read_file(osp.join(test_data_dir, test_txt[0]))
    gallery_lines = read_file(osp.join(test_data_dir, test_txt[1]))
    query_samples   = [(line.split(';;')[0], int(line.split(';;')[1]), int(line.split(';;')[2]), int(line.split(';;')[3])) for line in query_lines]
    gallery_samples = [(line.split(';;')[0], int(line.split(';;')[1]), int(line.split(';;')[2]), int(line.split(';;')[3])) for line in gallery_lines]
    gallery_set = FeatureDataset(test_data_dir, gallery_samples, desc_name, max_sequence_len)
    query_set   = FeatureDataset(test_data_dir, query_samples,   desc_name, max_sequence_len, gnd_data=test_gnd_data)
        
    return (train_set, query_train_set), (query_set, gallery_set)

In [17]:
class MetricLoaders(NamedTuple):
    train: DataLoader
    num_classes: int
    query: DataLoader
    query_train: DataLoader
    gallery: Optional[DataLoader] = None

In [18]:
def get_loaders(desc_name, train_data_dir, 
    batch_size=36, test_batch_size=36, 
    num_workers=8, pin_memory=True, 
    sampler='random', recalls=[1, 5, 10],
    num_candidates=100):

    (train_set, query_train_set), (query_set, gallery_set) = get_sets(desc_name, 
        train_data_dir=train_data_dir,
        test_data_dir=train_data_dir,
        train_txt='tuto_query.txt',
        test_txt=('tuto_query.txt', 'tuto_selection.txt'),
        test_gnd_file='gnd_tuto.pkl', 
        max_sequence_len=500)

    if sampler == 'random':
        train_sampler = BatchSampler(RandomSampler(train_set), batch_size=batch_size, drop_last=False)
    elif sampler == 'triplet':
        train_nn_inds = osp.join(train_data_dir, set_name+'_nn_inds_%s.pkl'%desc_name)
        train_sampler = TripletSampler(train_set.targets, batch_size, train_nn_inds, num_candidates)
    else:
        raise ValueError('Invalid choice of sampler ({}).'.format(sampler))
    train_loader = DataLoader(train_set, batch_sampler=train_sampler, num_workers=num_workers, pin_memory=pin_memory)
    query_train_loader = DataLoader(query_train_set, batch_size=test_batch_size, num_workers=num_workers, pin_memory=pin_memory)
        
    query_loader   = DataLoader(query_set, batch_size=test_batch_size, num_workers=num_workers, pin_memory=pin_memory)
    gallery_loader = DataLoader(gallery_set, batch_size=test_batch_size, num_workers=num_workers, pin_memory=pin_memory)

    return MetricLoaders(train=train_loader, query_train=query_train_loader, query=query_loader, gallery=gallery_loader, num_classes=len(train_set.categories)), recalls

In [19]:
train_data_dir = '/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt'

In [20]:
(train_set, query_train_set), (query_set, gallery_set) = get_sets('r50_gldv1',
            train_data_dir,
            train_data_dir,
            set_name+'_query.txt',
            (set_name+'_query.txt',set_name+'_selection.txt'),
            'gnd_'+set_name+'.pkl',
            500)

In [21]:
batch_size      = 16
test_batch_size = 16
max_sequence_len = 500
sampler = 'random'
if sampler == 'random':
   train_sampler = BatchSampler(RandomSampler(train_set), batch_size=batch_size, drop_last=False)

In [22]:
num_workers = 8  # number of workers used ot load the data
pin_memory  = True  # use the pin_memory option of DataLoader 
num_candidates = 100
recalls = [1, 5, 6, 10]

In [23]:
train_loader = DataLoader(train_set, batch_sampler=train_sampler, num_workers=num_workers, pin_memory=pin_memory)
query_train_loader = DataLoader(query_train_set, batch_size=test_batch_size, num_workers=num_workers, pin_memory=pin_memory)

In [24]:
query_loader   = DataLoader(query_set, batch_size=test_batch_size, num_workers=num_workers, pin_memory=pin_memory)
gallery_loader = DataLoader(gallery_set, batch_size=test_batch_size, num_workers=num_workers, pin_memory=pin_memory)

In [25]:
loaders, recall_ks = get_loaders(desc_name='r50_gldv1',
    train_data_dir='/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt', 
    batch_size=36, test_batch_size=36, 
    num_workers=8, pin_memory=True, 
    sampler='random', recalls=[1, 5, 10],
    num_candidates=100)

In [26]:
model_ingredient = Ingredient('model', interactive=True)

In [27]:
def get_model(num_global_features, num_local_features, seq_len, dim_K, dim_feedforward, nhead, num_encoder_layers, dropout, activation, normalize_before):
    return MatchERT(d_global=num_global_features, d_model=num_local_features, seq_len=seq_len, d_K=dim_K, nhead=nhead, num_encoder_layers=num_encoder_layers, 
            dim_feedforward=dim_feedforward, dropout=dropout, activation=activation, normalize_before=normalize_before)

In [28]:
name = 'rrt'
num_global_features = 2048  
num_local_features = 128  
seq_len = 1004
dim_K = 256
dim_feedforward = 1024
nhead = 4
num_encoder_layers = 6
dropout = 0.0 
activation = "relu"
normalize_before = False

In [29]:
model = get_model(num_global_features,num_local_features,seq_len,dim_K,dim_feedforward,nhead,num_encoder_layers,dropout,activation,normalize_before)

In [30]:
if resume is not None:
   checkpoint = torch.load(resume, map_location=torch.device('cpu'))
   model.load_state_dict(checkpoint['state'], strict=True)

In [31]:
model.to(device)
model.eval()

MatchERT(
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): _LinearWithBias(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=1024, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (linear2): Linear(in_features=1024, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.0, inplace=False)
        (dropout2): Dropout(p=0.0, inplace=False)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): _LinearWithBias(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=1024, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (linear2): Lin

In [32]:
loaders.query.dataset.desc_name, loaders.query.dataset.desc_name, loaders.query.dataset.data_dir

('r50_gldv1',
 'r50_gldv1',
 '/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt')

In [33]:
loaders.query

<torch.utils.data.dataloader.DataLoader at 0x7f9a3a2a5110>

In [34]:
nn_inds_path = osp.join(loaders.query.dataset.data_dir, set_name+'_nn_inds_%s.pkl'%loaders.query.dataset.desc_name)
cache_nn_inds = torch.from_numpy(pickle_load(nn_inds_path)).long()

In [35]:
cache_nn_inds.size()

torch.Size([100, 100])

In [36]:
num_samples, top_k = cache_nn_inds.size()
top_k = min(top_k, 100)
top_k

100

In [37]:
def evaluate(
        model: nn.Module,
        cache_nn_inds: torch.Tensor,
        query_loader: DataLoader,
        gallery_loader: DataLoader,
        recall: List[int]):
    model.eval()
    device = next(model.parameters()).device
    to_device = lambda x: x.to(device, non_blocking=True)

    query_global, query_local, query_mask, query_scales, query_positions, query_names = [], [], [], [], [], []
    gallery_global, gallery_local, gallery_mask, gallery_scales, gallery_positions, gallery_names = [], [], [], [], [], []

    with torch.no_grad():
        for entry in tqdm(query_loader, desc='Extracting query features', leave=False, ncols=80):
            q_global, q_local, q_mask, q_scales, q_positions, _, q_names = entry
            query_global.append(q_global.cpu())
            query_local.append(q_local.cpu())
            query_mask.append(q_mask.cpu())
            query_scales.append(q_scales.cpu())
            query_positions.append(q_positions.cpu())
            query_names.extend(list(q_names))

        query_global    = torch.cat(query_global, 0)
        query_local     = torch.cat(query_local, 0)
        query_mask      = torch.cat(query_mask, 0)
        query_scales    = torch.cat(query_scales, 0)
        query_positions = torch.cat(query_positions, 0)

        for entry in tqdm(gallery_loader, desc='Extracting gallery features', leave=False, ncols=80):
            g_global, g_local, g_mask, g_scales, g_positions, _, g_names = entry
            gallery_global.append(g_global.cpu())
            gallery_local.append(g_local.cpu())
            gallery_mask.append(g_mask.cpu())
            gallery_scales.append(g_scales.cpu())
            gallery_positions.append(g_positions.cpu())
            gallery_names.extend(list(g_names))
            
        gallery_global    = torch.cat(gallery_global, 0)
        gallery_local     = torch.cat(gallery_local, 0)
        gallery_mask      = torch.cat(gallery_mask, 0)
        gallery_scales    = torch.cat(gallery_scales, 0)
        gallery_positions = torch.cat(gallery_positions, 0)

        torch.cuda.empty_cache()
        evaluate_function = partial(mean_average_precision_revisited_rerank, model=model, cache_nn_inds=cache_nn_inds,
            query_global=query_global, query_local=query_local, query_mask=query_mask, query_scales=query_scales, query_positions=query_positions, 
            gallery_global=gallery_global, gallery_local=gallery_local, gallery_mask=gallery_mask, gallery_scales=gallery_scales, gallery_positions=gallery_positions, 
            ks=recall, 
            gnd=query_loader.dataset.gnd_data,
        )
        metrics = evaluate_function()
    return metrics 

In [38]:
model.eval()
device = next(model.parameters()).device
to_device = lambda x: x.to(device, non_blocking=True)

query_global, query_local, query_mask, query_scales, query_positions, query_names = [], [], [], [], [], []
gallery_global, gallery_local, gallery_mask, gallery_scales, gallery_positions, gallery_names = [], [], [], [], [], []

with torch.no_grad():
    for entry in tqdm(query_loader, desc='Extracting query features', leave=False, ncols=80):
        q_global, q_local, q_mask, q_scales, q_positions, _, q_names = entry
        query_global.append(q_global.cpu())
        query_local.append(q_local.cpu())
        query_mask.append(q_mask.cpu())
        query_scales.append(q_scales.cpu())
        query_positions.append(q_positions.cpu())
        query_names.extend(list(q_names))

    query_global    = torch.cat(query_global, 0)
    query_local     = torch.cat(query_local, 0)
    query_mask      = torch.cat(query_mask, 0)
    query_scales    = torch.cat(query_scales, 0)
    query_positions = torch.cat(query_positions, 0)
    
    for entry in tqdm(gallery_loader, desc='Extracting gallery features', leave=False, ncols=80):
        g_global, g_local, g_mask, g_scales, g_positions, _, g_names = entry
        gallery_global.append(g_global.cpu())
        gallery_local.append(g_local.cpu())
        gallery_mask.append(g_mask.cpu())
        gallery_scales.append(g_scales.cpu())
        gallery_positions.append(g_positions.cpu())
        gallery_names.extend(list(g_names))

    gallery_global    = torch.cat(gallery_global, 0)
    gallery_local     = torch.cat(gallery_local, 0)
    gallery_mask      = torch.cat(gallery_mask, 0)
    gallery_scales    = torch.cat(gallery_scales, 0)
    gallery_positions = torch.cat(gallery_positions, 0)

                                                                                

In [39]:
device = next(model.parameters()).device
query_global    = query_global.to(device)
query_local     = query_local.to(device)
query_mask      = query_mask.to(device)
query_scales    = query_scales.to(device)
query_positions = query_positions.to(device)

In [40]:
num_samples, top_k = cache_nn_inds.size()
top_k = min(100, top_k)

In [41]:
print("top_k: ", top_k)    
print("query_global size: ", query_global.shape)
print("query_local size: ", query_local.shape)
print("query_mask size: ", query_mask.shape)
print("query_scales size: ", query_scales.shape)
print("query_positions size: ", query_positions.shape)

top_k:  100
query_global size:  torch.Size([100, 2048])
query_local size:  torch.Size([100, 500, 128])
query_mask size:  torch.Size([100, 500])
query_scales size:  torch.Size([100, 500])
query_positions size:  torch.Size([100, 500, 2])


In [42]:
print("top_k: ", top_k)    
print("gallery_global size: ", gallery_global.shape)
print("gallery_local size: ", gallery_local.shape)
print("gallery_mask size: ", gallery_mask.shape)
print("gallery_scales size: ", gallery_scales.shape)
print("gallery_positions size: ", gallery_positions.shape)

top_k:  100
gallery_global size:  torch.Size([10000, 2048])
gallery_local size:  torch.Size([10000, 500, 128])
gallery_mask size:  torch.Size([10000, 500])
gallery_scales size:  torch.Size([10000, 500])
gallery_positions size:  torch.Size([10000, 500, 2])


In [43]:
gallery_global = gallery_global.view(query_global.size(dim=0), 100, query_global.size(dim=1))
gallery_local = gallery_local.view(query_local.size(dim=0), 100, query_local.size(dim=1), query_local.size(dim=2))
gallery_mask = gallery_mask.view(query_mask.size(dim=0), 100, query_mask.size(dim=1))
gallery_scales = gallery_scales.view(query_scales.size(dim=0), 100, query_scales.size(dim=1))
gallery_positions = gallery_positions.view(query_positions.size(dim=0), 100, query_positions.size(dim=1), query_positions.size(dim=2))

In [44]:
print("top_k: ", top_k)    
print("gallery_global size: ", gallery_global.shape)
print("gallery_local size: ", gallery_local.shape)
print("gallery_mask size: ", gallery_mask.shape)
print("gallery_scales size: ", gallery_scales.shape)
print("gallery_positions size: ", gallery_positions.shape)

top_k:  100
gallery_global size:  torch.Size([100, 100, 2048])
gallery_local size:  torch.Size([100, 100, 500, 128])
gallery_mask size:  torch.Size([100, 100, 500])
gallery_scales size:  torch.Size([100, 100, 500])
gallery_positions size:  torch.Size([100, 100, 500, 2])


In [45]:
torch.cuda.memory_summary(device='cuda:0', abbreviated=False)



In [46]:
gnd_data = pickle_load(osp.join(data_dir, gnd_name))

In [47]:
eval_nn_inds = deepcopy(cache_nn_inds.cpu().data.numpy())
# Exclude the junk images as in DELG (https://github.com/tensorflow/models/blob/44cad43aadff9dd12b00d4526830f7ea0796c047/research/delf/delf/python/detect_to_retrieve/image_reranking.py#L190)
for i in range(num_samples):
    junk_ids = gnd_data['gnd'][i]['junk']
    all_ids = eval_nn_inds[i]
    pos = np.in1d(all_ids, junk_ids)
    neg = np.array([not x for x in pos])
    new_ids = np.concatenate([np.arange(len(all_ids))[neg], np.arange(len(all_ids))[pos]])
    new_ids = all_ids[new_ids]
    eval_nn_inds[i] = new_ids
eval_nn_inds = torch.from_numpy(eval_nn_inds)

In [48]:
eval_nn_inds.shape

torch.Size([100, 100])

In [49]:
i = 0
nnids = eval_nn_inds[:, i]
nnids, nnids.shape, gallery_global.shape

(tensor([78, 97, 96,  0,  8, 95, 43, 98, 17, 43, 88, 63, 24, 79, 76, 50,  8, 40,
         67, 81, 95, 26, 69, 70,  0, 71, 30, 56, 20, 17, 84,  7,  5, 73, 61, 78,
         51,  1, 36, 99, 46, 20,  0, 65, 48, 60, 29, 43, 40, 63, 34, 20,  9, 74,
         65, 93, 23, 76, 48, 94, 22,  2,  8, 46, 98, 47, 22, 64, 98, 29, 92, 57,
         58, 90, 69, 30, 40, 13,  6, 15, 58, 58, 62, 59, 68, 76,  4, 67, 52, 16,
         94, 21, 76, 94,  1, 82, 82, 73, 65, 69]),
 torch.Size([100]),
 torch.Size([100, 100, 2048]))

In [50]:
torch.from_numpy(np.stack([gallery_global[i][nnids[i]] for i in range(len(nnids))], axis=0)).shape


torch.Size([100, 2048])

In [51]:
import gc

In [52]:
torch.cuda.empty_cache()
gc.collect()

165

In [53]:
torch.cuda.get_device_properties(0).total_memory/1e-10

3.4089730048e+20

In [54]:
torch.cuda.memory_reserved(0)/1e-10

3.9845888e+17

In [55]:
torch.cuda.memory_allocated(0)/1e-10

3.6243456e+17

In [56]:
scores = []

for i in tqdm(range(1)):
    nnids = eval_nn_inds[:, i]
    index_global        = []
    index_local         = []
    index_mask          = []
    index_scales        = []
    index_positions     = []
    torch.cuda.empty_cache()
    gc.collect()
    torch.cuda.empty_cache()
    gc.collect()
    for iterator in range(nnids.size(dim=0)):
        index_global.append(gallery_global[iterator, nnids[iterator]])
        index_local.append(gallery_local[iterator, nnids[iterator]])
        index_mask.append(gallery_mask[iterator, nnids[iterator]])
        index_scales.append(gallery_scales[iterator, nnids[iterator]])
        index_positions.append(gallery_positions[iterator, nnids[iterator]])
        
    torch.cuda.empty_cache()
    gc.collect()
    index_global = np.stack(index_global, axis=0)
    index_global = torch.from_numpy(index_global)
    index_global = index_global.to(device)
    
    torch.cuda.empty_cache()
    gc.collect()
    index_local = np.stack(index_local, axis=0)
    index_local = torch.from_numpy(index_local)
    index_local = index_local.to(device)
        
    torch.cuda.empty_cache()
    gc.collect()
    index_mask = np.stack(index_mask, axis=0)
    index_mask = torch.from_numpy(index_mask)
    index_mask = index_mask.to(device)
    
    torch.cuda.empty_cache()
    gc.collect()
    index_scales = np.stack(index_scales, axis=0)
    index_scales = torch.from_numpy(index_scales)
    index_scales = index_scales.to(device)
        
    torch.cuda.empty_cache()
    gc.collect()
    index_positions = np.stack(index_positions, axis=0)
    index_positions = torch.from_numpy(index_positions)
    index_positions = index_positions.to(device)

    print("index_global size: ", index_global.shape)
    print("index_local size: ", index_local.shape)
    print("index_mask size: ", index_mask.shape)
    print("index_scales size: ", index_scales.shape)
    print("index_positions size: ", index_positions.shape)
    
    
    print("index_global device: ", index_global.get_device())
    print("index_local device: ", index_local.get_device())
    print("index_mask device: ", index_mask.get_device())
    print("index_scales device: ", index_scales.get_device())
    print("index_positions device: ", index_positions.get_device())
    
    current_scores = model(
        query_global, query_local, query_mask, query_scales, query_positions,
        index_global,
        index_local,
        index_mask,
        index_scales,
        index_positions)

  0%|          | 0/1 [00:00<?, ?it/s]

index_global size:  torch.Size([100, 2048])
index_local size:  torch.Size([100, 500, 128])
index_mask size:  torch.Size([100, 500])
index_scales size:  torch.Size([100, 500])
index_positions size:  torch.Size([100, 500, 2])
index_global device:  0
index_local device:  0
index_mask device:  0
index_scales device:  0
index_positions device:  0


100%|██████████| 1/1 [00:00<00:00,  1.11it/s]


In [57]:
index_global    = []
for idx in range(nnids.size(dim=0)):
    index_global.append(gallery_global[idx, nnids[idx]])
index_global = np.stack(index_global, axis=0)
index_global = torch.from_numpy(index_global).cuda("cuda:0")
index_local     = []
for idx in range(nnids.size(dim=0)):
    index_local.append(gallery_local[idx, nnids[idx]])
index_local = np.stack(index_local, axis=0)
index_local = torch.from_numpy(index_local).cuda("cuda:0")
index_mask     = []
for idx in range(nnids.size(dim=0)):
    index_mask.append(gallery_mask[idx, nnids[idx]])
index_mask = np.stack(index_mask, axis=0)
index_mask = torch.from_numpy(index_mask).cuda("cuda:0")
index_scales     = []
for idx in range(nnids.size(dim=0)):
    index_scales.append(gallery_scales[idx, nnids[idx]])
index_scales = np.stack(index_scales, axis=0)
index_scales = torch.from_numpy(index_scales).cuda("cuda:0")
#index_scales    = gallery_scales[:, nnids]
index_positions     = []
for idx in range(nnids.size(dim=0)):
    index_positions.append(gallery_positions[idx, nnids[idx]])
index_positions = np.stack(index_positions, axis=0)
index_positions = torch.from_numpy(index_positions).cuda("cuda:0")
#index_positions = gallery_positions[:, nnids]

In [58]:
index_global.shape, index_local.shape, index_mask.shape, index_scales.shape, index_positions.shape

(torch.Size([100, 2048]),
 torch.Size([100, 500, 128]),
 torch.Size([100, 500]),
 torch.Size([100, 500]),
 torch.Size([100, 500, 2]))

In [59]:
query_global
#query_local, 
#query_mask, 
#query_scales, 
#query_positions

tensor([[-0.0106, -0.0303, -0.0169,  ...,  0.0108,  0.0321,  0.0226],
        [ 0.0081,  0.0088,  0.0147,  ..., -0.0036,  0.0040,  0.0296],
        [ 0.0081,  0.0088,  0.0147,  ..., -0.0036,  0.0040,  0.0296],
        ...,
        [-0.0107, -0.0030, -0.0060,  ...,  0.0003,  0.0077,  0.0132],
        [-0.0107, -0.0030, -0.0060,  ...,  0.0003,  0.0077,  0.0132],
        [-0.0107, -0.0030, -0.0060,  ...,  0.0003,  0.0077,  0.0132]],
       device='cuda:0')

In [60]:
torch.cuda.empty_cache()

In [61]:
device

device(type='cuda', index=0)

In [62]:
index_scales.to(device)

tensor([[4, 4, 2,  ..., 0, 2, 0],
        [2, 0, 4,  ..., 2, 0, 0],
        [2, 0, 4,  ..., 2, 0, 0],
        ...,
        [2, 2, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [2, 2, 0,  ..., 2, 0, 0]], device='cuda:0')

In [63]:
index_global.to(device),
index_local.to(device),
index_mask.to(device),
index_scales.to(device),
index_positions.to(device)

tensor([[[0.1768, 0.0000],
         [0.1250, 0.0000],
         [0.0884, 0.0000],
         ...,
         [0.6629, 0.1776],
         [0.4375, 0.3516],
         [0.8750, 0.0251]],

        [[0.6875, 0.1500],
         [0.3438, 0.9750],
         [0.5000, 0.1000],
         ...,
         [0.1768, 0.9899],
         [0.3977, 0.9546],
         [0.7071, 0.2828]],

        [[0.6875, 0.1500],
         [0.3438, 0.9750],
         [0.5000, 0.1000],
         ...,
         [0.1768, 0.9899],
         [0.3977, 0.9546],
         [0.7071, 0.2828]],

        ...,

        [[0.3750, 0.6015],
         [0.1250, 0.6015],
         [0.1250, 0.6617],
         ...,
         [0.0000, 0.0000],
         [0.0000, 0.0000],
         [0.0000, 0.0000]],

        [[0.7071, 0.1860],
         [0.6875, 0.1534],
         [0.4375, 0.5699],
         ...,
         [0.1562, 0.6575],
         [0.4861, 0.0930],
         [0.5312, 0.5918]],

        [[0.8839, 0.6675],
         [0.9375, 0.5664],
         [0.9281, 0.6007],
         ...,
 

In [64]:
scores = []
for i in tqdm(range(top_k)):
    nnids = eval_nn_inds[:, i]
    index_global    = []
    for idx in range(nnids.size(dim=0)):
        index_global.append(gallery_global[idx, nnids[idx]])
    index_global = np.stack(index_global, axis=0)
    index_global = torch.from_numpy(index_global)
    index_local     = []
    for idx in range(nnids.size(dim=0)):
        index_local.append(gallery_local[idx, nnids[idx]])
    index_local = np.stack(index_local, axis=0)
    index_local = torch.from_numpy(index_local)
    index_mask     = []
    for idx in range(nnids.size(dim=0)):
        index_mask.append(gallery_mask[idx, nnids[idx]])
    index_mask = np.stack(index_mask, axis=0)
    index_mask = torch.from_numpy(index_mask)
    index_scales     = []
    for idx in range(nnids.size(dim=0)):
        index_scales.append(gallery_scales[idx, nnids[idx]])
    index_scales = np.stack(index_scales, axis=0)
    index_scales = torch.from_numpy(index_scales)
    index_positions     = []
    for idx in range(nnids.size(dim=0)):
        index_positions.append(gallery_positions[idx, nnids[idx]])
    index_positions = np.stack(index_positions, axis=0)
    index_positions = torch.from_numpy(index_positions)
    
    index_global = index_global.to(device)
    index_local = index_local.to(device)
    index_mask = index_mask.to(device)
    index_scales = index_scales.to(device)
    index_positions = index_positions.to(device)

    print("index_global size: ", index_global.shape)
    print("index_local size: ", index_local.shape)
    print("index_mask size: ", index_mask.shape)
    print("index_scales size: ", index_scales.shape)
    print("index_positions size: ", index_positions.shape)
    
    current_scores = model(
        query_global, query_local, query_mask, query_scales, query_positions,
        index_global.to(device),
        index_local.to(device),
        index_mask.to(device),
        index_scales.to(device),
        index_positions.to(device))

  0%|          | 0/100 [00:00<?, ?it/s]

index_global size:  torch.Size([100, 2048])
index_local size:  torch.Size([100, 500, 128])
index_mask size:  torch.Size([100, 500])
index_scales size:  torch.Size([100, 500])
index_positions size:  torch.Size([100, 500, 2])





RuntimeError: CUDA out of memory. Tried to allocate 394.00 MiB (GPU 0; 31.75 GiB total capacity; 29.92 GiB already allocated; 16.00 MiB free; 30.58 GiB reserved in total by PyTorch)

In [65]:
index_global.shape, index_local.shape, index_mask.shape, index_scales.shape, index_positions.shape

(torch.Size([100, 2048]),
 torch.Size([100, 500, 128]),
 torch.Size([100, 500]),
 torch.Size([100, 500]),
 torch.Size([100, 500, 2]))

In [66]:
gallery_global[:, nnids].shape

torch.Size([100, 100, 2048])

In [68]:
scores = []
for i in tqdm(range(top_k)):
    print("gallery_global size: ", gallery_global.shape)
    print("gallery_local size: ", gallery_local.shape)
    print("gallery_mask size: ", gallery_mask.shape)
    print("gallery_scales size: ", gallery_scales.shape)
    print("gallery_positions size: ", gallery_positions.shape)
    nnids = eval_nn_inds[:, i]
    index_global    = []
    for idx in range(nnids.size(dim=0)):
        index_global.append(gallery_global[idx, nnids[idx]])
    index_global = np.stack(index_global, axis=0)
    index_global = torch.from_numpy(index_global)
    index_local     = []
    for idx in range(nnids.size(dim=0)):
        index_local.append(gallery_local[idx, nnids[idx]])
    index_local = np.stack(index_local, axis=0)
    index_local = torch.from_numpy(index_local)
    index_mask     = []
    for idx in range(nnids.size(dim=0)):
        index_mask.append(gallery_mask[idx, nnids[idx]])
    index_mask = np.stack(index_mask, axis=0)
    index_mask = torch.from_numpy(index_mask)
    index_scales     = []
    for idx in range(nnids.size(dim=0)):
        index_scales.append(gallery_scales[idx, nnids[idx]])
    index_scales = np.stack(index_scales, axis=0)
    index_scales = torch.from_numpy(index_scales)
    #index_scales    = gallery_scales[:, nnids]
    index_positions     = []
    for idx in range(nnids.size(dim=0)):
        index_positions.append(gallery_positions[idx, nnids[idx]])
    index_positions = np.stack(index_positions, axis=0)
    index_positions = torch.from_numpy(index_positions)
    torch.cuda.empty_cache()
    print("index_global size: ", index_global.shape)
    print("index_local size: ", index_local.shape)
    print("index_mask size: ", index_mask.shape)
    print("index_scales size: ", index_scales.shape)
    print("index_positions size: ", index_positions.shape)
    
    print("index_global device: ", index_global.get_device())
    print("index_local device: ", index_local.get_device())
    print("index_mask device: ", index_mask.get_device())
    print("index_scales device: ", index_scales.get_device())
    print("index_positions device: ", index_positions.get_device())
    
    index_global.to(device),
    index_local.to(device),
    index_mask.to(device),
    index_scales.to(device),
    index_positions.to(device)

    current_scores = model(
        query_global, query_local, query_mask, query_scales, query_positions,
        index_global.to(device),
        index_local.to(device),
        index_mask.to(device),
        index_scales.to(device),
        index_positions.to(device))
    scores.append(current_scores.cpu().data)
scores = torch.stack(scores, -1) # 70 x 100
closest_dists, indices = torch.sort(scores, dim=-1, descending=True)
closest_indices = torch.gather(medium_nn_inds, -1, indices)
ranks = deepcopy(medium_nn_inds)
ranks[:, :top_k] = deepcopy(closest_indices)
ranks = ranks.cpu().data.numpy().T
# pickle_save('medium_nn_inds.pkl', ranks.T)
out = compute_metrics('viquae', ranks, gnd['gnd'], kappas=ks)

  0%|          | 0/100 [00:00<?, ?it/s]

gallery_global size:  torch.Size([100, 100, 2048])
gallery_local size:  torch.Size([100, 100, 500, 128])
gallery_mask size:  torch.Size([100, 100, 500])
gallery_scales size:  torch.Size([100, 100, 500])
gallery_positions size:  torch.Size([100, 100, 500, 2])
index_global size:  torch.Size([100, 2048])
index_local size:  torch.Size([100, 500, 128])
index_mask size:  torch.Size([100, 500])
index_scales size:  torch.Size([100, 500])
index_positions size:  torch.Size([100, 500, 2])
index_global device:  -1
index_local device:  -1
index_mask device:  -1
index_scales device:  -1
index_positions device:  -1





RuntimeError: CUDA out of memory. Tried to allocate 50.00 MiB (GPU 0; 31.75 GiB total capacity; 30.31 GiB already allocated; 12.00 MiB free; 30.59 GiB reserved in total by PyTorch)