In [1]:
import argparse
import torch.nn as nn
import torch
import numpy as np
import torch.nn.functional as F
import torch.utils.data as data
from scipy.sparse import csr_matrix
from collections import OrderedDict, defaultdict#, Iterable
import datetime
import pandas as pd
import scipy.sparse as sp
from torch.utils.data import dataloader
#from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

In [2]:
path='C:/Users/Zhouziyue/workspace/tensorflow2.5_torch2.1/AAAI_feedback/ml-100k'

In [3]:
def parse_args(arg):
    parser = argparse.ArgumentParser(description='Model Params')

    # for gcn
    parser.add_argument('--embed_dim', default=64, type=int)
    parser.add_argument('--layer_num', default=3, type=int)

    # for ssl
    parser.add_argument('--SSL_reg', default=0.1, type=float)
    parser.add_argument('--SSL_dropout_ratio', default=0.1, type=float)
    parser.add_argument('--SSL_temp', default=0.2, type=float)

    # for train
    parser.add_argument('--batch_size', default=2048, type=int)
    parser.add_argument('--epoch_num', default=500, type=int)
    parser.add_argument('--stop_cnt', default=10, type=int)
    parser.add_argument('--lr', default=0.001, type=float)
    parser.add_argument('--reg', default=0.0001, type=float)

    # for test
    parser.add_argument('--k', default=10, type=int)

    # for save and read
    parser.add_argument('--train_data_path', default=path+'/train_imp.csv', type=str)
    parser.add_argument('--test_data_path', default=path+'/test_imp.csv', type=str)

    return parser.parse_args(arg)


args = parse_args(arg=[])

In [4]:

#from params import args


class LightGCN(nn.Module):
    def __init__(self, user_num, item_num, embed_dim, layer_num):
        super(LightGCN, self).__init__()
        self.user_num = user_num
        self.item_num = item_num
        self.embed_dim = embed_dim
        self.layer_num = layer_num
        self.dropout = nn.Dropout(p=0.1)

        self.user_embedding = nn.Embedding(self.user_num, self.embed_dim)
        self.item_embedding = nn.Embedding(self.item_num, self.embed_dim)

        self.reset_params()

    def reset_params(self):
        init = torch.nn.init.xavier_uniform_
        init(self.user_embedding.weight)
        init(self.item_embedding.weight)

    def forward(self, norm_adj):
        ego_embedding = torch.cat([self.user_embedding.weight, self.item_embedding.weight], dim=0)
        all_embedding = [ego_embedding]

        for i in range(self.layer_num):
            ego_embedding = torch.sparse.mm(norm_adj, ego_embedding)
            all_embedding += [ego_embedding]

        all_embedding = torch.stack(all_embedding, dim=1).mean(dim=1)
        user_embedding, item_embedding = torch.split(all_embedding, [self.user_num, self.item_num], dim=0)

        return user_embedding, item_embedding






In [5]:

#from params import args


def sp_mat_to_tensor(sp_mat):
    coo = sp_mat.tocoo().astype(np.float32)
    indices = torch.from_numpy(np.asarray([coo.row, coo.col]))
    return torch.sparse_coo_tensor(indices, coo.data, coo.shape).coalesce()


def inner_product(x1, x2):
    return torch.sum(torch.mul(x1, x2), dim=-1)


def compute_bpr_loss(x1, x2):
    #return -torch.sum(torch.log((x1.view(-1) - x2.view(-1)).sigmoid() + 1e-8))
    return -torch.sum(F.logsigmoid(x1-x2))

def compute_infoNCE_loss(x1, x2, temp):
    return torch.logsumexp((x2 - x1[:, None]) / temp, dim=1)


def compute_reg_loss(w1, w2, w3):
    return 0.5 * torch.sum(torch.pow(w1, 2) + torch.pow(w2, 2) + torch.pow(w3, 2))


def compute_metric(ratings, test_item):
    hit = 0
    DCG = 0.
    iDCG = 0.

    _, shoot_index = torch.topk(ratings, args.k)
    shoot_index = shoot_index.cpu().tolist()

    for i in range(len(shoot_index)):
        if shoot_index[i] in test_item:
            hit += 1
            DCG += 1 / np.log2(i + 2)
        if i < test_item.size()[0]:
            iDCG += 1 / np.log2(i + 2)

    recall = hit / test_item.size()[0]
    NDCG = DCG / iDCG

    return recall, NDCG


def hr(rank, k):
    """Hit Rate.
    Args:
        :param rank: A list.
        :param k: A scalar(int).
    :return: hit rate.
    """
    res = 0.0
    for r in rank:
        if r < k:
            res += 1
    return res / len(rank)


def mrr(rank, k):
    """Mean Reciprocal Rank.
    Args:
        :param rank: A list.
        :param k: A scalar(int).
    :return: mrr.
    """
    mrr = 0.0
    for r in rank:
        if r < k:
            mrr += 1 / (r + 1)
    return mrr / len(rank)


def ndcg(rank, k):
    """Normalized Discounted Cumulative Gain.
    Args:
        :param rank: A list.
        :param k: A scalar(int).
    :return: ndcg.
    """
    res = 0.0
    for r in rank:
        if r < k:
            res += 1 / np.log2(r + 2)
    return res / len(rank)

def my_compute_metric(ratings, test_item):
    hit = 0
    DCG = 0.
    iDCG = 0.

    _, shoot_index = torch.topk(ratings, args.k)
    shoot_index = shoot_index.cpu().tolist()

    rank=[]
    for target in test_item:
        if target in list(shoot_index):
            rank.append(list(shoot_index).index(target))
        else:
            rank.append(1e10)

    res_1 = hr(rank, args.k)
    res_2 = ndcg(rank, args.k)
    res_3 = mrr(rank, args.k)
    res_4 = hr(rank, 1)
    

    return res_1,res_2,res_3,res_4

In [6]:

#from params import args


class RecDataset_train(data.Dataset):
    def __init__(self, data, user_num, item_num):
        self.data = data
        self.user_num = user_num
        self.item_num = item_num

        self.user_item_pair = self.data.values
        self.user_index = self.user_item_pair[:, 0].flatten()
        self.item_index = self.user_item_pair[:, 1].flatten()
        self.interact_num = len(self.user_item_pair)

        self.user_pos_dict = OrderedDict()
        grouped_user = self.data.groupby('user')
        for user, user_data in grouped_user:
            self.user_pos_dict[user] = user_data['item'].to_numpy(dtype=np.int32)

        self.user_list, self.pos_item_list, self.neg_item_list = self.sample()

    def sample(self):
        """
        Sample user, pos_item, neg_item
        """
        user_arr = np.array(list(self.user_pos_dict.keys()), dtype=np.int32)
        user_list = np.random.choice(user_arr, size=self.interact_num, replace=True)

        user_pos_len = defaultdict(int)
        for u in user_list:
            user_pos_len[u] += 1

        user_pos_sample = dict()
        user_neg_sample = dict()
        for user, pos_len in user_pos_len.items():
            pos_item = self.user_pos_dict[user]
            pos_idx = np.random.choice(pos_item, size=pos_len, replace=True)
            user_pos_sample[user] = list(pos_idx)

            neg_item = np.random.randint(low=0, high=self.item_num, size=pos_len)
            for i in range(len(neg_item)):
                idx = neg_item[i]
                while idx in pos_item:
                    idx = np.random.randint(low=0, high=self.item_num)
                neg_item[i] = idx
            user_neg_sample[user] = list(neg_item)

        pos_item_list = [user_pos_sample[user].pop() for user in user_list]
        neg_item_list = [user_neg_sample[user].pop() for user in user_list]
        return user_list, pos_item_list, neg_item_list

    def __len__(self):
        return self.interact_num

    def __getitem__(self, idx):
        return self.user_list[idx], self.pos_item_list[idx], self.neg_item_list[idx]


class RecDataset_test(data.Dataset):
    def __init__(self, data):
        self.data = data

        self.user_item_pair = self.data.values

        self.user_pos_dict = OrderedDict()
        grouped_user = self.data.groupby('user')
        for user, user_data in grouped_user:
            self.user_pos_dict[user] = user_data['item'].to_numpy(dtype=np.int32)

        self.user_list = np.array(list(self.user_pos_dict.keys()))  # 用户不重复

    def __len__(self):
        return self.user_list.shape[0]

    def __getitem__(self, idx):
        return self.user_list[idx]

def my_RecDataset_test(tensors_list, batch_size):
    for i in range(0, len(tensors_list[0]), batch_size):
        yield tensors_list[0][i:i + batch_size], tensors_list[1][i:i + batch_size]

In [7]:

#from LightGCN import LightGCN
#from dataset import RecDataset_train, RecDataset_test
#from utils import sp_mat_to_tensor, inner_product, compute_infoNCE_loss, compute_bpr_loss, compute_reg_loss, compute_metric


class Model:

    def __init__(self):
        self.train_data_path = args.train_data_path
        self.test_data_path = args.test_data_path
        self.behavior_mats = {}
        self.behavior_mats_T = {}

        now_time = datetime.datetime.now()
        self.time = datetime.datetime.strftime(now_time, '%Y_%m_%d__%H_%M_%S')

        self.epoch = 0
        self.cnt = 0
        self.train_loss = []
        self.bpr_loss = []
        self.infoNCE_loss = []
        self.reg_loss = []
        self.recall_history = []
        self.NDCG_history = []
        self.best_recall = 0
        self.best_NDCG = 0
        self.best_epoch = 0
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        # Load data
        train_data = pd.read_csv(self.train_data_path,  header=0, names=['user', 'item'])
        test_data = pd.read_csv(self.test_data_path,  header=0, names=['user', 'item'])
        
        all_data = pd.concat([train_data, test_data])
        self.user_num = 943
        self.item_num = 1523

        self.train_dataset = RecDataset_train(train_data, self.user_num, self.item_num)
        self.train_loader = dataloader.DataLoader(self.train_dataset, batch_size=args.batch_size, shuffle=True,
                                                  num_workers=0, pin_memory=True)

        self.test_dataset = RecDataset_test(test_data)
        self.test_loader = dataloader.DataLoader(self.test_dataset, batch_size=args.batch_size, shuffle=True,
                                                 num_workers=0, pin_memory=True)

        # Model Config
        self.embed_dim = args.embed_dim
        self.layer_num = args.layer_num
        self.lr = args.lr
        self.model = LightGCN(self.user_num, self.item_num, self.embed_dim, self.layer_num).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.graph = self.create_adj_mat(is_subgraph=False)
        self.graph = sp_mat_to_tensor(self.graph).to(self.device)

    def run(self):
        for epoch in range(1, args.epoch_num + 1):
            self.epoch += 1

            epoch_loss, bpr_loss, infoNCE_loss, reg_loss= self.train_epoch()
            self.train_loss.append(epoch_loss)
            self.bpr_loss.append(bpr_loss)
            self.infoNCE_loss.append(infoNCE_loss)
            self.reg_loss.append(reg_loss)
            print(f"Epoch {self.epoch}:  loss:{epoch_loss/self.train_dataset.interact_num} \
                    bpr_loss:{bpr_loss/self.train_dataset.interact_num} \
                    info_NCE_loss:{infoNCE_loss/self.train_dataset.interact_num} \
                    reg_loss:{reg_loss/self.train_dataset.interact_num}")

            f_hr,f_mrr,f_ndcg,f_acc = self.test_epoch()
            #self.recall_history.append(epoch_recall)
            #self.NDCG_history.append(epoch_NDCG)
            print(f"Epoch {self.epoch}: hr: {f_hr}\n" + f"mrr: {f_mrr}\n" + f"mdcg: {f_ndcg}\n" + f"ac: {f_acc}")
            #print(f"Epoch {self.epoch}:  recall:{epoch_recall}, NDCG:{epoch_NDCG}")
            
            '''
            if epoch_recall > self.best_recall:
                self.cnt = 0
                self.best_recall = epoch_recall
                self.best_epoch = self.epoch
'''
            if f_ndcg > self.best_NDCG:
                self.cnt = 0
                self.best_NDCG = f_ndcg
                self.best_epoch = self.epoch
            else:
                self.cnt += 1
            '''    
            if epoch_recall < self.best_recall and epoch_NDCG < self.best_NDCG:
                self.cnt += 1'''

            #self.save_metrics()

            if self.cnt == args.stop_cnt:
                print(f"Early stop at {self.best_epoch}: best Recall: {self.best_recall}, best_NDCG: {self.best_NDCG}\n")
                #self.save_metrics()
                break

    def train_epoch(self):
        epoch_loss = 0
        epoch_bpr_loss = 0
        epoch_infoNCE_loss = 0
        epoch_reg_loss = 0
        sub_graph1 = self.create_adj_mat(is_subgraph=True)
        sub_graph1 = sp_mat_to_tensor(sub_graph1).to(self.device)
        sub_graph2 = self.create_adj_mat(is_subgraph=True)
        sub_graph2 = sp_mat_to_tensor(sub_graph2).to(self.device)

        for batch_user, batch_pos_item, batch_neg_item in tqdm(self.train_loader):
            batch_user = batch_user.long().to(self.device)
            batch_pos_item = batch_pos_item.long().to(self.device)
            batch_neg_item = batch_neg_item.long().to(self.device)

            all_user_embedding, all_item_embedding = self.model(self.graph)
            SSL_user_embedding1, SSL_item_embedding1 = self.model(sub_graph1)
            SSL_user_embedding2, SSL_item_embedding2 = self.model(sub_graph2)

            # 归一化，消除嵌入的模对相似度衡量的影响+
            SSL_user_embedding1 = F.normalize(SSL_user_embedding1)
            SSL_user_embedding2 = F.normalize(SSL_user_embedding2)
            SSL_item_embedding1 = F.normalize(SSL_item_embedding1)
            SSL_item_embedding2 = F.normalize(SSL_item_embedding2)

            batch_user_embedding = all_user_embedding[batch_user]
            batch_pos_item_embedding = all_item_embedding[batch_pos_item]
            batch_neg_item_embedding = all_item_embedding[batch_neg_item]
            batch_SSL_user_embedding1 = SSL_user_embedding1[batch_user]
            batch_SSL_user_embedding2 = SSL_user_embedding2[batch_user]
            batch_SSL_item_embedding1 = SSL_item_embedding1[batch_pos_item]
            batch_SSL_item_embedding2 = SSL_item_embedding2[batch_pos_item]

            # [batch_size]
            pos_score = inner_product(batch_user_embedding, batch_pos_item_embedding)  # [2048]
            neg_score = inner_product(batch_user_embedding, batch_neg_item_embedding)

            # [batch_size]
            SSL_user_pos_score = inner_product(batch_SSL_user_embedding1, batch_SSL_user_embedding2)  # 全1
            SSL_user_neg_score = torch.matmul(batch_SSL_user_embedding1, torch.transpose(SSL_user_embedding2, 0, 1))

            SSL_item_pos_score = inner_product(batch_SSL_item_embedding1, batch_SSL_item_embedding2)
            SSL_item_neg_score = torch.matmul(batch_SSL_item_embedding1, torch.transpose(SSL_item_embedding2, 0, 1))

            bpr_loss = compute_bpr_loss(pos_score, neg_score)  # 1419

            infoNCE_user_loss = compute_infoNCE_loss(SSL_user_pos_score, SSL_user_neg_score, args.SSL_temp)
            infoNCE_item_loss = compute_infoNCE_loss(SSL_item_pos_score, SSL_item_neg_score, args.SSL_temp)
            infoNCE_loss = torch.sum(infoNCE_user_loss + infoNCE_item_loss, dim=-1)  # 22375

            reg_loss = compute_reg_loss(  # 11
                self.model.user_embedding(batch_user),
                self.model.item_embedding(batch_pos_item),
                self.model.item_embedding(batch_neg_item)
            )

            loss = bpr_loss + infoNCE_loss * args.SSL_reg + reg_loss * args.reg  # 3657
            epoch_loss += loss
            epoch_bpr_loss += bpr_loss
            epoch_infoNCE_loss += infoNCE_loss * args.SSL_reg
            epoch_reg_loss += reg_loss * args.reg
            
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        return epoch_loss, epoch_bpr_loss, epoch_infoNCE_loss, epoch_reg_loss

    def test_epoch(self):
        test_user_pos_dict = self.test_dataset.user_pos_dict
        train_user_pos_dict = self.train_dataset.user_pos_dict

        epoch_recall = 0
        epoch_NDCG = 0
        tot = 0
        
        hr_b=[]
        mrr_b=[]
        ndcg_b=[]
        acc_b=[]
        #test_data = pd.read_csv(self.test_data_path,  header=0, names=['user', 'item'])
        df_test = pd.read_csv(self.test_data_path)
        #self.testUniqueUsers = pd.unique(df_test["users"])
        test_user = df_test["users"].to_numpy()
        test_item = df_test["pos_item"].to_numpy()
        for batch_users,batch_target in my_RecDataset_test([test_user,test_item],args.batch_size):
            #user_num = test_user.size()[0]
            
            test_user = batch_users#.long().to(self.device)
            test_item = batch_target#.long().to(self.device)

            all_user_embedding, all_item_embedding = self.model(self.graph)
            test_user_embedding = all_user_embedding[test_user].unsqueeze(1)
            
            tar_item_embeds = all_item_embedding[test_item].unsqueeze(1)#[b,1,h]
            #neg_indx = torch.randint(low=1, high=self.item_num, size=(tar_item_embeds.shape[0], 1000))#.to(users.device)
            neg_indx = torch.arange(self.item_num).cuda().unsqueeze(0)
            neg_item_embeds = all_item_embedding[neg_indx].expand(tar_item_embeds.shape[0], -1, -1)
            test_item_embeding=torch.cat([tar_item_embeds,neg_item_embeds],dim=1)#[b,1+1000,h]
            
            ratings = -torch.matmul(test_user_embedding, test_item_embeding.transpose(1,2)).squeeze(1)
            print(ratings.shape)
            rank = ratings.argsort().argsort()[:, 0]
            rank=rank.cpu()
            res_1 = hr(rank, args.k)
            res_2 = ndcg(rank, args.k)
            res_3 = mrr(rank, args.k)
            res_4 = hr(rank, 1)
            
            hr_b.append(res_1)
            ndcg_b.append(res_2)
            mrr_b.append(res_3)
            acc_b.append(res_4)
        f_hr=np.mean(hr_b)
        f_mrr=np.mean(mrr_b)
        f_ndcg=np.mean(ndcg_b)
        f_acc=np.mean(acc_b)
        return f_hr,f_mrr,f_ndcg,f_acc

        '''
                epoch_recall += recall
                epoch_NDCG += NDCG

            tot += user_num

        epoch_recall /= tot
        epoch_NDCG /= tot

        return epoch_recall, epoch_NDCG'''
        

    def create_adj_mat(self, is_subgraph):
        node_num = self.user_num + self.item_num
        user_np, item_np = self.train_dataset.user_index, self.train_dataset.item_index

        if is_subgraph:
            sample_size = int(user_np.shape[0]*(1-args.SSL_dropout_ratio))
            keep_index = np.arange(user_np.shape[0])
            np.random.shuffle(keep_index)
            keep_index = keep_index[:sample_size]
            # keep_index = np.random.randint(user_np.shape[0], size=3*sample_size)
            # keep_index = np.unique(keep_index)
            # keep_index = keep_index[:sample_size]
            # keep_idx = np.random.randint(user_np.shape[0], size=int(user_np.shape[0]*(1-args.SSL_dropout_ratio)))
            # keep_idx = np.random.choice(user_np, size=int(user_np.shape[0]*(1-args.SSL_dropout_ratio)), replace=False)
            user_np = np.array(user_np)[keep_index]
            item_np = np.array(item_np)[keep_index]
            ratings = np.ones_like(user_np)
            tmp_adj = sp.csr_matrix((ratings, (user_np, item_np + self.user_num)), shape=(node_num, node_num))


            # keep_idx = np.random.choice(user, size=int(len(user) * (1 - args.SSL_dropout_ratio)), replace=True)
            # keep_idx.tolist()
            # sub_user = np.array(user)[keep_idx]
            # sub_item = np.array(item)[keep_idx]
            # # rating = np.ones_like(sub_user, dtype=np.float32)
            # c = np.ones_like(sub_user, dtype=np.float32)
            # c = torch.ones(sub_user.shape[0])
            # # tmp_adj = sp.csr_matrix((rating, (sub_user, sub_item + self.user_num)), shape=(node_num, node_num))
            # a = sp.csr_matrix( (c, (sub_user, sub_item + self.user_num)), shape=(node_num, node_num))
            # b = sp.csr_matrix((c, (sub_user, sub_item)), shape=(node_num, node_num))
            # tmp_adj = sp.csr_matrix((c, (sub_user, sub_item + self.user_num)), shape=(node_num, node_num))
        else:
            rating = np.ones_like(user_np, dtype=np.float32)
            tmp_adj = sp.csr_matrix((rating, (user_np, item_np + self.user_num)), shape=(node_num, node_num))
        adj = tmp_adj + tmp_adj.T

        row_sum = np.array(adj.sum(1))
        d = np.power(row_sum, -0.5).flatten()
        d[np.isinf(d)] = 0.
        d_mat = sp.diags(d)
        norm_adj = d_mat.dot(adj)
        norm_adj = norm_adj.dot(d_mat)

        return norm_adj

    def save_metrics(self):
        path = './runs/' + self.time + '/' + str(self.epoch) + '/'
        writer = SummaryWriter(path)
        for i in range(self.epoch):
            writer.add_scalar('Loss', self.train_loss[i], i)
            writer.add_scalar('bpr_loss', self.bpr_loss[i], i)
            writer.add_scalar('infoNCE_loss', self.infoNCE_loss[i], i)
            writer.add_scalar('reg_loss', self.reg_loss[i], i)
            writer.add_scalar('Recall', self.recall_history[i], i)
            writer.add_scalar('NDCG', self.NDCG_history[i], i)


if __name__ == '__main__':
    model = Model()
    model.run()

  d = np.power(row_sum, -0.5).flatten()
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 31.81it/s]


Epoch 1:  loss:1.1667665243148804                     bpr_loss:0.6929677724838257                     info_NCE_loss:0.4737841784954071                     reg_loss:1.4638380889664404e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 1: hr: 0.009160400113507378
mrr: 0.0027245802339166403
mdcg: 0.004202731710023997
ac: 0.0008643991424872304


  d = np.power(row_sum, -0.5).flatten()
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 49.77it/s]


Epoch 2:  loss:1.1557632684707642                     bpr_loss:0.6929296851158142                     info_NCE_loss:0.4628184139728546                     reg_loss:1.5258934581652284e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 2: hr: 0.008581362842295687
mrr: 0.0029516429640352726
mdcg: 0.004251249194132869
ac: 0.0012306100799872304


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 59.04it/s]


Epoch 3:  loss:1.1530426740646362                     bpr_loss:0.6929125785827637                     info_NCE_loss:0.46011435985565186                     reg_loss:1.5706897102063522e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 3: hr: 0.009516704096020148
mrr: 0.003133482998237014
mdcg: 0.004609241754970336
ac: 0.0011085397674872304


  d = np.power(row_sum, -0.5).flatten()
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 63.27it/s]


Epoch 4:  loss:1.1518820524215698                     bpr_loss:0.6928815245628357                     info_NCE_loss:0.45898449420928955                     reg_loss:1.603641976544168e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])


  d = np.power(row_sum, -0.5).flatten()


Epoch 4: hr: 0.010747314176007378
mrr: 0.0032438396010547876
mdcg: 0.004968529221140588
ac: 0.0010986328125


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 62.14it/s]


Epoch 5:  loss:1.1509991884231567                     bpr_loss:0.6928528547286987                     info_NCE_loss:0.4581299424171448                     reg_loss:1.635970147617627e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 5: hr: 0.010483359641032917
mrr: 0.0032098894007503986
mdcg: 0.004883601643292063
ac: 0.0011085397674872304


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 62.44it/s]


Epoch 6:  loss:1.1503863334655762                     bpr_loss:0.6928264498710632                     info_NCE_loss:0.4575430750846863                     reg_loss:1.668325421633199e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 6: hr: 0.010310161127270148
mrr: 0.003005857113748789
mdcg: 0.004683424514166973
ac: 0.0009353412537244608


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 55.44it/s]


Epoch 7:  loss:1.1500027179718018                     bpr_loss:0.6927969455718994                     info_NCE_loss:0.4571888744831085                     reg_loss:1.6999398212647066e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 7: hr: 0.011745284012308456
mrr: 0.0032598094549030066
mdcg: 0.00518941364333534
ac: 0.0011184467224744608


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 64.17it/s]


Epoch 8:  loss:1.1497489213943481                     bpr_loss:0.6927688121795654                     info_NCE_loss:0.45696282386779785                     reg_loss:1.7326718079857528e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 8: hr: 0.011977924255994608
mrr: 0.003305127378553152
mdcg: 0.005279787487053325
ac: 0.0011283536774616913


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 64.58it/s]


Epoch 9:  loss:1.1494731903076172                     bpr_loss:0.6927383542060852                     info_NCE_loss:0.45671725273132324                     reg_loss:1.766632885846775e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 9: hr: 0.0123738560584563
mrr: 0.003281597513705492
mdcg: 0.005353484565868139
ac: 0.0010062833649616913


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 64.10it/s]


Epoch 10:  loss:1.1493616104125977                     bpr_loss:0.6927056908607483                     info_NCE_loss:0.4566379189491272                     reg_loss:1.8002088836510666e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 10: hr: 0.0126790318397063
mrr: 0.0034621329978108406
mdcg: 0.005572557515187851
ac: 0.0010062833649616913


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 64.08it/s]


Epoch 11:  loss:1.1492412090301514                     bpr_loss:0.6926743388175964                     info_NCE_loss:0.45654842257499695                     reg_loss:1.8348249795963056e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 11: hr: 0.013023835440905221
mrr: 0.0035627884790301323
mdcg: 0.005723661905543005
ac: 0.0010574115662244608


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 64.23it/s]


Epoch 12:  loss:1.1492000818252563                     bpr_loss:0.6926417946815491                     info_NCE_loss:0.4565395712852478                     reg_loss:1.8708038624026813e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 12: hr: 0.013592965757129682
mrr: 0.0037184187676757574
mdcg: 0.005958169031300277
ac: 0.0011893888337116913


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 61.83it/s]


Epoch 13:  loss:1.1489607095718384                     bpr_loss:0.6926087141036987                     info_NCE_loss:0.4563329815864563                     reg_loss:1.9083416191278957e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 13: hr: 0.013898141538379682
mrr: 0.00366798834875226
mdcg: 0.005991510219588032
ac: 0.0009963764099744608


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 61.95it/s]


Epoch 14:  loss:1.1488455533981323                     bpr_loss:0.6925750374794006                     info_NCE_loss:0.4562511146068573                     reg_loss:1.9483301002765074e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 14: hr: 0.013756257315905221
mrr: 0.0037348573096096516
mdcg: 0.006014813442893657
ac: 0.0010574115662244608


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 56.99it/s]


Epoch 15:  loss:1.1487425565719604                     bpr_loss:0.6925369501113892                     info_NCE_loss:0.4561856687068939                     reg_loss:1.9896180674550124e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 15: hr: 0.014071340052142452
mrr: 0.0037189668510109186
mdcg: 0.006072893950321149
ac: 0.0008743060974744608


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 64.73it/s]


Epoch 16:  loss:1.1486220359802246                     bpr_loss:0.6924975514411926                     info_NCE_loss:0.4561041295528412                     reg_loss:2.0321727788541466e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 16: hr: 0.013705129114642452
mrr: 0.0036207628436386585
mdcg: 0.005913642939175886
ac: 0.0009353412537244608


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 65.12it/s]


Epoch 17:  loss:1.1486238241195679                     bpr_loss:0.6924579739570618                     info_NCE_loss:0.45614513754844666                     reg_loss:2.076464625133667e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 17: hr: 0.013084870597155221
mrr: 0.003568127518519759
mdcg: 0.005736348792234974
ac: 0.0010574115662244608


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 62.69it/s]


Epoch 18:  loss:1.1484177112579346                     bpr_loss:0.6924228668212891                     info_NCE_loss:0.4559735357761383                     reg_loss:2.1221298084128648e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 18: hr: 0.013734849979604143
mrr: 0.003665993455797434
mdcg: 0.005964333841474472
ac: 0.0009452482087116913


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 54.78it/s]


Epoch 19:  loss:1.1483875513076782                     bpr_loss:0.6923805475234985                     info_NCE_loss:0.4559854567050934                     reg_loss:2.1703635866288096e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 19: hr: 0.013705129114642452
mrr: 0.0037692650221288204
mdcg: 0.006039019742750091
ac: 0.0010062833649616913


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 63.65it/s]


Epoch 20:  loss:1.1483889818191528                     bpr_loss:0.6923336386680603                     info_NCE_loss:0.45603302121162415                     reg_loss:2.2210870156413876e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 20: hr: 0.013541837555866913
mrr: 0.003801964223384857
mdcg: 0.006028887000758612
ac: 0.0011283536774616913


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 63.50it/s]


Epoch 21:  loss:1.1482484340667725                     bpr_loss:0.6922836303710938                     info_NCE_loss:0.4559420049190521                     reg_loss:2.2729154807166196e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 21: hr: 0.013949269739642452
mrr: 0.0038090581074357033
mdcg: 0.006115290492704477
ac: 0.0011893888337116913


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 62.02it/s]


Epoch 22:  loss:1.1481765508651733                     bpr_loss:0.6922346949577332                     info_NCE_loss:0.45591846108436584                     reg_loss:2.3283813789021224e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 22: hr: 0.013663907868366913
mrr: 0.0037777358666062355
mdcg: 0.006035000833398399
ac: 0.0011283536774616913


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 55.06it/s]


Epoch 23:  loss:1.1481282711029053                     bpr_loss:0.692180871963501                     info_NCE_loss:0.4559234082698822                     reg_loss:2.3866437913966365e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 23: hr: 0.013583058802142452
mrr: 0.00383245712146163
mdcg: 0.006062001257752943
ac: 0.0011893888337116913


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 62.86it/s]


Epoch 24:  loss:1.1480152606964111                     bpr_loss:0.6921230554580688                     info_NCE_loss:0.4558674693107605                     reg_loss:2.4490065698046237e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 24: hr: 0.014101060917104143
mrr: 0.0038953153416514397
mdcg: 0.006224360895557196
ac: 0.0011893888337116913


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 62.78it/s]


Epoch 25:  loss:1.1480287313461304                     bpr_loss:0.6920600533485413                     info_NCE_loss:0.4559435248374939                     reg_loss:2.5155341063509695e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 25: hr: 0.013654000913379682
mrr: 0.0037917050067335367
mdcg: 0.006049998633985358
ac: 0.0011184467224744608


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 59.86it/s]


Epoch 26:  loss:1.1479324102401733                     bpr_loss:0.6919934153556824                     info_NCE_loss:0.4559133052825928                     reg_loss:2.584051799203735e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 26: hr: 0.013277883020892452
mrr: 0.0037559913471341133
mdcg: 0.005938849441841898
ac: 0.0011283536774616913


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 64.06it/s]


Epoch 27:  loss:1.1478930711746216                     bpr_loss:0.6919291615486145                     info_NCE_loss:0.45593729615211487                     reg_loss:2.655639582371805e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 27: hr: 0.013267976065905221
mrr: 0.0036260171327739954
mdcg: 0.00582849917428125
ac: 0.0010062833649616913


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 62.91it/s]


Epoch 28:  loss:1.1477744579315186                     bpr_loss:0.6918514966964722                     info_NCE_loss:0.4558955430984497                     reg_loss:2.7313242753734812e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 28: hr: 0.013084870597155221
mrr: 0.0035907160490751266
mdcg: 0.005757904791835887
ac: 0.0009452482087116913


  d = np.power(row_sum, -0.5).flatten()
100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 62.84it/s]


Epoch 29:  loss:1.1476410627365112                     bpr_loss:0.6917708516120911                     info_NCE_loss:0.45584243535995483                     reg_loss:2.8126163670094684e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 29: hr: 0.01288195121843076
mrr: 0.0035187010653316975
mdcg: 0.005655008611696023
ac: 0.0009963764099744608


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 61.58it/s]


Epoch 30:  loss:1.1475892066955566                     bpr_loss:0.6916858553886414                     info_NCE_loss:0.4558742046356201                     reg_loss:2.8995247703278437e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 30: hr: 0.012647717548417991
mrr: 0.00340525945648551
mdcg: 0.005519729875054562
ac: 0.0008842130524616913


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 58.84it/s]


Epoch 31:  loss:1.1474910974502563                     bpr_loss:0.6915915012359619                     info_NCE_loss:0.45586955547332764                     reg_loss:2.992329427797813e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 31: hr: 0.012708752704667991
mrr: 0.003447805531322956
mdcg: 0.005564208507830702
ac: 0.0009452482087116913


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 64.37it/s]


Epoch 32:  loss:1.1474425792694092                     bpr_loss:0.6914907097816467                     info_NCE_loss:0.4559208154678345                     reg_loss:3.0885476007824764e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 32: hr: 0.01275988090593076
mrr: 0.0034170730505138636
mdcg: 0.00554746032003626
ac: 0.0009452482087116913


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 55.75it/s]


Epoch 33:  loss:1.1473667621612549                     bpr_loss:0.6913731694221497                     info_NCE_loss:0.4559614956378937                     reg_loss:3.192499571014196e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])


  d = np.power(row_sum, -0.5).flatten()


torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 33: hr: 0.0121297154334563
mrr: 0.003371839877218008
mdcg: 0.0053748239620786935
ac: 0.0010062833649616913


100%|██████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 63.39it/s]


Epoch 34:  loss:1.1471834182739258                     bpr_loss:0.6912523508071899                     info_NCE_loss:0.45589807629585266                     reg_loss:3.305011705379002e-05
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([2048, 1524])
torch.Size([1762, 1524])
Epoch 34: hr: 0.011987831210981839
mrr: 0.0031529979314655066
mdcg: 0.005158099826981134
ac: 0.0009452482087116913
Early stop at 24: best Recall: 0, best_NDCG: 0.006224360895557196

