In [1]:
import os
import torch
import matplotlib.pyplot as plt
import sys
from tqdm import tqdm
sys.path.insert(0, './src')
%matplotlib inline


from metrics import NDCG_binary_at_k_batch, Recall_at_k_batch
from models import MultiVAE, MultiDAE, Multi_our_VAE, MultiHoffmanVAE, Multi_ourHoffman_VAE
from training import train_model
from data import Dataset
from args import get_args
import numpy as np
import pandas as pd

import pdb

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import ncvis
import re
import gc

device = 'cuda:1'

class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

In [None]:
args = dotdict({})
args.train_batch_size = 500
args.val_batch_size = 10
args.n_val_samples = 100
args.device = device

args.learning_rate = 1e-4
args.n_epoches = 200
args.l2_coeff = 0.01 / args.train_batch_size
args.print_info_ = 1

args.annealing = True

if args.annealing:
    args.total_anneal_steps = 200000
    args.anneal_cap = 0.2
else:
    args.total_anneal_steps = 0
    args.anneal_cap = 1.

device_zero = torch.tensor(0., device=device, dtype=torch.float32)
device_one = torch.tensor(1., device=device, dtype=torch.float32)
std_normal = torch.distributions.Normal(loc=device_zero,
                                    scale=device_one)

In [None]:
models = [model for model in sorted(os.listdir('./models')) if model.startswith('best') ]
for m in models:
    print('\'' + m + '\',' )

In [None]:
# models = [
# 'best_model_MultiHoffmanVAE_K_1_N_1_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_MultiHoffmanVAE_K_1_N_1_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_MultiHoffmanVAE_K_2_N_2_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_MultiHoffmanVAE_K_2_N_2_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_MultiHoffmanVAE_K_3_N_3_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_MultiHoffmanVAE_K_3_N_3_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_MultiHoffmanVAE_K_5_N_5_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_MultiHoffmanVAE_K_5_N_5_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_MultiVAE_K_None_N_None_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_MultiVAE_K_None_N_None_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_MultiVAE_K_None_N_None_learnreverse_False_anneal_True_lrdec_3e-05_lrenc_0.001.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_2_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_2_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_2_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_2_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_3_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_3_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_3_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_3_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_5_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_5_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_5_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_5_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_2_N_1_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_2_N_1_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_2_N_1_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_2_N_1_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_2_N_2_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_2_N_2_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_2_N_2_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_2_N_2_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_3_N_1_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_3_N_1_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_3_N_1_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_3_N_1_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_3_N_3_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_3_N_3_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_3_N_3_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_3_N_3_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_5_N_1_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_5_N_1_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_5_N_1_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_5_N_1_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_5_N_5_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_5_N_5_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_5_N_5_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_5_N_5_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_1_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_1_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_1_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_1_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_2_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_2_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_2_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_2_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_3_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_3_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_3_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_3_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_5_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_5_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_5_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_5_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_2_N_1_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_2_N_1_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_2_N_1_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_2_N_1_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_2_N_2_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_2_N_2_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_2_N_2_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_2_N_2_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_3_N_1_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_3_N_1_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_3_N_1_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_3_N_1_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_3_N_3_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_3_N_3_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_3_N_3_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_3_N_3_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_5_N_1_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_5_N_1_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_5_N_1_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_5_N_1_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_5_N_1_learnreverse_True_anneal_True_lrdec_3e-05_lrenc_0.001.pt',
# 'best_model_Multi_our_VAE_K_5_N_5_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_5_N_5_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_5_N_5_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_5_N_5_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# ]

In [None]:
def load_model(path, device):
    device_zero = torch.tensor(0., device=device, dtype=torch.float32)
    device_one = torch.tensor(1., device=device, dtype=torch.float32)
    std_normal = torch.distributions.Normal(loc=device_zero,
                                        scale=device_one)
    uniform = torch.distributions.Uniform(low=device_zero, high=device_one)
    torch_log_2 = torch.tensor(np.log(2), device=device, dtype=torch.float32)
    
    model = torch.load(path, map_location=device)
    
    model.eval()
    model.std_normal = std_normal
    if hasattr(model, 'torch_log_2'):
        model.torch_log_2 = torch_log_2
    if hasattr(model, 'transitions'):
        model.transitions = model.transitions.to(device)
        if type(model.transitions) == type(torch.nn.ModuleList([])):
            for i in range(len(model.transitions)):
                model.transitions[i].device = device
                model.transitions[i].device_zero = device_zero
                model.transitions[i].device_one = device_one
        else:
            model.transitions.device = device
            model.transitions.device_zero = device_zero
            model.transitions.device_one = device_one
    if hasattr(model, 'reverse_kernel'):
        model.reverse_kernel = model.reverse_kernel.to(device)
        model.reverse_kernel.device = device
        model.device_one = device_one
        
    return model

In [None]:
resulting_dict = {'model': [], 'data': [], 'K': [], 'N': [], 'anneal': [], 'learnreverse': [], 'learntransitions': [], 'initstepsize': [],  'data': [], 'recall_5': [], 'recall_10': [], 'recall_20': [], 'recall_50': [], 'recall_100': [], 'ndcg_5': [], 'ndcg_10': [], 'ndcg_20': [], 'ndcg_50': [], 'ndcg_100': []}

In [7]:
# args.data = 'ml20m'
# dataset = Dataset(args, data_dir='./data/')

In [8]:
# './logs/metrics_ml20m_MultiVAE_K_None_N_None_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.txt',
# './logs/metrics_ml20m_Multi_ourHoffman_VAE_K_3_N_1_learnreverse_True_anneal_True_lrdec_0.0003_lrenc_0.001_learntransitions_True_initstepsize_0.1.txt',
# './logs/metrics_ml20m_Multi_ourHoffman_VAE_K_3_N_1_learnreverse_True_anneal_True_lrdec_0.0003_lrenc_0.001_learntransitions_False_initstepsize_0.005.txt',
# './logs/metrics_ml20m_Multi_our_VAE_K_3_N_1_learnreverse_True_anneal_True_lrdec_0.0003_lrenc_0.001_learntransitions_True_initstepsize_0.1.txt',
# './logs/metrics_ml20m_Multi_our_VAE_K_3_N_1_learnreverse_True_anneal_True_lrdec_0.0003_lrenc_0.001_learntransitions_False_initstepsize_0.005.txt',

In [9]:
models_list = [
#     [
        './models/best_model_MultiVAE_data_ml20m_K_None_N_None_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None_learntransitions_True_initstepsize_0.1.pt',
    
    
#         './models/best_model_MultiVAE_data_ml20m_K_None_N_None_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None_learntransitions_True_initstepsize_0.1.pt',
#         './models/best_model_Multi_our_VAE_data_ml20m_K_2_N_1_learnreverse_False_anneal_False_lrdec_0.001_lrenc_0.001_learntransitions_False_initstepsize_0.01.pt',
#         './models/best_model_Multi_our_VAE_data_ml20m_K_3_N_1_learnreverse_False_anneal_False_lrdec_0.001_lrenc_0.001_learntransitions_False_initstepsize_0.01.pt',
#     ],
    
#     [
#         './models/best_model_MultiVAE_data_foursquare_K_None_N_None_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None_learntransitions_True_initstepsize_0.1.pt',
#         './models/best_model_Multi_our_VAE_data_foursquare_K_2_N_1_learnreverse_False_anneal_False_lrdec_0.001_lrenc_0.001_learntransitions_False_initstepsize_0.01.pt',
#         './models/best_model_Multi_our_VAE_data_foursquare_K_3_N_1_learnreverse_False_anneal_False_lrdec_0.001_lrenc_0.001_learntransitions_False_initstepsize_0.01.pt',
#     ],
    

#     [
#         './models/best_model_MultiVAE_data_gowalla_K_None_N_None_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None_learntransitions_True_initstepsize_0.01.pt',
#         './models/best_model_Multi_our_VAE_data_gowalla_K_2_N_1_learnreverse_False_anneal_False_lrdec_0.001_lrenc_0.001_learntransitions_False_initstepsize_0.01.pt',
#         './models/best_model_Multi_our_VAE_data_gowalla_K_2_N_1_learnreverse_True_anneal_False_lrdec_0.001_lrenc_0.001_learntransitions_True_initstepsize_0.01.pt',
#         './models/best_model_Multi_our_VAE_data_gowalla_K_3_N_1_learnreverse_True_anneal_False_lrdec_0.001_lrenc_0.001_learntransitions_True_initstepsize_0.01.pt',
#         './models/best_model_Multi_our_VAE_data_gowalla_K_3_N_1_learnreverse_False_anneal_False_lrdec_0.001_lrenc_0.001_learntransitions_False_initstepsize_0.01.pt',
#     ],
]

In [10]:
pattern_1 = re.compile(r'best_model_(\w*)_data_([\w\d]*)_K_([\d\w]*)_N_([\d\w]*)_learnreverse_(\w*)_anneal_(\w*)_lrdec_([.\d]*)_lrenc_([.\d\w]*)_learntransitions_([\w]*)_initstepsize_([.\d]*)\.pt')

for i, data in enumerate(['gowalla']):#['ml100k', 'ml20m', 'ml25m']):
    args.data = data
    dataset = Dataset(args, data_dir='./data/')
    models = models_list[i]
    for m in tqdm(models):
        vals = re.findall(pattern_1, m)[0]
        resulting_dict['model'].append(vals[0])
        resulting_dict['data'].append(args.data)
        resulting_dict['K'].append(vals[2])
        resulting_dict['N'].append(vals[3])
        resulting_dict['learnreverse'].append(vals[4])
        resulting_dict['anneal'].append(vals[5])
        resulting_dict['learntransitions'].append(vals[8])
        resulting_dict['initstepsize'].append(vals[9])
        
        model = load_model(m, device)
        ndcg_5 = []
        ndcg_10 = []
        ndcg_20 = []
        ndcg_50 = []
        ndcg_100 = []
        recall_5 = []
        recall_10 = []
        recall_20 = []
        recall_50 = []
        recall_100 = []
        for bnum, batch_val in enumerate(dataset.next_val_batch()):
            reshaped_batch = batch_val[0].repeat((args.n_val_samples, 1))
            out = model(reshaped_batch)[0]  ### Here we are using noise when estimate metrics
            with torch.no_grad():
                pred_val = out.detach().view((args.n_val_samples, *batch_val[0].shape)).mean(0)
                X = batch_val[0].cpu().detach().numpy()
                pred_val = pred_val.cpu().detach().numpy()
                pred_val[X.nonzero()] = -np.inf
                ndcg_5.append(NDCG_binary_at_k_batch(pred_val, batch_val[1], k=5))
                ndcg_10.append(NDCG_binary_at_k_batch(pred_val, batch_val[1], k=10))
                ndcg_20.append(NDCG_binary_at_k_batch(pred_val, batch_val[1], k=20))
                ndcg_50.append(NDCG_binary_at_k_batch(pred_val, batch_val[1], k=50))
                ndcg_100.append(NDCG_binary_at_k_batch(pred_val, batch_val[1], k=100))
                recall_5.append(Recall_at_k_batch(pred_val, batch_val[1], k=5))
                recall_10.append(Recall_at_k_batch(pred_val, batch_val[1], k=10))
                recall_20.append(Recall_at_k_batch(pred_val, batch_val[1], k=20))
                recall_50.append(Recall_at_k_batch(pred_val, batch_val[1], k=50))
                recall_100.append(Recall_at_k_batch(pred_val, batch_val[1], k=100))
            del out
        ndcg_5_dist = np.concatenate(ndcg_5)
        ndcg_10_dist = np.concatenate(ndcg_10)
        ndcg_20_dist = np.concatenate(ndcg_20)
        ndcg_50_dist = np.concatenate(ndcg_50)
        ndcg_100_dist = np.concatenate(ndcg_100)

        recall_5_dist = np.concatenate(recall_5)
        recall_10_dist = np.concatenate(recall_10)
        recall_20_dist = np.concatenate(recall_20)
        recall_50_dist = np.concatenate(recall_50)
        recall_100_dist = np.concatenate(recall_100)

        current_ndcg_5 = ndcg_5_dist.mean()
        current_ndcg_10 = ndcg_10_dist.mean()
        current_ndcg_20 = ndcg_20_dist.mean()
        current_ndcg_50 = ndcg_50_dist.mean()
        current_ndcg_100 = ndcg_100_dist.mean()

        current_recall_5 = recall_5_dist.mean()
        current_recall_10 = recall_10_dist.mean()
        current_recall_20 = recall_20_dist.mean()
        current_recall_50 = recall_50_dist.mean()
        current_recall_100 = recall_100_dist.mean()
        
        resulting_dict['ndcg_5'].append(current_ndcg_5)
        resulting_dict['ndcg_10'].append(current_ndcg_10)
        resulting_dict['ndcg_20'].append(current_ndcg_20)
        resulting_dict['ndcg_50'].append(current_ndcg_50)
        resulting_dict['ndcg_100'].append(current_ndcg_100)

        resulting_dict['recall_5'].append(current_recall_5)
        resulting_dict['recall_10'].append(current_recall_10)
        resulting_dict['recall_20'].append(current_recall_20)
        resulting_dict['recall_50'].append(current_recall_50)
        resulting_dict['recall_100'].append(current_recall_100)
        del model
        gc.collect()
    del dataset

  a = DCG / IDCG
  recall = tmp / np.minimum(k, X_true_binary.sum(axis=1))
100%|██████████| 5/5 [01:32<00:00, 18.56s/it]


In [11]:
pd.set_option('display.max_colwidth', 100)

In [12]:
df = pd.DataFrame(resulting_dict)
# df.to_csv('./eval_results_bests.csv', index=False)
df.round(3)

Unnamed: 0,model,data,K,N,anneal,learnreverse,learntransitions,initstepsize,recall_5,recall_10,recall_20,recall_50,recall_100,ndcg_5,ndcg_10,ndcg_20,ndcg_50,ndcg_100
0,MultiVAE,gowalla,,,False,False,True,0.01,0.085,0.094,0.121,0.183,0.242,0.098,0.1,0.109,0.131,0.15
1,Multi_our_VAE,gowalla,2.0,1.0,False,False,False,0.01,0.094,0.102,0.126,0.186,0.248,0.103,0.104,0.112,0.134,0.153
2,Multi_our_VAE,gowalla,2.0,1.0,False,True,True,0.01,0.088,0.097,0.123,0.185,0.241,0.098,0.098,0.108,0.13,0.148
3,Multi_our_VAE,gowalla,3.0,1.0,False,True,True,0.01,0.089,0.094,0.12,0.182,0.251,0.102,0.101,0.11,0.132,0.153
4,Multi_our_VAE,gowalla,3.0,1.0,False,False,False,0.01,0.09,0.098,0.122,0.183,0.241,0.101,0.102,0.11,0.131,0.15


In [13]:
print(df.round(3).drop(['learnreverse', 'N', 'initstepsize',], axis=1).to_latex(index=False))

\begin{tabular}{lllllrrrrrrrrrr}
\toprule
         model &     data &     K & anneal & learntransitions &  recall\_5 &  recall\_10 &  recall\_20 &  recall\_50 &  recall\_100 &  ndcg\_5 &  ndcg\_10 &  ndcg\_20 &  ndcg\_50 &  ndcg\_100 \\
\midrule
      MultiVAE &  gowalla &  None &  False &             True &     0.085 &      0.094 &      0.121 &      0.183 &       0.242 &   0.098 &    0.100 &    0.109 &    0.131 &     0.150 \\
 Multi\_our\_VAE &  gowalla &     2 &  False &            False &     0.094 &      0.102 &      0.126 &      0.186 &       0.248 &   0.103 &    0.104 &    0.112 &    0.134 &     0.153 \\
 Multi\_our\_VAE &  gowalla &     2 &  False &             True &     0.088 &      0.097 &      0.123 &      0.185 &       0.241 &   0.098 &    0.098 &    0.108 &    0.130 &     0.148 \\
 Multi\_our\_VAE &  gowalla &     3 &  False &             True &     0.089 &      0.094 &      0.120 &      0.182 &       0.251 &   0.102 &    0.101 &    0.110 &    0.132 &     0.153 \\
 Multi\_

In [12]:
df = pd.DataFrame(resulting_dict)
# df.to_csv('./eval_results_bests.csv', index=False)
df.round(3)

Unnamed: 0,model,data,K,N,anneal,learnreverse,learntransitions,initstepsize,recall_5,recall_10,recall_20,recall_50,recall_100,ndcg_5,ndcg_10,ndcg_20,ndcg_50,ndcg_100
0,MultiVAE,ml20m,,,True,False,True,0.1,0.315,0.334,0.397,0.536,0.661,0.322,0.32,0.337,0.385,0.429
1,MultiVAE,ml20m,,,False,False,True,0.1,0.287,0.303,0.362,0.501,0.627,0.295,0.292,0.308,0.355,0.399
2,Multi_our_VAE,ml20m,2.0,1.0,False,False,False,0.01,0.308,0.329,0.388,0.528,0.654,0.314,0.313,0.33,0.378,0.421
3,Multi_our_VAE,ml20m,3.0,1.0,False,False,False,0.01,0.305,0.324,0.387,0.525,0.648,0.31,0.309,0.327,0.375,0.418
4,MultiVAE,foursquare,,,False,False,True,0.1,0.06,0.066,0.093,0.147,0.205,0.069,0.071,0.084,0.107,0.127
5,Multi_our_VAE,foursquare,2.0,1.0,False,False,False,0.01,0.065,0.064,0.095,0.155,0.206,0.075,0.071,0.087,0.112,0.13
6,Multi_our_VAE,foursquare,3.0,1.0,False,False,False,0.01,0.066,0.07,0.096,0.152,0.205,0.077,0.076,0.089,0.112,0.131


In [15]:
print(df.round(3).drop(['learnreverse', 'N', 'initstepsize',], axis=1).to_latex(index=False))

\begin{tabular}{lllllrrrrrrrrrr}
\toprule
         model &        data &     K & anneal & learntransitions &  recall\_5 &  recall\_10 &  recall\_20 &  recall\_50 &  recall\_100 &  ndcg\_5 &  ndcg\_10 &  ndcg\_20 &  ndcg\_50 &  ndcg\_100 \\
\midrule
      MultiVAE &       ml20m &  None &   True &             True &     0.315 &      0.334 &      0.397 &      0.536 &       0.661 &   0.322 &    0.320 &    0.337 &    0.385 &     0.429 \\
      MultiVAE &       ml20m &  None &  False &             True &     0.287 &      0.303 &      0.362 &      0.501 &       0.627 &   0.295 &    0.292 &    0.308 &    0.355 &     0.399 \\
 Multi\_our\_VAE &       ml20m &     2 &  False &            False &     0.308 &      0.329 &      0.388 &      0.528 &       0.654 &   0.314 &    0.313 &    0.330 &    0.378 &     0.421 \\
 Multi\_our\_VAE &       ml20m &     3 &  False &            False &     0.305 &      0.324 &      0.387 &      0.525 &       0.648 &   0.310 &    0.309 &    0.327 &    0.375 &     0.41

In [13]:
# df = pd.DataFrame(resulting_dict)
# # df.to_csv('./eval_results_bests.csv', index=False)
# df.round(3)

In [12]:
# df = pd.DataFrame(resulting_dict)
# # df.to_csv('./eval_results_bests.csv', index=False)
# df[df['model']!='Multi_ourHoffman_VAE'][df['model']!='MultiDAE'].round(3)
# # df

In [13]:
# print(df[df['model']!='Multi_ourHoffman_VAE'][df['model']!='MultiDAE'].round(3).drop(['learnreverse', 'N'], axis=1)#.to_latex(index=False))

In [14]:
# df[df.noise == 1.].sort_values('ndcg', ascending=False).head(30)

In [15]:
# df[df.noise == 0.].sort_values('ndcg', ascending=False).head(33)

# Parameters output

In [5]:
model = load_model('./models/best_model_Multi_our_VAE_data_ml20m_K_3_N_1_learnreverse_True_anneal_False_lrdec_0.0003_lrenc_0.001_learntransitions_False_initstepsize_0.005.pt', 'cuda')
for i, transition in enumerate(model.transitions):
    print(i)
    print('autoreg:', torch.sigmoid(transition.alpha_logit))
    print('stepsize', torch.exp(transition.gamma))
    print('-' * 100)

0
autoreg: tensor(0.9000, device='cuda:0')
stepsize tensor(0.0050, device='cuda:0')
----------------------------------------------------------------------------------------------------
1
autoreg: tensor(0.9000, device='cuda:0')
stepsize tensor(0.0050, device='cuda:0')
----------------------------------------------------------------------------------------------------
2
autoreg: tensor(0.9000, device='cuda:0')
stepsize tensor(0.0050, device='cuda:0')
----------------------------------------------------------------------------------------------------


In [16]:
models_list = [
    './models/best_model_MultiVAE_data_ml20m_K_None_N_None_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
    './models/best_model_Multi_our_VAE_data_ml20m_K_3_N_1_learnreverse_True_anneal_False_lrdec_0.0003_lrenc_0.001_learntransitions_False_initstepsize_0.005.pt',
    './models/best_model_Multi_our_VAE_data_ml20m_K_3_N_1_learnreverse_True_anneal_False_lrdec_0.0003_lrenc_0.001_learntransitions_True_initstepsize_0.1.pt',
    './models/best_model_MultiVAE_K_None_N_None_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
    './models/best_model_Multi_our_VAE_data_ml20m_K_3_N_1_learnreverse_True_anneal_True_lrdec_0.0003_lrenc_0.001_learntransitions_True_initstepsize_0.1.pt',
    './models/best_model_Multi_ourHoffman_VAE_data_ml20m_K_3_N_1_learnreverse_True_anneal_True_lrdec_0.0003_lrenc_0.001_learntransitions_False_initstepsize_0.005.pt'
]

In [21]:
for bnum, batch_train in enumerate(dataset.next_train_batch()):
    l2 = torch.sum(batch_train ** 2, 1)[..., None]
    x_normed = batch_train / torch.sqrt(torch.max(l2, torch.ones_like(l2) * 1e-12))
    x = model.dropout(x_normed)
    break

for model_name in models_list:
    print(model_name)
    model = load_model(model_name, device)
    enc_out = model.encoder(x).detach()
    mu, logvar = enc_out[:, :model.q_dims[-1]], enc_out[:, model.q_dims[-1]:]
    std = torch.exp(0.5 * logvar)
    print(std)
    print('min std', torch.min(std))
    print('-' * 100)

./models/best_model_MultiVAE_data_ml20m_K_None_N_None_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt
tensor([[1.0022, 0.9958, 0.9977,  ..., 0.6612, 0.9921, 0.9933],
        [1.0005, 1.0043, 1.0120,  ..., 0.7581, 0.9967, 0.9844],
        [0.9957, 1.0051, 1.0160,  ..., 0.4933, 1.0019, 0.9804],
        ...,
        [0.9935, 1.0088, 1.0026,  ..., 0.9374, 1.0044, 1.0035],
        [1.0048, 1.0114, 1.0095,  ..., 0.8276, 0.9971, 0.9945],
        [1.0014, 1.0036, 1.0072,  ..., 0.8943, 0.9962, 0.9930]],
       device='cuda:1')
min std tensor(0.0285, device='cuda:1')
----------------------------------------------------------------------------------------------------
./models/best_model_Multi_our_VAE_data_ml20m_K_3_N_1_learnreverse_True_anneal_False_lrdec_0.0003_lrenc_0.001_learntransitions_False_initstepsize_0.005.pt
tensor([[0.9840, 1.0332, 1.0254,  ..., 1.0187, 1.0236, 0.2076],
        [0.9756, 1.0036, 1.0029,  ..., 1.0025, 1.0062, 0.2216],
        [0.9666, 0.9646, 1.0012,  ..., 0.99

# Parameters tuning

In [4]:
args.metric = NDCG_binary_at_k_batch
args.n_epoches = 50

In [5]:
for lr in [0.01, 0.001, 0.0001]:
    print('lr ', lr)
    model = load_model('./models/best_model_MultiVAE_K_None_N_None_learnreverse_False_anneal_True.pt', 'cuda')
    args.learning_rate = lr

    for p in model.parameters():
        p.requires_grad_(False)
    model.eval()
    fixed_decoder = model.decoder
    del model

    layers = [200, 600, dataset.n_items]

    model = MultiVAE(layers, args=args)
    model.decoder = fixed_decoder
    model.to(device)
    train_model(model, dataset, args)

lr  0.01
Sequential(
  (0): Linear(in_features=20108, out_features=600, bias=True)
  (1): Tanh()
  (2): Linear(in_features=600, out_features=400, bias=True)
)
Sequential(
  (0): Linear(in_features=200, out_features=600, bias=True)
  (1): Tanh()
  (2): Linear(in_features=600, out_features=20108, bias=True)
)


  0%|          | 0/50 [00:00<?, ?it/s]

763.05676
493.5768
484.29013


  2%|▏         | 1/50 [00:10<08:28, 10.38s/it]

Best NDCG: 0.41631043509731697
Current NDCG: 0.41631043509731697
463.54306
482.39844
510.4981


  4%|▍         | 2/50 [00:19<08:04, 10.09s/it]

Best NDCG: 0.41631043509731697
Current NDCG: 0.4138572942071415
516.3949
430.68063
476.45618


  6%|▌         | 3/50 [00:29<07:44,  9.88s/it]

Best NDCG: 0.41631043509731697
Current NDCG: 0.4150682787431997
432.9654
462.94284
482.69537


  8%|▊         | 4/50 [00:38<07:31,  9.81s/it]

Best NDCG: 0.416692228130136
Current NDCG: 0.416692228130136
446.52295
479.25552
539.5878


 10%|█         | 5/50 [00:48<07:20,  9.78s/it]

Best NDCG: 0.41822493721940407
Current NDCG: 0.41822493721940407
498.0992
450.91455
442.12695


 12%|█▏        | 6/50 [00:57<07:04,  9.64s/it]

Best NDCG: 0.41822493721940407
Current NDCG: 0.4157217167584744
425.31308
527.37354
420.75342


 14%|█▍        | 7/50 [01:07<06:50,  9.55s/it]

Best NDCG: 0.41822493721940407
Current NDCG: 0.41723965865551965
441.63834
506.97534
453.34097


 16%|█▌        | 8/50 [01:16<06:38,  9.48s/it]

Best NDCG: 0.41822493721940407
Current NDCG: 0.41696528957827567
428.03107
480.46338
443.13422


 18%|█▊        | 9/50 [01:25<06:26,  9.43s/it]

Best NDCG: 0.41822493721940407
Current NDCG: 0.41637291438201235
466.2974
485.98727
445.702


 20%|██        | 10/50 [01:35<06:19,  9.48s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.41872035576003064
497.65695
443.57144
476.7284


 22%|██▏       | 11/50 [01:44<06:07,  9.43s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.40821928589195966
492.0634
473.75253
485.07083


 24%|██▍       | 12/50 [01:54<05:57,  9.41s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.3935823094644125
506.2767
431.33783
497.40567


 26%|██▌       | 13/50 [02:03<05:47,  9.38s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.40468618072816953
512.1399
461.3823
515.7036


 28%|██▊       | 14/50 [02:12<05:37,  9.37s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.4109761570949583
494.78336
475.4515
411.16495


 30%|███       | 15/50 [02:22<05:27,  9.37s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.40939278200979307
494.15845
493.4303
450.76077


 32%|███▏      | 16/50 [02:31<05:18,  9.36s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.4115480623475737
416.8693
456.25928
434.8326


 34%|███▍      | 17/50 [02:40<05:08,  9.34s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.41095623864209746
470.85425
474.4855
440.32755


 36%|███▌      | 18/50 [02:50<05:00,  9.40s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.4075055320637583
531.9137
576.6005
410.6428


 38%|███▊      | 19/50 [02:59<04:50,  9.37s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.4027733394503887
483.50406
493.18765
497.98743


 40%|████      | 20/50 [03:08<04:40,  9.36s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.4119445840121538
436.40628
422.42468
477.6472


 42%|████▏     | 21/50 [03:18<04:30,  9.34s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.4118891706101248
470.6049
417.9542
518.4358


 44%|████▍     | 22/50 [03:27<04:21,  9.34s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.41253563736824683
469.38947
443.94193
525.10986


 46%|████▌     | 23/50 [03:36<04:11,  9.32s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.41236912146225
404.9693
477.33673
407.43466


 48%|████▊     | 24/50 [03:46<04:02,  9.33s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.41205178310562324
472.1053
404.12073
471.79877


 50%|█████     | 25/50 [03:55<03:53,  9.33s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.4119115836004921
457.32584
493.08472
430.45514


 52%|█████▏    | 26/50 [04:04<03:43,  9.32s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.4182441207747557
434.7268
459.25348
457.3405


 54%|█████▍    | 27/50 [04:14<03:34,  9.32s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.4130704371242715
433.4441
482.33975
479.93915


 56%|█████▌    | 28/50 [04:23<03:25,  9.32s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.4178059544283289
458.89465
488.6268
420.29822


 58%|█████▊    | 29/50 [04:32<03:15,  9.32s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.41089189397336523
459.5349
436.976
442.9619


 60%|██████    | 30/50 [04:42<03:06,  9.32s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.41489744696114067
412.03546
472.16592
493.31262


 62%|██████▏   | 31/50 [04:51<02:57,  9.33s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.41627544334913563
482.59448
461.1499
473.14774


 64%|██████▍   | 32/50 [05:00<02:48,  9.33s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.41577254995261304
448.96042
460.09665
398.16974


 66%|██████▌   | 33/50 [05:10<02:38,  9.34s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.41349353852889137
497.5729
480.16797
464.068


 68%|██████▊   | 34/50 [05:19<02:29,  9.32s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.4138208675140373
456.14737
425.76715
482.6697


 70%|███████   | 35/50 [05:28<02:20,  9.35s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.4156197304278576
438.61133
488.2476
494.66318


 72%|███████▏  | 36/50 [05:38<02:10,  9.35s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.41570207822830424
387.51178
490.33212
474.3436


 74%|███████▍  | 37/50 [05:47<02:01,  9.34s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.4159786611950237
490.8797
439.35056
509.65662


 76%|███████▌  | 38/50 [05:56<01:52,  9.33s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.41305832058348496
459.40982
465.01038
533.0222


 78%|███████▊  | 39/50 [06:06<01:42,  9.33s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.41061405368707166
482.4204
468.17993
462.17258


 80%|████████  | 40/50 [06:15<01:33,  9.33s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.414121057225444
433.1036
507.91196
419.19348


 82%|████████▏ | 41/50 [06:24<01:23,  9.32s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.41285008671857365
512.4284
499.11304
503.1141


 84%|████████▍ | 42/50 [06:34<01:14,  9.31s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.4065111420818841
471.59485
458.30597
471.25266


 86%|████████▌ | 43/50 [06:43<01:05,  9.32s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.4108677823838324
444.7787
466.75778
459.08234


 88%|████████▊ | 44/50 [06:52<00:55,  9.33s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.41149944519272946
454.30627
473.861
448.36298


 90%|█████████ | 45/50 [07:02<00:46,  9.33s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.41508213514194486
490.8526
473.63467
467.9655


 92%|█████████▏| 46/50 [07:11<00:37,  9.33s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.4136204656274582
502.7781
444.20706
510.63327


 94%|█████████▍| 47/50 [07:20<00:27,  9.33s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.4135036943605552
450.47943
409.8124
514.3021


 96%|█████████▌| 48/50 [07:30<00:18,  9.33s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.4137055337919539
451.71307
401.311
445.27063


 98%|█████████▊| 49/50 [07:39<00:09,  9.33s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.4154281961741961
445.40555
472.95425
454.44913


100%|██████████| 50/50 [07:48<00:00,  9.38s/it]

Best NDCG: 0.41872035576003064
Current NDCG: 0.41698482994756886
lr  0.001
Sequential(
  (0): Linear(in_features=20108, out_features=600, bias=True)
  (1): Tanh()
  (2): Linear(in_features=600, out_features=400, bias=True)
)



  0%|          | 0/50 [00:00<?, ?it/s]

Sequential(
  (0): Linear(in_features=200, out_features=600, bias=True)
  (1): Tanh()
  (2): Linear(in_features=600, out_features=20108, bias=True)
)
782.51886
480.2633
432.68842


  2%|▏         | 1/50 [00:09<07:51,  9.62s/it]

Best NDCG: 0.38662161600485473
Current NDCG: 0.38662161600485473
495.66968
499.05185
513.42163


  4%|▍         | 2/50 [00:19<07:43,  9.65s/it]

Best NDCG: 0.41432420893457017
Current NDCG: 0.41432420893457017
471.09595
461.53983
479.60965


  6%|▌         | 3/50 [00:29<07:34,  9.67s/it]

Best NDCG: 0.42230070695795247
Current NDCG: 0.42230070695795247
464.45874
490.27377
472.27243


  8%|▊         | 4/50 [00:38<07:26,  9.70s/it]

Best NDCG: 0.4257661677582814
Current NDCG: 0.4257661677582814
469.2689
496.73697
407.6685


 10%|█         | 5/50 [00:48<07:17,  9.72s/it]

Best NDCG: 0.427807123093632
Current NDCG: 0.427807123093632
420.26855
502.60068
468.5354


 12%|█▏        | 6/50 [00:58<07:08,  9.73s/it]

Best NDCG: 0.42874252566664306
Current NDCG: 0.42874252566664306
518.1942
417.0701
445.0361


 14%|█▍        | 7/50 [01:08<06:57,  9.72s/it]

Best NDCG: 0.4291353194093543
Current NDCG: 0.4291353194093543
457.75784
501.37955
396.6067


 16%|█▌        | 8/50 [01:17<06:43,  9.61s/it]

Best NDCG: 0.4291353194093543
Current NDCG: 0.42869396804125026
409.67407
431.3434
449.9202


 18%|█▊        | 9/50 [01:26<06:30,  9.53s/it]

Best NDCG: 0.4291353194093543
Current NDCG: 0.42832434879584125
441.21844
393.03815
463.2394


 20%|██        | 10/50 [01:36<06:18,  9.47s/it]

Best NDCG: 0.4291353194093543
Current NDCG: 0.4287883753778909
455.50284
480.55414
473.80844


 22%|██▏       | 11/50 [01:45<06:07,  9.42s/it]

Best NDCG: 0.4291353194093543
Current NDCG: 0.4281735316478447
477.74243
444.563
460.47253


 24%|██▍       | 12/50 [01:54<05:56,  9.39s/it]

Best NDCG: 0.4291353194093543
Current NDCG: 0.4283900885164492
434.61337
451.51074
408.9474


 26%|██▌       | 13/50 [02:04<05:46,  9.37s/it]

Best NDCG: 0.4291353194093543
Current NDCG: 0.4283407180831401
461.85413
464.5484
492.98645


 28%|██▊       | 14/50 [02:13<05:36,  9.36s/it]

Best NDCG: 0.4291353194093543
Current NDCG: 0.4285480836072645
490.98718
480.79272
393.06076


 30%|███       | 15/50 [02:22<05:27,  9.34s/it]

Best NDCG: 0.4291353194093543
Current NDCG: 0.42827806028691684
386.82907
414.00092
446.0724


 32%|███▏      | 16/50 [02:31<05:17,  9.34s/it]

Best NDCG: 0.4291353194093543
Current NDCG: 0.42891751721232085
392.18137
482.71326
458.2764


 34%|███▍      | 17/50 [02:41<05:07,  9.33s/it]

Best NDCG: 0.4291353194093543
Current NDCG: 0.4279263777506431
448.27243
400.445
489.0945


 36%|███▌      | 18/50 [02:50<05:01,  9.41s/it]

Best NDCG: 0.4294112707072465
Current NDCG: 0.4294112707072465
433.2507
510.789
475.32346


 38%|███▊      | 19/50 [03:00<04:50,  9.38s/it]

Best NDCG: 0.4294112707072465
Current NDCG: 0.4294023707461469
483.1986
511.99338
481.6038


 40%|████      | 20/50 [03:09<04:43,  9.44s/it]

Best NDCG: 0.42990014162171647
Current NDCG: 0.42990014162171647
426.56573
509.5651
456.31595


 42%|████▏     | 21/50 [03:19<04:32,  9.41s/it]

Best NDCG: 0.42990014162171647
Current NDCG: 0.4296106442684806
476.3677
479.4097
471.11237


 44%|████▍     | 22/50 [03:28<04:22,  9.39s/it]

Best NDCG: 0.42990014162171647
Current NDCG: 0.42978173718180823
440.0625
489.23148
472.30777


 46%|████▌     | 23/50 [03:37<04:12,  9.37s/it]

Best NDCG: 0.42990014162171647
Current NDCG: 0.42973217204052355
440.46838
495.0654
405.1298


 48%|████▊     | 24/50 [03:47<04:03,  9.35s/it]

Best NDCG: 0.42990014162171647
Current NDCG: 0.4298220175015862
496.16324
445.12628
479.84546


 50%|█████     | 25/50 [03:56<03:55,  9.44s/it]

Best NDCG: 0.43059959279784826
Current NDCG: 0.43059959279784826
414.30573
452.61066
492.84012


 52%|█████▏    | 26/50 [04:06<03:45,  9.41s/it]

Best NDCG: 0.43059959279784826
Current NDCG: 0.43025935112737945
425.48373
484.18808
456.04733


 54%|█████▍    | 27/50 [04:15<03:35,  9.36s/it]

Best NDCG: 0.43059959279784826
Current NDCG: 0.4302213389958202
445.21793
414.53586
415.1827


 56%|█████▌    | 28/50 [04:24<03:25,  9.34s/it]

Best NDCG: 0.43059959279784826
Current NDCG: 0.43055656863256564
478.9595
468.5906
474.03937


 58%|█████▊    | 29/50 [04:33<03:15,  9.32s/it]

Best NDCG: 0.43059959279784826
Current NDCG: 0.4305775896581452
479.62234
469.8659
510.51575


 60%|██████    | 30/50 [04:43<03:06,  9.31s/it]

Best NDCG: 0.43059959279784826
Current NDCG: 0.4305214091290254
435.09717
459.45026
454.5471


 62%|██████▏   | 31/50 [04:52<02:58,  9.41s/it]

Best NDCG: 0.4307957031462108
Current NDCG: 0.4307957031462108
450.75214
518.9749
428.47934


 64%|██████▍   | 32/50 [05:02<02:49,  9.40s/it]

Best NDCG: 0.4307957031462108
Current NDCG: 0.430677440124794
517.1311
478.64035
490.82642


 66%|██████▌   | 33/50 [05:11<02:39,  9.38s/it]

Best NDCG: 0.4307957031462108
Current NDCG: 0.430219897251417
480.67514
459.42078
456.97345


 68%|██████▊   | 34/50 [05:20<02:29,  9.36s/it]

Best NDCG: 0.4307957031462108
Current NDCG: 0.4302748190228962
488.89185
525.93915
455.1071


 70%|███████   | 35/50 [05:30<02:20,  9.36s/it]

Best NDCG: 0.4307957031462108
Current NDCG: 0.4304652450016016
435.70337
488.76898
472.30194


 72%|███████▏  | 36/50 [05:39<02:10,  9.34s/it]

Best NDCG: 0.4307957031462108
Current NDCG: 0.43072722249390133
425.31445
431.94498
432.32446


 74%|███████▍  | 37/50 [05:49<02:02,  9.43s/it]

Best NDCG: 0.43092725119860914
Current NDCG: 0.43092725119860914
450.60544
427.26868
480.03406


 76%|███████▌  | 38/50 [05:58<01:52,  9.41s/it]

Best NDCG: 0.43092725119860914
Current NDCG: 0.4306447574509004
463.5559
433.89197
440.63892


 78%|███████▊  | 39/50 [06:08<01:44,  9.48s/it]

Best NDCG: 0.4311333062229762
Current NDCG: 0.4311333062229762
459.4418
382.93707
445.4471


 80%|████████  | 40/50 [06:17<01:34,  9.46s/it]

Best NDCG: 0.4311333062229762
Current NDCG: 0.4308529563782291
511.97003
512.48926
434.3785


 82%|████████▏ | 41/50 [06:27<01:25,  9.49s/it]

Best NDCG: 0.4311660437919053
Current NDCG: 0.4311660437919053
448.1055
487.91583
420.30835


 84%|████████▍ | 42/50 [06:36<01:15,  9.46s/it]

Best NDCG: 0.4311660437919053
Current NDCG: 0.43105280982639926
475.79468
482.73395
484.52417


 86%|████████▌ | 43/50 [06:45<01:05,  9.42s/it]

Best NDCG: 0.4311660437919053
Current NDCG: 0.4308543732761187
432.15128
437.39685
508.4714


 88%|████████▊ | 44/50 [06:55<00:56,  9.47s/it]

Best NDCG: 0.4312483532820965
Current NDCG: 0.4312483532820965
423.16635
459.33444
410.03006


 90%|█████████ | 45/50 [07:04<00:47,  9.44s/it]

Best NDCG: 0.4312483532820965
Current NDCG: 0.4312466431884652
469.0938
483.33835
474.1883


 92%|█████████▏| 46/50 [07:14<00:38,  9.50s/it]

Best NDCG: 0.43178609195001244
Current NDCG: 0.43178609195001244
454.81494
415.48566
439.30804


 94%|█████████▍| 47/50 [07:23<00:28,  9.46s/it]

Best NDCG: 0.43178609195001244
Current NDCG: 0.43168689745150896
463.6924
459.8599
412.23


 96%|█████████▌| 48/50 [07:33<00:18,  9.42s/it]

Best NDCG: 0.43178609195001244
Current NDCG: 0.4316650683023692
448.9584
533.5582
469.78873


 98%|█████████▊| 49/50 [07:42<00:09,  9.40s/it]

Best NDCG: 0.43178609195001244
Current NDCG: 0.43177935789024224
444.88535
416.44446
416.04764


100%|██████████| 50/50 [07:51<00:00,  9.44s/it]

Best NDCG: 0.43178609195001244
Current NDCG: 0.43140715952049724
lr  0.0001
Sequential(
  (0): Linear(in_features=20108, out_features=600, bias=True)
  (1): Tanh()
  (2): Linear(in_features=600, out_features=400, bias=True)
)



  0%|          | 0/50 [00:00<?, ?it/s]

Sequential(
  (0): Linear(in_features=200, out_features=600, bias=True)
  (1): Tanh()
  (2): Linear(in_features=600, out_features=20108, bias=True)
)
827.8938
590.48456
534.3836


  2%|▏         | 1/50 [00:09<07:49,  9.57s/it]

Best NDCG: 0.22169994789407302
Current NDCG: 0.22169994789407302
573.11334
484.31256
602.4589


  4%|▍         | 2/50 [00:19<07:42,  9.63s/it]

Best NDCG: 0.27096320488697156
Current NDCG: 0.27096320488697156
531.43774
461.91806
489.44623


  6%|▌         | 3/50 [00:29<07:33,  9.65s/it]

Best NDCG: 0.315709639563538
Current NDCG: 0.315709639563538
565.1189
457.64105
530.9767


  8%|▊         | 4/50 [00:38<07:23,  9.65s/it]

Best NDCG: 0.34152550977874924
Current NDCG: 0.34152550977874924
472.39383
492.26364
496.96155


 10%|█         | 5/50 [00:48<07:14,  9.66s/it]

Best NDCG: 0.35767958488767554
Current NDCG: 0.35767958488767554
509.6482
542.57837
497.9414


 12%|█▏        | 6/50 [00:58<07:05,  9.67s/it]

Best NDCG: 0.3688174430879954
Current NDCG: 0.3688174430879954
545.7222
473.49463
492.88116


 14%|█▍        | 7/50 [01:07<06:56,  9.69s/it]

Best NDCG: 0.37808103625944234
Current NDCG: 0.37808103625944234
479.60303
559.1282
448.22797


 16%|█▌        | 8/50 [01:17<06:48,  9.72s/it]

Best NDCG: 0.3866231003484781
Current NDCG: 0.3866231003484781
447.8226
475.1513
472.25662


 18%|█▊        | 9/50 [01:27<06:38,  9.73s/it]

Best NDCG: 0.39343623283579543
Current NDCG: 0.39343623283579543
414.77426
446.40207
492.43396


 20%|██        | 10/50 [01:37<06:29,  9.74s/it]

Best NDCG: 0.398446554429592
Current NDCG: 0.398446554429592
515.6718
454.57626
479.76694


 22%|██▏       | 11/50 [01:46<06:19,  9.73s/it]

Best NDCG: 0.4024166541297719
Current NDCG: 0.4024166541297719
525.1745
508.63312
465.8779


 24%|██▍       | 12/50 [01:56<06:10,  9.74s/it]

Best NDCG: 0.405950311992509
Current NDCG: 0.405950311992509
562.7871
481.19968
479.96252


 26%|██▌       | 13/50 [02:06<06:00,  9.75s/it]

Best NDCG: 0.4086932720654274
Current NDCG: 0.4086932720654274
461.4553
505.87653
509.20984


 28%|██▊       | 14/50 [02:16<05:51,  9.77s/it]

Best NDCG: 0.4107007743023827
Current NDCG: 0.4107007743023827
585.6129
468.3681
474.55072


 30%|███       | 15/50 [02:25<05:42,  9.78s/it]

Best NDCG: 0.41273007658090494
Current NDCG: 0.41273007658090494
459.02444
493.37457
468.7492


 32%|███▏      | 16/50 [02:35<05:32,  9.77s/it]

Best NDCG: 0.41453761855886023
Current NDCG: 0.41453761855886023
432.12814
472.82944
501.38925


 34%|███▍      | 17/50 [02:45<05:21,  9.75s/it]

Best NDCG: 0.41611462962107415
Current NDCG: 0.41611462962107415
445.3868
449.21317
495.6188


 36%|███▌      | 18/50 [02:55<05:11,  9.74s/it]

Best NDCG: 0.41741226285991245
Current NDCG: 0.41741226285991245
474.54257
478.85974
457.47202


 38%|███▊      | 19/50 [03:04<05:01,  9.74s/it]

Best NDCG: 0.4179302063695628
Current NDCG: 0.4179302063695628
473.70877
434.7447
497.02945


 40%|████      | 20/50 [03:14<04:52,  9.74s/it]

Best NDCG: 0.41948377929258157
Current NDCG: 0.41948377929258157
458.9773
483.5885
471.24966


 42%|████▏     | 21/50 [03:24<04:42,  9.74s/it]

Best NDCG: 0.42045178158384217
Current NDCG: 0.42045178158384217
428.12485
461.4407
481.6281


 44%|████▍     | 22/50 [03:34<04:32,  9.75s/it]

Best NDCG: 0.4212411935547899
Current NDCG: 0.4212411935547899
525.90857
512.5083
489.93582


 46%|████▌     | 23/50 [03:43<04:21,  9.70s/it]

Best NDCG: 0.4219350089905927
Current NDCG: 0.4219350089905927
444.00983
523.5572
439.01273


 48%|████▊     | 24/50 [03:53<04:12,  9.71s/it]

Best NDCG: 0.42255992412886245
Current NDCG: 0.42255992412886245
497.55344
497.66364
439.73535


 50%|█████     | 25/50 [04:03<04:03,  9.72s/it]

Best NDCG: 0.42341567797737645
Current NDCG: 0.42341567797737645
443.46115
450.0165
479.62158


 52%|█████▏    | 26/50 [04:12<03:53,  9.72s/it]

Best NDCG: 0.4239056776070247
Current NDCG: 0.4239056776070247
490.47644
412.3915
429.32614


 54%|█████▍    | 27/50 [04:22<03:43,  9.73s/it]

Best NDCG: 0.42393937947542537
Current NDCG: 0.42393937947542537
438.2827
491.7971
471.76785


 56%|█████▌    | 28/50 [04:32<03:34,  9.74s/it]

Best NDCG: 0.4244668871100178
Current NDCG: 0.4244668871100178
475.41904
426.25565
452.70374


 58%|█████▊    | 29/50 [04:42<03:24,  9.74s/it]

Best NDCG: 0.42512221065736094
Current NDCG: 0.42512221065736094
507.4054
436.80463
404.27426


 60%|██████    | 30/50 [04:51<03:14,  9.74s/it]

Best NDCG: 0.42578881514039424
Current NDCG: 0.42578881514039424
412.39395
448.1411
420.6178


 62%|██████▏   | 31/50 [05:01<03:05,  9.74s/it]

Best NDCG: 0.42628701115684536
Current NDCG: 0.42628701115684536
500.69562
503.62457
445.94675


 64%|██████▍   | 32/50 [05:11<02:55,  9.76s/it]

Best NDCG: 0.4266343032202315
Current NDCG: 0.4266343032202315
419.99277
425.90042
451.46722


 66%|██████▌   | 33/50 [05:21<02:45,  9.75s/it]

Best NDCG: 0.42715423653481577
Current NDCG: 0.42715423653481577
431.70065
439.8859
545.81433


 68%|██████▊   | 34/50 [05:30<02:35,  9.73s/it]

Best NDCG: 0.42729456369165064
Current NDCG: 0.42729456369165064
522.1799
450.32895
527.28687


 70%|███████   | 35/50 [05:40<02:25,  9.73s/it]

Best NDCG: 0.4277152275892275
Current NDCG: 0.4277152275892275
454.60876
435.59515
465.4087


 72%|███████▏  | 36/50 [05:50<02:16,  9.73s/it]

Best NDCG: 0.42804974369158333
Current NDCG: 0.42804974369158333
497.06882
510.586
474.51837


 74%|███████▍  | 37/50 [05:59<02:05,  9.62s/it]

Best NDCG: 0.42804974369158333
Current NDCG: 0.42803985327522187
457.452
496.57422
506.747


 76%|███████▌  | 38/50 [06:09<01:55,  9.62s/it]

Best NDCG: 0.4282300945402134
Current NDCG: 0.4282300945402134
489.67657
503.89124
440.71716


 78%|███████▊  | 39/50 [06:19<01:46,  9.67s/it]

Best NDCG: 0.4286481585917104
Current NDCG: 0.4286481585917104
462.96063
453.88333
446.9413


 80%|████████  | 40/50 [06:28<01:36,  9.70s/it]

Best NDCG: 0.42888021924503084
Current NDCG: 0.42888021924503084
487.79688
519.32477
420.70908


 82%|████████▏ | 41/50 [06:38<01:26,  9.58s/it]

Best NDCG: 0.42888021924503084
Current NDCG: 0.42872126054382714
528.5571
478.07965
491.09323


 84%|████████▍ | 42/50 [06:47<01:16,  9.59s/it]

Best NDCG: 0.4291061563307388
Current NDCG: 0.4291061563307388
448.9629
420.53308
511.9563


 86%|████████▌ | 43/50 [06:57<01:06,  9.51s/it]

Best NDCG: 0.4291061563307388
Current NDCG: 0.42900628592658957
454.22455
485.14368
522.1248


 88%|████████▊ | 44/50 [07:06<00:57,  9.54s/it]

Best NDCG: 0.42922364692967546
Current NDCG: 0.42922364692967546
510.02554
488.83942
474.16266


 90%|█████████ | 45/50 [07:15<00:47,  9.46s/it]

Best NDCG: 0.42922364692967546
Current NDCG: 0.42901505107762694
531.31494
424.42664
446.86783


 92%|█████████▏| 46/50 [07:25<00:38,  9.51s/it]

Best NDCG: 0.4293140874079092
Current NDCG: 0.4293140874079092
502.33246
435.3564
478.65262


 94%|█████████▍| 47/50 [07:35<00:28,  9.65s/it]

Best NDCG: 0.4294888444232336
Current NDCG: 0.4294888444232336
482.9473
427.95587
448.59763


 96%|█████████▌| 48/50 [07:45<00:19,  9.64s/it]

Best NDCG: 0.4297082162678262
Current NDCG: 0.4297082162678262
435.4895
418.33298
521.2349


 98%|█████████▊| 49/50 [07:54<00:09,  9.55s/it]

Best NDCG: 0.4297082162678262
Current NDCG: 0.4293932150163262
504.02423
381.1876
517.6016


100%|██████████| 50/50 [08:03<00:00,  9.68s/it]

Best NDCG: 0.4298818419895272
Current NDCG: 0.4298818419895272



