In [1]:
import pandas as pd
import os
import numpy as np
import time

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Data Processing

In [2]:
# Sizes of the user and item id spaces (presumably the counts in the processed
# ml-1m split -- TODO confirm against the preprocessing step).
# user_size / item_size set the number of rows of the GMF embedding tables;
# item_size is also the exclusive upper bound for sampled negative item ids.
user_size = 6040
item_size = 3760

In [3]:
# Directory (relative to the notebook's working directory) from which train.csv is read.
DATA_DIR = './processed_data/ml-1m/'

In [4]:
# Read the training interactions, keeping only the 'user' and 'item' columns;
# the first row of the file is treated as the header.
train_df = pd.read_csv(
    os.path.join(DATA_DIR, 'train.csv'),
    usecols=['user', 'item'],
    header=0,
)

In [5]:
# Set holding every item id in [0, item_size).
movieId_set = set(np.arange(item_size))

# user id -> list of the items that user has in the training frame
# (the "positive" items excluded when negatives are sampled).
u_dict = (
    train_df
    .groupby('user')['item']
    .apply(list)
    .to_dict()
)

In [6]:
class BCEDataset(Dataset):
    """Map-style dataset over three parallel sequences: users, items and labels.

    Element ``idx`` is the dict ``{'user': ..., 'item': ..., 'label': ...}``
    built from position ``idx`` of each sequence.
    """

    def __init__(self, users, items, labels):
        """Store the three index-aligned sequences without copying them."""
        self.users, self.items, self.labels = users, items, labels

    def __len__(self):
        """Return the number of samples, taken from the length of ``users``."""
        return len(self.users)

    def __getitem__(self, idx):
        """Return the sample at position ``idx`` as a dict."""
        return {
            'user': self.users[idx],
            'item': self.items[idx],
            'label': self.labels[idx],
        }

In [7]:
def bce_getTrain(N, train_batch_size):
    '''
    Build a shuffled training DataLoader with N sampled negatives per positive.

    For every (user, item) row of the module-level ``train_df`` this emits one
    sample labelled 1, followed by N samples labelled 0 whose item ids are
    drawn uniformly from [0, item_size) and rejected while they appear in the
    user's positive list ``u_dict[user]``. Negatives are drawn with
    ``np.random`` on each call, so every call yields a new sample set unless
    the NumPy global seed is fixed.

    N: num of negative samples per positive interaction
    train_batch_size: batch size passed to the DataLoader

    Returns a DataLoader over a BCEDataset (shuffle=True, 4 workers, pinned
    memory) yielding dict batches with keys 'user', 'item' and 'label'.

    Raises ValueError if N > 0 and some user already has every id in
    [0, item_size) as a positive, because rejection sampling could never end.
    '''

    train_u = []
    train_i = []
    train_l = []

    u_list = train_df['user'].values.tolist()
    i_list = train_df['item'].values.tolist()

    # The positive set depends only on the user, so build it once per user
    # instead of once per interaction row.
    positives_by_user = {}

    for u, i in zip(u_list, i_list):
        # One positive sample followed by N negative samples for this user.
        train_u.extend([u] * (N + 1))
        train_i.append(i)
        train_l.append(1)
        train_l.extend([0] * N)

        PositiveSet = positives_by_user.get(u)
        if PositiveSet is None:
            PositiveSet = set(u_dict[u])
            # The cheap length test short-circuits the full scan for almost
            # every user; the scan only runs when exhaustion is possible.
            if (N > 0 and len(PositiveSet) >= item_size
                    and all(c in PositiveSet for c in range(item_size))):
                raise ValueError(
                    'user %r has no item left to sample as a negative' % (u,)
                )
            positives_by_user[u] = PositiveSet

        for _ in range(N):  # sample negative items
            neg_i = np.random.randint(0, item_size)
            while neg_i in PositiveSet:
                neg_i = np.random.randint(0, item_size)
            train_i.append(neg_i)

    train_dataset = BCEDataset(train_u, train_i, train_l)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size = train_batch_size,
                                  shuffle = True,
                                  num_workers = 4,
                                  pin_memory = True,
                                 )

    return train_dataloader

# Model

In [8]:
class GMF(nn.Module):
    """Matrix-factorization scorer: dot product of a user and an item embedding.

    Holds one embedding table per side, both of width ``embed_size`` and both
    initialised with Xavier-uniform weights. ``forward`` returns raw scores;
    no sigmoid is applied inside the module.
    """

    def __init__(self, user_size, item_size, embed_size):
        """
        user_size:  number of rows in the user embedding table
        item_size:  number of rows in the item embedding table
        embed_size: width of every embedding vector
        """
        super().__init__()

        self.user_size = user_size
        self.item_size = item_size
        self.embed_size = embed_size

        self.embedding_user = nn.Embedding(user_size, embed_size)
        nn.init.xavier_uniform_(self.embedding_user.weight)

        self.embedding_item = nn.Embedding(item_size, embed_size)
        nn.init.xavier_uniform_(self.embedding_item.weight)

    def forward(self, user, item):
        """Score index-aligned (user, item) id pairs.

        user, item: integer id tensors with the same number of elements.
        Returns a 1-D tensor holding one dot-product score per pair.
        """
        user_embedding = self.embedding_user(user)
        item_embedding = self.embedding_item(item)

        # Reshape with the width this instance was built with. Reading a
        # module-level `embed_size` here would fail when that global is absent
        # and mis-shape the vectors when it differs from the model's width.
        user_vec = user_embedding.view([-1, self.embed_size])
        item_vec = item_embedding.view([-1, self.embed_size])

        # Element-wise product summed over the embedding axis = dot product.
        dot = torch.sum(torch.mul(user_vec, item_vec), dim = 1).view(-1)

        return dot

# Hyperparameters

In [9]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

train_batch_size = 128   # samples per training mini-batch
embed_size = 16          # width of the user / item embedding vectors
learning_rate = 0.001    # step size given to the Adam optimizer
epochs = 20              # number of passes of the training loop

# Training 

In [10]:
# Build the GMF model and move its parameters to the selected device.
model = GMF(
    user_size=user_size,
    item_size=item_size,
    embed_size=embed_size,
)
model = model.to(device)

# Mean-squared error between the sigmoid-squashed score and the 0/1 label.
# Alternative: loss_function = nn.BCELoss()
loss_function = nn.MSELoss()

# Adam over every model parameter.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# Train for `epochs` passes over the training interactions.
for epoch in range(epochs):
    
   
    # Rebuild the loader every epoch so that 4 negatives per positive are
    # re-sampled each time instead of being fixed for the whole run.
    train_dataloader = bce_getTrain(4, train_batch_size)
    
    print(epoch)
    for idx, batch_data in enumerate(train_dataloader):
        # Ids are cast to long before being moved to the device and used as
        # embedding indices by the model.
        user = batch_data['user'].long().to(device)
        item = batch_data['item'].long().to(device)

        # NOTE(review): nn.BCELoss takes floating-point targets, so the
        # .long() variant below would likely need .float() as well -- confirm
        # before switching the loss.
        #label = batch_data['label'].long().to(device) # for BCE
        label = batch_data['label'].float().to(device) # for MSE
   
        
        # Clear gradients left over from the previous step.
        model.zero_grad()
           
        # The model returns raw dot-product scores; squash them into (0, 1)
        # so they are comparable with the 0/1 labels.
        x_ij = torch.sigmoid(model(user, item))
        
        loss = loss_function(x_ij, label)
             
        loss.backward()
        
        optimizer.step()

0
1
