In this notebook, we will work on improving the model and continue to test metrics to see if our model has improved. We will look at offline training and validation metrics.

## Model with Repetition Penalty

We trained a new model where we added penalty for the model whenever the model returns a prediction that already exists in the input tokens. 


### New Loss functions

The major change in the new model is the loss functino. Here is the loss function that we defined for the model to penalize duplicates.

```python
class DedupCrossEntropyLoss(nn.Module):
    """
    Custom loss function that combines cross-entropy loss with a penalty term
    for tokens that are repeated from the input sequence.
    """
    def __init__(self, penalty_weight = 1.0):
        super(DedupCrossEntropyLoss, self).__init__()
        self.penalty_weight = penalty_weight
        self.ce_loss = nn.CrossEntropyLoss()
        logger.info(f"Penalty weight: {penalty_weight}")

    def forward(self, logits: torch.Tensor, labels: torch.Tensor, input_tokens: torch.Tensor):
        # Calculate the standard cross-entropy loss
        ce_loss = self.ce_loss(logits, labels)
        if self.penalty_weight == 0:
            return ce_loss
        token_output = torch.argmax(logits, dim=1)
        duplicated_masks = torch.eq(input_tokens, token_output.unsqueeze(-1)).any(dim=-1).float()
        penalty = duplicated_masks * self.penalty_weight
        loss = ce_loss + penalty.mean()
        return loss
```

### Model training diff

I played around with a few different learning rates and penalty factor and finally landed on the numbers below. We show the diff from the original model trained in previous post below:

```diff
-> % git --no-pager diff training_config.yaml 
diff --git a/movielens-ntp/training_config.yaml b/movielens-ntp/training_config.yaml
index 235e7b5..9dff23a 100644
--- a/movielens-ntp/training_config.yaml
+++ b/movielens-ntp/training_config.yaml
@@ -2,12 +2,13 @@ trainer_config:
   data_dir: ./data/ml-1m
   model_dir: ./models
   batch_size: 512
-  starting_learning_rate: 0.0005
+  starting_learning_rate: 0.0008
   learning_rate_decay: 0.95
   device: cuda
   num_epochs: 1000
   validation_fraction: 0.15
   tensorboard_dir: ./runs
+  penalize_duplicates_factor: 0.2
 
 movie_transformer_config:
   context_window_size: 5

```

Finally, we trained the model:

```shell
python model_train.py --config_file=./training_config.yaml --penalize-duplicates
```

Next, we will evaluate our newly trained model. Let's define a few simple functions first!

In [35]:
import torch
import torch.nn.functional as F
from model_train import run_model_training, load_config, get_model_config
from data import MovieLensSequenceDataset
from torch.utils.data import DataLoader
from tbparse import SummaryReader
import plotly.express as px
from eval import (
    get_model_predictions,
    get_model_predictions,
    calculate_metrics,
    calculate_relevance,
)
import pandas as pd
import numpy as np
import torch.nn as nn

In [22]:
from movielens_transformer import MovieLensTransformer


def load_model_artifacts(model_file: str, config_file: str):
    config_file = "./training_config.yaml"
    config = load_config(config_file)
    sequence_length = config["movie_transformer_config"]["context_window_size"]
    batch_size = config["trainer_config"]["batch_size"]
    valid_dataset = MovieLensSequenceDataset(
        movies_file="./data/ml-1m/movies.dat",
        users_file="./data/ml-1m/users.dat",
        ratings_file="./data/ml-1m/ratings.dat",
        sequence_length=sequence_length,
        window_size=1,
        is_validation=True,
    )
    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
    with open(model_file, "rb") as f:
        model_state_dict = torch.load(f, weights_only=True)
        model_config = get_model_config(config, valid_dataset)
    trained_model = MovieLensTransformer(model_config)
    trained_model.load_state_dict(model_state_dict)
    return config, trained_model, valid_dataloader


def predict_next_movie(model, movie_ids, user_ids):
    model.eval()
    logits = model(movie_ids=movie_ids, user_ids=user_ids)
    probabilities = F.softmax(logits, dim=-1)
    predicted_movie_ids = torch.argmax(probabilities, dim=-1)
    return predicted_movie_ids

In [23]:
def read_tensorboard_logs(log_file: str):
    reader = SummaryReader(log_file)
    metrics = reader.scalars
    return metrics

In [24]:
models = {
    "baseline": (
        "./models/model_1000e_512_32_32_4_4.pth",
        "./models/events.out.tfevents.1724703053.kapilsh-dev-big.514232.0",
    ),
    "with_duplication_penalty": (
        "./models/model_1000e_512_32_32_4_4_w_pen.pth",
        "./models/events.out.tfevents.1725245582.kapilsh-dev-big.1079645.0",
    ),
}

In [32]:
def populate_ranking_metrics(
    model_name, model, valid_dataloader, sequence_length, k_values
):
    metrics = {}

    for k in k_values:
        # we would torch.cat these later
        model_relevances = []
        model_scores = []

        for batch in valid_dataloader:
            (
                movie_id_tokens,
                rating_id_tokens,
                user_id_tokens,
                output_movie_id_tokens,
                output_rating_id_tokens,
            ) = batch
            model_predictions = get_model_predictions(
                model, movie_id_tokens, user_id_tokens, n=k
            )
            model_relevance = calculate_relevance(
                model_predictions.predictions, output_movie_id_tokens
            )
            model_relevances.append(model_relevance)
            model_scores.append(model_predictions.scores)

        model_relevances_tensor = torch.cat(model_relevances)
        model_scores_tensor = torch.cat(model_scores)

        # Calculate the metrics
        model_metrics = calculate_metrics(model_relevances_tensor, model_scores_tensor)

        metrics[k] = model_metrics

    metrics_df = pd.DataFrame(
        [
            {
                "k": k,
                f"MRR_{model_name}": v.MRR,
                f"MAP_{model_name}": v.MAP,
                f"NDCG_{model_name}": v.NDCG,
            }
            for k, v in metrics.items()
        ]
    )
    return metrics_df

In [33]:
k_values = [3, 5, 10]

for model_name, (model_file, tb_file) in models.items():
    print(f"======= Model: {model_name} =======")
    config, model, valid_dataloader = load_model_artifacts(
        model_file, "./training_config.yaml"
    )
    metrics = read_tensorboard_logs(tb_file)
    metrics_by_run = metrics.groupby(["step", "tag"]).mean().unstack()["value"]
    metrics_by_run.columns.name = None
    fig = px.line(metrics_by_run, title=f"Model: {model_name}")
    fig.show()

    print("======= Ranking Metrics =======")
    ranking_metrics = populate_ranking_metrics(
        model_name,
        model,
        valid_dataloader,
        config["movie_transformer_config"]["context_window_size"],
        k_values,
    )
    print(ranking_metrics)

2024-09-02 06:36:38.625 | INFO     | data:__init__:89 - Creating MovieLensSequenceDataset with validation set: %s
2024-09-02 06:36:38.625 | INFO     | data:read_movielens_data:12 - Reading data from files




2024-09-02 06:36:40.441 | INFO     | data:_add_tokens:140 - Adding tokens to data
2024-09-02 06:36:40.477 | INFO     | data:_generate_sequences:159 - Generating sequences
2024-09-02 06:36:41.153 | INFO     | data:__init__:110 - Train data length: 883906
2024-09-02 06:36:41.153 | INFO     | data:__init__:111 - Validation data length: 98183
2024-09-02 06:36:41.168 | INFO     | model_train:get_model_config:99 - Model config:
MovieLensTransformerConfig(movie_transformer_config=TransformerConfig(vocab_size=3885, context_window_size=5, embedding_dimension=32, num_layers=4, num_heads=4, dropout_embeddings=0.1, dropout_attention=0.1, dropout_residual=0.1, layer_norm_epsilon=1e-05), user_embedding_dimension=32, num_users=6040, interaction_mlp_hidden_sizes=[16]) 




2024-09-02 06:36:59.639 | INFO     | data:__init__:89 - Creating MovieLensSequenceDataset with validation set: %s
2024-09-02 06:36:59.639 | INFO     | data:read_movielens_data:12 - Reading data from files


    k  MRR_baseline  MAP_baseline  NDCG_baseline
0   3      0.074770      0.115040       0.085057
1   5      0.087306      0.169510       0.107540
2  10      0.100790      0.269741       0.140009


2024-09-02 06:37:01.362 | INFO     | data:_add_tokens:140 - Adding tokens to data
2024-09-02 06:37:01.400 | INFO     | data:_generate_sequences:159 - Generating sequences
2024-09-02 06:37:02.076 | INFO     | data:__init__:110 - Train data length: 884058
2024-09-02 06:37:02.077 | INFO     | data:__init__:111 - Validation data length: 98031
2024-09-02 06:37:02.099 | INFO     | model_train:get_model_config:99 - Model config:
MovieLensTransformerConfig(movie_transformer_config=TransformerConfig(vocab_size=3885, context_window_size=5, embedding_dimension=32, num_layers=4, num_heads=4, dropout_embeddings=0.1, dropout_attention=0.1, dropout_residual=0.1, layer_norm_epsilon=1e-05), user_embedding_dimension=32, num_users=6040, interaction_mlp_hidden_sizes=[16]) 


    k  MRR_with_duplication_penalty  MAP_with_duplication_penalty  \
0   3                      0.071360                      0.109241   
1   5                      0.083089                      0.161582   
2  10                      0.095927                      0.258459   

   NDCG_with_duplication_penalty  
0                       0.081040  
1                       0.102399  
2                       0.133622  


Based on these results, the new model seems worse! Let's check if the dedup logic is actually working. We will check how many predicted values are duplicated in baseline vs duplication_penalty model to root-cause.

In [39]:
def get_duplicated_movies(k: int, model: nn.Module, valid_dataloader: DataLoader):
    duplicated_movies = []
    for i, batch in enumerate(valid_dataloader):
        movie_id_tokens, rating_ids, user_id_tokens, movie_targets, rating_targets = (
            batch
        )
        with torch.no_grad():
            # batch x num_tokens
            output = model(movie_id_tokens, user_id_tokens)
        output_probabilites = F.softmax(output, dim=-1)
        _, top_tokens = output_probabilites.topk(k, dim=-1)

        for i in range(movie_id_tokens.shape[0]):
            input_tokens = movie_id_tokens[i]
            output_tokens = top_tokens[i]
            concat_tensor, counts = torch.cat([input_tokens, output_tokens]).unique(
                return_counts=True
            )
            intersection = concat_tensor[torch.where(counts.gt(1))]
            if intersection.shape[0] > 0:
                duplicated_movies.append(intersection)

    return torch.cat(duplicated_movies)

In [43]:
k_values = [3, 5, 10]

for model_name, (model_file, tb_file) in models.items():
    config, model, valid_dataloader = load_model_artifacts(
        model_file, "./training_config.yaml"
    )

    for k in k_values:
        duplicated_movies = get_duplicated_movies(k, model, valid_dataloader)
        average_duplications = duplicated_movies.shape[0] / (
            k * len(valid_dataloader.dataset)
        )
        print(f"[{model_name}] Average Duplications @ {k}: ", average_duplications)

2024-09-02 06:55:15.188 | INFO     | data:__init__:89 - Creating MovieLensSequenceDataset with validation set: %s
2024-09-02 06:55:15.189 | INFO     | data:read_movielens_data:12 - Reading data from files
2024-09-02 06:55:16.938 | INFO     | data:_add_tokens:140 - Adding tokens to data
2024-09-02 06:55:16.974 | INFO     | data:_generate_sequences:159 - Generating sequences
2024-09-02 06:55:17.646 | INFO     | data:__init__:110 - Train data length: 884093
2024-09-02 06:55:17.647 | INFO     | data:__init__:111 - Validation data length: 97996
2024-09-02 06:55:17.662 | INFO     | model_train:get_model_config:99 - Model config:
MovieLensTransformerConfig(movie_transformer_config=TransformerConfig(vocab_size=3885, context_window_size=5, embedding_dimension=32, num_layers=4, num_heads=4, dropout_embeddings=0.1, dropout_attention=0.1, dropout_residual=0.1, layer_norm_epsilon=1e-05), user_embedding_dimension=32, num_users=6040, interaction_mlp_hidden_sizes=[16]) 


[baseline] Average Duplications @ 3:  0.14658081282229207
[baseline] Average Duplications @ 5:  0.13649332625821461


2024-09-02 06:55:26.988 | INFO     | data:__init__:89 - Creating MovieLensSequenceDataset with validation set: %s
2024-09-02 06:55:26.989 | INFO     | data:read_movielens_data:12 - Reading data from files


[baseline] Average Duplications @ 10:  0.11793134413649536


2024-09-02 06:55:28.595 | INFO     | data:_add_tokens:140 - Adding tokens to data
2024-09-02 06:55:28.632 | INFO     | data:_generate_sequences:159 - Generating sequences
2024-09-02 06:55:29.404 | INFO     | data:__init__:110 - Train data length: 883429
2024-09-02 06:55:29.404 | INFO     | data:__init__:111 - Validation data length: 98660
2024-09-02 06:55:29.411 | INFO     | model_train:get_model_config:99 - Model config:
MovieLensTransformerConfig(movie_transformer_config=TransformerConfig(vocab_size=3885, context_window_size=5, embedding_dimension=32, num_layers=4, num_heads=4, dropout_embeddings=0.1, dropout_attention=0.1, dropout_residual=0.1, layer_norm_epsilon=1e-05), user_embedding_dimension=32, num_users=6040, interaction_mlp_hidden_sizes=[16]) 


[with_duplication_penalty] Average Duplications @ 3:  0.14265153050881815
[with_duplication_penalty] Average Duplications @ 5:  0.1335069937157916
[with_duplication_penalty] Average Duplications @ 10:  0.11543077234948307


We do see fewer repetitions. So, our penalty term is working. We might need to to tweak some hyperparameters or let the model train longer. At this point, we might just push ahead make some other changes to the model to improve training efficiency for learning and performance. 

For example,

- Learning rate scheduler: We slightly increased learning rate in penalty and that made model learn faster. We could consider a learning rate scheduler to have higher learning rate initially and decay it as the model trains
- `torch.compile`: torch.compile should significantly improve our model training performance

I will checkpoint the repo now and we will work on that next.