# Matrix Factorization
We will build model for the recommendation system using Matrix Factorization (MF).

---

In [1]:
# import packages
import pandas as pd
import numpy as np

import pytorch_lightning as pl
from sklearn.model_selection import train_test_split
import torch
from torch import nn
import torch.multiprocessing
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import warnings
warnings.filterwarnings("ignore", category=FutureWarning) 

---

In [2]:
ratings = pd.read_csv('./data/ratings.dat', sep = "::", 
                            names = ['UserID', 'MovieID', 'Rating', 'Timestamp'], 
                            engine='python')

movies = pd.read_csv('./data/movies.dat', sep = "::", 
                        names = ['MovieID', 'Title', 'Genres'], 
                        engine='python', encoding="ISO-8859-1")

users = pd.read_csv('./data/users.dat', sep = "::", 
                        names = ['UserID', 'Gender', 'Age', 'Occupation', 'Zip-code'], 
                        engine='python')

We will normalize the `Rating` column and split the data into training set and validation set.

In [3]:
# Normalize the rating to the range (-2, 2)
ratings['Rating'] = ratings['Rating'] - 3

# Train test split the data
train_ratings, validation_ratings = train_test_split(
    ratings, test_size=0.1, random_state=42
)

Prepare the data in the `torch.utils.data.DataLoader` format for Pytorch.

In [4]:
# Map movie id and user id to indexes
movie_idx = {id: i for i, id in enumerate(movies['MovieID'])}
user_idx = {id: i for i, id in enumerate(users['UserID'])}

# Load the dataset
class Loader(Dataset):
    def __init__(self, df):
        self.df = df
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        user_id = self.df['UserID'].iloc[index]
        movie_id = self.df['MovieID'].iloc[index]
        rating = self.df['Rating'].iloc[index]
        user_by_idx = user_idx[user_id]
        movie_by_idx = movie_idx[movie_id]
        return user_by_idx, movie_by_idx, rating
    
training_data = Loader(train_ratings)
validation_data = Loader(validation_ratings)
batch_size = 1024
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, num_workers=0)
validation_dataloader = DataLoader(validation_data, batch_size=batch_size, shuffle=True, num_workers=0)

Now, we will implement a `MatrixFactorization` class. Given by user `i` and item `j`, the utility function between user `i` and item `j` can be described by:
$$
\hat{y}_{ij} \approx \mathbf{w}_i^T \mathbf{x}_j + b_i + d_j + a
$$
where:
- $b_i$ is the user `i`'s bias
- $d_j$ is the item `j`'s bias
- a is the global bias

We use `nn.Embedding` to store the embedding matrix for user and item and `Stochastic Gradient Descent` for model optimization.

In [5]:
LR = 1
WEIGHT_DECAY = 5e-5

class MatrixFactorization(pl.LightningModule):
    '''
    Pytorch Lightning class for Matrix Factorization
    Attributes:
        - n_users: number of users
        - n_items: number of movies
        - n_factors: number of embedding factors (embedding size, 40 by default)
    '''
    def __init__(self, n_users, n_items, n_factors = 40):
        super().__init__()
        self.n_users = n_users
        self.n_items = n_items
        self.n_factors = n_factors

        # Add bias
        self.user_biases = nn.Embedding(n_users, 1)
        self.item_biases = nn.Embedding(n_items, 1)
        self.bias = nn.Parameter(data=torch.rand(1))
        self.user_embeddings = nn.Embedding(n_users, n_factors)
        self.item_embeddings = nn.Embedding(n_items, n_factors)

    def forward(self, users, items):
        """
        Forward pass through the model. For a single user and item, this
        looks like:
        bias + user_bias + item_bias + user_embeddings.dot(item_embeddings)

        Parameters:
            - users: Array of user indices
            - items : Array of item indices
        Returns:
            - preds: Predicted ratings.
        """
        batch_user_embs = self.user_embeddings(users)
        batch_item_embs = self.item_embeddings(items)

        preds = torch.reshape(
            torch.diag(
                torch.matmul(batch_user_embs, torch.transpose(batch_item_embs, 0, 1))
            ),
            (-1, 1),
        )

        # add bias
        preds += self.user_biases(users) + self.item_biases(items) + self.bias

        # make sure the predicted ratings are constrained between (-2, 2)
        return torch.clip(preds.squeeze(), min=-2, max=2) 
    
    def training_step(self, batch, batch_idx):
        users, items, rating = batch
        rating = rating.to(torch.float32)
        output = self.forward(users, items)
        loss = F.mse_loss(rating, output)
        self.log("train_loss", loss)
        return {"loss": loss}
    
    def validation_step(self, batch, batch_idx):
        users, items, rating = batch
        rating = rating.to(torch.float32)
        output = self.forward(users, items)
        loss = F.mse_loss(rating, output)
        self.log('val_loss', loss)
        
        # Calculate mean absolute error for accuracy
        mae = torch.mean(torch.abs(rating - output))
        self.log("val_mae", mae)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr = LR, weight_decay=WEIGHT_DECAY)
        return optimizer

In [6]:
n_users = len(user_idx)
n_movies = len(movie_idx)
n_factors = 40

model = MatrixFactorization(n_users, n_movies, n_factors)
trainer = pl.Trainer(accelerator='auto', max_epochs=100, logger=False)
result = trainer.fit(model, train_dataloader, validation_dataloader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/minhle/Library/Python/3.9/lib/python/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /Users/minhle/Documents/Recommendation-System/checkpoints exists and is not empty.

  | Name            | Type      | Params | Mode 
------------------------------------------------------
0 | user_biases     | Embedding | 6.0 K  | train
1 | item_biases     | Embedding | 3.9 K  | train
2 | user_embeddings | Embedding | 241 K  | train
3 | item_embeddings | Embedding | 155 K  | train
  | other params    | n/a       | 1      | n/a  
------------------------------------------------------
406 K     Trainable params
0         Non-trainable params
406 K     Total params
1.627     Total estimated model params size (MB)
4         Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/minhle/Library/Python/3.9/lib/python/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:476: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/Users/minhle/Library/Python/3.9/lib/python/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/Users/minhle/Library/Python/3.9/lib/python/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


In [8]:
def eval_model(model, dataloader):
    total_loss = 0  # Initialize total_loss for RMSE
    total_mae = 0  # Initialize total_mae for MAE
    total_samples = 0  # Track the total number of samples
    
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # Disable gradient computation for evaluation
        for users, items, ratings in dataloader:
            ratings = ratings.to(torch.float32)
            predictions = model(users, items)
            
            # Calculate and accumulate the total loss (MSE)
            total_loss += F.mse_loss(predictions, ratings, reduction='sum').item()
            
            # Calculate and accumulate the total MAE
            total_mae += torch.sum(torch.abs(predictions - ratings)).item()
            
            # Count the number of ratings
            total_samples += len(ratings)
    
    # Calculate RMSE and MAE
    rmse = (total_loss / total_samples) ** 0.5
    mae = total_mae / total_samples
    return rmse, mae

# Evaluate the model on training and validation data
train_rmse, train_mae = eval_model(model, train_dataloader)
val_rmse, val_mae = eval_model(model, validation_dataloader)

# Print the results
print(f"Train RMSE: {train_rmse:.3f}, Train MAE: {train_mae:.3f}")
print(f"Validation RMSE: {val_rmse:.3f}, Validation MAE: {val_mae:.3f}")


Train RMSE: 0.798, Train MAE: 0.643
Validation RMSE: 0.880, Validation MAE: 0.706


In [9]:
trainer.validate(model, validation_dataloader)

/Users/minhle/Library/Python/3.9/lib/python/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:476: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/Users/minhle/Library/Python/3.9/lib/python/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Validation: |          | 0/? [00:00<?, ?it/s]

[{'val_loss': 0.7739256024360657, 'val_mae': 0.705941379070282}]