In [1]:
import torch
import torch.nn as nn

class GMF(nn.Module):
    """Generalized Matrix Factorization scorer for (user, item) index pairs.

    Users and items each have their own embedding table. The two looked-up
    vectors are multiplied element-wise, projected to a single logit by a
    linear layer, and squashed with a sigmoid.
    """

    def __init__(self, num_users, num_items, hidden_dim):
        """Build the two embedding tables and the output projection.

        Args:
            num_users: number of rows in the user embedding table.
            num_items: number of rows in the item embedding table.
            hidden_dim: width of both embeddings and input size of the
                final linear layer.
        """
        super(GMF, self).__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self.gmf_user_emb = nn.Embedding(num_embeddings=num_users,
                                         embedding_dim=hidden_dim)
        self.gmf_item_emb = nn.Embedding(num_embeddings=num_items,
                                         embedding_dim=hidden_dim)
        self.gmf_linear = nn.Linear(in_features=hidden_dim, out_features=1)

    def forward(self, user_indices, item_indices):
        """Score each (user, item) pair.

        Args:
            user_indices: integer tensor of user row indices.
            item_indices: integer tensor of item row indices, paired
                position-wise with ``user_indices``.

        Returns:
            Tensor of sigmoid scores in (0, 1) with a trailing dimension of
            size 1 (the output width of ``gmf_linear``).
        """
        user_vec = self.gmf_user_emb(user_indices)
        item_vec = self.gmf_item_emb(item_indices)
        return torch.sigmoid(self.gmf_linear(user_vec * item_vec))

In [2]:
import pandas as pd
from utils.data import SampleGenerator

# Load the raw ratings table. The item column is called 'movieId' on disk and
# is renamed to the 'itemId' name used by the rest of this notebook.
ratings = pd.read_csv('./dataset/ratings.csv')
ratings = ratings.rename(columns={'movieId': 'itemId'})

# Re-index raw user/item ids to contiguous integers 0..n-1 so they can be used
# directly as nn.Embedding row indices. factorize(sort=True) numbers the ids in
# ascending raw-id order, which gives a well-defined mapping (no dependence on
# set iteration order) without building helper frames and merging them back.
# raw_ids[col][code] recovers the raw id behind a code.
raw_ids = {}
for id_col in ('userId', 'itemId'):
    # pop() followed by re-assignment moves the re-indexed column to the end of
    # the frame: the remaining columns come first, then 'userId', then 'itemId'.
    codes, raw_ids[id_col] = pd.factorize(ratings.pop(id_col), sort=True)
    # factorize marks missing values with -1, which is not a valid embedding row.
    assert (codes >= 0).all(), "missing values in column '{0}'".format(id_col)
    ratings[id_col] = codes.astype('int64')

# NOTE(review): implicit=True presumably makes SampleGenerator binarize the
# 'rating' column (see the pandas warning emitted by this cell) -- confirm in
# utils.data.
data = SampleGenerator(ratings, implicit=True)

# Model / optimisation hyperparameters.
hidden_dim = 128   # embedding width of the GMF model
lr = 0.01          # Adam learning rate
batch_size = 2048
epochs = 15

num_users = data.num_users
num_items = data.num_items
# Passed as `num_negatives` to the train / test loader factories respectively.
num_negatives_train = 5
num_negatives_test = 500

cuda = torch.cuda.is_available()

model = GMF(num_users, num_items, hidden_dim)
# Move the model to the GPU *before* building the optimizer so the optimizer is
# constructed over the parameters that are actually trained.
if cuda:
    model.cuda()

# GMF.forward already applies a sigmoid, so plain BCELoss (not the
# with-logits variant) is the matching criterion.
criterion = nn.BCELoss()
optim = torch.optim.Adam(model.parameters(), lr=lr)
    


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
  ratings['rating'][ratings['rating'] >0] = 1.0


In [3]:
import os
from utils.eval import Evaluation

# Directory that receives the checkpoint written after every epoch.
# exist_ok=True avoids the check-then-create race of a separate exists() test.
os.makedirs("./checkpoint", exist_ok=True)


def collect_predictions(model, loader, device):
    """Score every batch of ``loader`` with ``model`` and flatten the results.

    Args:
        model: callable taking ``(user, item)`` index tensors and returning
            a tensor of scores.
        loader: iterable of batches whose first two entries are the user and
            item index tensors.
        device: device the index tensors are moved to before scoring.

    Returns:
        Tuple ``(users, items, scores)`` of flat Python lists holding the
        flattened user indices, item indices and model outputs of all batches.
    """
    users, items, scores = [], [], []
    with torch.no_grad():
        for batch in loader:
            user, item = batch[0], batch[1]
            users += user.view(-1).tolist()
            items += item.view(-1).tolist()
            pred = model(user.to(device), item.to(device))
            # .cpu() is a no-op for tensors that already live on the CPU.
            scores += pred.cpu().view(-1).tolist()
    return users, items, scores


device = torch.device("cuda" if cuda else "cpu")

# The evaluation loaders are built once and reused after every epoch.
test_loader, negative_loader = data.instance_test_loader(
    num_negatives=num_negatives_test, batch_size=batch_size)

for epoch in range(1, epochs + 1):
    # A new training loader is requested for every epoch.
    # NOTE(review): presumably this re-samples the training negatives each
    # epoch -- confirm in utils.data.
    train_loader = data.instance_a_train_loader(
        num_negatives=num_negatives_train, batch_size=batch_size)

    # train()/eval() do not change the behaviour of the Embedding/Linear
    # layers, but keep the loop correct if mode-dependent layers (dropout,
    # batch norm) are ever added to the model.
    model.train()
    total_loss = 0.0  # sum (not mean) of the per-batch losses of this epoch
    for batch in train_loader:
        user, item, rating = batch[0], batch[1], batch[2]
        user, item = user.to(device), item.to(device)
        rating = rating.float().to(device)  # BCELoss needs float targets

        optim.zero_grad()
        pred = model(user, item)
        # The model output has a trailing dimension; flatten it before
        # comparing with the rating targets.
        loss = criterion(pred.view(-1), rating)
        loss.backward()
        optim.step()
        total_loss += loss.item()
    print("epoch{0} loss:{1:.4f}".format(epoch, total_loss))

    # Overwrites the previous checkpoint: only the latest epoch is kept.
    torch.save(model.state_dict(), "./checkpoint/gmf.pt")

    model.eval()
    test_users, test_items, test_preds = collect_predictions(model, test_loader, device)
    neg_users, neg_items, neg_preds = collect_predictions(model, negative_loader, device)

    # Named `evaluator` so the builtin eval() is not shadowed in the kernel.
    evaluator = Evaluation([test_users, test_items, test_preds,
                            neg_users, neg_items, neg_preds])
    evaluator.print_eval_score_k(10)

epoch1 loss:109.3049
recall@10:0.1792, prec@10:0.0410
epoch2 loss:88.3565
recall@10:0.4837, prec@10:0.1628
epoch3 loss:64.3846
recall@10:0.5625, prec@10:0.2273
epoch4 loss:51.1067
recall@10:0.5556, prec@10:0.2257
epoch5 loss:35.0028
recall@10:0.5506, prec@10:0.2207
epoch6 loss:23.7136
recall@10:0.5599, prec@10:0.2269
epoch7 loss:17.0887
recall@10:0.5626, prec@10:0.2297
epoch8 loss:13.4299
recall@10:0.5738, prec@10:0.2360
epoch9 loss:10.7152
recall@10:0.5781, prec@10:0.2391
epoch10 loss:8.8684
recall@10:0.5752, prec@10:0.2384
epoch11 loss:7.4999
recall@10:0.5776, prec@10:0.2411
epoch12 loss:6.5381
recall@10:0.5831, prec@10:0.2435
epoch13 loss:5.6737
recall@10:0.5860, prec@10:0.2462
epoch14 loss:4.8836
recall@10:0.5912, prec@10:0.2485
epoch15 loss:4.5576
recall@10:0.5886, prec@10:0.2479
