In [1]:
import os
import torch
import matplotlib.pyplot as plt
import sys
from tqdm import tqdm
sys.path.insert(0, './src')
%matplotlib inline


from metrics import NDCG_binary_at_k_batch, Recall_at_k_batch
from models import MultiVAE, Multi_our_VAE
from training import train_model
from data import Dataset
from args import get_args
import numpy as np
import pandas as pd

import pdb

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import ncvis
import re
import gc

# Device used for evaluation. Defaults to the second GPU as before; set the
# EVAL_DEVICE environment variable (e.g. EVAL_DEVICE=cpu or cuda:0) on machines
# where 'cuda:1' does not exist.
device = os.environ.get('EVAL_DEVICE', 'cuda:1')

class dotdict(dict):
    """A dict whose keys can also be read, written and deleted as attributes.

    Reading an attribute that is not a key yields None (no AttributeError),
    writing an attribute stores a key, and deleting an attribute removes the
    key (raising KeyError if it is absent).
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails; mirror dict.get.
        return dict.get(self, name)

    def __setattr__(self, name, value):
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        dict.__delitem__(self, name)

In [2]:
# Run configuration handed to Dataset(args, ...) in the evaluation cell.
args = dotdict({})

# Batch sizes and the number of stochastic forward passes averaged per
# validation example.
args.train_batch_size = 500
args.val_batch_size = 10
args.n_val_samples = 200
args.device = device

# Optimisation settings; NOTE(review): presumably unused at evaluation time,
# kept to mirror the training configuration -- confirm.
args.learning_rate = 1e-4
args.n_epoches = 200
args.l2_coeff = 0.01 / args.train_batch_size
args.print_info_ = 1

# Annealing schedule: 200000 steps capped at 0.2 when enabled,
# otherwise 0 steps with cap 1.
args.annealing = True
args.total_anneal_steps = 200000 if args.annealing else 0
args.anneal_cap = 0.2 if args.annealing else 1.

# Standard normal distribution whose parameters live on the evaluation device.
device_zero = torch.tensor(0., device=device, dtype=torch.float32)
device_one = torch.tensor(1., device=device, dtype=torch.float32)
std_normal = torch.distributions.Normal(loc=device_zero, scale=device_one)

In [4]:
# models = [
# 'best_model_MultiHoffmanVAE_K_1_N_1_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_MultiHoffmanVAE_K_1_N_1_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_MultiHoffmanVAE_K_2_N_2_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_MultiHoffmanVAE_K_2_N_2_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_MultiHoffmanVAE_K_3_N_3_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_MultiHoffmanVAE_K_3_N_3_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_MultiHoffmanVAE_K_5_N_5_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_MultiHoffmanVAE_K_5_N_5_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_MultiVAE_K_None_N_None_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_MultiVAE_K_None_N_None_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_MultiVAE_K_None_N_None_learnreverse_False_anneal_True_lrdec_3e-05_lrenc_0.001.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_2_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_2_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_2_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_2_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_3_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_3_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_3_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_3_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_5_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_5_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_5_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_5_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_2_N_1_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_2_N_1_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_2_N_1_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_2_N_1_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_2_N_2_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_2_N_2_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_2_N_2_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_2_N_2_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_3_N_1_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_3_N_1_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_3_N_1_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_3_N_1_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_3_N_3_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_3_N_3_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_3_N_3_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_3_N_3_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_5_N_1_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_5_N_1_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_5_N_1_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_5_N_1_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_5_N_5_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_5_N_5_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_5_N_5_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_ourHoffman_VAE_K_5_N_5_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_1_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_1_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_1_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_1_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_2_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_2_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_2_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_2_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_3_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_3_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_3_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_3_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_5_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_5_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_5_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_1_N_5_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_2_N_1_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_2_N_1_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_2_N_1_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_2_N_1_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_2_N_2_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_2_N_2_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_2_N_2_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_2_N_2_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_3_N_1_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_3_N_1_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_3_N_1_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_3_N_1_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_3_N_3_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_3_N_3_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_3_N_3_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_3_N_3_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_5_N_1_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_5_N_1_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_5_N_1_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_5_N_1_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_5_N_1_learnreverse_True_anneal_True_lrdec_3e-05_lrenc_0.001.pt',
# 'best_model_Multi_our_VAE_K_5_N_5_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_5_N_5_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_5_N_5_learnreverse_True_anneal_False_lrdec_0.001_lrenc_None.pt',
# 'best_model_Multi_our_VAE_K_5_N_5_learnreverse_True_anneal_True_lrdec_0.001_lrenc_None.pt',
# ]

In [5]:
def load_model(path, device):
    """Load a whole pickled model checkpoint onto ``device`` and rebind device state.

    ``map_location`` moves the stored tensors, but attributes the model keeps as
    plain Python attributes (device strings, constant tensors, the standard
    normal distribution) are rebuilt here for the target device.

    Args:
        path: checkpoint file containing a full model object (the result is
            used as a module: ``.eval()`` is called on it).
        device: target device, e.g. 'cuda:1' or 'cpu'.

    Returns:
        The loaded model, switched to eval mode.

    Security: the checkpoint is unpickled with ``weights_only=False``, which can
    execute arbitrary code -- only load trusted files.
    """
    import inspect

    device_zero = torch.tensor(0., device=device, dtype=torch.float32)
    device_one = torch.tensor(1., device=device, dtype=torch.float32)
    std_normal = torch.distributions.Normal(loc=device_zero, scale=device_one)
    torch_log_2 = torch.tensor(np.log(2), device=device, dtype=torch.float32)

    load_kwargs = {'map_location': device}
    # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle whole
    # nn.Module objects; torch < 1.13 does not accept the argument at all.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    model = torch.load(path, **load_kwargs)

    model.eval()
    model.std_normal = std_normal
    if hasattr(model, 'torch_log_2'):
        model.torch_log_2 = torch_log_2
    if hasattr(model, 'transitions'):
        model.transitions = model.transitions.to(device)
        # Either a ModuleList of transition kernels or a single kernel module.
        if isinstance(model.transitions, torch.nn.ModuleList):
            kernels = list(model.transitions)
        else:
            kernels = [model.transitions]
        for kernel in kernels:
            kernel.device = device
            kernel.device_zero = device_zero
            kernel.device_one = device_one
    if hasattr(model, 'reverse_kernel'):
        model.reverse_kernel = model.reverse_kernel.to(device)
        model.reverse_kernel.device = device
        # NOTE(review): unlike the transitions, device_one is set on the model,
        # not on reverse_kernel -- kept as-is; confirm which object reads it.
        model.device_one = device_one

    return model

In [6]:
# Column-oriented accumulator filled by the evaluation cell: one list per result
# column, one entry per evaluated checkpoint. Every key is listed exactly once.
resulting_dict = {
    'model': [], 'data': [], 'K': [], 'N': [], 'anneal': [], 'learnreverse': [],
    'learntransitions': [], 'learnscale': [], 'initstepsize': [],
    'recall_5': [], 'std_recall_5': [], 'recall_10': [], 'std_recall_10': [],
    'recall_20': [], 'recall_50': [], 'recall_100': [],
    'ndcg_5': [], 'ndcg_10': [], 'ndcg_20': [], 'ndcg_50': [], 'ndcg_100': [],
    'std_ndcg_100': [],
}

In [7]:
# Checkpoints to evaluate: one sub-list per dataset, in the order
# ml20m, gowalla, foursquare (the order used by the evaluation loop).
models_list = [
    # ml20m -- annealed runs
    [
        './models/best_model_MultiVAE_data_ml20m_K_None_N_None_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None_learntransitions_True_initstepsize_0.1.pt',
        './models/best_model_Multi_our_VAE_data_ml20m_K_2_N_1_learnreverse_False_anneal_True_lrdec_0.003_lrenc_0.001_learntransitions_False_initstepsize_0.005_learnscale_True.pt',
        './models/best_model_Multi_our_VAE_data_ml20m_K_3_N_1_learnreverse_False_anneal_True_lrdec_0.003_lrenc_0.001_learntransitions_False_initstepsize_0.005_learnscale_True.pt',
        './models/best_model_Multi_our_VAE_data_ml20m_K_10_N_1_learnreverse_False_anneal_True_lrdec_0.003_lrenc_0.001_learntransitions_False_initstepsize_0.005_learnscale_True.pt',
        
    
    ## ml20m -- non-annealed runs.
    ## To validate only the 2 uncommented models below, comment out everything else,
    ## keep only ml20m in the dataset loop of the evaluation cell, and re-run it.
        './models/best_model_MultiVAE_data_ml20m_K_None_N_None_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None_learntransitions_True_initstepsize_0.1.pt',
#         './models/best_model_Multi_our_VAE_data_ml20m_K_2_N_1_learnreverse_False_anneal_False_lrdec_0.001_lrenc_0.001_learntransitions_False_initstepsize_0.01.pt',
#         './models/best_model_Multi_our_VAE_data_ml20m_K_3_N_1_learnreverse_False_anneal_False_lrdec_0.001_lrenc_0.001_learntransitions_False_initstepsize_0.01.pt',
        './models/best_model_Multi_our_VAE_data_ml20m_K_10_N_1_learnreverse_False_anneal_False_lrdec_0.003_lrenc_0.001_learntransitions_False_initstepsize_0.005_learnscale_True.pt',
        
    ],
    # gowalla -- non-annealed runs first, then annealed runs
    [
        './models/best_model_MultiVAE_data_gowalla_K_None_N_None_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None_learntransitions_True_initstepsize_0.01.pt',
        './models/best_model_Multi_our_VAE_data_gowalla_K_2_N_1_learnreverse_False_anneal_False_lrdec_0.003_lrenc_0.001_learntransitions_False_initstepsize_0.005_learnscale_True.pt',
        './models/best_model_Multi_our_VAE_data_gowalla_K_3_N_1_learnreverse_False_anneal_False_lrdec_0.003_lrenc_0.001_learntransitions_False_initstepsize_0.005_learnscale_True.pt',
        './models/best_model_Multi_our_VAE_data_gowalla_K_5_N_5_learnreverse_False_anneal_False_lrdec_0.003_lrenc_0.001_learntransitions_False_initstepsize_0.005_learnscale_True.pt',
        './models/best_model_Multi_our_VAE_data_gowalla_K_10_N_1_learnreverse_False_anneal_False_lrdec_0.003_lrenc_0.001_learntransitions_False_initstepsize_0.005_learnscale_True.pt',
        
        './models/best_model_MultiVAE_data_gowalla_K_None_N_None_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None_learntransitions_True_initstepsize_0.01.pt',
        './models/best_model_Multi_our_VAE_data_gowalla_K_2_N_1_learnreverse_False_anneal_True_lrdec_0.003_lrenc_0.001_learntransitions_False_initstepsize_0.005_learnscale_True.pt',
        './models/best_model_Multi_our_VAE_data_gowalla_K_3_N_1_learnreverse_False_anneal_True_lrdec_0.003_lrenc_0.001_learntransitions_False_initstepsize_0.005_learnscale_True.pt',
        './models/best_model_Multi_our_VAE_data_gowalla_K_3_N_5_learnreverse_False_anneal_True_lrdec_0.003_lrenc_0.001_learntransitions_False_initstepsize_0.005_learnscale_True.pt',
        './models/best_model_Multi_our_VAE_data_gowalla_K_5_N_5_learnreverse_False_anneal_True_lrdec_0.003_lrenc_0.001_learntransitions_False_initstepsize_0.005_learnscale_True.pt',
        './models/best_model_Multi_our_VAE_data_gowalla_K_10_N_1_learnreverse_False_anneal_True_lrdec_0.003_lrenc_0.001_learntransitions_False_initstepsize_0.005_learnscale_True.pt',
    ],
    
    # foursquare -- non-annealed runs first, then annealed runs
    [
        './models/best_model_MultiVAE_data_foursquare_K_None_N_None_learnreverse_False_anneal_False_lrdec_0.001_lrenc_None_learntransitions_True_initstepsize_0.1.pt',
        './models/best_model_Multi_our_VAE_data_foursquare_K_2_N_1_learnreverse_False_anneal_False_lrdec_0.003_lrenc_0.001_learntransitions_False_initstepsize_0.005_learnscale_True.pt',
        './models/best_model_Multi_our_VAE_data_foursquare_K_3_N_1_learnreverse_False_anneal_False_lrdec_0.003_lrenc_0.001_learntransitions_False_initstepsize_0.005_learnscale_True.pt',
        './models/best_model_Multi_our_VAE_data_foursquare_K_3_N_3_learnreverse_False_anneal_False_lrdec_0.003_lrenc_0.001_learntransitions_False_initstepsize_0.005_learnscale_True.pt',
        './models/best_model_Multi_our_VAE_data_foursquare_K_5_N_5_learnreverse_False_anneal_False_lrdec_0.003_lrenc_0.001_learntransitions_False_initstepsize_0.005_learnscale_True.pt',
        './models/best_model_Multi_our_VAE_data_foursquare_K_10_N_1_learnreverse_False_anneal_False_lrdec_0.003_lrenc_0.001_learntransitions_False_initstepsize_0.005_learnscale_True.pt',
        
        
        './models/best_model_MultiVAE_data_foursquare_K_None_N_None_learnreverse_False_anneal_True_lrdec_0.001_lrenc_None_learntransitions_True_initstepsize_0.01.pt',
        './models/best_model_Multi_our_VAE_data_foursquare_K_2_N_1_learnreverse_False_anneal_True_lrdec_0.003_lrenc_0.001_learntransitions_False_initstepsize_0.005_learnscale_True.pt',
        './models/best_model_Multi_our_VAE_data_foursquare_K_3_N_1_learnreverse_False_anneal_True_lrdec_0.003_lrenc_0.001_learntransitions_False_initstepsize_0.005_learnscale_True.pt',
        './models/best_model_Multi_our_VAE_data_foursquare_K_3_N_3_learnreverse_False_anneal_True_lrdec_0.003_lrenc_0.001_learntransitions_False_initstepsize_0.005_learnscale_True.pt',
        './models/best_model_Multi_our_VAE_data_foursquare_K_5_N_5_learnreverse_False_anneal_True_lrdec_0.003_lrenc_0.001_learntransitions_False_initstepsize_0.005_learnscale_True.pt',
        './models/best_model_Multi_our_VAE_data_foursquare_K_10_N_1_learnreverse_False_anneal_True_lrdec_0.003_lrenc_0.001_learntransitions_False_initstepsize_0.005_learnscale_True.pt',  
    ],

]

In [None]:
# Checkpoint filenames encode each run's hyper-parameters. Capture groups:
# 1 model class, 2 dataset, 3 K, 4 N, 5 learnreverse, 6 anneal, 7 lrdec, 8 lrenc,
# 9 learntransitions, 10 initstepsize, 11 optional '_learnscale_True' suffix,
# 12 the 'True' inside that suffix.
pattern_1 = re.compile(r'best_model_(\w*)_data_([\w\d]*)_K_([\d\w]*)_N_([\d\w]*)_learnreverse_(\w*)_anneal_(\w*)_lrdec_([.\d]*)_lrenc_([.\d\w]*)_learntransitions_([\w]*)_initstepsize_([.\d]*)(_learnscale_(True))?\.pt')

# Order of the sub-lists in models_list.
MODELS_LIST_DATASETS = ['ml20m', 'gowalla', 'foursquare']
# Datasets evaluated by this cell; any subset of MODELS_LIST_DATASETS, e.g. ['ml20m'].
EVAL_DATASETS = MODELS_LIST_DATASETS
# Ranking cut-offs k, reported as ndcg_<k> / recall_<k>.
METRIC_KS = (5, 10, 20, 50, 100)
# Metrics whose standard error of the mean is also stored, as std_<metric>.
SEM_METRICS = ('recall_5', 'recall_10', 'ndcg_100')


def parse_model_name(path):
    """Return the hyper-parameters encoded in a checkpoint filename, as strings.

    Keys: model, data, K, N, learnreverse, anneal, learntransitions,
    initstepsize, learnscale. ``learnscale`` is 'True' when the optional
    ``_learnscale_True`` suffix is present and 'False' otherwise.

    Raises:
        ValueError: if ``path`` does not match ``pattern_1``.
    """
    match = pattern_1.search(path)
    if match is None:
        raise ValueError('Unrecognised checkpoint filename: {}'.format(path))
    vals = match.groups()
    return {
        'model': vals[0],
        'data': vals[1],
        'K': vals[2],
        'N': vals[3],
        'learnreverse': vals[4],
        'anneal': vals[5],
        'learntransitions': vals[8],
        'initstepsize': vals[9],
        # Group 12 is None when the optional suffix is absent.
        'learnscale': vals[11] or 'False',
    }


def evaluate_model(model, dataset, n_samples, ks=METRIC_KS):
    """Compute NDCG@k and Recall@k over ``dataset``'s validation batches.

    Each validation input is repeated ``n_samples`` times, the first output of
    ``model`` is averaged over the repeats, and entries that are nonzero in the
    input are set to -inf so they rank last.

    Returns:
        dict mapping 'ndcg_<k>' / 'recall_<k>' to the metric values
        concatenated over all validation batches.
    """
    names = ['ndcg_{}'.format(k) for k in ks] + ['recall_{}'.format(k) for k in ks]
    per_batch = {name: [] for name in names}
    for batch_val in dataset.next_val_batch():
        inputs, targets = batch_val[0], batch_val[1]
        reshaped_batch = inputs.repeat((n_samples, 1))
        # The forward pass uses noise, hence the averaging over n_samples repeats.
        # It stays outside torch.no_grad() -- NOTE(review): presumably the
        # models' transitions need autograd inside forward; confirm before
        # moving it under no_grad.
        out = model(reshaped_batch)[0]
        with torch.no_grad():
            pred_val = out.detach().view((n_samples, *inputs.shape)).mean(0)
            observed = inputs.cpu().detach().numpy()
            pred_val = pred_val.cpu().detach().numpy()
            pred_val[observed.nonzero()] = -np.inf
            for k in ks:
                per_batch['ndcg_{}'.format(k)].append(NDCG_binary_at_k_batch(pred_val, targets, k=k))
            for k in ks:
                per_batch['recall_{}'.format(k)].append(Recall_at_k_batch(pred_val, targets, k=k))
        del out
    return {name: np.concatenate(parts) for name, parts in per_batch.items()}


def standard_error(values):
    """Standard error of the mean: ``np.std`` (ddof=0) divided by sqrt(n)."""
    return values.std() / np.sqrt(len(values))


for data in EVAL_DATASETS:
    args.data = data
    dataset = Dataset(args, data_dir='./data/')
    models = models_list[MODELS_LIST_DATASETS.index(data)]
    for m in tqdm(models, desc=data):
        row = parse_model_name(m)
        if row['data'] != data:
            raise ValueError('Checkpoint {} is listed under dataset {}'.format(m, data))

        model = load_model(m, device)
        dists = evaluate_model(model, dataset, args.n_val_samples)
        for name, dist in dists.items():
            row[name] = dist.mean()
        for name in SEM_METRICS:
            row['std_' + name] = standard_error(dists[name])

        # Append only once the whole row is computed, so a failure part-way
        # through a model cannot leave resulting_dict's columns with unequal lengths.
        for column in resulting_dict:
            resulting_dict[column].append(row[column])
        del model
        gc.collect()
    del dataset

In [None]:
# Show up to 100 characters per cell so long values are not truncated in displayed frames.
pd.options.display.max_colwidth = 100

In [74]:
# Summary table: runs with N == "1" plus runs with N == "None" (the MultiVAE
# baseline), keeping only the headline metrics and their standard errors.
df = (
    pd.DataFrame(resulting_dict)
    .query('N == "1" or N=="None"')
    .drop(columns=['learnreverse', 'N', 'learnscale', 'initstepsize', 'ndcg_5', 'ndcg_10', 'ndcg_20', 'recall_100', 'ndcg_50', 'recall_20', 'recall_50', 'learntransitions'])
    .sort_values(['data', 'anneal', 'model'])
)
# The CSV keeps full precision.
df.to_csv('./eval_results_bests.csv', index=False)
# Last expression, so Jupyter actually renders the rounded table.
df.round(3)

Unnamed: 0,model,data,K,anneal,recall_5,std_recall_5,recall_10,std_recall_10,ndcg_100,std_ndcg_100
17,MultiVAE,foursquare,,False,0.059,0.003,0.066,0.003,0.127,0.004
18,Multi_our_VAE,foursquare,2.0,False,0.064,0.004,0.07,0.003,0.136,0.004
19,Multi_our_VAE,foursquare,3.0,False,0.066,0.003,0.068,0.003,0.135,0.004
20,Multi_our_VAE,foursquare,3.0,False,0.068,0.004,0.071,0.003,0.134,0.004
21,Multi_our_VAE,foursquare,5.0,False,0.053,0.003,0.057,0.003,0.116,0.004
22,Multi_our_VAE,foursquare,10.0,False,0.064,0.003,0.068,0.003,0.133,0.004
23,MultiVAE,foursquare,,True,0.061,0.003,0.067,0.003,0.13,0.004
24,Multi_our_VAE,foursquare,2.0,True,0.066,0.004,0.068,0.003,0.136,0.004
25,Multi_our_VAE,foursquare,3.0,True,0.064,0.004,0.067,0.003,0.133,0.004
26,Multi_our_VAE,foursquare,3.0,True,0.066,0.004,0.07,0.003,0.132,0.004


In [16]:
# print() (not rich display) so the raw LaTeX source is shown for copy-pasting.
latex_table = df.round(3).to_latex(index=False)
print(latex_table)

\begin{tabular}{rlllllrrrrrrrr}
\toprule
 index &          model &        data &     K & anneal & learnscale &  recall\_5 &  std\_recall\_5 &  recall\_10 &  std\_recall\_10 &  recall\_100 &  ndcg\_20 &  ndcg\_100 &  std\_ndcg\_100 \\
\midrule
    12 &       MultiVAE &  foursquare &  None &  False &            &     0.059 &         0.003 &      0.065 &          0.003 &       0.204 &    0.085 &     0.128 &         0.004 \\
    13 &  Multi\_our\_VAE &  foursquare &     2 &  False &       True &     0.064 &         0.004 &      0.070 &          0.003 &       0.213 &    0.092 &     0.136 &         0.004 \\
    14 &  Multi\_our\_VAE &  foursquare &     3 &  False &       True &     0.067 &         0.003 &      0.068 &          0.003 &       0.212 &    0.092 &     0.135 &         0.004 \\
    15 &       MultiVAE &  foursquare &  None &   True &            &     0.060 &         0.003 &      0.066 &          0.003 &       0.205 &    0.087 &     0.130 &         0.004 \\
    16 &  Multi\_our\_VAE

In [11]:
# Unfiltered results for every evaluated checkpoint (all N values, all metrics).
df = pd.DataFrame(data=resulting_dict)
df.round(decimals=3)

Unnamed: 0,model,data,K,N,anneal,learnreverse,learntransitions,learnscale,initstepsize,recall_5,recall_10,recall_20,recall_50,recall_100,ndcg_5,ndcg_10,ndcg_20,ndcg_50,ndcg_100
0,MultiVAE,ml20m,,,True,False,True,,0.1,0.315,0.333,0.397,0.536,0.662,0.322,0.32,0.338,0.386,0.429
1,Multi_our_VAE,ml20m,2.0,1.0,True,False,False,True,0.005,0.321,0.336,0.395,0.533,0.658,0.33,0.325,0.34,0.387,0.43
2,Multi_our_VAE,ml20m,3.0,1.0,True,False,False,True,0.005,0.322,0.333,0.394,0.534,0.66,0.332,0.325,0.34,0.388,0.431
