In [1]:
import argparse
from collections import OrderedDict
import copy
from multiprocessing import Process,Manager
import numpy as np
import pandas as pd
from scipy import sparse
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tensorboardX import SummaryWriter
import time
from tqdm import tqdm

import models
import data
import metric

In [2]:
# import importlib
# importlib.reload(metric)

In [3]:
# Set Configs

In [4]:
##  Set the random seed manually for reproductibility.
seed = 1
torch.manual_seed(seed)

<torch._C.Generator at 0x7f301d906170>

In [5]:
# device= torch.device("cuda")
device = torch.device("cpu")

In [6]:
# Load Data
loader = data.DataLoader('ml-20m')

n_items = loader.load_n_items()
train_data = loader.load_data('train')
vad_data_tr, vad_data_te = loader.load_data('validation')
test_data_tr, test_data_te = loader.load_data('test')

N = train_data.shape[0]
idxlist = list(range(N))

print("# of items:{}".format(n_items))

# of items:20101


In [7]:
# Build the model

p_dims = [200, 600, n_items]
model = models.MultiVAE(p_dims).to(device)

print(f"Model Structure:{model}\n")
# for name, param in model.named_parameters():
#     print(f"Layer: {name} | Size: {param.size()} | Values : {param[:2]} \n")

optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.00)
criterion = models.loss_function

Model Structure:MultiVAE(
  (q_layers): ModuleList(
    (0): Linear(in_features=20101, out_features=600, bias=True)
    (1): Linear(in_features=600, out_features=400, bias=True)
  )
  (p_layers): ModuleList(
    (0): Linear(in_features=200, out_features=600, bias=True)
    (1): Linear(in_features=600, out_features=20101, bias=True)
  )
  (drop): Dropout(p=0.5, inplace=False)
)



In [8]:
# TensorboardX Writer
writer= SummaryWriter()

In [9]:
# Train

In [10]:
BATCH_SIZE = 500
TOTAL_ANNEAL_STEPS = 200000
ANNEAL_CAP = 0.2
LOG_INTERVAL = 100
# EPOCHS = 100
EPOCHS = 200
SAVE_PATH = 'model.pt'

In [11]:
def sparse2torch_sparse(data):
    """
    Convert scipy sparse matrix to torch sparse tensor with L2 Normalization
    This is much faster than naive use of torch.FloatTensor(data.toarray())
    https://discuss.pytorch.org/t/sparse-tensor-use-cases/22047/2
    """
    samples = data.shape[0]
    features = data.shape[1]
    coo_data = data.tocoo()
    indices = torch.LongTensor([coo_data.row, coo_data.col])
    row_norms_inv = 1 / np.sqrt(data.sum(1))
    row2val = {i : row_norms_inv[i].item() for i in range(samples)}
    values = np.array([row2val[r] for r in coo_data.row])
    t = torch.sparse.FloatTensor(indices, torch.from_numpy(values).float(), [samples, features])
    return t

In [12]:
def naive_sparse2tensor(data):
    return torch.FloatTensor(data.toarray())

In [13]:
def train():
    # Turn on training mode
    model.train()
    train_loss = 0.0
    start_time = time.time()
    global update_count

    np.random.shuffle(idxlist)
    
    for batch_idx, start_idx in enumerate(range(0, N, BATCH_SIZE)):
        end_idx = min(start_idx + BATCH_SIZE, N)
        data = train_data[idxlist[start_idx:end_idx]]
        data = naive_sparse2tensor(data).to(device)

        if TOTAL_ANNEAL_STEPS > 0:
            anneal = min(ANNEAL_CAP, 
                            1. * update_count / TOTAL_ANNEAL_STEPS)
        else:
            anneal = ANNEAL_CAP

        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        
        loss = criterion(recon_batch, data, mu, logvar, anneal)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        update_count += 1

        if batch_idx % LOG_INTERVAL == 0 and batch_idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:4d}/{:4d} batches | ms/batch {:4.2f} | '
                    'loss {:4.2f}'.format(
                        epoch, batch_idx, len(range(0, N, BATCH_SIZE)),
                        elapsed * 1000 / LOG_INTERVAL,
                        train_loss / LOG_INTERVAL))
            
            # Log loss to tensorboard
            n_iter = (epoch - 1) * len(range(0, N, BATCH_SIZE)) + batch_idx
            writer.add_scalars('data/loss', {'train': train_loss / LOG_INTERVAL}, n_iter)

            start_time = time.time()
            train_loss = 0.0

In [14]:
def evaluate(data_tr, data_te):
    # Turn on evaluation mode
    model.eval()
    total_loss = 0.0
    global update_count
    e_idxlist = list(range(data_tr.shape[0]))
    e_N = data_tr.shape[0]
    n1_list = []
    n100_list = []
    r20_list = []
    r50_list = []
    
    with torch.no_grad():
        for start_idx in range(0, e_N, BATCH_SIZE):
            end_idx = min(start_idx + BATCH_SIZE, N)
            data = data_tr[e_idxlist[start_idx:end_idx]]
            heldout_data = data_te[e_idxlist[start_idx:end_idx]]
    
            # cno : avoid users who have no clicks in heldout_data
            u_idxlist_wo_any_iteracts = [i for i, x in enumerate(heldout_data.toarray().sum(axis=1)) if x >0]
            data = data[u_idxlist_wo_any_iteracts]
            heldout_data = heldout_data[u_idxlist_wo_any_iteracts]
            
            data_tensor = naive_sparse2tensor(data).to(device)

            if TOTAL_ANNEAL_STEPS > 0:
                anneal = min(ANNEAL_CAP, 
                               1. * update_count / TOTAL_ANNEAL_STEPS)
            else:
                anneal = ANNEAL_CAP

            recon_batch, mu, logvar = model(data_tensor)

            loss = criterion(recon_batch, data_tensor, mu, logvar, anneal)
            total_loss += loss.item()

            # Exclude examples from training set
            recon_batch = recon_batch.cpu().numpy()
            recon_batch[data.nonzero()] = -np.inf

            n1 = metric.NDCG_binary_at_k_batch(recon_batch, heldout_data, 1)
            n100 = metric.NDCG_binary_at_k_batch(recon_batch, heldout_data, 100)
            r20 = metric.Recall_at_k_batch(recon_batch, heldout_data, 20)
            r50 = metric.Recall_at_k_batch(recon_batch, heldout_data, 50)

            n1_list.append(n1)
            n100_list.append(n100)
            r20_list.append(r20)
            r50_list.append(r50)
 
    total_loss /= len(range(0, e_N, BATCH_SIZE))
    n1_list = np.concatenate(n1_list)
    n100_list = np.concatenate(n100_list)
    r20_list = np.concatenate(r20_list)
    r50_list = np.concatenate(r50_list)

    return total_loss, n1_list, n100_list, r20_list, r50_list

In [13]:
best_n100 = -np.inf
update_count = 0

# At any point you can hit Ctrl + C to break out of training early.
try:
    for epoch in range(1, EPOCHS + 1):
        epoch_start_time = time.time()
        train()
        val_loss, n100, r20, r50 = evaluate(vad_data_tr, vad_data_te)
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:4.2f}s | valid loss {:4.2f} | '
                'n100 {:5.3f} | r20 {:5.3f} | r50 {:5.3f}'.format(
                    epoch, time.time() - epoch_start_time, val_loss,
                    n100, r20, r50))
        print('-' * 89)

        n_iter = epoch * len(range(0, N, BATCH_SIZE))
        writer.add_scalars('data/loss', {'valid': val_loss}, n_iter)
        writer.add_scalar('data/n100', n100, n_iter)
        writer.add_scalar('data/r20', r20, n_iter)
        writer.add_scalar('data/r50', r50, n_iter)

        # Save the model if the n100 is the best we've seen so far.
        if n100 > best_n100:
            with open(SAVE_PATH, 'wb') as f:
                torch.save(model, f)
            best_n100 = n100

except KeyboardInterrupt:
    print('-' * 89)
    print('Exiting from training early')

print(update_count)

| epoch   1 |  100/ 233 batches | ms/batch 134.92 | loss 572.41
-----------------------------------------------------------------------------------------
Exiting from training early


In [15]:
# Load the best saved model.
MODEL_PATH = SAVE_PATH
with open(SAVE_PATH, 'rb') as f:
    model = torch.load(f).cpu()

In [35]:
# Run on test data.
# update_count = 0
test_loss, n1, n100, r20, r50 = evaluate(test_data_tr, test_data_te)
print('=' * 89)
print('| End of training | test loss {:4.2f} | n1 {:4.3f}({:4.3f}) | n100 {:4.3f}({:4.3f}) | r20 {:4.3f}({:4.3f}) | '
        'r50 {:4.3f}'.format(test_loss, np.mean(n1), np.std(n1), np.mean(n100), np.std(n100), np.mean(r20), np.std(r20), np.mean(r50), np.std(r50)))
print('=' * 89)

| End of training | test loss 366.42 | n1 0.369 | n100 0.428 | r20 0.400 | r50 0.537


In [None]:
# index items by using weights in the encoding of VAE-model

In [16]:
stdict = model.state_dict()
print(stdict.keys())

odict_keys(['q_layers.0.weight', 'q_layers.0.bias', 'q_layers.1.weight', 'q_layers.1.bias', 'p_layers.0.weight', 'p_layers.0.bias', 'p_layers.1.weight', 'p_layers.1.bias'])


In [17]:
P0 = stdict['p_layers.0.weight']
p0_bias = stdict['p_layers.0.bias']
P1 = stdict['p_layers.1.weight']
p1_bias = stdict['p_layers.1.bias']
print(P0.shape)
print(p0_bias.shape)
print(P1.shape)
print(p1_bias.shape)

torch.Size([600, 200])
torch.Size([600])
torch.Size([20101, 600])
torch.Size([20101])


In [18]:
# B(tanh(Az+b))+b' = Bz'+b' = ([B,b'](z',1))
P1_dash = torch.column_stack((P1,p1_bias))
print(P1_dash.shape)

torch.Size([20101, 601])


In [19]:
# https://qiita.com/saliton/items/3650e8518d8bf0684332
import scann
d = P1_dash.shape[1]
scann_brute = scann.scann_ops_pybind.builder(P1_dash.numpy(), d, "dot_product").score_brute_force().build()

2024-03-28 18:43:43.980406: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
scann_searcher = scann.scann_ops_pybind.builder(P1_dash.numpy(), d, "dot_product").tree(
    num_leaves=2000, num_leaves_to_search=100, training_sample_size=250000).score_ah(
    2, anisotropic_quantization_threshold=0.2).reorder(100).build()

# indexing by faiss

In [20]:
import faiss

# build a flat (CPU) index
d = P1_dash.shape[1]
index_flat = faiss.IndexFlatIP(d)

# make it into a gpu index
# gpu_index_flat = faiss.index_cpu_to_gpu(res, 0, index_flat)
# # make it into a gpu index(multi GPUs)
# gpu_index = faiss.index_cpu_to_all_gpus(  # build the index
#     cpu_index
# )

# indexing
index_flat.add(P1_dash)  
print(index_flat.ntotal)

20101


In [28]:
nlist = 100
quantizer = faiss.IndexFlatIP(d)
index_ivf = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_INNER_PRODUCT)
index_ivf.train(P1_dash)
index_ivf.add(P1_dash)

index_ivf.nprobe = 3

In [41]:
# update_count = 

# tauをアイテム数で決める(gumbel-sharpをアイテム数で決めてみる)

In [84]:
# beta = np.log(1/n_items)
beta = 1
# tau = 0.1
tau = 1/np.log(np.log(n_items))
print(tau)
print(1/tau)

0.43603469522020893
2.293395481969557


In [85]:
def sampling_ranking(scores, heldout_data, seed, n1_list_per_sampling, n100_list_per_sampling, r20_list_per_sampling, r50_list_per_sampling):
    # Add Gumbel samples
    np.random.seed(seed=seed)
    gumbel_sampled_scores = scores + np.vectorize(gumbel_inverse)(np.random.uniform(size=scores.shape))
    # Exclude examples from training set
    # gumbel_sampled_scores[data.nonzero()] = -np.inf

    n1_list_per_sampling.append(metric.NDCG_binary_at_k_batch(gumbel_sampled_scores, heldout_data, 1))
    n100_list_per_sampling.append(metric.NDCG_binary_at_k_batch(gumbel_sampled_scores, heldout_data, 100))
    r20_list_per_sampling.append(metric.Recall_at_k_batch(gumbel_sampled_scores, heldout_data, 20))
    r50_list_per_sampling.append(metric.Recall_at_k_batch(gumbel_sampled_scores, heldout_data, 50))

In [87]:
def evaluate_expectation2(data_tr, data_te, n_sampling=1):
    # Turn on evaluation mode
    model.eval()
    total_loss = 0.0
    global update_count
    e_idxlist = list(range(data_tr.shape[0]))
    e_N = data_tr.shape[0]
    n1_list = []
    n100_list = []
    r20_list = []
    r50_list = []
    n1_list_per_sampling = []
    n100_list_per_sampling = []
    r20_list_per_sampling = []
    r50_list_per_sampling = []
    
    manager = Manager()
    dummy = manager.dict()
    
    with torch.no_grad():
        with tqdm(range(0, e_N, BATCH_SIZE)) as pbar:
        # for start_idx in tqdm(range(0, e_N, BATCH_SIZE)):
            for start_idx in pbar:
                pbar.set_description("[test]")
                  
                end_idx = min(start_idx + BATCH_SIZE, N)
                data = data_tr[e_idxlist[start_idx:end_idx]]
                heldout_data = data_te[e_idxlist[start_idx:end_idx]]

                u_idxlist_wo_any_iteracts = [i for i, x in enumerate(heldout_data.toarray().sum(axis=1)) if x >0]
                data = data[u_idxlist_wo_any_iteracts]
                heldout_data = heldout_data[u_idxlist_wo_any_iteracts]

                data_tensor = naive_sparse2tensor(data).to(device)

                if TOTAL_ANNEAL_STEPS > 0:
                    anneal = min(ANNEAL_CAP, 
                                   1. * update_count / TOTAL_ANNEAL_STEPS)
                else:
                    anneal = ANNEAL_CAP

                recon_batch, mu, logvar = model(data_tensor)

                loss = criterion(recon_batch, data_tensor, mu, logvar, anneal)
                total_loss += loss.item()
                # pbar.set_description(OrderedDict(total_loss=total_loss))

                # print(torch.mean(recon_batch,1))
                # print(torch.transpose(recon_batch,0,1).size())
                # print(torch.mean(torch.transpose(torch.div(torch.transpose(recon_batch,0,1), torch.mean(recon_batch,1)),0,1),1))
                # print(torch.div(torch.transpose(torch.div(torch.transpose(recon_batch,0,1), torch.mean(recon_batch,1)),0,1),tau).size())
                # print(torch.mean(torch.div(torch.transpose(torch.div(torch.transpose(recon_batch,0,1), torch.mean(recon_batch,1)),0,1),tau),1))

                recon_batch = F.log_softmax(torch.div(recon_batch,tau), 1)
                # recon_batch = F.log_softmax(torch.div(torch.transpose(torch.div(torch.transpose(recon_batch,0,1), torch.mean(recon_batch,1)),0,1),tau), 1)
                recon_batch = recon_batch.cpu().numpy()
                # recon_batch[data.nonzero()] = -np.inf
                
                with Manager() as manager:
                    # d = manager.dict()
                    # l = manager.list()
                    n1_list_per_sampling = manager.list()
                    n100_list_per_sampling = manager.list()
                    r20_list_per_sampling = manager.list()
                    r50_list_per_sampling = manager.list()
                    p_list = []
                    for l in range(n_sampling):
                        p = Process(target=sampling_ranking, args=(recon_batch,heldout_data,l,n1_list_per_sampling,n100_list_per_sampling,r20_list_per_sampling,r50_list_per_sampling))
                        p.start()
                        p_list.append(p)
                        if len(p_list) % 4 == 0:
                            for p in p_list:
                                p.join()
                            p_list = []
                    
                    for p in p_list:
                        p.join()
                    
                    # print(n1_list_per_sampling)
                    # print(n100_list_per_sampling)

                    n1_list.append(np.concatenate(n1_list_per_sampling))
                    n100_list.append(np.concatenate(n100_list_per_sampling))
                    r20_list.append(np.concatenate(r20_list_per_sampling))
                    r50_list.append(np.concatenate(r50_list_per_sampling))
    
    total_loss /= len(range(0, e_N, BATCH_SIZE))
    n1_list = np.concatenate(n1_list)
    n100_list = np.concatenate(n100_list)
    r20_list = np.concatenate(r20_list)
    r50_list = np.concatenate(r50_list)

    return total_loss, n1_list, n100_list, r20_list, r50_list

In [88]:
# Run on test data.
test_loss2, n1_list2, n100_list2, r20_list2, r50_list2 = evaluate_expectation2(test_data_tr, test_data_te, n_sampling=20)
print('=' * 89)
print('| End of training | test loss {:4.2f} | n1 {:4.3f}({:4.3f}) | n100 {:4.3f}({:4.3f}) | r20 {:4.3f}({:4.3f}) | '
        'r50 {:4.3f}({:4.3f})'.format(test_loss2, np.mean(n1_list2), np.std(n1_list2)/np.sqrt(len(n1_list2)), np.mean(n100_list2), np.std(n100_list2)/np.sqrt(len(n100_list2)), np.mean(r20_list2), np.std(r20_list2)/np.sqrt(len(r20_list2)), np.mean(r50_list2), np.std(r50_list2)/np.sqrt(len(r50_list2))))
print('=' * 89)

[test]: 100%|██████████| 20/20 [2:08:43<00:00, 386.16s/it]t]

| End of training | test loss 366.42 | n1 0.068(0.001) | n100 0.247(0.000) | r20 0.201(0.001) | r50 0.386(0.001)





# stochastic VAE

In [28]:
def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return eps.mul(std).add_(mu)

def evaluate_stochasticVAE(data_tr, data_te, n_sampling=1):
    # Turn on evaluation mode
    model.eval()
    total_loss = 0.0
    global update_count
    e_idxlist = list(range(data_tr.shape[0]))
    e_N = data_tr.shape[0]
    n1_list = []
    n100_list = []
    r20_list = []
    r50_list = []
    n1_list_per_sampling = []
    n100_list_per_sampling = []
    r20_list_per_sampling = []
    r50_list_per_sampling = []
    
    with torch.no_grad():
        with tqdm(range(0, e_N, BATCH_SIZE)) as pbar:
        # for start_idx in tqdm(range(0, e_N, BATCH_SIZE)):
            for start_idx in pbar:
                pbar.set_description("[test]")
                  
                end_idx = min(start_idx + BATCH_SIZE, N)
                data = data_tr[e_idxlist[start_idx:end_idx]]
                heldout_data = data_te[e_idxlist[start_idx:end_idx]]

                u_idxlist_wo_any_iteracts = [i for i, x in enumerate(heldout_data.toarray().sum(axis=1)) if x >0]
                data = data[u_idxlist_wo_any_iteracts]
                heldout_data = heldout_data[u_idxlist_wo_any_iteracts]

                data_tensor = naive_sparse2tensor(data).to(device)

                if TOTAL_ANNEAL_STEPS > 0:
                    anneal = min(ANNEAL_CAP, 
                                   1. * update_count / TOTAL_ANNEAL_STEPS)
                else:
                    anneal = ANNEAL_CAP

                # recon_batch, mu, logvar = model(data_tensor)
                mu, logvar = model.encode(data_tensor)

                # loss = criterion(recon_batch, data_tensor, mu, logvar, anneal)
                # total_loss += loss.item()
                # recon_batch = recon_batch.cpu().numpy()
                # recon_batch[data.nonzero()] = -np.inf

                for l in range(n_sampling):
                    z = reparameterize(mu, logvar)
                    # print(z)
                    recon_batch = model.decode(z)
                    recon_batch = recon_batch.cpu().numpy()

                    n1_list_per_sampling.append(metric.NDCG_binary_at_k_batch(recon_batch, heldout_data, 1))
                    n100_list_per_sampling.append(metric.NDCG_binary_at_k_batch(recon_batch, heldout_data, 100))
                    r20_list_per_sampling.append(metric.Recall_at_k_batch(recon_batch, heldout_data, 20))
                    r50_list_per_sampling.append(metric.Recall_at_k_batch(recon_batch, heldout_data, 50))

                n1_list.append(np.concatenate(n1_list_per_sampling))
                n100_list.append(np.concatenate(n100_list_per_sampling))
                r20_list.append(np.concatenate(r20_list_per_sampling))
                r50_list.append(np.concatenate(r50_list_per_sampling))
    
    total_loss /= len(range(0, e_N, BATCH_SIZE))
    n1_list = np.concatenate(n1_list)
    n100_list = np.concatenate(n100_list)
    r20_list = np.concatenate(r20_list)
    r50_list = np.concatenate(r50_list)

    return total_loss, n1_list, n100_list, r20_list, r50_list

In [25]:
test_loss, n1_list, n100_list, r20_list, r50_list = evaluate_stochasticVAE(test_data_tr, test_data_te, n_sampling=20)
print('=' * 89)
print('| End of training | test loss {:4.2f} | n1 {:4.3f}({:4.3f}) | n100 {:4.3f}({:4.3f}) | r20 {:4.3f}({:4.3f}) | '
        'r50 {:4.3f}({:4.3f})'.format(test_loss, np.mean(n1_list), np.std(n1_list)/np.sqrt(len(n1_list)), np.mean(n100_list), np.std(n100_list)/np.sqrt(len(n100_list)), np.mean(r20_list), np.std(r20_list)/np.sqrt(len(r20_list)), np.mean(r50_list), np.std(r50_list)/np.sqrt(len(r50_list))))
print('=' * 89)

[test]: 100%|██████████| 20/20 [02:07<00:00,  6.37s/it]

| End of training | test loss 0.00 | n1 0.059(0.000) | n100 0.254(0.000) | r20 0.217(0.000) | r50 0.395(0.000)





# evaluate multi-VAE, Gumbel-VAE, Stochastic-VAE

In [40]:
beta = 0.2
# beta = 1
# beta_dash = 0.2

def gumbel_inverse(x):
    return -beta*np.log(-np.log(x))

def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return eps.mul(std).add_(mu)

def evaluate_stochastic(data_tr, data_te, n_sampling=1):
    # Turn on evaluation mode
    model.eval()
    total_loss = 0.0
    global update_count
    e_idxlist = list(range(data_tr.shape[0]))
    e_N = data_tr.shape[0]
    
    metrics = {"ndcg@20":[[] for _ in range(n_sampling)],
               "ndcg@100":[[] for _ in range(n_sampling)],
               "recall@20":[[] for _ in range(n_sampling)],
               "recall@50":[[] for _ in range(n_sampling)],
               "precision@20":[[] for _ in range(n_sampling)],
               "precision@50":[[] for _ in range(n_sampling)],
               "hit_rate@20" : [[] for _ in range(n_sampling)],
               "hit_rate@100" : [[] for _ in range(n_sampling)],
               "prediction_time": [[] for _ in range(n_sampling)],
              }
    
    metrics_dic = {
        "multi-VAE":copy.deepcopy(metrics),
        "multi-VAE-Gumbel":copy.deepcopy(metrics),
        # "multi-VAE-Gumbel-low-beta":copy.deepcopy(metrics),
        "multi-VAE-Stochastic":copy.deepcopy(metrics),
        "multi-VAE-Stochastic-Faiss":copy.deepcopy(metrics),
                }
    
    with torch.no_grad():
        with tqdm(range(0, e_N, BATCH_SIZE)) as pbar:
        # for start_idx in tqdm(range(0, e_N, BATCH_SIZE)):
            for start_idx in pbar:
                pbar.set_description("[test]")
                  
                end_idx = min(start_idx + BATCH_SIZE, N)
                data = data_tr[e_idxlist[start_idx:end_idx]]
                heldout_data = data_te[e_idxlist[start_idx:end_idx]]

                u_idxlist_wo_any_iteracts = [i for i, x in enumerate(heldout_data.toarray().sum(axis=1)) if x >0]
                data = data[u_idxlist_wo_any_iteracts]
                heldout_data = heldout_data[u_idxlist_wo_any_iteracts]

                data_tensor = naive_sparse2tensor(data).to(device)

                # if TOTAL_ANNEAL_STEPS > 0:
                #     anneal = min(ANNEAL_CAP, 
                #                    1. * update_count / TOTAL_ANNEAL_STEPS)
                # else:
                #     anneal = ANNEAL_CAP

                # encoding
                start = time.perf_counter()
                mu, logvar = model.encode(data_tensor)
                t_encode = time.perf_counter() - start
                
                # decoding
                start = time.perf_counter()
                recon_batch = model.decode(mu)
                recon_batch_cpu = recon_batch.numpy()
                t_decode = time.perf_counter() - start
                
                # start = time.perf_counter()
                # recon_batch_cpu = recon_batch.cpu()
                # t_to_cpu = time.perf_counter() - start
                recon_batch_cpu[data.nonzero()] = -np.inf

                # loss = criterion(recon_batch, data_tensor, mu, logvar, anneal)
                # total_loss += loss.item()
                # recon_batch = recon_batch.cpu().numpy()
                # 

                for l in range(n_sampling):
                    torch.manual_seed(l)
                    
                    # bluring z
                    start = time.perf_counter()
                    z_blurred = reparameterize(mu, logvar)
                    t_blurred = time.perf_counter() - start
                    
                    # Stochastic multi-VAE
                    start = time.perf_counter()
                    recon_batch_blurred = model.decode(z_blurred)
                    # recon_batch_blurred = recon_batch_blurred.cpu()
                    recon_batch_blurred = recon_batch_blurred.numpy()
                    t_decode_blurred = time.perf_counter() - start
                    recon_batch_blurred[data.nonzero()] = -np.inf

                    metrics_dic["multi-VAE-Stochastic"]["ndcg@20"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_blurred, heldout_data, 20))
                    metrics_dic["multi-VAE-Stochastic"]["ndcg@100"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_blurred, heldout_data, 100))
                    metrics_dic["multi-VAE-Stochastic"]["recall@20"][l].append(metric.Recall_at_k_batch(recon_batch_blurred, heldout_data, 20))
                    metrics_dic["multi-VAE-Stochastic"]["recall@50"][l].append(metric.Recall_at_k_batch(recon_batch_blurred, heldout_data, 50))
                    metrics_dic["multi-VAE-Stochastic"]["precision@20"][l].append(metric.Precision_at_k_batch(recon_batch_blurred, heldout_data, 20))
                    metrics_dic["multi-VAE-Stochastic"]["precision@50"][l].append(metric.Precision_at_k_batch(recon_batch_blurred, heldout_data, 50))
                    metrics_dic["multi-VAE-Stochastic"]["hit_rate@20"][l].append(metric.HitRate_at_k_batch(recon_batch_blurred, heldout_data, 20))
                    metrics_dic["multi-VAE-Stochastic"]["hit_rate@100"][l].append(metric.HitRate_at_k_batch(recon_batch_blurred, heldout_data, 100))
                    metrics_dic["multi-VAE-Stochastic"]["prediction_time"][l].append(t_encode + t_blurred + t_decode_blurred)
                    
                    # Stochastic multi-VAE with Faiss NNS
                    start = time.perf_counter()
                    z_dash_blurred = torch.tanh(torch.add(torch.matmul(z_blurred,P0.T),p0_bias))
                    z_dash_blurred_wi_constant = torch.column_stack((z_dash_blurred,torch.ones(z_dash_blurred.shape[0], device=device)))
                    # topk_scores, topk_indexes = gpu_index_flat.search(z_dash_blurred_wi_constant.cpu(), 200)
                    topk_indexes, topk_scores = scann_brute.search_batched(z_dash_blurred_wi_constant, 200)
                    t_nns_topk = time.perf_counter() - start
                    
                    recon_batch_dummy = np.ones((z_blurred.shape[0],n_items)) * (-np.inf)
                    np.put_along_axis(recon_batch_dummy, topk_indexes, topk_scores, axis=1)
                    # np.put_along_axis(recon_batch_dummy, topk_indexes, np.reciprocal(topk_dists), axis=1)
                    recon_batch_dummy[data.nonzero()] = -np.inf

                    metrics_dic["multi-VAE-Stochastic-Faiss"]["ndcg@20"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    metrics_dic["multi-VAE-Stochastic-Faiss"]["ndcg@100"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_dummy, heldout_data, 100))
                    metrics_dic["multi-VAE-Stochastic-Faiss"]["recall@20"][l].append(metric.Recall_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    metrics_dic["multi-VAE-Stochastic-Faiss"]["recall@50"][l].append(metric.Recall_at_k_batch(recon_batch_dummy, heldout_data, 50))
                    metrics_dic["multi-VAE-Stochastic-Faiss"]["precision@20"][l].append(metric.Precision_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    metrics_dic["multi-VAE-Stochastic-Faiss"]["precision@50"][l].append(metric.Precision_at_k_batch(recon_batch_dummy, heldout_data, 50))
                    metrics_dic["multi-VAE-Stochastic-Faiss"]["hit_rate@20"][l].append(metric.HitRate_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    metrics_dic["multi-VAE-Stochastic-Faiss"]["hit_rate@100"][l].append(metric.HitRate_at_k_batch(recon_batch_dummy, heldout_data, 100))
                    metrics_dic["multi-VAE-Stochastic-Faiss"]["prediction_time"][l].append(t_encode + t_blurred + t_nns_topk)
                    
                    # multi-VAE + Gumbel Max Sampling
                    start = time.perf_counter() 
                    # recon_batch_gumbel_sampled = recon_batch + np.vectorize(gumbel_inverse)(np.random.uniform(size=recon_batch.shape))
                    # https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax
                    # recon_batch_gumbel_sampled = recon_batch - torch.empty_like(recon_batch, memory_format=torch.legacy_contiguous_format).exponential_().log()
                    recon_batch_gumbel_sampled = recon_batch - beta * (-torch.rand(recon_batch.shape, device=device).log()).log()
                    recon_batch_gumbel_sampled = recon_batch_gumbel_sampled
                    t_gumbel_sampling = time.perf_counter() - start
                    
                    recon_batch_gumbel_sampled = recon_batch_gumbel_sampled.numpy()
                    recon_batch_gumbel_sampled[data.nonzero()] = -np.inf
                    
                    metrics_dic["multi-VAE-Gumbel"]["ndcg@20"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 20))
                    metrics_dic["multi-VAE-Gumbel"]["ndcg@100"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 100))
                    metrics_dic["multi-VAE-Gumbel"]["recall@20"][l].append(metric.Recall_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 20))
                    metrics_dic["multi-VAE-Gumbel"]["recall@50"][l].append(metric.Recall_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 50))
                    metrics_dic["multi-VAE-Gumbel"]["precision@20"][l].append(metric.Precision_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 20))
                    metrics_dic["multi-VAE-Gumbel"]["precision@50"][l].append(metric.Precision_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 50))
                    metrics_dic["multi-VAE-Gumbel"]["hit_rate@20"][l].append(metric.HitRate_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 20))
                    metrics_dic["multi-VAE-Gumbel"]["hit_rate@100"][l].append(metric.HitRate_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 100))
                    metrics_dic["multi-VAE-Gumbel"]["prediction_time"][l].append(t_encode + t_decode + t_gumbel_sampling)
                    
                    # multi-VAE
                    metrics_dic["multi-VAE"]["ndcg@20"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_cpu, heldout_data, 20))
                    metrics_dic["multi-VAE"]["ndcg@100"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_cpu, heldout_data, 100))
                    metrics_dic["multi-VAE"]["recall@20"][l].append(metric.Recall_at_k_batch(recon_batch_cpu, heldout_data, 20))
                    metrics_dic["multi-VAE"]["recall@50"][l].append(metric.Recall_at_k_batch(recon_batch_cpu, heldout_data, 50))
                    metrics_dic["multi-VAE"]["precision@20"][l].append(metric.Precision_at_k_batch(recon_batch_cpu, heldout_data, 20))
                    metrics_dic["multi-VAE"]["precision@50"][l].append(metric.Precision_at_k_batch(recon_batch_cpu, heldout_data, 50))
                    metrics_dic["multi-VAE"]["hit_rate@20"][l].append(metric.HitRate_at_k_batch(recon_batch_cpu, heldout_data, 20))
                    metrics_dic["multi-VAE"]["hit_rate@100"][l].append(metric.HitRate_at_k_batch(recon_batch_cpu, heldout_data, 100))
                    metrics_dic["multi-VAE"]["prediction_time"][l].append(t_encode + t_decode)
    
    # total_loss /= len(range(0, e_N, BATCH_SIZE))
    
    for method_name, metrics in metrics_dic.items():
        for metric_name, metric_list in metrics.items():
            if metric_name == "prediction_time":
                continue
            for l in range(n_sampling):
                metric_list[l] = np.concatenate(metric_list[l])

    return total_loss, metrics_dic

In [46]:
beta = 0.2
# beta = 1
# beta_dash = 0.2

def gumbel_inverse(x):
    return -beta*np.log(-np.log(x))

def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return eps.mul(std).add_(mu)

def evaluate_stochastic(data_tr, data_te, n_sampling=1):
    # Turn on evaluation mode
    model.eval()
    total_loss = 0.0
    global update_count
    e_idxlist = list(range(data_tr.shape[0]))
    e_N = data_tr.shape[0]
    
    metrics = {"ndcg@20":[[] for _ in range(n_sampling)],
               # "ndcg@100":[[] for _ in range(n_sampling)],
               "recall@20":[[] for _ in range(n_sampling)],
               # "recall@50":[[] for _ in range(n_sampling)],
               "precision@20":[[] for _ in range(n_sampling)],
               # "precision@50":[[] for _ in range(n_sampling)],
               "hit_rate@20" : [[] for _ in range(n_sampling)],
               # "hit_rate@100" : [[] for _ in range(n_sampling)],
               "prediction_time": [[] for _ in range(n_sampling)],
              }
    
    metrics_dic = {
        "multi-VAE":copy.deepcopy(metrics),
        "multi-VAE-Faiss":copy.deepcopy(metrics),
        "multi-VAE-Faiss-IVF":copy.deepcopy(metrics),
        # "multi-VAE-ScaNN":copy.deepcopy(metrics),
        "multi-VAE-Gumbel":copy.deepcopy(metrics),
        # "multi-VAE-Gumbel-low-beta":copy.deepcopy(metrics),
        "multi-VAE-Stochastic":copy.deepcopy(metrics),
        "multi-VAE-Stochastic-Faiss":copy.deepcopy(metrics),
        "multi-VAE-Stochastic-Faiss-IVF":copy.deepcopy(metrics),
                }
    
    with torch.no_grad():
        with tqdm(range(0, e_N, BATCH_SIZE)) as pbar:
        # for start_idx in tqdm(range(0, e_N, BATCH_SIZE)):
            for start_idx in pbar:
                pbar.set_description("[test]")
                  
                end_idx = min(start_idx + BATCH_SIZE, N)
                data = data_tr[e_idxlist[start_idx:end_idx]]
                heldout_data = data_te[e_idxlist[start_idx:end_idx]]

                u_idxlist_wo_any_iteracts = [i for i, x in enumerate(heldout_data.toarray().sum(axis=1)) if x >0]
                data = data[u_idxlist_wo_any_iteracts]
                heldout_data = heldout_data[u_idxlist_wo_any_iteracts]

                data_tensor = naive_sparse2tensor(data).to(device)
                
                n_batch_user = data.shape[0]
                non_zero_indices = data.nonzero()

                # if TOTAL_ANNEAL_STEPS > 0:
                #     anneal = min(ANNEAL_CAP, 
                #                    1. * update_count / TOTAL_ANNEAL_STEPS)
                # else:
                #     anneal = ANNEAL_CAP

                # encoding
                start = time.perf_counter()
                mu, logvar = model.encode(data_tensor)
                t_encode = time.perf_counter() - start
                
                # decoding
                start = time.perf_counter()
                recon_batch = model.decode(mu)  
                t_decode = time.perf_counter() - start
                
                recon_batch_clone = recon_batch.clone()
                
                start = time.perf_counter()
                recon_batch = recon_batch.numpy()
                # recon_batch[data.nonzero()] = -np.inf
                t_to_cpu = time.perf_counter() - start

                # loss = criterion(recon_batch, data_tensor, mu, logvar, anneal)
                # total_loss += loss.item()
                # recon_batch = recon_batch.cpu().numpy()
                
                # https://stackoverflow.com/questions/59338537/summarize-non-zero-values-in-a-scipy-matrix-by-axis
                n_already_intaract_item = data.indptr[1:] - data.indptr[:-1]
                max_n_already_intaract_item = int(np.max(n_already_intaract_item))

                for l in range(n_sampling):
                    torch.manual_seed(l)
                    
                    # bluring z
                    start = time.perf_counter()
                    z_blurred = reparameterize(mu, logvar)
                    t_blurred = time.perf_counter() - start
                    
                    # Stochastic multi-VAE
                    start = time.perf_counter()
                    recon_batch_blurred = model.decode(z_blurred)
                    recon_batch_blurred = recon_batch_blurred.numpy()
                    # recon_batch_blurred[non_zero_indices] = -np.inf
                    t_decode_blurred = time.perf_counter() - start
                    
                    start = time.perf_counter()
                    topk_indexes = metric.get_idx_topk(recon_batch_blurred, max_n_already_intaract_item+20)
                    t_get_idx_topk = time.perf_counter() - start
                    
                    recon_batch_dummy = np.ones((n_batch_user,n_items)) * (-np.inf)
                    np.put_along_axis(recon_batch_dummy, topk_indexes, np.reciprocal(np.array(range(1,max_n_already_intaract_item+20+1), dtype=float)), axis=1)
                    # np.put_along_axis(recon_batch_dummy, topk_indexes, np.reciprocal(topk_dists), axis=1)
                    recon_batch_dummy[data.nonzero()] = -np.inf

                    metrics_dic["multi-VAE-Stochastic"]["ndcg@20"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Stochastic"]["ndcg@100"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_blurred, heldout_data, 100))
                    metrics_dic["multi-VAE-Stochastic"]["recall@20"][l].append(metric.Recall_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Stochastic"]["recall@50"][l].append(metric.Recall_at_k_batch(recon_batch_blurred, heldout_data, 50))
                    metrics_dic["multi-VAE-Stochastic"]["precision@20"][l].append(metric.Precision_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Stochastic"]["precision@50"][l].append(metric.Precision_at_k_batch(recon_batch_blurred, heldout_data, 50))
                    metrics_dic["multi-VAE-Stochastic"]["hit_rate@20"][l].append(metric.HitRate_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Stochastic"]["hit_rate@100"][l].append(metric.HitRate_at_k_batch(recon_batch_blurred, heldout_data, 100))
                    metrics_dic["multi-VAE-Stochastic"]["prediction_time"][l].append((t_encode + t_blurred + t_decode_blurred + t_get_idx_topk)/n_batch_user)
                    
                    
                    # Stochastic multi-VAE with Faiss NNS
                    start = time.perf_counter()
                    z_dash_blurred = torch.tanh(torch.add(torch.matmul(z_blurred,P0.T),p0_bias))
                    z_dash_blurred_wi_constant = torch.column_stack((z_dash_blurred,torch.ones(n_batch_user, device=device)))
                    topk_scores, topk_indexes = index_flat.search(z_dash_blurred_wi_constant, max_n_already_intaract_item+20) # perform searching on GPU
                    # topk_indexes, topk_scores = scann_brute.search_batched(z_dash_blurred_wi_constant, max_n_already_intaract_item+20)
                    # topk_indexes = list(map(lambda i: [ele for ele in topk_indexes[i] if ele not in non_zero_indices[1][non_zero_indices[0] == i]][:20], range(n_batch_user)))
                    t_k_nns = time.perf_counter() - start
                    
                    # create dummy recon_batch which include the corresponded score to top-k items
                    recon_batch_dummy = np.ones((n_batch_user,n_items)) * (-np.inf)
                    np.put_along_axis(recon_batch_dummy, topk_indexes, topk_scores, axis=1)
                    # np.put_along_axis(recon_batch_dummy, topk_indexes, np.reciprocal(topk_dists), axis=1)
                    recon_batch_dummy[data.nonzero()] = -np.inf

                    # metrics_dic["multi-VAE-Stochastic-Faiss"]["ndcg@20"][l].append(metric.NDCG_binary_at_k_batch_by_topk_indexes(topk_indexes, heldout_data, n_batch_user, 20))
                    metrics_dic["multi-VAE-Stochastic-Faiss"]["ndcg@20"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Stochastic-Faiss"]["ndcg@100"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_dummy, heldout_data, 100))
                    metrics_dic["multi-VAE-Stochastic-Faiss"]["recall@20"][l].append(metric.Recall_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Stochastic-Faiss"]["recall@50"][l].append(metric.Recall_at_k_batch(recon_batch_dummy, heldout_data, 50))
                    metrics_dic["multi-VAE-Stochastic-Faiss"]["precision@20"][l].append(metric.Precision_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Stochastic-Faiss"]["precision@50"][l].append(metric.Precision_at_k_batch(recon_batch_dummy, heldout_data, 50))
                    metrics_dic["multi-VAE-Stochastic-Faiss"]["hit_rate@20"][l].append(metric.HitRate_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Stochastic-Faiss"]["hit_rate@100"][l].append(metric.HitRate_at_k_batch(recon_batch_dummy, heldout_data, 100))
                    metrics_dic["multi-VAE-Stochastic-Faiss"]["prediction_time"][l].append((t_encode + t_blurred + t_k_nns)/n_batch_user)
                    
                    # Stochastic multi-VAE with Faiss NNS(IVF)
                    start = time.perf_counter()
                    z_dash_blurred = torch.tanh(torch.add(torch.matmul(z_blurred,P0.T),p0_bias))
                    z_dash_blurred_wi_constant = torch.column_stack((z_dash_blurred,torch.ones(n_batch_user, device=device)))
                    topk_scores, topk_indexes = index_ivf.search(z_dash_blurred_wi_constant, max_n_already_intaract_item+20)
                    # topk_indexes, topk_scores = scann_brute.search_batched(z_dash_blurred_wi_constant, max_n_already_intaract_item+20)
                    # topk_indexes = list(map(lambda i: [ele for ele in topk_indexes[i] if ele not in non_zero_indices[1][non_zero_indices[0] == i]][:20], range(n_batch_user)))
                    t_k_nns = time.perf_counter() - start
                    
                    # create dummy recon_batch which include the corresponded score to top-k items
                    recon_batch_dummy = np.ones((n_batch_user,n_items)) * (-np.inf)
                    np.put_along_axis(recon_batch_dummy, topk_indexes, topk_scores, axis=1)
                    # np.put_along_axis(recon_batch_dummy, topk_indexes, np.reciprocal(topk_dists), axis=1)
                    recon_batch_dummy[data.nonzero()] = -np.inf

                    # metrics_dic["multi-VAE-Stochastic-Faiss"]["ndcg@20"][l].append(metric.NDCG_binary_at_k_batch_by_topk_indexes(topk_indexes, heldout_data, n_batch_user, 20))
                    metrics_dic["multi-VAE-Stochastic-Faiss-IVF"]["ndcg@20"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Stochastic-Faiss"]["ndcg@100"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_dummy, heldout_data, 100))
                    metrics_dic["multi-VAE-Stochastic-Faiss-IVF"]["recall@20"][l].append(metric.Recall_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Stochastic-Faiss"]["recall@50"][l].append(metric.Recall_at_k_batch(recon_batch_dummy, heldout_data, 50))
                    metrics_dic["multi-VAE-Stochastic-Faiss-IVF"]["precision@20"][l].append(metric.Precision_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Stochastic-Faiss"]["precision@50"][l].append(metric.Precision_at_k_batch(recon_batch_dummy, heldout_data, 50))
                    metrics_dic["multi-VAE-Stochastic-Faiss-IVF"]["hit_rate@20"][l].append(metric.HitRate_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Stochastic-Faiss"]["hit_rate@100"][l].append(metric.HitRate_at_k_batch(recon_batch_dummy, heldout_data, 100))
                    metrics_dic["multi-VAE-Stochastic-Faiss-IVF"]["prediction_time"][l].append((t_encode + t_blurred + t_k_nns)/n_batch_user)
                    
                    # multi-VAE + Gumbel Max Sampling
                    start = time.perf_counter() 
                    # recon_batch_gumbel_sampled = recon_batch + np.vectorize(gumbel_inverse)(np.random.uniform(size=recon_batch.shape))
                    # https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax
                    # recon_batch_gumbel_sampled = recon_batch - torch.empty_like(recon_batch, memory_format=torch.legacy_contiguous_format).exponential_().log()
                    recon_batch_gumbel_sampled = recon_batch_clone - beta * (-torch.rand(recon_batch_clone.shape, device=device).log()).log()
                    recon_batch_gumbel_sampled = recon_batch_gumbel_sampled.numpy()
                    # recon_batch_gumbel_sampled[non_zero_indices] = -np.inf
                    t_gumbel_sampling = time.perf_counter() - start
                    
                    start = time.perf_counter()
                    topk_indexes = metric.get_idx_topk(recon_batch_gumbel_sampled, max_n_already_intaract_item+20)
                    t_get_idx_topk = time.perf_counter() - start
                    
                    recon_batch_dummy = np.ones((n_batch_user,n_items)) * (-np.inf)
                    np.put_along_axis(recon_batch_dummy, topk_indexes, np.reciprocal(np.array(range(1,max_n_already_intaract_item+20+1), dtype=float)), axis=1)
                    # np.put_along_axis(recon_batch_dummy, topk_indexes, np.array(range(1,21), dtype=float), axis=1)
                    recon_batch_dummy[data.nonzero()] = -np.inf
                    
                    metrics_dic["multi-VAE-Gumbel"]["ndcg@20"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Gumbel"]["ndcg@100"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 100))
                    metrics_dic["multi-VAE-Gumbel"]["recall@20"][l].append(metric.Recall_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Gumbel"]["recall@50"][l].append(metric.Recall_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 50))
                    metrics_dic["multi-VAE-Gumbel"]["precision@20"][l].append(metric.Precision_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Gumbel"]["precision@50"][l].append(metric.Precision_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 50))
                    metrics_dic["multi-VAE-Gumbel"]["hit_rate@20"][l].append(metric.HitRate_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Gumbel"]["hit_rate@100"][l].append(metric.HitRate_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 100))
                    metrics_dic["multi-VAE-Gumbel"]["prediction_time"][l].append((t_encode + t_decode + t_gumbel_sampling + t_get_idx_topk)/n_batch_user)
                    
                    # multi-VAE + Faiss
                    start = time.perf_counter()
                    z_dash = torch.tanh(torch.add(torch.matmul(mu,P0.T),p0_bias))
                    z_dash_wi_constant = torch.column_stack((z_dash, torch.ones(n_batch_user, device=device)))
                    topk_scores, topk_indexes = index_flat.search(z_dash_wi_constant, max_n_already_intaract_item+20)
                    # topk_indexes = list(map(lambda i: [ele for ele in topk_indexes[i] if ele not in non_zero_indices[1][non_zero_indices[0] == i]][:20], range(n_batch_user)))
                    t_k_nns = time.perf_counter() - start
                    
                    # create dummy recon_batch which include the corresponded score to top-k items
                    recon_batch_dummy = np.ones((n_batch_user,n_items)) * (-np.inf)
                    np.put_along_axis(recon_batch_dummy, topk_indexes, topk_scores, axis=1)
                    # np.put_along_axis(recon_batch_dummy, topk_indexes, np.reciprocal(topk_dists), axis=1)
                    recon_batch_dummy[data.nonzero()] = -np.inf
                    
                    metrics_dic["multi-VAE-Faiss"]["ndcg@20"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Gumbel"]["ndcg@100"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 100))
                    metrics_dic["multi-VAE-Faiss"]["recall@20"][l].append(metric.Recall_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Gumbel"]["recall@50"][l].append(metric.Recall_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 50))
                    metrics_dic["multi-VAE-Faiss"]["precision@20"][l].append(metric.Precision_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Gumbel"]["precision@50"][l].append(metric.Precision_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 50))
                    metrics_dic["multi-VAE-Faiss"]["hit_rate@20"][l].append(metric.HitRate_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Gumbel"]["hit_rate@100"][l].append(metric.HitRate_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 100))
                    metrics_dic["multi-VAE-Faiss"]["prediction_time"][l].append((t_encode + t_k_nns)/n_batch_user)
                    
#                     # multi-VAE + ScaNN
#                     start = time.perf_counter()
#                     z_dash = torch.tanh(torch.add(torch.matmul(mu,P0.T),p0_bias))
#                     z_dash_wi_constant = torch.column_stack((z_dash, torch.ones(n_batch_user, device=device)))
#                     # topk_scores, topk_indexes = gpu_index_flat.search(z_dash_wi_constant.cpu(), max_n_already_intaract_item+20) # perform searching on GPU
#                     topk_indexes, topk_scores = scann_brute.search_batched(z_dash_wi_constant, max_n_already_intaract_item+20)
#                     # topk_indexes = list(map(lambda i: [ele for ele in topk_indexes[i] if ele not in non_zero_indices[1][non_zero_indices[0] == i]][:20], range(n_batch_user)))
#                     t_k_nns = time.perf_counter() - start
                    
#                     # create dummy recon_batch which include the corresponded score to top-k items
#                     recon_batch_dummy = np.ones((n_batch_user,n_items)) * (-np.inf)
#                     np.put_along_axis(recon_batch_dummy, topk_indexes, topk_scores, axis=1)
#                     # np.put_along_axis(recon_batch_dummy, topk_indexes, np.reciprocal(topk_dists), axis=1)
#                     recon_batch_dummy[data.nonzero()] = -np.inf
                    
#                     metrics_dic["multi-VAE-ScaNN"]["ndcg@20"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_dummy, heldout_data, 20))
#                     # metrics_dic["multi-VAE-Gumbel"]["ndcg@100"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 100))
#                     metrics_dic["multi-VAE-ScaNN"]["recall@20"][l].append(metric.Recall_at_k_batch(recon_batch_dummy, heldout_data, 20))
#                     # metrics_dic["multi-VAE-Gumbel"]["recall@50"][l].append(metric.Recall_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 50))
#                     metrics_dic["multi-VAE-ScaNN"]["precision@20"][l].append(metric.Precision_at_k_batch(recon_batch_dummy, heldout_data, 20))
#                     # metrics_dic["multi-VAE-Gumbel"]["precision@50"][l].append(metric.Precision_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 50))
#                     metrics_dic["multi-VAE-ScaNN"]["hit_rate@20"][l].append(metric.HitRate_at_k_batch(recon_batch_dummy, heldout_data, 20))
#                     # metrics_dic["multi-VAE-Gumbel"]["hit_rate@100"][l].append(metric.HitRate_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 100))
#                     metrics_dic["multi-VAE-ScaNN"]["prediction_time"][l].append((t_encode + t_k_nns)/n_batch_user)
                    
                    # multi-VAE + Faiss(IVF)
                    start = time.perf_counter()
                    z_dash = torch.tanh(torch.add(torch.matmul(mu,P0.T),p0_bias))
                    z_dash_wi_constant = torch.column_stack((z_dash, torch.ones(n_batch_user, device=device)))
                    topk_scores, topk_indexes = index_ivf.search(z_dash_wi_constant, max_n_already_intaract_item+20) # perform searching on GPU
                    # topk_indexes, topk_scores = scann_brute.search_batched(z_dash_wi_constant, max_n_already_intaract_item+20)
                    # topk_indexes = list(map(lambda i: [ele for ele in topk_indexes[i] if ele not in non_zero_indices[1][non_zero_indices[0] == i]][:20], range(n_batch_user)))
                    t_k_nns = time.perf_counter() - start
                    
                    # create dummy recon_batch which include the corresponded score to top-k items
                    recon_batch_dummy = np.ones((n_batch_user,n_items)) * (-np.inf)
                    np.put_along_axis(recon_batch_dummy, topk_indexes, topk_scores, axis=1)
                    # np.put_along_axis(recon_batch_dummy, topk_indexes, np.reciprocal(topk_dists), axis=1)
                    recon_batch_dummy[data.nonzero()] = -np.inf
                    
                    metrics_dic["multi-VAE-Faiss-IVF"]["ndcg@20"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Gumbel"]["ndcg@100"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 100))
                    metrics_dic["multi-VAE-Faiss-IVF"]["recall@20"][l].append(metric.Recall_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Gumbel"]["recall@50"][l].append(metric.Recall_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 50))
                    metrics_dic["multi-VAE-Faiss-IVF"]["precision@20"][l].append(metric.Precision_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Gumbel"]["precision@50"][l].append(metric.Precision_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 50))
                    metrics_dic["multi-VAE-Faiss-IVF"]["hit_rate@20"][l].append(metric.HitRate_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Gumbel"]["hit_rate@100"][l].append(metric.HitRate_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 100))
                    metrics_dic["multi-VAE-Faiss-IVF"]["prediction_time"][l].append((t_encode + t_k_nns)/n_batch_user)
                    
                    # multi-VAE
                    start = time.perf_counter()
                    topk_indexes = metric.get_idx_topk(recon_batch, max_n_already_intaract_item+20)
                    t_get_idx_topk = time.perf_counter() - start
                    
                    recon_batch_dummy = np.ones((n_batch_user,n_items)) * (-np.inf)
                    np.put_along_axis(recon_batch_dummy, topk_indexes, np.reciprocal(np.array(range(1,max_n_already_intaract_item+20+1), dtype=float)), axis=1)
                    # np.put_along_axis(recon_batch_dummy, topk_indexes, np.array(range(1,21), dtype=float), axis=1)
                    recon_batch_dummy[data.nonzero()] = -np.inf
                    
                    metrics_dic["multi-VAE"]["ndcg@20"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE"]["ndcg@100"][l].append(metric.NDCG_binary_at_k_batch(recon_batch, heldout_data, 100))
                    metrics_dic["multi-VAE"]["recall@20"][l].append(metric.Recall_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE"]["recall@50"][l].append(metric.Recall_at_k_batch(recon_batch, heldout_data, 50))
                    metrics_dic["multi-VAE"]["precision@20"][l].append(metric.Precision_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE"]["precision@50"][l].append(metric.Precision_at_k_batch(recon_batch, heldout_data, 50))
                    metrics_dic["multi-VAE"]["hit_rate@20"][l].append(metric.HitRate_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE"]["hit_rate@100"][l].append(metric.HitRate_at_k_batch(recon_batch, heldout_data, 100))
                    metrics_dic["multi-VAE"]["prediction_time"][l].append((t_encode + t_decode + t_to_cpu + t_get_idx_topk)/n_batch_user)
                    
    
    # total_loss /= len(range(0, e_N, BATCH_SIZE))
    
    for method_name, metrics in metrics_dic.items():
        for metric_name, metric_list in metrics.items():
            if metric_name == "prediction_time":
                continue
            for l in range(n_sampling):
                metric_list[l] = np.concatenate(metric_list[l])

    return total_loss, metrics_dic

In [49]:
# Run on test data.
total_loss, metrics_dic = evaluate_stochastic(test_data_tr, test_data_te, n_sampling=50)
# print('=' * 89)
# print('| End of training | test loss {:4.2f} | n1 {:4.3f}({:4.3f}) | n100 {:4.3f}({:4.3f}) | r20 {:4.3f}({:4.3f}) | '
#         'r50 {:4.3f}({:4.3f})'.format(test_loss, np.mean(n1_list), np.std(n1_list)/np.sqrt(len(n1_list)), np.mean(n100_list), np.std(n100_list)/np.sqrt(len(n100_list)), np.mean(r20_list), np.std(r20_list)/np.sqrt(len(r20_list)), np.mean(r50_list), np.std(r50_list)/np.sqrt(len(r50_list))))
# print('=' * 89)

[test]: 100%|██████████| 20/20 [1:13:21<00:00, 220.09s/it]


In [53]:
method_names = list(metrics_dic.keys())
metric_names = list(metrics_dic[method_names[0]].keys())

results_all = []
for method_name, metrics in metrics_dic.items():
    results = []
    for metric_name, metric_list in metrics.items():
        if metric_name == "prediction_time":
            results.append("{:6f}({:6f})".format(np.mean(metrics_dic[method_name][metric_name]), np.std(metrics_dic[method_name][metric_name])))
        else:
        # # results.append(np.mean(metrics_dic[method_name][metric_name]))
            results.append("{:4.4f}({:4.4f})".format(np.mean(metrics_dic[method_name][metric_name]), np.std(metrics_dic[method_name][metric_name])))
    results_all.append(results)
results_all = pd.DataFrame(results_all, columns=metric_names, index=method_names)
print("All")
results_all

All


Unnamed: 0,ndcg@20,recall@20,precision@20,hit_rate@20,prediction_time
multi-VAE,0.3390(0.2239),0.3656(0.2772),0.1759(0.1524),0.8777(0.3276),0.000614(0.000031)
multi-VAE-Faiss,0.3390(0.2239),0.3656(0.2772),0.1759(0.1524),0.8777(0.3276),0.000493(0.000046)
multi-VAE-Faiss-IVF,0.3273(0.2258),0.3467(0.2772),0.1672(0.1499),0.8620(0.3449),0.000333(0.000017)
multi-VAE-Gumbel,0.3079(0.2102),0.3506(0.2762),0.1634(0.1382),0.8725(0.3336),0.000848(0.000035)
multi-VAE-Stochastic,0.2948(0.2110),0.3268(0.2666),0.1591(0.1428),0.8529(0.3542),0.000628(0.000047)
multi-VAE-Stochastic-Faiss,0.2948(0.2110),0.3268(0.2666),0.1591(0.1428),0.8529(0.3542),0.000498(0.000053)
multi-VAE-Stochastic-Faiss-IVF,0.2894(0.2129),0.3140(0.2661),0.1538(0.1423),0.8409(0.3658),0.000337(0.000020)


In [51]:
# method_names = list(metrics_dic.keys())
# metric_names = list(metrics_dic[method_names[0]].keys())

results_top20per = []
for method_name, metrics in metrics_dic.items():
    results = []
    for metric_name, metric_list in metrics.items():
        if metric_name == "prediction_time":
            results.append("{:6f}({:6f})".format(np.mean(metrics_dic[method_name][metric_name]), np.std(metrics_dic[method_name][metric_name])))
        else:
            results.append("{:4.4f}({:4.4f})".format(np.mean(np.quantile(metrics_dic[method_name][metric_name], 0.8, axis=0)), np.std(np.quantile(metrics_dic[method_name][metric_name], 0.8, axis=0))))
            # results.append(np.mean(np.quantile(metrics_dic[method_name][metric_name], 0.8, axis=0)))
    results_top20per.append(results)
results_top20per = pd.DataFrame(results_top20per, columns=metric_names, index=method_names)
print("Top 20%")
results_top20per

Top 20%


Unnamed: 0,ndcg@20,recall@20,precision@20,hit_rate@20,prediction_time
multi-VAE,0.3390(0.2239),0.3656(0.2772),0.1759(0.1524),0.8777(0.3276),0.000614(0.000031)
multi-VAE-Faiss,0.3390(0.2239),0.3656(0.2772),0.1759(0.1524),0.8777(0.3276),0.000493(0.000046)
multi-VAE-Faiss-IVF,0.3273(0.2258),0.3467(0.2772),0.1672(0.1499),0.8620(0.3449),0.000333(0.000017)
multi-VAE-Gumbel,0.3647(0.2190),0.3916(0.2778),0.1849(0.1484),0.8991(0.3007),0.000848(0.000035)
multi-VAE-Stochastic,0.3644(0.2136),0.3956(0.2728),0.1862(0.1504),0.9110(0.2837),0.000628(0.000047)
multi-VAE-Stochastic-Faiss,0.3644(0.2136),0.3956(0.2728),0.1862(0.1504),0.9110(0.2837),0.000498(0.000053)
multi-VAE-Stochastic-Faiss-IVF,0.3543(0.2183),0.3767(0.2759),0.1782(0.1498),0.8941(0.3068),0.000337(0.000020)


In [52]:
results_bottom20per = []
for method_name, metrics in metrics_dic.items():
    results = []
    for metric_name, metric_list in metrics.items():
        if metric_name == "prediction_time":
            results.append("{:6f}({:6f})".format(np.mean(metrics_dic[method_name][metric_name]), np.std(metrics_dic[method_name][metric_name])))
        else:
            results.append("{:4.4f}({:4.4f})".format(np.mean(np.quantile(metrics_dic[method_name][metric_name], 0.2, axis=0)), np.std(np.quantile(metrics_dic[method_name][metric_name], 0.2, axis=0))))
            # results.append(np.mean(np.quantile(metrics_dic[method_name][metric_name], 0.2, axis=0)))
    results_bottom20per.append(results)
results_bottom20per = pd.DataFrame(results_bottom20per, columns=metric_names, index=method_names)
print("Bottom 20%")
results_bottom20per

Bottom 20%


Unnamed: 0,ndcg@20,recall@20,precision@20,hit_rate@20,prediction_time
multi-VAE,0.3390(0.2239),0.3656(0.2772),0.1759(0.1524),0.8777(0.3276),0.000614(0.000031)
multi-VAE-Faiss,0.3390(0.2239),0.3656(0.2772),0.1759(0.1524),0.8777(0.3276),0.000493(0.000046)
multi-VAE-Faiss-IVF,0.3273(0.2258),0.3467(0.2772),0.1672(0.1499),0.8620(0.3449),0.000333(0.000017)
multi-VAE-Gumbel,0.2480(0.1765),0.3072(0.2646),0.1413(0.1208),0.8465(0.3600),0.000848(0.000035)
multi-VAE-Stochastic,0.2222(0.1726),0.2564(0.2329),0.1316(0.1265),0.8025(0.3970),0.000628(0.000047)
multi-VAE-Stochastic-Faiss,0.2222(0.1726),0.2564(0.2329),0.1316(0.1265),0.8025(0.3970),0.000498(0.000053)
multi-VAE-Stochastic-Faiss-IVF,0.2216(0.1765),0.2499(0.2328),0.1291(0.1277),0.7929(0.4042),0.000337(0.000020)


In [102]:
beta = 0.2
# beta = 1
# beta_dash = 0.2

def gumbel_inverse(x):
    return -beta*np.log(-np.log(x))

def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return eps.mul(std).add_(mu)

def evaluate_stochastic_naive(data_tr, data_te, n_sampling=1):
    # Turn on evaluation mode
    model.eval()
    total_loss = 0.0
    global update_count
    e_idxlist = list(range(data_tr.shape[0]))
    e_N = data_tr.shape[0]
    
    metrics = {"ndcg@20":[[] for _ in range(n_sampling)],
               # "ndcg@100":[[] for _ in range(n_sampling)],
               # "recall@20":[[] for _ in range(n_sampling)],
               # "recall@50":[[] for _ in range(n_sampling)],
               # "precision@20":[[] for _ in range(n_sampling)],
               # "precision@50":[[] for _ in range(n_sampling)],
               # "hit_rate@20" : [[] for _ in range(n_sampling)],
               # "hit_rate@100" : [[] for _ in range(n_sampling)],
               "prediction_time": [[] for _ in range(n_sampling)],
              }
    
    metrics_dic = {
        "multi-VAE":copy.deepcopy(metrics),
        "multi-VAE-Gumbel":copy.deepcopy(metrics),
        # "multi-VAE-Gumbel-low-beta":copy.deepcopy(metrics),
        "multi-VAE-Stochastic":copy.deepcopy(metrics),
        "multi-VAE-Stochastic-Faiss":copy.deepcopy(metrics),
                }
    
    with torch.no_grad():
        with tqdm(range(0, e_N, BATCH_SIZE)) as pbar:
        # for start_idx in tqdm(range(0, e_N, BATCH_SIZE)):
            for start_idx in pbar:
                pbar.set_description("[test]")
                  
                end_idx = min(start_idx + BATCH_SIZE, N)
                data = data_tr[e_idxlist[start_idx:end_idx]]
                heldout_data = data_te[e_idxlist[start_idx:end_idx]]

                u_idxlist_wo_any_iteracts = [i for i, x in enumerate(heldout_data.toarray().sum(axis=1)) if x >0]
                data = data[u_idxlist_wo_any_iteracts]
                heldout_data = heldout_data[u_idxlist_wo_any_iteracts]

                data_tensor = naive_sparse2tensor(data).to(device)

                # if TOTAL_ANNEAL_STEPS > 0:
                #     anneal = min(ANNEAL_CAP, 
                #                    1. * update_count / TOTAL_ANNEAL_STEPS)
                # else:
                #     anneal = ANNEAL_CAP

                # encoding
                start = time.perf_counter()
                mu, logvar = model.encode(data_tensor)
                t_encode = time.perf_counter() - start

                # loss = criterion(recon_batch, data_tensor, mu, logvar, anneal)
                # total_loss += loss.item()
                # recon_batch = recon_batch.cpu().numpy()
                # 

                for l in range(n_sampling):
                    # Stochastic multi-VAE with Faiss NNS
                    torch.manual_seed(l)
                    start = time.perf_counter()
                    z_blurred = reparameterize(mu, logvar)
                    z_dash_blurred = torch.tanh(torch.add(torch.matmul(z_blurred,P0.T),p0_bias))
                    z_dash_blurred_wi_constant = torch.column_stack((z_dash_blurred,torch.ones(z_dash_blurred.shape[0], device=device)))
                    # topk_scores, topk_indexes = gpu_index_flat.search(z_dash_blurred_wi_constant.cpu(), 200)
                    topk_indexes, topk_scores = scann_brute.search_batched(z_dash_blurred_wi_constant, 200)
                    t_decode = time.perf_counter() - start
                    
                    recon_batch_dummy = np.ones((z_blurred.shape[0],n_items)) * (-np.inf)
                    np.put_along_axis(recon_batch_dummy, topk_indexes, topk_scores, axis=1)
                    # np.put_along_axis(recon_batch_dummy, topk_indexes, np.reciprocal(topk_dists), axis=1)
                    recon_batch_dummy[data.nonzero()] = -np.inf

                    metrics_dic["multi-VAE-Stochastic-Faiss"]["ndcg@20"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Stochastic-Faiss"]["ndcg@100"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_dummy, heldout_data, 100))
                    # metrics_dic["multi-VAE-Stochastic-Faiss"]["recall@20"][l].append(metric.Recall_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Stochastic-Faiss"]["recall@50"][l].append(metric.Recall_at_k_batch(recon_batch_dummy, heldout_data, 50))
                    # metrics_dic["multi-VAE-Stochastic-Faiss"]["precision@20"][l].append(metric.Precision_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Stochastic-Faiss"]["precision@50"][l].append(metric.Precision_at_k_batch(recon_batch_dummy, heldout_data, 50))
                    # metrics_dic["multi-VAE-Stochastic-Faiss"]["hit_rate@20"][l].append(metric.HitRate_at_k_batch(recon_batch_dummy, heldout_data, 20))
                    # metrics_dic["multi-VAE-Stochastic-Faiss"]["hit_rate@100"][l].append(metric.HitRate_at_k_batch(recon_batch_dummy, heldout_data, 100))
                    metrics_dic["multi-VAE-Stochastic-Faiss"]["prediction_time"][l].append((t_encode + t_decode)/recon_batch_dummy.shape[0])
                    print('---')
                    print(t_encode)
                    print(t_decode)
                    
                    torch.manual_seed(l)
                    # Stochastic multi-VAE
                    start = time.perf_counter()
                    z_blurred = reparameterize(mu, logvar)
                    recon_batch_blurred = model.decode(z_blurred)
                    recon_batch_blurred = recon_batch_blurred.numpy()
                    t_decode = time.perf_counter() - start
                    
                    recon_batch_blurred[data.nonzero()] = -np.inf

                    metrics_dic["multi-VAE-Stochastic"]["ndcg@20"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_blurred, heldout_data, 20))
                    # metrics_dic["multi-VAE-Stochastic"]["ndcg@100"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_blurred, heldout_data, 100))
                    # metrics_dic["multi-VAE-Stochastic"]["recall@20"][l].append(metric.Recall_at_k_batch(recon_batch_blurred, heldout_data, 20))
                    # metrics_dic["multi-VAE-Stochastic"]["recall@50"][l].append(metric.Recall_at_k_batch(recon_batch_blurred, heldout_data, 50))
                    # metrics_dic["multi-VAE-Stochastic"]["precision@20"][l].append(metric.Precision_at_k_batch(recon_batch_blurred, heldout_data, 20))
                    # metrics_dic["multi-VAE-Stochastic"]["precision@50"][l].append(metric.Precision_at_k_batch(recon_batch_blurred, heldout_data, 50))
                    # metrics_dic["multi-VAE-Stochastic"]["hit_rate@20"][l].append(metric.HitRate_at_k_batch(recon_batch_blurred, heldout_data, 20))
                    # metrics_dic["multi-VAE-Stochastic"]["hit_rate@100"][l].append(metric.HitRate_at_k_batch(recon_batch_blurred, heldout_data, 100))
                    metrics_dic["multi-VAE-Stochastic"]["prediction_time"][l].append((t_encode + t_decode)/recon_batch_blurred.shape[0])
                    print(t_decode)
                    
                    
                    
                    # multi-VAE + Gumbel Max Sampling
                    # recon_batch_gumbel_sampled = recon_batch + np.vectorize(gumbel_inverse)(np.random.uniform(size=recon_batch.shape))
                    # https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax
                    # recon_batch_gumbel_sampled = recon_batch - torch.empty_like(recon_batch, memory_format=torch.legacy_contiguous_format).exponential_().log()
                    start = time.perf_counter()
                    recon_batch = model.decode(mu)  
                    recon_batch_gumbel_sampled = recon_batch - beta * (-torch.rand(recon_batch.shape, device=device).log()).log()
                    recon_batch_gumbel_sampled = recon_batch_gumbel_sampled.numpy()
                    t_decode = time.perf_counter() - start
                    
                    recon_batch_gumbel_sampled[data.nonzero()] = -np.inf
                    
                    metrics_dic["multi-VAE-Gumbel"]["ndcg@20"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 20))
                    # metrics_dic["multi-VAE-Gumbel"]["ndcg@100"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 100))
                    # metrics_dic["multi-VAE-Gumbel"]["recall@20"][l].append(metric.Recall_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 20))
                    # metrics_dic["multi-VAE-Gumbel"]["recall@50"][l].append(metric.Recall_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 50))
                    # metrics_dic["multi-VAE-Gumbel"]["precision@20"][l].append(metric.Precision_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 20))
                    # metrics_dic["multi-VAE-Gumbel"]["precision@50"][l].append(metric.Precision_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 50))
                    # metrics_dic["multi-VAE-Gumbel"]["hit_rate@20"][l].append(metric.HitRate_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 20))
                    # metrics_dic["multi-VAE-Gumbel"]["hit_rate@100"][l].append(metric.HitRate_at_k_batch(recon_batch_gumbel_sampled, heldout_data, 100))
                    metrics_dic["multi-VAE-Gumbel"]["prediction_time"][l].append((t_encode + t_decode)/recon_batch_gumbel_sampled.shape[0])
                    
                    # multi-VAE
                    start = time.perf_counter()
                    recon_batch = model.decode(mu)  
                    recon_batch = recon_batch.numpy()
                    t_decode = time.perf_counter() - start
                    
                    recon_batch[data.nonzero()] = -np.inf
                    metrics_dic["multi-VAE"]["ndcg@20"][l].append(metric.NDCG_binary_at_k_batch(recon_batch, heldout_data, 20))
                    # metrics_dic["multi-VAE"]["ndcg@100"][l].append(metric.NDCG_binary_at_k_batch(recon_batch_cpu, heldout_data, 100))
                    # metrics_dic["multi-VAE"]["recall@20"][l].append(metric.Recall_at_k_batch(recon_batch_cpu, heldout_data, 20))
                    # metrics_dic["multi-VAE"]["recall@50"][l].append(metric.Recall_at_k_batch(recon_batch_cpu, heldout_data, 50))
                    # metrics_dic["multi-VAE"]["precision@20"][l].append(metric.Precision_at_k_batch(recon_batch_cpu, heldout_data, 20))
                    # metrics_dic["multi-VAE"]["precision@50"][l].append(metric.Precision_at_k_batch(recon_batch_cpu, heldout_data, 50))
                    # metrics_dic["multi-VAE"]["hit_rate@20"][l].append(metric.HitRate_at_k_batch(recon_batch_cpu, heldout_data, 20))
                    # metrics_dic["multi-VAE"]["hit_rate@100"][l].append(metric.HitRate_at_k_batch(recon_batch_cpu, heldout_data, 100))
                    metrics_dic["multi-VAE"]["prediction_time"][l].append((t_encode + t_decode)/recon_batch.shape[0])
                    print(t_decode)
    
    # total_loss /= len(range(0, e_N, BATCH_SIZE))
    
    for method_name, metrics in metrics_dic.items():
        for metric_name, metric_list in metrics.items():
            if metric_name == "prediction_time":
                continue
            for l in range(n_sampling):
                metric_list[l] = np.concatenate(metric_list[l])

    return total_loss, metrics_dic

In [103]:
total_loss_naive, metrics_dic_naive = evaluate_stochastic_naive(test_data_tr, test_data_te, n_sampling=10)

[test]:   0%|          | 0/20 [00:00<?, ?it/s]

---
0.10165012600191403
0.20238748300471343
0.10031771200010553
0.09559305400762241
---
0.10165012600191403
0.20341147101134993
0.09514349499659147
0.0977189169934718
---
0.10165012600191403
0.1947029730072245
0.09814140500384383
0.09816164898802526
---
0.10165012600191403
0.20039231599366758
0.09703384699241724
0.09781384198868182
---
0.10165012600191403
0.20374920299218502
0.09622806799598038
0.09781208899221383


[test]:   0%|          | 0/20 [00:05<?, ?it/s]

---
0.10165012600191403
0.20513160800328478





KeyboardInterrupt: 

In [98]:
method_names = list(metrics_dic_naive.keys())
metric_names = list(metrics_dic_naive[method_names[0]].keys())

results_all = []
for method_name, metrics in metrics_dic_naive.items():
    results = []
    for metric_name, metric_list in metrics.items():
        if metric_name == "prediction_time":
            results.append("{:6.6f}({:6.6f})".format(np.mean(metric_list), np.std(metric_list)))
        else:
        # # results.append(np.mean(metrics_dic[method_name][metric_name]))
            results.append("{:4.4f}({:4.4f})".format(np.mean(metric_list), np.std(metric_list)))
    results_all.append(results)
results_all = pd.DataFrame(results_all, columns=metric_names, index=method_names)
print("All")
results_all

All


Unnamed: 0,ndcg@20,prediction_time
multi-VAE,0.3390(0.2239),0.000394(0.000035)
multi-VAE-Gumbel,0.3081(0.2102),0.000629(0.000046)
multi-VAE-Stochastic,0.2952(0.2113),0.000397(0.000035)
multi-VAE-Stochastic-Faiss,0.2952(0.2113),0.000589(0.000016)


In [99]:
results_top20per = []
for method_name, metrics in metrics_dic_naive.items():
    results = []
    for metric_name, metric_list in metrics.items():
        if metric_name == "prediction_time":
            results.append("{:4.4f}({:4.4f})".format(np.mean(metric_list), np.std(metric_list)))
        else:
            results.append("{:4.4f}({:4.4f})".format(np.mean(np.quantile(metric_list, 0.8, axis=0)), np.std(np.quantile(metric_list, 0.8, axis=0))))
            # results.append(np.mean(np.quantile(metrics_dic[method_name][metric_name], 0.8, axis=0)))
    results_top20per.append(results)
results_top20per = pd.DataFrame(results_top20per, columns=metric_names, index=method_names)
print("Top 20%")
results_top20per

Top 20%


Unnamed: 0,ndcg@20,prediction_time
multi-VAE,0.3390(0.2239),0.0004(0.0000)
multi-VAE-Gumbel,0.3585(0.2177),0.0006(0.0000)
multi-VAE-Stochastic,0.3568(0.2140),0.0004(0.0000)
multi-VAE-Stochastic-Faiss,0.3568(0.2140),0.0006(0.0000)
