# GMF (Generalized Matrix Factorization) in PyTorch

In [1]:
# Load (or reload) the blackcellmagic IPython extension.
%reload_ext blackcellmagic
# Load (or reload) IPython's autoreload extension and set it to mode 2, which
# re-imports all modules before each cell is executed.
%reload_ext autoreload
%autoreload 2

In [2]:
import sys
# Add the parent directory and ../bpr to the module search path so that
# modules stored there can be imported by the cells below
# (presumably where `kutils` lives -- TODO confirm).
sys.path.extend(['../', '../bpr'])

In [37]:
from pprint import pprint

import numpy as np
import textwrap
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset as TorchDataset, DataLoader

from kutils import load_ml100k, Dataset

In [4]:
# Build the shared dataset object used by every cell below. The rest of the
# notebook reads `train_matrix`, `test_matrix`, `negatives`, `n_users` and
# `n_items` from it.
# NOTE(review): presumably the MovieLens-100k ratings file; whether
# `load_ml100k` expands the leading "~" is not visible here -- confirm.
dataset = Dataset(load_ml100k('~/datasets/ml-100k/u.data'))

In [5]:
class GMF(nn.Module):
    """Score (user, item) pairs from the element-wise product of embeddings.

    The score for a pair is ``sigmoid(fc(embed_u[user] * embed_i[item]))``,
    so every output lies strictly between 0 and 1.

    Args:
        n_users: Number of rows in the user embedding table.
        n_items: Number of rows in the item embedding table.
        embed_size: Dimension shared by the user and item embeddings.
        device: Device the module's parameters are placed on at construction.
    """

    def __init__(self, n_users, n_items, embed_size, device="cpu"):
        super().__init__()
        self.device = device
        self.embed_u = nn.Embedding(num_embeddings=n_users, embedding_dim=embed_size)
        self.embed_i = nn.Embedding(num_embeddings=n_items, embedding_dim=embed_size)
        self.fc = nn.Linear(embed_size, 1)

        nn.init.normal_(self.embed_u.weight)
        nn.init.normal_(self.embed_i.weight)

        # ``Tensor.to`` returns a new tensor and leaves the original untouched,
        # so calling it on a weight and discarding the result moves nothing.
        # ``Module.to`` relocates every parameter (embeddings and ``fc``).
        self.to(device)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of scores.

        Args:
            x: Integer tensor whose column 0 holds user ids and column 1 holds
                item ids. Any further columns are ignored.
        """
        # Follow the parameters' actual device rather than ``self.device``:
        # the module may have been moved with ``.to(...)`` after construction.
        device = self.embed_u.weight.device
        user_id = x[:, 0].to(device)
        item_id = x[:, 1].to(device)
        multiplied = self.embed_u(user_id) * self.embed_i(item_id)
        return torch.sigmoid(self.fc(multiplied))

In [6]:
class GMFDataset(TorchDataset):
    """Dataset for pointwise recommendation.

    Every stored interaction of ``pos_matrix`` becomes one positive sample
    (label 1). For each user, ``n_negatives`` negatives per positive are drawn
    with replacement from that user's entries in ``neg_matrix`` (label 0).

    Args:
        pos_matrix: Matrix with one row per user; indexing a row must yield an
            object whose ``.indices`` lists the user's positive item ids, and
            the matrix must expose ``.shape`` (e.g. a SciPy CSR matrix).
        neg_matrix: Same layout; a row's ``.indices`` is the pool of negative
            candidates for that user.
        n_negatives: Number of negatives sampled per positive.

    Attributes:
        samples: ``LongTensor`` with one ``(user, item, label)`` row per sample.
    """

    def __init__(self, pos_matrix, neg_matrix, n_negatives):
        self.pos_matrix = pos_matrix
        self.neg_matrix = neg_matrix
        self.n_negatives = n_negatives
        self.samples = self._create_samples()

    def _create_samples(self):
        """Build the full ``(user, item, label)`` sample tensor once, up front."""
        batch = []
        for user in range(self.pos_matrix.shape[0]):
            pos_items = self.pos_matrix[user].indices
            batch.append(
                np.column_stack(
                    [np.full_like(pos_items, user), pos_items, np.ones_like(pos_items)]
                )
            )

            # Sampling with replacement keeps the draw valid even when the
            # candidate pool is smaller than the number of negatives requested.
            neg_items = np.random.choice(
                a=self.neg_matrix[user].indices,
                size=len(pos_items) * self.n_negatives,
                replace=True,
            )
            batch.append(
                np.column_stack(
                    [np.full_like(neg_items, user), neg_items, np.zeros_like(neg_items)]
                )
            )
        samples = torch.from_numpy(np.concatenate(batch)).type(torch.LongTensor)
        return samples

    def __len__(self):
        # ``samples`` holds the positives AND their negatives, i.e.
        # nnz * (1 + n_negatives) rows. Reporting nnz * n_negatives would hide
        # the trailing rows from any DataLoader built on this dataset.
        return len(self.samples)

    def __getitem__(self, idx):
        assert self.samples is not None and len(self.samples) > 0
        return self.samples[idx]

In [38]:
def train(
    model,
    loader_train,
    epochs,
    criterion,
    optimizer,
    device="cpu",
    eval_metrics=True,
    loader_test=None,
):
    """Train ``model`` and print the per-epoch loss (and optional metrics).

    Args:
        model: Module called as ``model(batch)``; its output is flattened and
            compared against column 2 of the batch.
        loader_train: Sized iterable of integer batch tensors whose column 2
            holds the 0/1 target.
        epochs: Number of passes over ``loader_train``.
        criterion: Loss callable invoked as ``criterion(prediction, target)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        device: Device the model and every batch are moved to.
        eval_metrics: When true, call ``evaluate`` after each epoch and print
            its metrics next to the loss.
        loader_test: Forwarded to ``evaluate``; required when ``eval_metrics``
            is true.

    Returns:
        NumPy array of length ``epochs`` holding the mean batch loss per epoch.

    Raises:
        ValueError: If ``eval_metrics`` is true and ``loader_test`` is None.
    """
    # Local import: the notebook's import cell only brings in ``pprint``.
    from pprint import pformat

    # Fail before any training happens rather than after a full epoch.
    if eval_metrics and loader_test is None:
        raise ValueError("loader_test is required")

    model.to(device)
    train_losses = np.zeros(epochs)

    for epoch in range(epochs):
        # evaluate() switches the model to eval mode, so training mode has to
        # be re-enabled at the start of every epoch, not just once.
        model.train()
        running_loss = 0.0
        for batch in loader_train:
            # ``Tensor.to`` is not in-place; keep the moved tensor.
            batch = batch.to(device)
            optimizer.zero_grad()
            prediction = model(batch)
            target = batch[:, 2].float()
            loss = criterion(prediction.view(-1), target.view(-1))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        train_losses[epoch] = running_loss / len(loader_train)

        # Start from an empty dict so the loss can still be reported when
        # metric evaluation is switched off.
        metrics = (
            evaluate(model=model, loader_test=loader_test) if eval_metrics else {}
        )
        metrics["loss"] = train_losses[epoch]

        s = textwrap.indent(pformat(metrics), prefix="  ")
        print(f"Epoch {epoch}\n{s}")

    return train_losses


def evaluate(model, loader_test, topk=10, test_matrix=None, n_items=None):
    """Compute top-k ranking metrics by scoring every item for every user.

    Args:
        model: Module called with an integer ``(n_items, 2)`` tensor of
            ``(user, item)`` rows; its output is flattened to one score per
            item. The model is left in eval mode on return.
        loader_test: Not used by this function; accepted so existing callers
            keep working.
        topk: Number of highest-scoring items recommended to each user.
        test_matrix: Matrix with one row per user whose ``.indices`` are the
            user's held-out items. Defaults to the notebook-level
            ``dataset.test_matrix``.
        n_items: Number of items to score. Defaults to
            ``dataset.train_matrix.shape[1]``.

    Returns:
        Dict with ``Precision@k``, ``Recall@k`` and ``HitRatio@k`` averaged
        over the users that have at least one held-out item (all 0.0 when no
        user has any).
    """
    if test_matrix is None:
        test_matrix = dataset.test_matrix
    if n_items is None:
        n_items = dataset.train_matrix.shape[1]

    model.eval()
    # Score on whatever device the model lives on; fall back to CPU for a
    # parameter-free model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    item_ids = np.arange(n_items, dtype=np.int64)
    precisions, recalls, hit_ratios = [], [], []
    with torch.no_grad():
        for user in range(test_matrix.shape[0]):
            rated = test_matrix[user].indices
            # Recall is undefined for a user without held-out items; skipping
            # them avoids a ZeroDivisionError.
            if len(rated) == 0:
                continue
            data = np.column_stack([np.full_like(item_ids, user), item_ids])
            predictions = model(torch.from_numpy(data).to(device)).reshape(-1)
            # argsort is ascending, so the last ``topk`` entries are the best;
            # ``.cpu()`` is required before ``.numpy()`` on non-CPU tensors.
            recommendations = torch.argsort(predictions)[-topk:].cpu().numpy()
            overlap = np.intersect1d(rated, recommendations)
            precisions.append(len(overlap) / topk)
            recalls.append(len(overlap) / len(rated))
            hit_ratios.append(int(len(overlap) > 0))

    if not precisions:
        # No user could be evaluated; avoid the NaN np.mean gives for [].
        return {
            f"Precision@{topk}": 0.0,
            f"Recall@{topk}": 0.0,
            f"HitRatio@{topk}": 0.0,
        }
    return {
        f"Precision@{topk}": np.mean(precisions),
        f"Recall@{topk}": np.mean(recalls),
        f"HitRatio@{topk}": np.mean(hit_ratios),
    }

In [42]:
# Pick the GPU when one is available ...
device = "cuda" if torch.cuda.is_available() else "cpu"
# ... then pin the run to the CPU: this assignment overrides the line above.
device = "cpu"

model = GMF(
    n_users=dataset.n_users,
    n_items=dataset.n_items,
    embed_size=10,
    device=device,
)

# Training samples: 5 sampled negatives for every observed interaction.
train_samples = GMFDataset(
    pos_matrix=dataset.train_matrix,
    neg_matrix=dataset.negatives,
    n_negatives=5,
)
loader_train = DataLoader(train_samples, shuffle=True, batch_size=200)

# Test samples: 100 sampled negatives for every held-out interaction.
test_samples = GMFDataset(
    pos_matrix=dataset.test_matrix,
    neg_matrix=dataset.negatives,
    n_negatives=100,
)
loader_test = DataLoader(test_samples, shuffle=True, batch_size=200)

# Binary cross-entropy on the sigmoid scores, optimised with Adam defaults.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters())

# Kept as the cell's final expression so the notebook displays the returned
# array of per-epoch losses.
train(
    model=model,
    loader_train=loader_train,
    epochs=10,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    loader_test=loader_test,
)

Epoch 0
  {'HitRatio@10': 0.08589607635206786,
   'Precision@10': 0.009331919406150585,
   'Recall@10': 0.003996460316824586,
   'loss': 0.5334028985588949}
Epoch 1
  {'HitRatio@10': 0.0975609756097561,
   'Precision@10': 0.01102863202545069,
   'Recall@10': 0.004999832748415365,
   'loss': 0.45221020705672377}
Epoch 2
  {'HitRatio@10': 0.09544008483563096,
   'Precision@10': 0.010180275715800637,
   'Recall@10': 0.004300962045435114,
   'loss': 0.4507141452767273}
Epoch 3
  {'HitRatio@10': 0.13891834570519618,
   'Precision@10': 0.015482502651113469,
   'Recall@10': 0.008642128071555611,
   'loss': 0.45035033948094566}
Epoch 4
  {'HitRatio@10': 0.20572640509013787,
   'Precision@10': 0.026511134676564158,
   'Recall@10': 0.013534684409414972,
   'loss': 0.4479252805008115}
Epoch 5
  {'HitRatio@10': 0.26617179215270415,
   'Precision@10': 0.03552492046659597,
   'Recall@10': 0.017320498216882918,
   'loss': 0.43646731667336597}
Epoch 6
  {'HitRatio@10': 0.34358430540827145,
   'Precisi

array([0.5334029 , 0.45221021, 0.45071415, 0.45035034, 0.44792528,
       0.43646732, 0.40813142, 0.3740552 , 0.34826327, 0.33180366])