In [91]:
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

In [92]:
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [93]:
# MovieLens 100k ratings file: tab-separated, column names supplied here.
url = 'http://files.grouplens.org/datasets/movielens/ml-100k/u.data'
df = pd.read_csv(
    url,
    sep='\t',
    names=['user_id', 'item_id', 'rating', 'timestamp'],
)

# Dense user x item array; (user, item) pairs without a rating become 0.
ratings_matrix = (
    df.pivot_table(index='user_id', columns='item_id', values='rating')
    .fillna(0)
    .to_numpy()
)

In [94]:
class FISM(nn.Module):
    """FISM-style item-similarity model with separate query/target embeddings.

    The score for a (user, target item) pair is the rating-weighted sum of
    dot products between the target item's embedding and the embeddings of
    the user's other rated items, divided by the number of those other
    items, plus global, user and item bias terms.
    """

    def __init__(self, num_users, num_items, embedding_dim):
        """Create bias tables and the two item embedding tables.

        Args:
            num_users: number of rows in the user-bias table.
            num_items: number of rows in the item tables.
            embedding_dim: width of each item embedding.
        """
        super().__init__()
        # Creation order is kept stable so seeded initialisation is reproducible.
        self.user_bias = nn.Embedding(num_users, 1)
        self.item_bias = nn.Embedding(num_items, 1)
        self.global_bias = nn.Parameter(torch.zeros(1))
        self.query_embeddings = nn.Embedding(num_items, embedding_dim)
        self.target_embeddings = nn.Embedding(num_items, embedding_dim)

    def forward(self, user, item_i, item_j, batch_score, n):
        """Predict one score per target item in the batch.

        Args:
            user: integer tensor (B,) of user indices.
            item_i: integer tensor (B,) of target item indices.
            item_j: integer tensor (B, n - 1) with the other rated items of
                each row (target excluded).
            batch_score: float tensor (B, n - 1) with the ratings that
                correspond to ``item_j``.
            n: number of rated items including the target; ``n - 1`` is the
                normaliser.

        Returns:
            Float tensor (B,) of predicted scores.
        """
        user_bias = self.user_bias(user)            # (B, 1)
        item_bias = self.item_bias(item_i)          # (B, 1)

        query_emb = self.query_embeddings(item_j)                 # (B, n-1, d)
        target_emb = self.target_embeddings(item_i).unsqueeze(2)  # (B, d, 1)
        weights = batch_score.unsqueeze(1)                        # (B, 1, n-1)

        sim_mat = torch.bmm(query_emb, target_emb)                # (B, n-1, 1)

        # With a single rated item there are no neighbours: the interaction
        # term is zero, and clamping the divisor keeps the result at the bias
        # terms instead of producing 0 / 0 = NaN.
        neighbours = max(n - 1, 1)
        pred = torch.bmm(weights, sim_mat).squeeze(-1) / neighbours  # (B, 1)
        pred = pred + self.global_bias + user_bias + item_bias
        return pred.squeeze(-1)

In [95]:
class FISM_simple(nn.Module):
    """FISM-style item-similarity model with one shared item embedding table.

    Identical in structure to the two-table variant except that the same
    ``item_embeddings`` table is looked up for both the target item and the
    user's other rated items.
    """

    def __init__(self, num_users, num_items, embedding_dim):
        """Create bias tables and the shared item embedding table.

        Args:
            num_users: number of rows in the user-bias table.
            num_items: number of rows in the item tables.
            embedding_dim: width of each item embedding.
        """
        super().__init__()
        # Creation order is kept stable so seeded initialisation is reproducible.
        self.user_bias = nn.Embedding(num_users, 1)
        self.item_bias = nn.Embedding(num_items, 1)
        self.global_bias = nn.Parameter(torch.zeros(1))
        self.item_embeddings = nn.Embedding(num_items, embedding_dim)

    def forward(self, user, item_i, item_j, batch_score, n):
        """Predict one score per target item in the batch.

        Args:
            user: integer tensor (B,) of user indices.
            item_i: integer tensor (B,) of target item indices.
            item_j: integer tensor (B, n - 1) with the other rated items of
                each row (target excluded).
            batch_score: float tensor (B, n - 1) with the ratings that
                correspond to ``item_j``.
            n: number of rated items including the target; ``n - 1`` is the
                normaliser.

        Returns:
            Float tensor (B,) of predicted scores.
        """
        user_bias = self.user_bias(user)            # (B, 1)
        item_bias = self.item_bias(item_i)          # (B, 1)

        query_emb = self.item_embeddings(item_j)                 # (B, n-1, d)
        target_emb = self.item_embeddings(item_i).unsqueeze(2)   # (B, d, 1)
        weights = batch_score.unsqueeze(1)                       # (B, 1, n-1)

        sim_mat = torch.bmm(query_emb, target_emb)               # (B, n-1, 1)

        # With a single rated item there are no neighbours: the interaction
        # term is zero, and clamping the divisor keeps the result at the bias
        # terms instead of producing 0 / 0 = NaN.
        neighbours = max(n - 1, 1)
        pred = torch.bmm(weights, sim_mat).squeeze(-1) / neighbours  # (B, 1)
        pred = pred + self.global_bias + user_bias + item_bias
        return pred.squeeze(-1)

In [96]:
def train(n_epoch, model, optimizer, loss_fn, device, ratings=None):
    """Train ``model`` with one optimisation step per user per epoch.

    For every user with at least two rated items, a leave-one-out batch is
    built: each rated item in turn is the target, and the remaining rated
    items (with their ratings) are the inputs used to predict it.

    Args:
        n_epoch: number of passes over all users.
        model: module called as ``model(user, item_i, item_j, batch_score, n)``
            and returning one prediction per target item.
        optimizer: optimiser already bound to ``model``'s parameters.
        loss_fn: callable ``loss_fn(pred, target)`` returning a scalar tensor.
        device: device the model and batches are placed on.
        ratings: 2-D user x item array where 0 means "not rated". When
            omitted, the notebook-level ``ratings_matrix`` is used, which
            keeps existing five-argument calls working.

    Returns:
        List with the summed per-user loss of each epoch.
    """
    if ratings is None:
        ratings = ratings_matrix
    ratings = np.asarray(ratings)

    model.to(device)
    model.train()
    loss_history = []

    for epoch in range(n_epoch):
        total_loss = 0

        for user_idx, row in enumerate(ratings):
            rated = np.flatnonzero(row)
            n = rated.size

            # A single rating leaves nothing to predict it from.
            if n <= 1:
                continue

            # Row i of `others` holds every rated item except rated[i], in
            # the original order: mask the diagonal of an n x n tiling.
            off_diagonal = ~np.eye(n, dtype=bool)
            others = np.broadcast_to(rated, (n, n))[off_diagonal].reshape(n, n - 1)

            users = torch.full((n,), user_idx, dtype=torch.long, device=device)
            batch_i = torch.as_tensor(rated, device=device)
            batch_j = torch.as_tensor(others, device=device)
            batch_score = torch.as_tensor(row[others], dtype=torch.float32, device=device)
            target = torch.as_tensor(row[rated], dtype=torch.float32, device=device)

            optimizer.zero_grad()
            # Call the module itself (not .forward) so registered hooks run.
            pred = model(users, batch_i, batch_j, batch_score, n)
            loss = loss_fn(pred, target)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        loss_history.append(total_loss)
        print(f'Epoch {epoch + 1}, Loss: {total_loss}')
    return loss_history

In [None]:
# Two-table FISM: 20-dimensional embeddings, Adam at lr=1e-3, 30 epochs.
n_users, n_items = ratings_matrix.shape
fism = FISM(num_users=n_users, num_items=n_items, embedding_dim=20)
optimizer = optim.Adam(fism.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

fism_loss = train(
    n_epoch=30,
    model=fism,
    optimizer=optimizer,
    loss_fn=loss_fn,
    device=device,
)

In [88]:
# Shared-table FISM: 20-dimensional embeddings, Adam at lr=1e-3, 30 epochs.
n_users, n_items = ratings_matrix.shape
fism_simple = FISM_simple(num_users=n_users, num_items=n_items, embedding_dim=20)
optimizer = optim.Adam(fism_simple.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

fism_simple_loss = train(
    n_epoch=30,
    model=fism_simple,
    optimizer=optimizer,
    loss_fn=loss_fn,
    device=device,
)