In [1]:
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import scipy.stats as stats
import pandas as pd
import utils
import os
import matplotlib.pyplot as plt
import time

from tqdm import tqdm
from utils import *

In [2]:
import torch
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')

In [3]:
def BV_PR_gpu(A, D, gamma):
    
    AT = A.T
    
    #G1 = A.sum(0)
    G2 = (D@A).sum(0)
    
    W = (1.0 - gamma)* AT@D@A + gamma * torch.diag(G2).to(device)
    
    W = torch.inverse(W)
    
    W = W@AT@D
   
    
    return W

In [4]:
nums = {'pinterest-20':9916, 'yelp':25815, 'ml-20m':20720}

datasets = ['ml-20m', 'pinterest-20', 'yelp']


models = ['EASE','MultiVAE','NeuMF','itemKNN', 'ALS']

metrics = ['Recall', 'NDCG', 'AP']

In [5]:
n = 500 # sample size

repeats = 100

In [6]:
def get_estimate(model, dataset, estimator, n):
    
    temp_list = list()
    path = '../save_PR/' +'fix_sample_%d/'%n +dataset + '/'+estimator+'/' + model + '/'
    for re in range(repeats):

        R = np.load(path + '%d.npz'%re)['R']
        temp_list.append(R)

    temp_array = np.array(temp_list)
    
    return temp_array

In [7]:
for dataset in (datasets):
    
    N = nums[dataset]
    AA = utils.A_Nn(N, n)
    AA = torch.tensor(AA).float().to(device)

    for model in models:
        
        
        print(model)
       
        save_path = '../save_PR/' + 'fix_sample_%d/'%n + dataset + '/BV_MES/' + model + '/'

        if not os.path.exists(save_path):
            os.makedirs(save_path)
        

#         em_PR = np.ones(N)*1/N
#         em_PR = torch.tensor(em_PR).double()
        
        PR_MES = get_estimate(model, dataset, 'MES', n).mean(0)
        em_PR = torch.tensor(PR_MES).float()
        
        D = torch.diag(em_PR).to(device)

        W = BV_PR_gpu(AA, D, gamma = 0.01)
        W = W.cpu().numpy()
        
        for re in tqdm(range(repeats)):

            ru, Ru = fix_load_model(model, dataset, n, re)
            WW = W[:][ru]
            PR = WW.mean(0)

            np.savez(save_path+ '%d.npz'%re, R = PR)

EASE


100%|██████████| 100/100 [05:50<00:00,  3.50s/it]


MultiVAE


100%|██████████| 100/100 [05:48<00:00,  3.48s/it]


NeuMF


100%|██████████| 100/100 [05:50<00:00,  3.51s/it]


itemKNN


100%|██████████| 100/100 [05:55<00:00,  3.55s/it]


ALS


100%|██████████| 100/100 [05:50<00:00,  3.50s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

EASE


100%|██████████| 100/100 [01:09<00:00,  1.44it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

MultiVAE


100%|██████████| 100/100 [01:09<00:00,  1.44it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

NeuMF


100%|██████████| 100/100 [01:09<00:00,  1.45it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

itemKNN


100%|██████████| 100/100 [01:08<00:00,  1.45it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

ALS


100%|██████████| 100/100 [01:09<00:00,  1.43it/s]


EASE


100%|██████████| 100/100 [01:29<00:00,  1.12it/s]


MultiVAE


100%|██████████| 100/100 [01:29<00:00,  1.12it/s]


NeuMF


100%|██████████| 100/100 [01:29<00:00,  1.12it/s]


itemKNN


100%|██████████| 100/100 [01:29<00:00,  1.12it/s]


ALS


100%|██████████| 100/100 [01:30<00:00,  1.11it/s]
