In [None]:
import os
import torch
import matplotlib.pyplot as plt
import sys
from tqdm import tqdm
sys.path.insert(0, './src')
%matplotlib inline


from metrics import NDCG_binary_at_k_batch, Recall_at_k_batch
from models import MultiVAE, MultiDAE, Multi_our_VAE, MultiHoffmanVAE, Multi_ourHoffman_VAE
from training import train_model
from data import Dataset
from args import get_args
import numpy as np
import pandas as pd

import pdb

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import ncvis




# Torch device string used throughout this notebook.
device = 'cpu'

class dotdict(dict):
    """dict with dot-notation access to its keys.

    ``d.key`` reads ``d['key']`` and returns ``None`` when the key is missing
    (same as ``dict.get``); ``d.key = v`` and ``del d.key`` write/delete the
    corresponding dict entry.

    Dunder names (``__getstate__``, ``__deepcopy__``, ...) are never looked up
    in the dict and raise ``AttributeError`` like on an ordinary object.
    Answering ``None`` for them (as a bare ``__getattr__ = dict.get`` does)
    makes ``hasattr`` report attributes that do not exist and can break
    ``pickle``/``copy`` on some Python versions, which then try to call ``None``.
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return self.get(name)

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

In [2]:
# Run configuration handed to Dataset, MultiVAE and train_model.
args = dotdict(
    train_batch_size=500,
    val_batch_size=2000,
    data='ml20m',
    device=device,
    learning_rate=1e-4,
    n_epoches=200,
)
# The L2 coefficient is divided by the training batch size.
args.l2_coeff = 0.01 / args.train_batch_size
args.print_info_ = 1

# Annealing schedule (interpreted by the training code); with annealing
# switched off there are no anneal steps and the cap is 1.
args.annealing = True
args.total_anneal_steps, args.anneal_cap = (200000, 0.2) if args.annealing else (0, 1.)

# Float32 constants 0 and 1 on the target device, and a standard normal built from them.
device_zero, device_one = (
    torch.tensor(value, device=device, dtype=torch.float32) for value in (0., 1.)
)
std_normal = torch.distributions.Normal(loc=device_zero, scale=device_one)

dataset = Dataset(args, data_dir='./data/')

In [3]:
# Checkpoint filenames in ./models that start with 'best', in sorted order.
models = sorted(name for name in os.listdir('./models') if name.startswith('best'))
# Echo them as quoted, comma-terminated lines (ready to paste into a Python list).
for m in models:
    print(f"'{m}',")

'best_model_MultiHoffmanVAE_K_1_N_1_learnreverse_False_anneal_False.pt',
'best_model_MultiHoffmanVAE_K_1_N_1_learnreverse_False_anneal_True.pt',
'best_model_MultiHoffmanVAE_K_2_N_2_learnreverse_False_anneal_False.pt',
'best_model_MultiHoffmanVAE_K_2_N_2_learnreverse_False_anneal_True.pt',
'best_model_MultiHoffmanVAE_K_3_N_3_learnreverse_False_anneal_False.pt',
'best_model_MultiHoffmanVAE_K_3_N_3_learnreverse_False_anneal_True.pt',
'best_model_MultiHoffmanVAE_K_5_N_5_learnreverse_False_anneal_False.pt',
'best_model_MultiHoffmanVAE_K_5_N_5_learnreverse_False_anneal_True.pt',
'best_model_MultiVAE_K_None_N_None_learnreverse_False_anneal_False.pt',
'best_model_MultiVAE_K_None_N_None_learnreverse_False_anneal_True.pt',
'best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_False_anneal_False.pt',
'best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_False_anneal_True.pt',
'best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_True_anneal_False.pt',
'best_model_Multi_ourHoffman_VAE_K_1_N_1_lear

In [4]:
# models = [
# 'best_model_MultiVAE_K_None_N_None_learnreverse_False_anneal_False.pt',
# 'best_model_Multi_our_VAE_K_2_N_2_learnreverse_False_anneal_True.pt',
# 'best_model_Multi_our_VAE_K_1_N_3_learnreverse_True_anneal_True.pt',
# 'best_model_MultiHoffmanVAE_K_1_N_1_learnreverse_False_anneal_True.pt',
# 'best_model_Multi_our_VAE_K_2_N_1_learnreverse_False_anneal_True.pt',
# 'best_model_Multi_our_VAE_K_1_N_2_learnreverse_False_anneal_False.pt',
# 'best_model_Multi_ourHoffman_VAE_K_2_N_2_learnreverse_False_anneal_False.pt',
# 'best_model_Multi_our_VAE_K_1_N_1_learnreverse_True_anneal_True.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_False_anneal_True.pt',
# 'best_model_Multi_ourHoffman_VAE_K_2_N_2_learnreverse_True_anneal_True.pt',
# 'best_model_Multi_our_VAE_K_1_N_3_learnreverse_False_anneal_True.pt',
# 'best_model_Multi_our_VAE_K_1_N_1_learnreverse_False_anneal_False.pt',
# 'best_model_Multi_our_VAE_K_2_N_1_learnreverse_True_anneal_True.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_True_anneal_False.pt',
# 'best_model_Multi_ourHoffman_VAE_K_2_N_2_learnreverse_True_anneal_False.pt',
# 'best_model_Multi_our_VAE_K_1_N_5_learnreverse_False_anneal_True.pt',
# 'best_model_Multi_our_VAE_K_2_N_1_learnreverse_False_anneal_False.pt',
# 'best_model_MultiHoffmanVAE_K_1_N_1_learnreverse_False_anneal_False.pt',
# 'best_model_Multi_our_VAE_K_2_N_1_learnreverse_True_anneal_False.pt',
# 'best_model_Multi_our_VAE_K_3_N_1_learnreverse_False_anneal_False.pt',
# 'best_model_Multi_our_VAE_K_1_N_1_learnreverse_True_anneal_False.pt',
# 'best_model_Multi_our_VAE_K_1_N_2_learnreverse_True_anneal_False.pt',
# 'best_model_Multi_our_VAE_K_1_N_2_learnreverse_True_anneal_True.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_True_anneal_True.pt',
# 'best_model_Multi_our_VAE_K_1_N_2_learnreverse_False_anneal_True.pt',
# 'best_model_MultiVAE_K_None_N_None_learnreverse_False_anneal_True.pt',
# 'best_model_Multi_our_VAE_K_3_N_1_learnreverse_True_anneal_True.pt',
# 'best_model_Multi_our_VAE_K_3_N_1_learnreverse_True_anneal_False.pt',
# 'best_model_MultiHoffmanVAE_K_2_N_2_learnreverse_False_anneal_False.pt',
# 'best_model_Multi_our_VAE_K_5_N_1_learnreverse_False_anneal_True.pt',
# 'best_model_MultiHoffmanVAE_K_2_N_2_learnreverse_False_anneal_True.pt',
# 'best_model_Multi_our_VAE_K_2_N_2_learnreverse_False_anneal_False.pt',
# 'best_model_Multi_our_VAE_K_1_N_1_learnreverse_False_anneal_True.pt',
# 'best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_False_anneal_False.pt',
# 'best_model_Multi_our_VAE_K_2_N_2_learnreverse_True_anneal_True.pt',
# 'best_model_Multi_ourHoffman_VAE_K_2_N_2_learnreverse_False_anneal_True.pt',
# 'best_model_Multi_our_VAE_K_5_N_1_learnreverse_True_anneal_False.pt',
# 'best_model_Multi_our_VAE_K_2_N_2_learnreverse_True_anneal_False.pt',
# 'best_model_Multi_our_VAE_K_3_N_1_learnreverse_False_anneal_True.pt',
# 'best_model_Multi_our_VAE_K_5_N_1_learnreverse_False_anneal_False.pt',
# 'best_model_Multi_our_VAE_K_5_N_1_learnreverse_True_anneal_True.pt',
# ]

In [3]:
def load_model(path, device):
    """Load a whole pickled model checkpoint and re-attach device-bound helpers.

    The file must contain a complete model object (the result is used directly
    as a module: ``.eval()``, ``.transitions``, ...), presumably written with
    ``torch.save(model, path)``, not a state dict.

    Args:
        path: path to the ``.pt`` checkpoint.
        device: target device; used as ``map_location`` and stored on the
            transition / reverse-kernel sub-modules.

    Returns:
        The loaded model in eval mode, with ``std_normal`` (and, when present,
        ``torch_log_2``, ``transitions`` and ``reverse_kernel`` helpers)
        rebuilt on ``device``.
    """
    import inspect

    device_zero = torch.tensor(0., device=device, dtype=torch.float32)
    device_one = torch.tensor(1., device=device, dtype=torch.float32)
    std_normal = torch.distributions.Normal(loc=device_zero, scale=device_one)
    torch_log_2 = torch.tensor(np.log(2), device=device, dtype=torch.float32)

    # Unpickling a full module is refused by torch>=2.6 defaults
    # (weights_only=True), so opt out explicitly where the argument exists;
    # older torch versions do not accept it.
    # SECURITY: unpickling can execute arbitrary code - only load trusted files.
    load_kwargs = {'map_location': device}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    model = torch.load(path, **load_kwargs)

    model.eval()
    model.std_normal = std_normal
    if hasattr(model, 'torch_log_2'):
        model.torch_log_2 = torch_log_2
    if hasattr(model, 'transitions'):
        model.transitions = model.transitions.to(device)
        # A ModuleList holds several kernels: configure each one; otherwise
        # configure the single transition module itself.
        if isinstance(model.transitions, torch.nn.ModuleList):
            kernels = list(model.transitions)
        else:
            kernels = [model.transitions]
        for kernel in kernels:
            kernel.device = device
            kernel.device_zero = device_zero
            kernel.device_one = device_one
    if hasattr(model, 'reverse_kernel'):
        model.reverse_kernel = model.reverse_kernel.to(device)
        model.reverse_kernel.device = device
        # NOTE(review): device_one goes on the model here, not on
        # reverse_kernel (unlike the transitions above) - kept as-is; confirm.
        model.device_one = device_one

    return model

In [6]:
# Evaluate every checkpoint in `models` twice: with is_training_ph=1. and with
# is_training_ph=0. (recorded in the `noise` column). The flag is forwarded to
# the model's forward pass; presumably it switches latent sampling noise
# on/off - confirm in models.py.
# NOTE(review): deliberately no torch.no_grad() here - the MCMC-style models
# (transitions / reverse kernels) may need autograd inside forward; confirm
# before adding it.
resulting_dict = {'model': [], 'ndcg': [], 'recall': [], 'noise': []}

for noise in [1., 0.]:
    for m in tqdm(models):
        model = load_model(os.path.join('./models', m), device)

        ndcg = []
        recall = []
        for bnum, batch_val in enumerate(dataset.next_val_batch()):
            out = model(batch_val[0], is_training_ph=noise)  # noise flag also applies when estimating metrics
            pred_val = out[0]
            X = batch_val[0].cpu().detach().numpy()
            pred_val = pred_val.cpu().detach().numpy()
            # exclude examples from training and validation (if any)
            pred_val[X.nonzero()] = -np.inf
            ndcg.append(NDCG_binary_at_k_batch(pred_val, batch_val[1]))
            recall.append(Recall_at_k_batch(pred_val, batch_val[1]))

        ndcg_dist = np.concatenate(ndcg)
        recall_dist = np.concatenate(recall)
        current_ndcg = ndcg_dist.mean()
        current_recall = recall_dist.mean()
        resulting_dict['model'].append(m)
        resulting_dict['ndcg'].append(current_ndcg)
        resulting_dict['recall'].append(current_recall)
        resulting_dict['noise'].append(noise)
        # tqdm.write (not print) keeps the progress bar intact; same text as
        # print(m, 'NDCG: ', current_ndcg, '\t', 'Recall: ', current_recall).
        tqdm.write(f'{m} NDCG:  {current_ndcg!s} \t Recall:  {current_recall!s}')
        del model

  2%|▏         | 1/50 [00:22<18:41, 22.90s/it]

best_model_MultiHoffmanVAE_K_1_N_1_learnreverse_False_anneal_False.pt NDCG:  0.3599860801663479 	 Recall:  0.5913827374137361


  4%|▍         | 2/50 [00:46<18:33, 23.20s/it]

best_model_MultiHoffmanVAE_K_1_N_1_learnreverse_False_anneal_True.pt NDCG:  0.3869960989637224 	 Recall:  0.6260463131454891


  6%|▌         | 3/50 [01:34<23:54, 30.51s/it]

best_model_MultiHoffmanVAE_K_2_N_2_learnreverse_False_anneal_False.pt NDCG:  0.3556332987902591 	 Recall:  0.584346510422563


  8%|▊         | 4/50 [02:19<26:40, 34.80s/it]

best_model_MultiHoffmanVAE_K_2_N_2_learnreverse_False_anneal_True.pt NDCG:  0.3871900650448047 	 Recall:  0.6202633598928076


 10%|█         | 5/50 [02:26<20:00, 26.68s/it]

best_model_MultiVAE_K_None_N_None_learnreverse_False_anneal_False.pt NDCG:  0.3619242533990528 	 Recall:  0.5955743713151105


 12%|█▏        | 6/50 [02:34<15:18, 20.88s/it]

best_model_MultiVAE_K_None_N_None_learnreverse_False_anneal_True.pt NDCG:  0.3923779405684078 	 Recall:  0.6286353124365546


 14%|█▍        | 7/50 [02:58<15:36, 21.78s/it]

best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_False_anneal_False.pt NDCG:  0.35879325821807323 	 Recall:  0.5894073144022108


 16%|█▌        | 8/50 [03:20<15:23, 22.00s/it]

best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_False_anneal_True.pt NDCG:  0.3941336728180382 	 Recall:  0.6283833780897363


 18%|█▊        | 9/50 [03:43<15:17, 22.38s/it]

best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_True_anneal_False.pt NDCG:  0.3605106032479916 	 Recall:  0.5908602258858912


 20%|██        | 10/50 [04:07<15:12, 22.82s/it]

best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_True_anneal_True.pt NDCG:  0.38840191318862777 	 Recall:  0.6276457781768106


 22%|██▏       | 11/50 [04:42<17:04, 26.28s/it]

best_model_Multi_ourHoffman_VAE_K_2_N_1_learnreverse_False_anneal_True.pt NDCG:  0.3913519364640719 	 Recall:  0.6287199617554532


 24%|██▍       | 12/50 [05:18<18:29, 29.19s/it]

best_model_Multi_ourHoffman_VAE_K_2_N_1_learnreverse_True_anneal_True.pt NDCG:  0.3918867038125033 	 Recall:  0.6288448197424459


 26%|██▌       | 13/50 [06:06<21:36, 35.03s/it]

best_model_Multi_ourHoffman_VAE_K_2_N_2_learnreverse_False_anneal_False.pt NDCG:  0.3596865985814431 	 Recall:  0.5885949688111318


 28%|██▊       | 14/50 [06:56<23:44, 39.58s/it]

best_model_Multi_ourHoffman_VAE_K_2_N_2_learnreverse_False_anneal_True.pt NDCG:  0.3898459361938759 	 Recall:  0.6283403871704691


 30%|███       | 15/50 [07:45<24:35, 42.15s/it]

best_model_Multi_ourHoffman_VAE_K_2_N_2_learnreverse_True_anneal_False.pt NDCG:  0.36035243063928446 	 Recall:  0.5896698527121945


 32%|███▏      | 16/50 [08:33<24:59, 44.10s/it]

best_model_Multi_ourHoffman_VAE_K_2_N_2_learnreverse_True_anneal_True.pt NDCG:  0.3928141160977155 	 Recall:  0.6306737619778007


 34%|███▍      | 17/50 [09:24<25:23, 46.18s/it]

best_model_Multi_ourHoffman_VAE_K_3_N_1_learnreverse_False_anneal_True.pt NDCG:  0.3924545336360928 	 Recall:  0.6294564018133613


 36%|███▌      | 18/50 [10:18<25:45, 48.31s/it]

best_model_Multi_ourHoffman_VAE_K_3_N_1_learnreverse_True_anneal_True.pt NDCG:  0.39866222776061483 	 Recall:  0.6353191968540188


 38%|███▊      | 19/50 [10:39<20:51, 40.36s/it]

best_model_Multi_our_VAE_K_1_N_1_learnreverse_False_anneal_False.pt NDCG:  0.36170831143295373 	 Recall:  0.5920325614287665


 40%|████      | 20/50 [11:01<17:25, 34.84s/it]

best_model_Multi_our_VAE_K_1_N_1_learnreverse_False_anneal_True.pt NDCG:  0.39265774620648664 	 Recall:  0.6289027201314517


 42%|████▏     | 21/50 [11:25<15:10, 31.39s/it]

best_model_Multi_our_VAE_K_1_N_1_learnreverse_True_anneal_False.pt NDCG:  0.36047320184691256 	 Recall:  0.5931143705545985


 44%|████▍     | 22/50 [11:48<13:31, 28.99s/it]

best_model_Multi_our_VAE_K_1_N_1_learnreverse_True_anneal_True.pt NDCG:  0.3898550466243452 	 Recall:  0.6279303140893364


 46%|████▌     | 23/50 [12:18<13:13, 29.38s/it]

best_model_Multi_our_VAE_K_1_N_2_learnreverse_False_anneal_False.pt NDCG:  0.3621371269608224 	 Recall:  0.5914337166871393


 48%|████▊     | 24/50 [12:49<12:54, 29.80s/it]

best_model_Multi_our_VAE_K_1_N_2_learnreverse_False_anneal_True.pt NDCG:  0.3898612432063945 	 Recall:  0.6270494068091119


 50%|█████     | 25/50 [13:18<12:21, 29.66s/it]

best_model_Multi_our_VAE_K_1_N_2_learnreverse_True_anneal_False.pt NDCG:  0.3591312197403387 	 Recall:  0.5871543506232507


 52%|█████▏    | 26/50 [13:45<11:30, 28.76s/it]

best_model_Multi_our_VAE_K_1_N_2_learnreverse_True_anneal_True.pt NDCG:  0.39080217299606784 	 Recall:  0.6292453610457914


 54%|█████▍    | 27/50 [14:18<11:31, 30.07s/it]

best_model_Multi_our_VAE_K_1_N_3_learnreverse_False_anneal_False.pt NDCG:  0.3600667941508042 	 Recall:  0.5913560818731581


 56%|█████▌    | 28/50 [14:52<11:22, 31.03s/it]

best_model_Multi_our_VAE_K_1_N_3_learnreverse_False_anneal_True.pt NDCG:  0.3908040670410487 	 Recall:  0.6262386175731842


 58%|█████▊    | 29/50 [15:24<11:00, 31.43s/it]

best_model_Multi_our_VAE_K_1_N_3_learnreverse_True_anneal_False.pt NDCG:  0.36010092377790864 	 Recall:  0.5890684228548533


 60%|██████    | 30/50 [15:54<10:21, 31.07s/it]

best_model_Multi_our_VAE_K_1_N_3_learnreverse_True_anneal_True.pt NDCG:  0.39037588122232236 	 Recall:  0.6278359611884752


 62%|██████▏   | 31/50 [16:35<10:45, 33.97s/it]

best_model_Multi_our_VAE_K_1_N_5_learnreverse_False_anneal_False.pt NDCG:  0.3624237380773636 	 Recall:  0.5912327783371122


 64%|██████▍   | 32/50 [17:17<10:54, 36.34s/it]

best_model_Multi_our_VAE_K_1_N_5_learnreverse_False_anneal_True.pt NDCG:  0.3900237260555439 	 Recall:  0.6291605778062179


 66%|██████▌   | 33/50 [17:56<10:34, 37.35s/it]

best_model_Multi_our_VAE_K_1_N_5_learnreverse_True_anneal_False.pt NDCG:  0.36026106882678843 	 Recall:  0.5897391586414781


 68%|██████▊   | 34/50 [18:37<10:15, 38.46s/it]

best_model_Multi_our_VAE_K_1_N_5_learnreverse_True_anneal_True.pt NDCG:  0.35309516260490253 	 Recall:  0.5746678078241246


 70%|███████   | 35/50 [19:10<09:12, 36.81s/it]

best_model_Multi_our_VAE_K_2_N_1_learnreverse_False_anneal_False.pt NDCG:  0.35979460033857474 	 Recall:  0.5909036847654905


 72%|███████▏  | 36/50 [19:45<08:26, 36.21s/it]

best_model_Multi_our_VAE_K_2_N_1_learnreverse_False_anneal_True.pt NDCG:  0.3908701160887269 	 Recall:  0.6273082869828766


 74%|███████▍  | 37/50 [20:19<07:40, 35.46s/it]

best_model_Multi_our_VAE_K_2_N_1_learnreverse_True_anneal_False.pt NDCG:  0.36002180894910496 	 Recall:  0.5879555332388788


 76%|███████▌  | 38/50 [20:53<07:01, 35.12s/it]

best_model_Multi_our_VAE_K_2_N_1_learnreverse_True_anneal_True.pt NDCG:  0.39206050872789583 	 Recall:  0.6264729359510488


 78%|███████▊  | 39/50 [21:37<06:53, 37.58s/it]

best_model_Multi_our_VAE_K_2_N_2_learnreverse_False_anneal_False.pt NDCG:  0.3605611358313147 	 Recall:  0.5926063677003554


 80%|████████  | 40/50 [22:20<06:33, 39.37s/it]

best_model_Multi_our_VAE_K_2_N_2_learnreverse_False_anneal_True.pt NDCG:  0.39164801091899065 	 Recall:  0.62714971770978


 82%|████████▏ | 41/50 [23:05<06:08, 40.94s/it]

best_model_Multi_our_VAE_K_2_N_2_learnreverse_True_anneal_False.pt NDCG:  0.3589631122757636 	 Recall:  0.592109615272958


 84%|████████▍ | 42/50 [23:50<05:38, 42.34s/it]

best_model_Multi_our_VAE_K_2_N_2_learnreverse_True_anneal_True.pt NDCG:  0.3922028186841143 	 Recall:  0.6267401614228858


 86%|████████▌ | 43/50 [24:38<05:08, 44.08s/it]

best_model_Multi_our_VAE_K_3_N_1_learnreverse_False_anneal_False.pt NDCG:  0.36096035802848503 	 Recall:  0.5917265339403359


 88%|████████▊ | 44/50 [25:30<04:38, 46.46s/it]

best_model_Multi_our_VAE_K_3_N_1_learnreverse_False_anneal_True.pt NDCG:  0.3906878219166558 	 Recall:  0.6297148680824918


 90%|█████████ | 45/50 [26:17<03:52, 46.60s/it]

best_model_Multi_our_VAE_K_3_N_1_learnreverse_True_anneal_False.pt NDCG:  0.3618596966775307 	 Recall:  0.5912938704494618


 92%|█████████▏| 46/50 [27:04<03:06, 46.61s/it]

best_model_Multi_our_VAE_K_3_N_1_learnreverse_True_anneal_True.pt NDCG:  0.3939162894991049 	 Recall:  0.6315151084976999


 94%|█████████▍| 47/50 [28:16<02:43, 54.36s/it]

best_model_Multi_our_VAE_K_5_N_1_learnreverse_False_anneal_False.pt NDCG:  0.36261712896943066 	 Recall:  0.5934934622207817


 96%|█████████▌| 48/50 [29:29<01:59, 59.94s/it]

best_model_Multi_our_VAE_K_5_N_1_learnreverse_False_anneal_True.pt NDCG:  0.39473817871855926 	 Recall:  0.6320618573761625


 98%|█████████▊| 49/50 [30:43<01:04, 64.15s/it]

best_model_Multi_our_VAE_K_5_N_1_learnreverse_True_anneal_False.pt NDCG:  0.35894335180999404 	 Recall:  0.5892748618434689


100%|██████████| 50/50 [31:58<00:00, 38.37s/it]
  0%|          | 0/50 [00:00<?, ?it/s]

best_model_Multi_our_VAE_K_5_N_1_learnreverse_True_anneal_True.pt NDCG:  0.39334609972000656 	 Recall:  0.62731210396978


  2%|▏         | 1/50 [00:22<18:24, 22.54s/it]

best_model_MultiHoffmanVAE_K_1_N_1_learnreverse_False_anneal_False.pt NDCG:  0.39675711317791784 	 Recall:  0.6282092046507072


  4%|▍         | 2/50 [00:44<17:54, 22.38s/it]

best_model_MultiHoffmanVAE_K_1_N_1_learnreverse_False_anneal_True.pt NDCG:  0.4320197027484642 	 Recall:  0.6639729173601653


  6%|▌         | 3/50 [01:29<22:49, 29.14s/it]

best_model_MultiHoffmanVAE_K_2_N_2_learnreverse_False_anneal_False.pt NDCG:  0.38992758360543983 	 Recall:  0.6185308342589493


  8%|▊         | 4/50 [02:14<26:04, 34.01s/it]

best_model_MultiHoffmanVAE_K_2_N_2_learnreverse_False_anneal_True.pt NDCG:  0.4259487427528229 	 Recall:  0.6564985219294623


 10%|█         | 5/50 [02:21<19:22, 25.84s/it]

best_model_MultiVAE_K_None_N_None_learnreverse_False_anneal_False.pt NDCG:  0.39716638389173303 	 Recall:  0.6285131186203207


 12%|█▏        | 6/50 [02:28<14:45, 20.12s/it]

best_model_MultiVAE_K_None_N_None_learnreverse_False_anneal_True.pt NDCG:  0.43223991156491354 	 Recall:  0.6613630759783731


 14%|█▍        | 7/50 [02:49<14:44, 20.57s/it]

best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_False_anneal_False.pt NDCG:  0.3946547290584871 	 Recall:  0.6256582981588972


 16%|█▌        | 8/50 [03:12<14:44, 21.05s/it]

best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_False_anneal_True.pt NDCG:  0.43228676690172263 	 Recall:  0.6638881900904336


 18%|█▊        | 9/50 [03:34<14:33, 21.32s/it]

best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_True_anneal_False.pt NDCG:  0.3947283115488491 	 Recall:  0.6255529804913874


 20%|██        | 10/50 [03:54<14:03, 21.10s/it]

best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_True_anneal_True.pt NDCG:  0.43245409754443087 	 Recall:  0.6646831408889161


 22%|██▏       | 11/50 [04:29<16:24, 25.25s/it]

best_model_Multi_ourHoffman_VAE_K_2_N_1_learnreverse_False_anneal_True.pt NDCG:  0.4328045726658963 	 Recall:  0.6615015550440665


 24%|██▍       | 12/50 [05:04<17:49, 28.15s/it]

best_model_Multi_ourHoffman_VAE_K_2_N_1_learnreverse_True_anneal_True.pt NDCG:  0.4325158172193171 	 Recall:  0.6622387232585234


 26%|██▌       | 13/50 [05:52<20:59, 34.05s/it]

best_model_Multi_ourHoffman_VAE_K_2_N_2_learnreverse_False_anneal_False.pt NDCG:  0.39382461108769085 	 Recall:  0.6236045177792566


 28%|██▊       | 14/50 [06:41<23:06, 38.52s/it]

best_model_Multi_ourHoffman_VAE_K_2_N_2_learnreverse_False_anneal_True.pt NDCG:  0.43182996277755803 	 Recall:  0.6637024662945274


 30%|███       | 15/50 [07:31<24:32, 42.07s/it]

best_model_Multi_ourHoffman_VAE_K_2_N_2_learnreverse_True_anneal_False.pt NDCG:  0.39383953869150795 	 Recall:  0.6236438501675707


 32%|███▏      | 16/50 [08:15<24:13, 42.74s/it]

best_model_Multi_ourHoffman_VAE_K_2_N_2_learnreverse_True_anneal_True.pt NDCG:  0.43292576692225787 	 Recall:  0.6636614790184263


 34%|███▍      | 17/50 [09:05<24:36, 44.73s/it]

best_model_Multi_ourHoffman_VAE_K_3_N_1_learnreverse_False_anneal_True.pt NDCG:  0.43202368971913774 	 Recall:  0.6627317034824056


 36%|███▌      | 18/50 [09:54<24:33, 46.05s/it]

best_model_Multi_ourHoffman_VAE_K_3_N_1_learnreverse_True_anneal_True.pt NDCG:  0.431450756640749 	 Recall:  0.6630818135419878


 38%|███▊      | 19/50 [10:16<20:02, 38.80s/it]

best_model_Multi_our_VAE_K_1_N_1_learnreverse_False_anneal_False.pt NDCG:  0.39465373238242174 	 Recall:  0.6240018366045105


 40%|████      | 20/50 [10:35<16:27, 32.90s/it]

best_model_Multi_our_VAE_K_1_N_1_learnreverse_False_anneal_True.pt NDCG:  0.4328944051494731 	 Recall:  0.6630478076655533


 42%|████▏     | 21/50 [10:55<14:02, 29.04s/it]

best_model_Multi_our_VAE_K_1_N_1_learnreverse_True_anneal_False.pt NDCG:  0.3942925126324004 	 Recall:  0.6247742734367686


 44%|████▍     | 22/50 [11:16<12:24, 26.58s/it]

best_model_Multi_our_VAE_K_1_N_1_learnreverse_True_anneal_True.pt NDCG:  0.4335142018978614 	 Recall:  0.6646206260489538


 46%|████▌     | 23/50 [11:42<11:51, 26.35s/it]

best_model_Multi_our_VAE_K_1_N_2_learnreverse_False_anneal_False.pt NDCG:  0.39365518318366255 	 Recall:  0.6238886548207991


 48%|████▊     | 24/50 [12:08<11:22, 26.26s/it]

best_model_Multi_our_VAE_K_1_N_2_learnreverse_False_anneal_True.pt NDCG:  0.43236103284695937 	 Recall:  0.663000695818898


 50%|█████     | 25/50 [12:33<10:50, 26.02s/it]

best_model_Multi_our_VAE_K_1_N_2_learnreverse_True_anneal_False.pt NDCG:  0.393905761746157 	 Recall:  0.6227148824735784


 52%|█████▏    | 26/50 [12:59<10:26, 26.10s/it]

best_model_Multi_our_VAE_K_1_N_2_learnreverse_True_anneal_True.pt NDCG:  0.4324113587347856 	 Recall:  0.6636629934597966


 54%|█████▍    | 27/50 [13:32<10:45, 28.06s/it]

best_model_Multi_our_VAE_K_1_N_3_learnreverse_False_anneal_False.pt NDCG:  0.39464665494824275 	 Recall:  0.6258909873033474


 56%|█████▌    | 28/50 [14:06<10:58, 29.95s/it]

best_model_Multi_our_VAE_K_1_N_3_learnreverse_False_anneal_True.pt NDCG:  0.43326115555105943 	 Recall:  0.6637666364663711


 58%|█████▊    | 29/50 [14:41<10:55, 31.21s/it]

best_model_Multi_our_VAE_K_1_N_3_learnreverse_True_anneal_False.pt NDCG:  0.39462169394171054 	 Recall:  0.624403719428512


 60%|██████    | 30/50 [15:11<10:22, 31.10s/it]

best_model_Multi_our_VAE_K_1_N_3_learnreverse_True_anneal_True.pt NDCG:  0.4329023493760451 	 Recall:  0.6639837995103338


 62%|██████▏   | 31/50 [15:54<10:54, 34.46s/it]

best_model_Multi_our_VAE_K_1_N_5_learnreverse_False_anneal_False.pt NDCG:  0.3945503265071165 	 Recall:  0.6248159238741994


 64%|██████▍   | 32/50 [16:39<11:20, 37.82s/it]

best_model_Multi_our_VAE_K_1_N_5_learnreverse_False_anneal_True.pt NDCG:  0.43269630202915166 	 Recall:  0.66395870049706


 66%|██████▌   | 33/50 [17:25<11:23, 40.23s/it]

best_model_Multi_our_VAE_K_1_N_5_learnreverse_True_anneal_False.pt NDCG:  0.39429636004612667 	 Recall:  0.6246460009604804


 68%|██████▊   | 34/50 [18:10<11:02, 41.43s/it]

best_model_Multi_our_VAE_K_1_N_5_learnreverse_True_anneal_True.pt NDCG:  0.3559153380951852 	 Recall:  0.5781656100894702


 70%|███████   | 35/50 [18:45<09:54, 39.63s/it]

best_model_Multi_our_VAE_K_2_N_1_learnreverse_False_anneal_False.pt NDCG:  0.3942202946477678 	 Recall:  0.6246435925939003


 72%|███████▏  | 36/50 [19:19<08:49, 37.81s/it]

best_model_Multi_our_VAE_K_2_N_1_learnreverse_False_anneal_True.pt NDCG:  0.43288761373089774 	 Recall:  0.6632895886751105


 74%|███████▍  | 37/50 [19:53<07:56, 36.68s/it]

best_model_Multi_our_VAE_K_2_N_1_learnreverse_True_anneal_False.pt NDCG:  0.39431061501902753 	 Recall:  0.6247187095380861


 76%|███████▌  | 38/50 [20:26<07:09, 35.82s/it]

best_model_Multi_our_VAE_K_2_N_1_learnreverse_True_anneal_True.pt NDCG:  0.43289151464250564 	 Recall:  0.6626528310821381


 78%|███████▊  | 39/50 [21:12<07:06, 38.81s/it]

best_model_Multi_our_VAE_K_2_N_2_learnreverse_False_anneal_False.pt NDCG:  0.395547052235505 	 Recall:  0.6263853735124593


 80%|████████  | 40/50 [21:59<06:53, 41.34s/it]

best_model_Multi_our_VAE_K_2_N_2_learnreverse_False_anneal_True.pt NDCG:  0.4328883350539547 	 Recall:  0.6637168059647561


 82%|████████▏ | 41/50 [22:44<06:20, 42.30s/it]

best_model_Multi_our_VAE_K_2_N_2_learnreverse_True_anneal_False.pt NDCG:  0.39582178938050677 	 Recall:  0.6267767628830229


 84%|████████▍ | 42/50 [23:29<05:44, 43.07s/it]

best_model_Multi_our_VAE_K_2_N_2_learnreverse_True_anneal_True.pt NDCG:  0.43207361635955494 	 Recall:  0.6616242015818475


 86%|████████▌ | 43/50 [24:16<05:09, 44.25s/it]

best_model_Multi_our_VAE_K_3_N_1_learnreverse_False_anneal_False.pt NDCG:  0.39468654404632586 	 Recall:  0.6237404600852448


 88%|████████▊ | 44/50 [25:04<04:32, 45.49s/it]

best_model_Multi_our_VAE_K_3_N_1_learnreverse_False_anneal_True.pt NDCG:  0.4320835099533606 	 Recall:  0.6633997587474327


 90%|█████████ | 45/50 [25:50<03:48, 45.68s/it]

best_model_Multi_our_VAE_K_3_N_1_learnreverse_True_anneal_False.pt NDCG:  0.3948872528611854 	 Recall:  0.6240550797583879


 92%|█████████▏| 46/50 [26:39<03:06, 46.59s/it]

best_model_Multi_our_VAE_K_3_N_1_learnreverse_True_anneal_True.pt NDCG:  0.4324688935500537 	 Recall:  0.6636623519483595


 94%|█████████▍| 47/50 [27:52<02:43, 54.38s/it]

best_model_Multi_our_VAE_K_5_N_1_learnreverse_False_anneal_False.pt NDCG:  0.39485367862430837 	 Recall:  0.6255770918723464


 96%|█████████▌| 48/50 [29:05<02:00, 60.03s/it]

best_model_Multi_our_VAE_K_5_N_1_learnreverse_False_anneal_True.pt NDCG:  0.4325399914134912 	 Recall:  0.6629434271066635


 98%|█████████▊| 49/50 [30:19<01:04, 64.30s/it]

best_model_Multi_our_VAE_K_5_N_1_learnreverse_True_anneal_False.pt NDCG:  0.3947527895847348 	 Recall:  0.6248612436214498


100%|██████████| 50/50 [31:33<00:00, 37.87s/it]

best_model_Multi_our_VAE_K_5_N_1_learnreverse_True_anneal_True.pt NDCG:  0.4341622631175437 	 Recall:  0.6630982773745904





In [7]:
# Widen pandas column display so long checkpoint filenames are shown in full.
pd.options.display.max_colwidth = 100

In [8]:
# Tabulate the evaluation results, save them as CSV, and display the table.
df = pd.DataFrame.from_dict(resulting_dict)
df.to_csv('./eval_results.csv', index=False)
df

Unnamed: 0,model,ndcg,recall,noise
0,best_model_MultiHoffmanVAE_K_1_N_1_learnreverse_False_anneal_False.pt,0.359986,0.591383,1.0
1,best_model_MultiHoffmanVAE_K_1_N_1_learnreverse_False_anneal_True.pt,0.386996,0.626046,1.0
2,best_model_MultiHoffmanVAE_K_2_N_2_learnreverse_False_anneal_False.pt,0.355633,0.584347,1.0
3,best_model_MultiHoffmanVAE_K_2_N_2_learnreverse_False_anneal_True.pt,0.387190,0.620263,1.0
4,best_model_MultiVAE_K_None_N_None_learnreverse_False_anneal_False.pt,0.361924,0.595574,1.0
5,best_model_MultiVAE_K_None_N_None_learnreverse_False_anneal_True.pt,0.392378,0.628635,1.0
6,best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_False_anneal_False.pt,0.358793,0.589407,1.0
7,best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_False_anneal_True.pt,0.394134,0.628383,1.0
8,best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_True_anneal_False.pt,0.360511,0.590860,1.0
9,best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_True_anneal_True.pt,0.388402,0.627646,1.0


In [14]:
# Best 9 checkpoints by NDCG among rows evaluated with noise == 1.
df.loc[df['noise'] == 1.].sort_values(by='ndcg', ascending=False).head(9)

Unnamed: 0,model,ndcg,recall,noise
17,best_model_Multi_ourHoffman_VAE_K_3_N_1_learnreverse_True_anneal_True.pt,0.398662,0.635319,1.0
47,best_model_Multi_our_VAE_K_5_N_1_learnreverse_False_anneal_True.pt,0.394738,0.632062,1.0
7,best_model_Multi_ourHoffman_VAE_K_1_N_1_learnreverse_False_anneal_True.pt,0.394134,0.628383,1.0
45,best_model_Multi_our_VAE_K_3_N_1_learnreverse_True_anneal_True.pt,0.393916,0.631515,1.0
49,best_model_Multi_our_VAE_K_5_N_1_learnreverse_True_anneal_True.pt,0.393346,0.627312,1.0
15,best_model_Multi_ourHoffman_VAE_K_2_N_2_learnreverse_True_anneal_True.pt,0.392814,0.630674,1.0
19,best_model_Multi_our_VAE_K_1_N_1_learnreverse_False_anneal_True.pt,0.392658,0.628903,1.0
16,best_model_Multi_ourHoffman_VAE_K_3_N_1_learnreverse_False_anneal_True.pt,0.392455,0.629456,1.0
5,best_model_MultiVAE_K_None_N_None_learnreverse_False_anneal_True.pt,0.392378,0.628635,1.0


In [12]:
# Best 19 checkpoints by NDCG among rows evaluated with noise == 0.
df.loc[df['noise'] == 0.].sort_values(by='ndcg', ascending=False).head(19)

Unnamed: 0,model,ndcg,recall,noise
99,best_model_Multi_our_VAE_K_5_N_1_learnreverse_True_anneal_True.pt,0.434162,0.663098,0.0
71,best_model_Multi_our_VAE_K_1_N_1_learnreverse_True_anneal_True.pt,0.433514,0.664621,0.0
77,best_model_Multi_our_VAE_K_1_N_3_learnreverse_False_anneal_True.pt,0.433261,0.663767,0.0
65,best_model_Multi_ourHoffman_VAE_K_2_N_2_learnreverse_True_anneal_True.pt,0.432926,0.663661,0.0
79,best_model_Multi_our_VAE_K_1_N_3_learnreverse_True_anneal_True.pt,0.432902,0.663984,0.0
69,best_model_Multi_our_VAE_K_1_N_1_learnreverse_False_anneal_True.pt,0.432894,0.663048,0.0
87,best_model_Multi_our_VAE_K_2_N_1_learnreverse_True_anneal_True.pt,0.432892,0.662653,0.0
89,best_model_Multi_our_VAE_K_2_N_2_learnreverse_False_anneal_True.pt,0.432888,0.663717,0.0
85,best_model_Multi_our_VAE_K_2_N_1_learnreverse_False_anneal_True.pt,0.432888,0.66329,0.0
60,best_model_Multi_ourHoffman_VAE_K_2_N_1_learnreverse_False_anneal_True.pt,0.432805,0.661502,0.0


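# Parameter tuning

Train a fresh MultiVAE encoder on top of the frozen decoder taken from a pretrained MultiVAE checkpoint.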

In [4]:
# Pretrained MultiVAE checkpoint (trained without annealing).
# Alternative: 'best_model_MultiVAE_K_None_N_None_learnreverse_False_anneal_True.pt'
model = load_model(
    './models/best_model_MultiVAE_K_None_N_None_learnreverse_False_anneal_False.pt',
    'cpu',
)

In [5]:
# Freeze every parameter of the loaded checkpoint and keep only its decoder.
model.requires_grad_(False)
model.eval()
fixed_decoder = model.decoder
del model

# Fresh MultiVAE whose layer sizes match the pretrained decoder
# (200 -> 600 -> n_items), with the frozen decoder swapped in.
layers = [200, 600, dataset.n_items]
model = MultiVAE(layers, args=args)
model.decoder = fixed_decoder

Sequential(
  (0): Linear(in_features=20108, out_features=600, bias=True)
  (1): Tanh()
  (2): Linear(in_features=600, out_features=400, bias=True)
)
Sequential(
  (0): Linear(in_features=200, out_features=600, bias=True)
  (1): Tanh()
  (2): Linear(in_features=600, out_features=20108, bias=True)
)


In [9]:
# Train the new MultiVAE (its decoder is frozen, taken from the pretrained
# checkpoint) with the settings in `args`.
# NOTE: long-running - the recorded run below was interrupted manually.
train_model(model, dataset, args)

  0%|          | 0/200 [00:00<?, ?it/s]

666.53955
592.1258


KeyboardInterrupt: 