In [1]:
import numpy as np
import os.path as osp
import pickle, json, random

In [2]:
import os, math
import os.path as osp
from copy import deepcopy
from functools import partial
from pprint import pprint

In [3]:
import sacred
import torch
import torch.nn as nn
import torch.nn.functional as F
from sacred import SETTINGS
from sacred.utils import apply_backspaces_and_linefeeds
from torch.backends import cudnn
from torch.optim import SGD, Adam, AdamW, lr_scheduler
from tqdm import tqdm

from torch.backends import cudnn
# from visdom_logger import VisdomLogger

In [4]:
from torch.utils.data import DataLoader, RandomSampler, BatchSampler
from typing import NamedTuple, Optional, List

In [5]:
from models.matcher import MatchERT
from models.ingredient import model_ingredient, get_model

In [6]:
from models.ingredient import model_ingredient, get_model
from utils import state_dict_to_cpu, num_of_trainable_params
from utils import pickle_load, pickle_save
#from utils.data.utils import TripletSampler
from utils import BinaryCrossEntropyWithLogits
from utils.data.dataset_ingredient import data_ingredient, get_loaders
from utils.training import train_one_epoch, evaluate_viquae



In [7]:
from utils import pickle_load
from sacred import Experiment
from utils.data.dataset_ingredient import data_ingredient, get_loaders
from utils.data.dataset import FeatureDataset

In [8]:
ex = sacred.Experiment('RRT Training', ingredients=[data_ingredient, model_ingredient], interactive=True)
# Filter backspaces and linefeeds
SETTINGS.CAPTURE_MODE = 'sys'
ex.captured_out_filter = apply_backspaces_and_linefeeds

In [9]:
epochs = 15
lr = 0.0001
momentum = 0.
nesterov = False
weight_decay = 5e-4
optim = 'adamw'
scheduler = 'multistep'
max_norm = 0.0
seed = 0

visdom_port = None
visdom_freq = 100
cpu = False  # Force training on CPU
cudnn_flag = 'benchmark'
temp_dir = osp.join('outputs', 'temp')

no_bias_decay = False
loss = 'bce'
scheduler_tau = [16, 18]
scheduler_gamma = 0.1

resume = '/mnt/beegfs/home/smessoud/RerankingTransformer/RRT_GLD/rrt_gld_ckpts/r50_gldv1.pt'

In [10]:
epochs, cpu, cudnn_flag, visdom_port, visdom_freq, temp_dir, seed, no_bias_decay, max_norm, resume

(15,
 False,
 'benchmark',
 None,
 100,
 'outputs/temp',
 0,
 False,
 0.0,
 '/mnt/beegfs/home/smessoud/RerankingTransformer/RRT_GLD/rrt_gld_ckpts/r50_gldv1.pt')

In [11]:
def get_optimizer_scheduler(parameters, optim, loader_length, epochs, lr, momentum, nesterov, weight_decay, scheduler, scheduler_tau, scheduler_gamma, lr_step=None):
    if optim == 'sgd':
        optimizer = SGD(parameters, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=True if nesterov and momentum else False)
    elif optim == 'adam':
        optimizer = Adam(parameters, lr=lr, weight_decay=weight_decay) 
    else:
        optimizer = AdamW(parameters, lr=lr, weight_decay=weight_decay)
    
    if epochs == 0:
        scheduler = None
        update_per_iteration = None
    elif scheduler == 'cos':
        # scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs * loader_length, eta_min=0.000005)
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0.000001)
        update_per_iteration = False
    elif scheduler == 'warmcos':
        # warm_cosine = lambda i: min((i + 1) / 3, (1 + math.cos(math.pi * i / (epochs * loader_length))) / 2)
        warm_cosine = lambda i: min((i + 1) / 3, (1 + math.cos(math.pi * i / epochs)) / 2)
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_cosine)
        update_per_iteration = False
    elif scheduler == 'multistep':
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=scheduler_tau, gamma=scheduler_gamma)
        update_per_iteration = False
    elif scheduler == 'warmstep':
        warm_step = lambda i: min((i + 1) / 100, 1) * 0.1 ** (i // (lr_step * loader_length))
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_step)
        update_per_iteration = True
    else:
        scheduler = lr_scheduler.StepLR(optimizer, epochs * loader_length)
        update_per_iteration = True

    return optimizer, (scheduler, update_per_iteration)



In [12]:
def get_loss(loss):
    if loss == 'bce':
        return BinaryCrossEntropyWithLogits()
    else:
        raise Exception('Unsupported loss {}'.format(loss))

In [13]:
device = torch.device('cuda:0' if torch.cuda.is_available() and not cpu else 'cpu')
# callback = VisdomLogger(port=visdom_port) if visdom_port else None
if cudnn_flag == 'deterministic':
    setattr(cudnn, cudnn_flag, True)

In [14]:
def read_file(filename):
    with open(filename) as f:
        lines = f.read().splitlines()
    return lines

In [15]:
name = 'tuto_viquae_tuto_r50_gldv1'
set_name = 'train'
train_txt = 'train.txt'
test_txt = ('dev_query.txt', 'dev_selection.txt')
train_data_dir = 'data/viquae_for_rrt'
test_data_dir  = 'data/viquae_for_rrt'
test_gnd_file = 'gnd_tuto.pkl'
train_gnd_file = 'training_gnd_'+set_name+'.pkl'
desc_name = 'r50_gldv1'
sampler = 'triplet'
split_char  = ';;'

In [16]:
def get_sets(desc_name, 
        train_data_dir, test_data_dir, 
        train_txt, test_txt, test_gnd_file, 
        max_sequence_len, split_char):
    ####################################################################################################################################
    train_lines     = read_file(osp.join(train_data_dir, train_txt))
    train_samples   = [(line.split(split_char)[0], int(line.split(split_char)[1]), int(line.split(split_char)[2]), int(line.split(split_char)[3])) for line in train_lines]
    train_set       = FeatureDataset(train_data_dir, train_samples, desc_name, max_sequence_len)
    query_train_set = FeatureDataset(train_data_dir, train_samples, desc_name, max_sequence_len)
    ####################################################################################################################################
    test_gnd_data = None if test_gnd_file is None else pickle_load(osp.join(test_data_dir, test_gnd_file))
    query_lines   = read_file(osp.join(test_data_dir, test_txt[0]))
    gallery_lines = read_file(osp.join(test_data_dir, test_txt[1]))
    query_samples   = [(line.split(split_char)[0], int(line.split(split_char)[1]), int(line.split(split_char)[2]), int(line.split(split_char)[3])) for line in query_lines]
    gallery_samples = [(line.split(split_char)[0], int(line.split(split_char)[1]), int(line.split(split_char)[2]), int(line.split(split_char)[3])) for line in gallery_lines]
    gallery_set = FeatureDataset(test_data_dir, gallery_samples, desc_name, max_sequence_len)
    query_set   = FeatureDataset(test_data_dir, query_samples,   desc_name, max_sequence_len, gnd_data=test_gnd_data)
        
    return (train_set, query_train_set), (query_set, gallery_set)


In [17]:
def get_loaders(desc_name, train_data_dir, 
    batch_size, test_batch_size, 
    num_workers, pin_memory, 
    sampler, recalls, set_name,
    train_gnd_file,
    num_candidates=100,):

    (train_set, query_train_set), (query_set, gallery_set) = get_sets(desc_name=desc_name, 
        train_data_dir=train_data_dir, test_data_dir=train_data_dir, 
        train_txt=train_txt, test_txt=test_txt, test_gnd_file=test_gnd_file, 
        max_sequence_len=max_sequence_len, split_char=split_char)

    if sampler == 'random':
        train_sampler = BatchSampler(RandomSampler(train_set), batch_size=batch_size, drop_last=False)
    elif sampler == 'triplet':
        s_name = set_name
        if s_name != '':
            s_name = set_name + '_'
        def map_nnids_labels(train_data_dir, train_gnd_file, s_categories):
            gnd =  pickle_load(osp.join(train_data_dir, train_gnd_file))
            selection_gallery = gnd['simlist']
            s_categories = s_categories.reshape(np.array(selection_gallery).shape)
            selection_ids_to_cat_dict = [{k: s_categories[i][k] for k in range(len(selection_gallery[i]))} for i in range(len(selection_gallery))]

            return selection_ids_to_cat_dict
        s_path = train_data_dir+'/'+set_name+'_s_categories.txt'
        print('s_path: ', s_path)
        s_categories = np.loadtxt(s_path, dtype='int64')
        print('s_categories: ', s_categories.shape)
        map_nnids_labels = map_nnids_labels(train_data_dir, train_gnd_file, s_categories)
        train_nn_inds = osp.join(train_data_dir, 'training_' + s_name+'nn_inds_%s.pkl'%desc_name)
        train_sampler = TripletSampler(train_set.targets, batch_size, train_nn_inds, num_candidates, map_nnids_labels)
    else:
        raise ValueError('Invalid choice of sampler ({}).'.format(sampler))
    train_loader = DataLoader(train_set, batch_sampler=train_sampler, num_workers=num_workers, pin_memory=pin_memory)
    query_train_loader = DataLoader(query_train_set, batch_size=test_batch_size, num_workers=num_workers, pin_memory=pin_memory)
        
    query_loader   = DataLoader(query_set, batch_size=test_batch_size, num_workers=num_workers, pin_memory=pin_memory)
    gallery_loader = DataLoader(gallery_set, batch_size=test_batch_size, num_workers=num_workers, pin_memory=pin_memory)

    return MetricLoaders(train=train_loader, query_train=query_train_loader, query=query_loader, gallery=gallery_loader, num_classes=len(train_set.categories),set_name=set_name), recalls


In [18]:
train_data_dir = '/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt'

In [19]:
class MetricLoaders(NamedTuple):
    train: DataLoader
    num_classes: int
    query: DataLoader
    query_train: DataLoader
    set_name: str = ''
    gallery: Optional[DataLoader] = None

In [20]:
batch_size      = 32
test_batch_size = 32
max_sequence_len = 500
num_workers = 8  # number of workers used ot load the data
pin_memory  = True  # use the pin_memory option of DataLoader 
num_candidates = 100
recalls = [1, 5, 6, 10]

In [21]:
class TripletSampler():
    def __init__(self, labels, batch_size, nn_inds_path, num_candidates, map_nnids_labels):
        self.batch_size     = batch_size
        self.num_candidates = num_candidates
        self.cache_nn_inds  = pickle_load(nn_inds_path)
        self.labels = labels
        self.map_nnids_labels = map_nnids_labels
        print('nn_inds_path: ', nn_inds_path)
        print('labels len: ', len(labels))
        assert (len(self.cache_nn_inds) == len(labels))
        #############################################################################
        ## Collect valid tuples
        valids = np.zeros_like(labels)
        for i in range(len(self.cache_nn_inds)):
            nnids = self.cache_nn_inds[i]
            query_label = labels[i]
            index_labels = np.array([map_nnids_labels[i][j] for j in nnids])
            #index_labels = np.array([labels[j] for j in nnids])
            positives = np.where(index_labels == query_label)[0]
            if len(positives) < 1:
                continue
            valids[i] = 1
        self.valids = np.where(valids > 0)[0]
        self.num_samples = len(self.valids)

    def __iter__(self):
        batch = []
        cands = torch.randperm(self.num_samples).tolist()
        for i in range(len(cands)):
            anchor_idx = self.valids[cands[i]]
            anchor_label = self.labels[anchor_idx]
            nnids = self.cache_nn_inds[anchor_idx]

            positive_inds = [j for j in nnids if self.map_nnids_labels[anchor_idx][j] == anchor_label]
            negative_inds = [j for j in nnids if self.map_nnids_labels[anchor_idx][j] != anchor_label]
            assert(len(positive_inds) > 0)
            assert(len(negative_inds) > 0)

            random.shuffle(positive_inds)
            random.shuffle(negative_inds)

            batch.append(anchor_idx)
            batch.append(positive_inds[0]) 
            batch.append(negative_inds[0])

            if len(batch) >= self.batch_size:
                yield batch
                batch = []
                
        if len(batch) > 0:
            yield batch

    def __len__(self):
        return (self.num_samples * 3 + self.batch_size - 1) // self.batch_size



In [22]:
nn_inds_path = '/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt/training_train_nn_inds_r50_gldv1.pkl'
cache_nn_inds  = pickle_load(nn_inds_path)
cache_nn_inds.shape

(13899, 100)

In [23]:
torch.manual_seed(seed)
loaders, recall_ks = get_loaders('r50_gldv1', train_data_dir, 
    batch_size, test_batch_size, 
    num_workers, pin_memory, 
    sampler, recalls, set_name, train_gnd_file,
    num_candidates=100)

s_path:  /mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt/train_s_categories.txt
s_categories:  (1389900,)
nn_inds_path:  /mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt/training_train_nn_inds_r50_gldv1.pkl
labels len:  13899


In [24]:
def get_model(num_global_features, num_local_features, seq_len, dim_K, dim_feedforward, nhead, num_encoder_layers, dropout, activation, normalize_before):
    return MatchERT(d_global=num_global_features, d_model=num_local_features, seq_len=seq_len, d_K=dim_K, nhead=nhead, num_encoder_layers=num_encoder_layers, 
            dim_feedforward=dim_feedforward, dropout=dropout, activation=activation, normalize_before=normalize_before)

In [25]:
name = 'rrt'
num_global_features = 2048  
num_local_features = 128  
seq_len = 1004
dim_K = 256
dim_feedforward = 1024
nhead = 4
num_encoder_layers = 6
dropout = 0.0 
activation = "relu"
normalize_before = False

In [26]:
torch.manual_seed(seed+1)
model = get_model(num_global_features,num_local_features,seq_len,dim_K,dim_feedforward,nhead,num_encoder_layers,dropout,activation,normalize_before)

In [27]:
if resume is not None:
    checkpoint = torch.load(resume, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['state'], strict=True)
print('# of trainable parameters: ', num_of_trainable_params(model))
class_loss = get_loss(loss)
nn_inds_path = osp.join(loaders.query.dataset.data_dir, loaders.set_name + '_nn_inds_%s.pkl'%loaders.query.dataset.desc_name)
cache_nn_inds = torch.from_numpy(pickle_load(nn_inds_path)).long()

# of trainable parameters:  2243201


In [28]:
torch.manual_seed(seed+2)
model.to(device)
model = nn.DataParallel(model)
parameters = []
if no_bias_decay:
    parameters.append({'params': [par for par in model.parameters() if par.dim() != 1]})
    parameters.append({'params': [par for par in model.parameters() if par.dim() == 1], 'weight_decay': 0})
else:
    parameters.append({'params': model.parameters()})
optimizer, scheduler = get_optimizer_scheduler(parameters=parameters, loader_length=len(loaders.train),
                                               optim=optim, epochs=epochs, lr=lr,
                                               momentum=momentum, nesterov=nesterov, weight_decay=weight_decay,
                                               scheduler=scheduler, scheduler_tau=scheduler_tau, 
                                               scheduler_gamma=scheduler_gamma, lr_step=None)
if resume is not None and checkpoint.get('optim', None) is not None:
    optimizer.load_state_dict(checkpoint['optim'])
    del checkpoint


In [29]:
torch.manual_seed(seed+3)
# setup partial function to simplify call
eval_function = partial(evaluate_viquae, model=model, 
    cache_nn_inds=cache_nn_inds,
    recall=recall_ks, query_loader=loaders.query, gallery_loader=loaders.gallery)

In [30]:
# result = eval_function()
# pprint(result)
# best_val = (0, result, deepcopy(model.state_dict()))

In [31]:
torch.manual_seed(seed+4)

<torch._C.Generator at 0x7f22d17b7810>

In [32]:
"""
for epoch in range(1):
    if cudnn_flag == 'benchmark':
        setattr(cudnn, cudnn_flag, True)

    torch.cuda.empty_cache()
    train_one_epoch(model=model, loader=loaders.train, class_loss=class_loss, optimizer=optimizer, scheduler=scheduler, max_norm=max_norm, epoch=epoch, freq=visdom_freq, ex=None)
    
    # validation
    if cudnn_flag == 'benchmark':
        setattr(cudnn, cudnn_flag, False)
    
    torch.cuda.empty_cache()
    result = eval_function()
    
    print('Validation [{:03d}]'.format(epoch)), pprint(result)
    #ex.log_scalar('val.map', result['map'], step=epoch + 1)

    if result['map'] >= best_val[1]['map']:
        print('New best model in epoch %d.'%epoch)
        best_val = (epoch + 1, result, deepcopy(model.state_dict()))
        torch.save({'state': state_dict_to_cpu(best_val[2]), 'optim': optimizer.state_dict()}, save_name)
"""

"\nfor epoch in range(1):\n    if cudnn_flag == 'benchmark':\n        setattr(cudnn, cudnn_flag, True)\n\n    torch.cuda.empty_cache()\n    train_one_epoch(model=model, loader=loaders.train, class_loss=class_loss, optimizer=optimizer, scheduler=scheduler, max_norm=max_norm, epoch=epoch, freq=visdom_freq, ex=None)\n    \n    # validation\n    if cudnn_flag == 'benchmark':\n        setattr(cudnn, cudnn_flag, False)\n    \n    torch.cuda.empty_cache()\n    result = eval_function()\n    \n    print('Validation [{:03d}]'.format(epoch)), pprint(result)\n    #ex.log_scalar('val.map', result['map'], step=epoch + 1)\n\n    if result['map'] >= best_val[1]['map']:\n        print('New best model in epoch %d.'%epoch)\n        best_val = (epoch + 1, result, deepcopy(model.state_dict()))\n        torch.save({'state': state_dict_to_cpu(best_val[2]), 'optim': optimizer.state_dict()}, save_name)\n"

In [33]:
train_lines     = read_file(osp.join(train_data_dir, train_txt))
train_samples   = [(line.split(split_char)[0], int(line.split(split_char)[1]), int(line.split(split_char)[2]), int(line.split(split_char)[3])) for line in train_lines]
train_set       = FeatureDataset(train_data_dir, train_samples, desc_name, max_sequence_len)
query_train_set = FeatureDataset(train_data_dir, train_samples, desc_name, max_sequence_len)

In [34]:
s_name = set_name
if s_name != '':
    s_name = set_name + '_'
def map_nnids_labels(train_data_dir, train_gnd_file, s_categories):
    gnd =  pickle_load(osp.join(train_data_dir, train_gnd_file))
    selection_gallery = gnd['simlist']
    s_categories = s_categories.reshape(np.array(selection_gallery).shape)
    selection_ids_to_cat_dict = [{k: s_categories[i][k] for k in range(len(selection_gallery[i]))} for i in range(len(selection_gallery))]
    print(s_categories.shape)

    return selection_ids_to_cat_dict

s_path = train_data_dir+'/'+set_name+'_s_categories.txt'
s_categories = np.loadtxt(s_path, dtype='int64')
map_nnids_labels = map_nnids_labels(train_data_dir, train_gnd_file, s_categories)
train_nn_inds = osp.join(train_data_dir, 'training_' + s_name+'nn_inds_%s.pkl'%desc_name)
train_sampler = TripletSampler(train_set.targets, batch_size, train_nn_inds, num_candidates, map_nnids_labels)

(13899, 100)
nn_inds_path:  /mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt/training_train_nn_inds_r50_gldv1.pkl
labels len:  13899


In [35]:
nn_inds_path =  osp.join(train_data_dir, 'training_' + s_name+'nn_inds_%s.pkl'%desc_name)
cache_nn_inds = torch.from_numpy(pickle_load(nn_inds_path)).long()
cache_nn_inds.shape

torch.Size([13899, 100])

In [36]:
len(train_sampler.valids), np.where(train_sampler.valids > 0)[0]

(13691, array([    0,     1,     2, ..., 13688, 13689, 13690]))

In [37]:
for item in train_sampler:
    print(item)

[5082, 16, 35, 2869, 82, 33, 11954, 92, 35, 9266, 41, 80, 5634, 2, 89, 8889, 95, 76, 4815, 50, 87, 9425, 57, 4, 8799, 72, 42, 1152, 59, 40, 5763, 35, 52]
[499, 93, 14, 1993, 11, 13, 1364, 36, 90, 6657, 18, 3, 12119, 18, 82, 3920, 65, 48, 13767, 72, 98, 11946, 5, 32, 1707, 13, 2, 11663, 38, 90, 13199, 75, 40]
[1029, 76, 63, 8097, 5, 30, 10167, 81, 94, 7177, 56, 27, 13774, 99, 64, 1687, 80, 63, 11791, 54, 59, 11271, 43, 28, 7764, 70, 17, 351, 2, 37, 8478, 63, 27]
[13568, 12, 30, 2281, 48, 96, 527, 95, 48, 11903, 78, 89, 354, 2, 34, 4886, 0, 27, 13837, 50, 63, 4531, 32, 85, 1810, 0, 17, 3163, 97, 72, 9886, 97, 88]
[3735, 75, 11, 3640, 19, 96, 6686, 1, 46, 5124, 35, 73, 5451, 49, 62, 8156, 73, 52, 1488, 69, 41, 8911, 20, 81, 12788, 87, 2, 5028, 25, 56, 10678, 46, 70]
[5616, 26, 12, 9038, 66, 38, 3884, 42, 2, 13360, 55, 13, 6467, 87, 17, 11720, 36, 31, 5914, 60, 9, 1415, 73, 69, 6572, 29, 65, 7623, 40, 3, 5563, 53, 52]
[6602, 77, 0, 10751, 69, 35, 8459, 48, 29, 659, 65, 79, 684, 65, 15, 343

[3052, 84, 89, 9364, 92, 62, 990, 82, 83, 5695, 36, 64, 7046, 50, 98, 11639, 69, 77, 1301, 90, 9, 8563, 37, 92, 11334, 8, 13, 13269, 36, 3, 9384, 31, 4]
[6544, 85, 78, 11298, 48, 30, 6420, 5, 80, 11405, 56, 97, 4112, 22, 59, 7687, 24, 9, 11338, 55, 93, 8039, 20, 41, 3289, 12, 52, 672, 66, 78, 1234, 58, 61]
[7242, 71, 72, 3644, 18, 26, 1184, 86, 46, 9559, 96, 48, 4424, 2, 89, 373, 11, 94, 9647, 35, 90, 5434, 4, 53, 5180, 17, 30, 9134, 12, 32, 7893, 85, 47]
[760, 22, 85, 4495, 40, 18, 13475, 58, 89, 1808, 21, 9, 13867, 30, 50, 129, 7, 57, 3860, 80, 1, 13171, 56, 96, 3040, 50, 20, 6236, 19, 35, 11949, 42, 23]
[3012, 49, 39, 4671, 83, 34, 4985, 65, 97, 5475, 49, 61, 8508, 71, 31, 13049, 71, 38, 5849, 71, 7, 2810, 85, 8, 7915, 24, 15, 708, 59, 54, 7077, 18, 9]
[12372, 42, 70, 8670, 22, 9, 4742, 16, 49, 589, 94, 17, 5502, 65, 12, 11365, 26, 80, 988, 90, 79, 1363, 49, 18, 6871, 30, 97, 5639, 61, 60, 2735, 14, 42]
[3857, 47, 14, 3944, 78, 63, 10964, 17, 71, 1800, 0, 29, 10071, 98, 16, 778, 62,

[3127, 21, 17, 4913, 38, 40, 4377, 45, 88, 6372, 76, 3, 8124, 50, 39, 9660, 88, 3, 7048, 48, 60, 1905, 93, 53, 3826, 97, 0, 13282, 98, 14, 10869, 77, 32]
[6700, 64, 7, 2666, 28, 50, 8761, 43, 80, 155, 26, 69, 1061, 47, 89, 11144, 89, 60, 4967, 34, 91, 12410, 50, 95, 4568, 85, 77, 6144, 12, 37, 11829, 78, 62]
[3244, 25, 9, 4136, 96, 32, 1129, 88, 18, 13878, 30, 43, 6256, 41, 51, 7234, 46, 47, 171, 99, 71, 12226, 56, 1, 6793, 63, 72, 3009, 55, 28, 913, 10, 15]
[1316, 88, 92, 4962, 4, 80, 11950, 85, 11, 10862, 70, 10, 13345, 33, 65, 790, 71, 10, 1806, 31, 17, 6996, 59, 53, 10560, 21, 14, 326, 6, 53, 10445, 64, 4]
[6406, 30, 47, 11747, 12, 33, 217, 55, 90, 3324, 14, 71, 12720, 90, 9, 4437, 50, 59, 2387, 2, 43, 5103, 41, 71, 1158, 74, 51, 30, 13, 25, 5014, 78, 72]
[8769, 91, 37, 7746, 49, 81, 2620, 62, 86, 7547, 88, 79, 1072, 58, 11, 12938, 2, 18, 1437, 63, 18, 1841, 12, 92, 8873, 99, 67, 10262, 92, 80, 3926, 49, 4]
[8622, 48, 40, 1132, 71, 24, 1677, 27, 45, 9623, 41, 86, 11152, 3, 63, 785,

[13169, 85, 27, 9912, 83, 65, 11939, 27, 81, 11633, 64, 74, 9092, 93, 78, 10027, 41, 50, 8974, 92, 51, 9905, 69, 88, 5981, 5, 41, 5324, 96, 11, 914, 14, 20]
[5642, 65, 28, 11305, 15, 72, 4626, 17, 51, 7085, 25, 99, 12210, 93, 5, 4773, 35, 32, 5329, 9, 30, 11411, 86, 77, 12676, 2, 61, 13855, 68, 91, 9349, 19, 76]
[6329, 24, 44, 5437, 50, 63, 1030, 10, 26, 2641, 16, 96, 4416, 73, 34, 8056, 65, 37, 11684, 23, 46, 1783, 66, 47, 10221, 35, 97, 11126, 37, 99, 4974, 48, 72]
[4188, 72, 53, 10580, 12, 34, 5267, 32, 18, 11050, 15, 89, 701, 57, 7, 243, 35, 69, 12809, 37, 16, 2280, 59, 9, 8397, 39, 68, 611, 46, 20, 6496, 77, 89]
[12433, 38, 13, 7960, 8, 37, 13806, 51, 56, 460, 71, 38, 5433, 45, 51, 5988, 94, 2, 6259, 80, 25, 9011, 71, 74, 4589, 70, 98, 2783, 57, 16, 8282, 11, 85]
[9400, 66, 68, 7684, 75, 9, 8290, 11, 78, 10062, 44, 87, 1910, 18, 76, 5235, 35, 31, 10876, 83, 4, 397, 11, 59, 7244, 9, 71, 8162, 70, 30, 11805, 15, 69]
[3075, 72, 3, 8690, 70, 24, 12414, 50, 19, 6033, 63, 52, 12865, 83,

[10576, 6, 65, 10095, 9, 3, 899, 4, 70, 7617, 54, 77, 9681, 47, 45, 11095, 27, 25, 1384, 92, 42, 13890, 38, 55, 5913, 72, 19, 12140, 58, 77, 9294, 15, 98]
[6452, 80, 81, 2054, 39, 79, 2073, 28, 56, 4034, 86, 67, 10734, 39, 30, 12163, 32, 88, 12534, 67, 8, 5406, 58, 23, 12674, 80, 57, 10176, 74, 78, 13172, 33, 25]
[8747, 85, 42, 2538, 96, 3, 10001, 10, 77, 11489, 92, 80, 8844, 3, 82, 7480, 25, 94, 4265, 18, 77, 8380, 20, 7, 811, 83, 18, 777, 65, 21, 3824, 13, 45]
[7730, 0, 41, 12357, 66, 17, 8663, 52, 41, 8030, 4, 18, 11210, 79, 5, 13294, 38, 96, 508, 86, 45, 2952, 55, 64, 10441, 31, 29, 13393, 25, 34, 8434, 79, 38]
[2187, 60, 95, 10460, 8, 30, 5788, 91, 34, 1924, 98, 40, 2686, 4, 74, 12824, 0, 17, 5922, 25, 54, 2467, 43, 51, 13064, 73, 68, 6243, 4, 39, 7688, 24, 78]
[11363, 79, 15, 10222, 53, 89, 9669, 82, 17, 7104, 89, 96, 5894, 55, 43, 8457, 57, 70, 13195, 52, 26, 5863, 34, 79, 9674, 95, 73, 6977, 60, 63, 10680, 68, 15]
[12111, 97, 83, 4327, 84, 3, 1134, 30, 73, 12215, 66, 1, 2128, 7

[3135, 31, 34, 10296, 53, 27, 3017, 30, 88, 1473, 53, 18, 1116, 3, 58, 9066, 67, 2, 10097, 20, 87, 4050, 93, 52, 7672, 21, 4, 2377, 17, 56, 6058, 87, 86]
[12423, 22, 78, 9090, 93, 14, 8638, 7, 77, 797, 73, 37, 11001, 69, 85, 13222, 90, 40, 1092, 20, 23, 7907, 92, 79, 8229, 85, 19, 13645, 1, 72, 5703, 36, 92]
[8637, 15, 40, 11002, 49, 55, 53, 2, 73, 5562, 93, 90, 12559, 87, 10, 5560, 40, 65, 9631, 33, 80, 8643, 20, 59, 12477, 56, 84, 8320, 22, 54, 812, 97, 35]
[13418, 28, 91, 737, 70, 72, 11545, 67, 57, 12276, 65, 7, 7453, 65, 6, 1986, 96, 75, 6321, 81, 14, 3841, 43, 9, 9696, 87, 58, 12759, 79, 16, 10685, 83, 6]
[5512, 94, 46, 13134, 18, 46, 2684, 24, 77, 12027, 99, 77, 1626, 81, 78, 9499, 24, 37, 9115, 3, 47, 6455, 87, 95, 1089, 76, 70, 5820, 59, 23, 13149, 41, 99]
[676, 64, 0, 3526, 40, 18, 2552, 66, 17, 3087, 72, 38, 12299, 19, 39, 4217, 71, 98, 11948, 95, 35, 1117, 91, 21, 7156, 56, 57, 11239, 78, 24, 13727, 84, 47]
[8275, 30, 34, 7246, 67, 12, 2440, 69, 70, 1526, 93, 85, 4646, 0, 8

[5510, 78, 35, 11157, 57, 76, 3553, 75, 41, 12542, 92, 11, 11472, 7, 54, 3152, 92, 23, 1984, 89, 45, 10906, 39, 25, 1119, 5, 94, 10233, 40, 82, 10214, 35, 93]
[6250, 4, 97, 8212, 62, 86, 4762, 98, 77, 10845, 55, 82, 10426, 0, 81, 6186, 75, 16, 10914, 1, 11, 12694, 94, 15, 477, 40, 95, 559, 8, 97, 2186, 91, 39]
[7604, 91, 63, 4338, 6, 12, 8203, 17, 88, 8929, 54, 49, 12512, 84, 94, 8495, 90, 69, 7259, 9, 53, 4577, 32, 6, 13692, 85, 12, 12887, 68, 7, 12539, 78, 99]
[2755, 94, 20, 9611, 16, 20, 5332, 36, 23, 13173, 69, 14, 13450, 39, 49, 1143, 12, 80, 12330, 1, 96, 3290, 4, 48, 12883, 33, 35, 13361, 41, 20, 6608, 94, 5]
[2561, 85, 99, 11532, 86, 28, 5973, 7, 45, 9712, 42, 8, 390, 51, 82, 8846, 56, 14, 3402, 80, 66, 10484, 27, 30, 5139, 77, 16, 3030, 45, 26, 54, 29, 46]
[10302, 62, 70, 3565, 88, 42, 4856, 83, 27, 1963, 52, 1, 10135, 90, 1, 11136, 57, 96, 9874, 40, 1, 3970, 73, 69, 1728, 81, 54, 2516, 61, 51, 4147, 14, 49]
[6002, 70, 54, 8855, 98, 15, 8491, 92, 26, 9201, 64, 7, 9782, 67, 77,

[8765, 33, 86, 4651, 51, 83, 4950, 46, 1, 5288, 37, 49, 6483, 16, 15, 12023, 15, 25, 11475, 51, 94, 6721, 67, 36, 8685, 92, 98, 9369, 38, 2, 13037, 10, 17]
[12526, 73, 11, 5494, 19, 62, 5786, 44, 5, 1898, 27, 44, 2705, 90, 76, 13511, 32, 25, 12963, 13, 64, 13490, 13, 92, 11968, 40, 80, 11923, 4, 91, 6834, 94, 10]
[7101, 38, 7, 4552, 87, 28, 6061, 42, 61, 6261, 27, 43, 1204, 72, 78, 5937, 32, 38, 141, 4, 1, 3898, 47, 31, 13598, 3, 79, 7039, 83, 25, 6059, 25, 80]
[732, 69, 63, 5490, 10, 53, 3013, 51, 17, 2505, 94, 4, 4310, 54, 25, 281, 11, 75, 8898, 9, 83, 12160, 33, 13, 10040, 2, 45, 4039, 11, 6, 1870, 59, 54]
[4305, 53, 25, 10905, 92, 99, 10358, 17, 63, 7459, 2, 94, 5186, 25, 61, 911, 12, 32, 13695, 4, 23, 6968, 66, 17, 7493, 88, 93, 8515, 97, 24, 5253, 72, 45]
[2437, 35, 8, 4400, 66, 7, 11524, 36, 55, 5783, 96, 54, 13787, 2, 38, 381, 51, 50, 5216, 81, 36, 3315, 16, 62, 2117, 68, 13, 11941, 8, 23, 578, 71, 97]
[7561, 23, 12, 4407, 92, 53, 116, 94, 40, 3481, 76, 22, 8525, 37, 79, 4842, 

[1603, 27, 50, 5453, 98, 20, 8187, 74, 47, 12033, 16, 10, 13682, 50, 49, 655, 79, 99, 3292, 18, 82, 10009, 16, 71, 1459, 93, 28, 4926, 62, 25, 3355, 88, 49]
[304, 9, 12, 20, 14, 24, 4252, 12, 4, 5791, 39, 10, 11933, 53, 94, 4972, 35, 95, 10168, 64, 80, 4794, 75, 1, 4541, 38, 1, 9397, 63, 8, 1831, 37, 19]
[7874, 76, 26, 9210, 79, 51, 4040, 11, 46, 10054, 1, 5, 2906, 17, 9, 7429, 36, 30, 2665, 16, 93, 5910, 82, 68, 10825, 25, 69, 10923, 18, 25, 6268, 29, 51]
[5108, 98, 69, 3294, 12, 17, 3476, 76, 34, 12521, 60, 95, 6194, 65, 39, 10898, 52, 73, 3029, 62, 29, 11064, 73, 21, 11793, 15, 40, 13808, 57, 34, 6863, 31, 73]
[4928, 55, 77, 12928, 58, 26, 1674, 97, 92, 9318, 23, 95, 10496, 43, 39, 5859, 20, 54, 13148, 38, 6, 2719, 19, 50, 9193, 21, 77, 4530, 35, 84, 713, 16, 37]
[2251, 91, 16, 3938, 79, 42, 9105, 40, 38, 4578, 70, 81, 3768, 75, 37, 13575, 80, 0, 13375, 59, 29, 11135, 45, 19, 4740, 79, 52, 3416, 10, 29, 12230, 13, 4]
[10881, 73, 24, 9473, 51, 40, 4186, 70, 9, 7307, 84, 61, 2942, 6, 

[12380, 98, 83, 7465, 57, 87, 9212, 35, 96, 13124, 11, 41, 4214, 64, 35, 11107, 43, 13, 1321, 88, 80, 4085, 15, 70, 8003, 85, 61, 4489, 88, 8, 12279, 55, 6]
[1136, 81, 74, 13023, 75, 31, 4814, 67, 73, 7582, 67, 50, 11194, 73, 35, 882, 76, 47, 3773, 22, 73, 12095, 67, 85, 2415, 9, 16, 11316, 67, 11, 7520, 56, 46]
[4601, 25, 5, 10552, 23, 89, 10486, 70, 87, 13550, 37, 44, 13582, 2, 56, 6211, 5, 88, 11694, 81, 29, 9003, 87, 62, 4736, 94, 7, 9533, 3, 31, 10910, 15, 22]
[8220, 37, 8, 2652, 18, 65, 8273, 30, 37, 1518, 64, 60, 2244, 0, 16, 9360, 46, 91, 10451, 3, 52]


In [38]:
len(train_sampler)

1284

In [39]:
len(item)

21

In [41]:
gnd =  pickle_load(osp.join(train_data_dir, train_gnd_file))

In [49]:
gnd['gnd'][10]['hard']

[16,
 14,
 12,
 9,
 4,
 98,
 43,
 34,
 53,
 58,
 11,
 13,
 1,
 6,
 2,
 96,
 7,
 8,
 92,
 3,
 94,
 93,
 10,
 97,
 99,
 95,
 91,
 15,
 0]

In [56]:
gnd['simlist'][1], np.where(s_categories.reshape((-1, 100))[1] == 1)

(['512px-Dr._Francis_M._Forster.jpg',
  '512px-Tom_Pettit_of_NBC_News_at_1976_DNC.jpg',
  '512px-Eldon_D._Rudd.jpg',
  '512px-JFK_limousine.png',
  '512px-Irving_June_2019_37_(Ruth_Paine_Home).jpg',
  '512px-Carcano_mod._1891.jpg',
  '512px-J._D._Tippit_in_his_Dallas_Police_Department_photo_distributed_in_1963.jpg',
  '512px-Botham_Jean_Blvd_-_Dallas_Police_HQ_-_June_2021_-_03.jpg',
  '512px-Dallas_Collage_Montage.png',
  '512px-Dallas_Collage_Montage.png',
  '512px-SchoolbookDepository.jpg',
  '512px-Bertram_Chalres_Hill.jpg',
  '512px-Jim_Leavelle_(clear).jpg',
  '512px-Jim_Leavelle_(clear).jpg',
  '512px-Geoff_Edwards.JPG',
  '512px-John_Peel_BBC_cropped.jpg',
  '512px-Jack_Ruby-1.jpg',
  '512px-Jack_Ruby-1.jpg',
  '512px-Jack_Ruby-1.jpg',
  '512px-Jack_Ruby-1.jpg',
  '512px-Jack_Ruby-1.jpg',
  '512px-Jack_Ruby-1.jpg',
  '512px-Jack_Ruby-1.jpg',
  '512px-Jack_Ruby-1.jpg',
  '512px-Jack_Ruby-1.jpg',
  '512px-Jack_Ruby-1.jpg',
  '512px-Jack_Ruby-1.jpg',
  '512px-Jack_Ruby-1.jpg',
  '5

In [60]:
np.where(np.array(gnd['simlist'][1].index("512px-Jack_Ruby-1.jpg")

16

In [59]:
gnd['simlist'][1][6:17]

['512px-J._D._Tippit_in_his_Dallas_Police_Department_photo_distributed_in_1963.jpg',
 '512px-Botham_Jean_Blvd_-_Dallas_Police_HQ_-_June_2021_-_03.jpg',
 '512px-Dallas_Collage_Montage.png',
 '512px-Dallas_Collage_Montage.png',
 '512px-SchoolbookDepository.jpg',
 '512px-Bertram_Chalres_Hill.jpg',
 '512px-Jim_Leavelle_(clear).jpg',
 '512px-Jim_Leavelle_(clear).jpg',
 '512px-Geoff_Edwards.JPG',
 '512px-John_Peel_BBC_cropped.jpg',
 '512px-Jack_Ruby-1.jpg']