In [1]:
import warnings
warnings.filterwarnings('ignore')

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
import numpy as np

from tqdm import tqdm

import scipy.stats as stats


from utils import *
import utils

import torch
import time

In [2]:
def update_gpu(xu, pi, P_Rr):
    '''
    One EM (fixed-point maximum-likelihood) update of the distribution ``pi``.

    xu   : 1-D integer tensor of observed ranks r ("xu is the rank of ru"),
           each value in [0, n). Presumably one entry per user.
    pi   : length-N tensor, current probabilities aligned with the rows (R)
           of ``P_Rr``.
    P_Rr : N x n tensor holding P(r | R).

    Returns the updated length-N estimate

        pi_new[R] = mean_u  pi[R] * P(xu[u] | R) / sum_R' pi[R'] * P(xu[u] | R')

    N and n are taken from ``P_Rr`` itself, so the function does not depend
    on any notebook global. The mean over ``xu`` is computed from a histogram
    of ``xu``. This keeps memory at O(N * n) instead of O(len(xu) * N), and
    the result equals the gather-then-mean form up to floating-point
    summation order.

    Raises
    ------
    ValueError : if ``xu`` is empty (the mean would be undefined / NaN).
    IndexError : if any value of ``xu`` lies outside [0, n).
    '''
    if xu.numel() == 0:
        raise ValueError('xu must contain at least one observed rank')

    n_ranks = P_Rr.shape[1]
    # Explicit range check: the histogram below would otherwise fail with an
    # unhelpful shape error (or reject negatives with a RuntimeError).
    if int(xu.min()) < 0 or int(xu.max()) >= n_ranks:
        raise IndexError('xu values must lie in [0, %d)' % n_ranks)

    # Joint weights pi[R] * P(r | R), shape N x n.
    joint = pi.view(-1, 1) * P_Rr
    # Posterior P(R | r): normalise each column over R. The length-n column
    # sums broadcast across the N rows, so no explicit repeat is needed.
    posterior = joint / joint.sum(0)

    # Averaging posterior[:, xu[u]] over u is the same as weighting column r
    # by how many times r occurs in xu. `.to(posterior)` matches dtype and device.
    counts = torch.bincount(xu, minlength=n_ranks).to(posterior)
    return posterior @ counts / xu.numel()

In [3]:
# GPU to run on. Defaults to index 2 as before; override with the GPU_INDEX
# environment variable. If that index does not exist on this machine, fall
# back to the default CUDA device rather than failing later at `.to(device)`.
GPU_INDEX = int(os.environ.get('GPU_INDEX', '2'))
if torch.cuda.is_available() and 0 <= GPU_INDEX < torch.cuda.device_count():
    device = torch.device('cuda:%d' % GPU_INDEX)
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [4]:
# Per-dataset size N: used as the first argument of utils.A_Nn and as the
# length of the estimated distribution in the loop below.
nums = dict([
    ('pinterest-20', 9916),
    ('yelp', 25815),
    ('ml-20m', 20720),
])

# Datasets and models iterated over by the estimation loop, in this order.
datasets = ['ml-20m', 'pinterest-20', 'yelp']
models = ['EASE', 'MultiVAE', 'NeuMF', 'itemKNN', 'ALS']

# Metric names; not referenced by the estimation loop itself.
metrics = ['Recall', 'NDCG', 'AP']

In [5]:
# n       : sample size passed to utils.A_Nn / fix_load_model; also names
#           the output folder (fix_sample_<n>).
# repeats : number of repeats (result files 0 .. repeats-1) per dataset/model.
# epoch   : number of update_gpu iterations run for each repeat.
n, repeats, epoch = 500, 100, 100

In [6]:
# For every (dataset, model, repeat): run `epoch` update_gpu iterations from a
# uniform start and save the result as
# ../save_PR/fix_sample_<n>/<dataset>/MLE/<model>/<repeat>.npz (array key "R").
for dataset in datasets:

    # N and AA stay notebook globals so other cells can still read them.
    N = nums[dataset]
    AA = utils.A_Nn(N, n)
    # update_gpu documents P_Rr as P(r | R) with shape N x n; this assumes
    # utils.A_Nn(N, n) returns exactly that matrix (TODO confirm in utils).
    P_Rr = torch.as_tensor(AA, device=device)

    for model in models:

        print(model)

        save_path = os.path.join('..', 'save_PR', 'fix_sample_%d' % n,
                                 dataset, 'MLE', model)
        # exist_ok=True replaces the racy "exists() then makedirs()" check.
        os.makedirs(save_path, exist_ok=True)

        # `rep`, not `re`, so the loop does not shadow the stdlib `re` module
        # (which `from utils import *` may have brought into scope).
        for rep in tqdm(range(repeats)):

            # Presumably ru holds the observed sampled ranks for this repeat;
            # the second return value is not used here.
            ru, _ = fix_load_model(model, dataset, n, rep)
            xu = torch.as_tensor(ru, device=device)

            # Uniform initial estimate over the N rows of P_Rr.
            pi = torch.ones(N, device=device) / N
            for _step in range(epoch):
                pi = update_gpu(xu, pi, P_Rr)

            np.savez(os.path.join(save_path, '%d.npz' % rep), R=pi.cpu().numpy())

  0%|          | 0/100 [00:00<?, ?it/s]

EASE


100%|██████████| 100/100 [26:34<00:00, 15.94s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

MultiVAE


100%|██████████| 100/100 [24:55<00:00, 14.96s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

NeuMF


100%|██████████| 100/100 [27:12<00:00, 16.33s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

itemKNN


100%|██████████| 100/100 [32:56<00:00, 19.76s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

ALS


100%|██████████| 100/100 [29:27<00:00, 17.68s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

EASE


100%|██████████| 100/100 [05:41<00:00,  3.41s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

MultiVAE


100%|██████████| 100/100 [05:32<00:00,  3.33s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

NeuMF


100%|██████████| 100/100 [05:54<00:00,  3.55s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

itemKNN


100%|██████████| 100/100 [05:29<00:00,  3.30s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

ALS


100%|██████████| 100/100 [06:20<00:00,  3.80s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

EASE


100%|██████████| 100/100 [13:06<00:00,  7.86s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

MultiVAE


100%|██████████| 100/100 [13:17<00:00,  7.98s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

NeuMF


100%|██████████| 100/100 [13:13<00:00,  7.94s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

itemKNN


100%|██████████| 100/100 [12:50<00:00,  7.70s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

ALS


100%|██████████| 100/100 [13:30<00:00,  8.11s/it]
