In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import time

from config import bpr_config
from data_utils import RecDataset, DataProvider
from evaluation.rec_evaluator import RecEvaluator

In [2]:
# Training schedule and embedding size, read once from the shared BPR config.
epochs, batch_size, emb_dim = (
    bpr_config.epochs,
    bpr_config.batch_size,
    bpr_config.emb_dim,
)
# Target device plus optimizer settings (eta is passed as the learning rate).
device, eta, weight_decay = (
    bpr_config.device,
    bpr_config.eta,
    bpr_config.weight_decay,
)

In [3]:
# Dataset handle and dataset-specific values shared by the cells below.
rec_dataset = RecDataset(bpr_config.dir_path)
all_users = rec_dataset.get_users()
# Item pool handed to dp.prepare_bpr_triplets when training triplets are built.
all_items = rec_dataset.get_items()
# The two counts size the user and item embedding tables of the BPR model.
num_users = rec_dataset.get_num_users()
num_items = rec_dataset.get_num_items()
# Indexed by user id inside BPR and cast to bool to hide items from top-k output;
# presumably shaped (num_users, num_items) — TODO confirm against RecDataset.
bought_mask = rec_dataset.get_bought_mask().to(device)
# Records requested with the "test" argument; passed to RecEvaluator.
eval_dict = rec_dataset.get_interaction_records("test")
# Records requested with no argument; passed to dp.prepare_bpr_triplets.
bought_dict = rec_dataset.get_interaction_records()
train_ui = rec_dataset.get_user_item_pairs()

In [4]:
# Shared helpers for the training cell: batch provider, evaluator and TensorBoard writer.
dp = DataProvider(device)
# NOTE(review): the meaning of the second positional argument (None) is not
# visible in this notebook — confirm against RecEvaluator.
evaluator = RecEvaluator(eval_dict, None, device)
# Scalars are logged under runs/BPR-Adam; the training cell closes this writer.
writer = SummaryWriter("runs/BPR-Adam")

In [5]:
class BPR(nn.Module):
    '''
    BPR model: user and item embeddings scored by dot product and trained with
    a pairwise log-sigmoid ranking loss over (user, positive, negative) triplets.
    '''
    def __init__(self, num_users, num_items, emb_dim, bought_mask):
        """
        Args:
            num_users: The number of users.
            num_items: The number of items.
            emb_dim: Embedding dimension of embedding layer.
            bought_mask: Tensor indexed by user id whose non-zero entries mark
                items to hide from top-k output. Presumably shaped
                (num_users, num_items) — confirm against the dataset helper.
        """
        super(BPR, self).__init__()
        self.user_emb = nn.Embedding(num_users, emb_dim)
        self.item_emb = nn.Embedding(num_items, emb_dim)
        # Stored as a plain attribute, not a registered buffer: ``.to(device)``
        # on the module does not move it, so the caller must place it on the
        # same device as the embeddings.
        self.bought_mask = bought_mask
        nn.init.normal_(self.user_emb.weight, std=0.01)
        nn.init.normal_(self.item_emb.weight, std=0.01)

    def loss(self, users, pos_items, neg_items):
        """ Computes the mean pairwise ranking loss for a batch of triplets.

        Args:
            users: 1-D tensor of user ids.
            pos_items: 1-D tensor of item ids that should score higher.
            neg_items: 1-D tensor of item ids that should score lower.

        Returns:
            A scalar tensor: the negated mean of log(sigmoid(x_ui - x_uj)).
        """
        emb_users = self.user_emb(users)
        x_ui = torch.sum(emb_users * self.item_emb(pos_items), 1)
        x_uj = torch.sum(emb_users * self.item_emb(neg_items), 1)
        # logsigmoid stays finite where log(sigmoid(x)) underflows to -inf
        # for strongly negative margins.
        return -F.logsigmoid(x_ui - x_uj).mean()

    def forward(self, users, k = 10, delete_bought = True):
        """ Returns the top k item ids for each user (see top_k_items_for_users). """
        return self.top_k_items_for_users(users, k, delete_bought)

    def top_k_items_for_users(self, users, k = 5, delete_bought = True):
        """ Gets top k items for users.

        Args:
            users: Target users.
            k: The number of items to recommend for each user.
            delete_bought: Boolean indicating whether items flagged in
                ``bought_mask`` are excluded from the recommendations.

        Returns:
            A tensor containing top k items for target users.
        """
        scores = torch.mm(self.user_emb(users), self.item_emb.weight.t())
        # Ranking uses the raw dot products: a sigmoid is monotonic, so it
        # cannot change the order and would only add ties once it saturates.
        if delete_bought:
            hidden = self.bought_mask[users].bool()
            scores = scores.masked_fill(hidden, float("-inf"))
        _, top_k_items = torch.topk(scores, k)
        return top_k_items

    def rank_items_for_users(self, users, items=None):
        '''
            For each user we have many candidate items, rank them based on scores.

            Args:
                users: 1-D tensor of user ids.
                items: Optional (len(users), num_candidates) tensor of candidate
                    item ids. When omitted, every item id in the item
                    embedding table is ranked for every user.

            Returns:
                A tensor shaped like ``items`` holding each user's candidates
                ordered from highest to lowest score.
        '''
        num_users = users.size(0)
        if items is None:
            # Derive the full item id range from the embedding table.
            all_item_ids = torch.arange(
                self.item_emb.num_embeddings, device=users.device
            )
            items = all_item_ids.repeat(num_users, 1)
        items_emb = self.item_emb(items)
        # Broadcast each user vector across that user's candidate items.
        users_emb = self.user_emb(users).unsqueeze(1)
        scores = (users_emb * items_emb).sum(2)
        _, items_ind = torch.sort(scores, 1, descending=True)
        return torch.gather(items, 1, items_ind)

In [6]:
# Build the BPR model and move its parameters onto the configured device.
model = BPR(num_users, num_items, emb_dim, bought_mask).to(device)
# Adam with weight decay; an SGD variant (momentum 0.9, same lr and weight
# decay) previously sat here as a commented-out alternative.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=eta,
    weight_decay=weight_decay,
)

In [7]:
# Train for `epochs` epochs, evaluate top-k metrics after each one and log
# everything to TensorBoard. The writer is closed in `finally` so the event
# file is flushed even if training is interrupted or raises.
eval_cutoffs = [3, 5, 10]
try:
    for epoch in range(epochs):
        model.train()
        time_start = time.time()
        # Triplets are rebuilt every 5 epochs and reused in between, which is
        # why the logged epoch loss jumps at those boundaries.
        if epoch % 5 == 0:
            train_data = dp.prepare_bpr_triplets(all_items, bought_dict, batch_size)
        # Sum of per-batch mean losses, so it scales with the number of batches.
        loss_epoch = 0.0
        for users, pos_items, neg_items in train_data:
            loss = model.loss(users, pos_items, neg_items)
            loss_epoch += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        time_end = time.time()
        print(
                "\t[Epochs %d/%d] [epoch loss: %6.5f] [Time:%6.5f]"
                % (epoch + 1, epochs, loss_epoch, time_end - time_start)
        )

        model.eval()
        with torch.no_grad():
            res = evaluator.top_k_evaluation(model, eval_cutoffs)

        # Each res[i] is unpacked as (ndcg, precision, hit, map, mrr) for the
        # i-th cutoff.
        metrics = {}
        for i, k in enumerate(eval_cutoffs):
            ndcg, precision, hit, mean_ap, mrr = res[i]
            metrics[k] = {
                "NDCG": ndcg,
                "Precision": precision,
                "Hit": hit,
                "MAP": mean_ap,
                "MRR": mrr,
            }

        writer.add_scalar('Loss', loss_epoch, epoch)
        for metric_name in ("NDCG", "Precision", "Hit"):
            for k in eval_cutoffs:
                writer.add_scalar(
                    'Metrics/%s@%d' % (metric_name, k), metrics[k][metric_name], epoch
                )
        # MAP and MRR are logged at cutoff 10 only.
        writer.add_scalar('Metrics/MAP', metrics[10]["MAP"], epoch)
        writer.add_scalar('Metrics/MRR', metrics[10]["MRR"], epoch)
finally:
    writer.close()

	[Epochs 1/200] [epoch loss: 478.24875] [Time:2.27315]
	[Epochs 2/200] [epoch loss: 473.61350] [Time:2.90516]
	[Epochs 3/200] [epoch loss: 451.62170] [Time:2.90628]
	[Epochs 4/200] [epoch loss: 411.18129] [Time:2.35845]
	[Epochs 5/200] [epoch loss: 364.65451] [Time:2.21662]
	[Epochs 6/200] [epoch loss: 327.26519] [Time:2.62271]
	[Epochs 7/200] [epoch loss: 296.45812] [Time:2.25556]
	[Epochs 8/200] [epoch loss: 274.53698] [Time:2.28421]
	[Epochs 9/200] [epoch loss: 259.13092] [Time:2.19936]
	[Epochs 10/200] [epoch loss: 248.15015] [Time:2.40050]
	[Epochs 11/200] [epoch loss: 247.17495] [Time:2.52729]
	[Epochs 12/200] [epoch loss: 239.92952] [Time:2.45936]
	[Epochs 13/200] [epoch loss: 234.43255] [Time:2.75039]
	[Epochs 14/200] [epoch loss: 230.03656] [Time:2.98786]
	[Epochs 15/200] [epoch loss: 226.37882] [Time:3.09285]
	[Epochs 16/200] [epoch loss: 228.08649] [Time:2.48400]
	[Epochs 17/200] [epoch loss: 222.93125] [Time:2.41042]
	[Epochs 18/200] [epoch loss: 218.97714] [Time:2.43231]
	

	[Epochs 147/200] [epoch loss: 121.38925] [Time:3.00074]
	[Epochs 148/200] [epoch loss: 117.44673] [Time:2.93821]
	[Epochs 149/200] [epoch loss: 114.15843] [Time:2.98322]
	[Epochs 150/200] [epoch loss: 111.31742] [Time:2.69824]
	[Epochs 151/200] [epoch loss: 127.24883] [Time:2.58521]
	[Epochs 152/200] [epoch loss: 122.14845] [Time:2.38126]
	[Epochs 153/200] [epoch loss: 118.26262] [Time:2.47794]
	[Epochs 154/200] [epoch loss: 115.02007] [Time:2.52003]
	[Epochs 155/200] [epoch loss: 112.21620] [Time:2.48532]
	[Epochs 156/200] [epoch loss: 128.04656] [Time:2.96758]
	[Epochs 157/200] [epoch loss: 122.82951] [Time:2.83874]
	[Epochs 158/200] [epoch loss: 118.87540] [Time:3.39596]
	[Epochs 159/200] [epoch loss: 115.59406] [Time:3.29333]
	[Epochs 160/200] [epoch loss: 112.77213] [Time:2.92961]
	[Epochs 161/200] [epoch loss: 127.57106] [Time:3.11313]
	[Epochs 162/200] [epoch loss: 122.35566] [Time:2.97733]
	[Epochs 163/200] [epoch loss: 118.38934] [Time:3.00661]
	[Epochs 164/200] [epoch loss: 