# Training a Skip-Gram Model for Item2Vec

# Set up

In [1]:
# autoreload re-imports edited modules (e.g. src/) before each cell runs, so changes
# are picked up without restarting the kernel; the tensorboard extension provides
# the %tensorboard magic used in the overfit section.
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

In [2]:
import json
import os
import sys
from typing import Any

import lightning as L
import mlflow
import torch
from dotenv import load_dotenv
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import MLFlowLogger
from loguru import logger
from mlflow.models.signature import infer_signature
from pydantic import BaseModel, PrivateAttr
from torch.utils.data import DataLoader

load_dotenv()

sys.path.insert(0, "..")

from src.id_mapper import IDMapper
from src.skipgram.dataset import SkipGramDataset
from src.skipgram.model import SkipGram
from src.skipgram.trainer import LitSkipGram
from src.skipgram.inference import SkipGramInferenceWrapper
from src.viz import blueq_colors

# Controller

In [None]:
# Epoch budget for the full training run; Args.max_epochs takes its default from here.
max_epochs: int = 1

In [3]:
class Args(BaseModel):
    """Run configuration for this notebook.

    Call ``init()`` after construction: it creates the run's persistence
    directory and, when ``MLFLOW_TRACKING_URI`` is set, the Lightning MLflow logger.
    """

    testing: bool = False
    log_to_mlflow: bool = True
    # Set by init() only when MLflow logging is enabled; None otherwise.
    _mlf_logger: Any = PrivateAttr(default=None)
    experiment_name: str = "RecSys MVP"
    run_name: str = "000-first-attempt"
    # Absolute path filled in by init(); None until then.
    notebook_persist_dp: str | None = None
    random_seed: int = 41
    # Lightning accelerator name; None falls back to "auto" at Trainer construction.
    device: str | None = None

    max_epochs: int = max_epochs
    batch_size: int = 128

    num_negative_samples: int = 2
    window_size: int = 1

    embedding_dim: int = 128
    early_stopping_patience: int = 5
    learning_rate: float = 0.01
    l2_reg: float = 1e-5

    mlf_model_name: str = "item2vec"
    # The newly registered model is aliased "champion" only above this val ROC-AUC.
    min_roc_auc: float = 0.7

    def init(self):
        """Perform run setup side effects and return ``self`` for chaining.

        - Creates ``data/<run_name>`` and stores its absolute path in
          ``notebook_persist_dp``.
        - Turns ``log_to_mlflow`` off when ``MLFLOW_TRACKING_URI`` is unset.
        - Otherwise builds an ``MLFlowLogger`` for ``experiment_name``/``run_name``.
        """
        self.notebook_persist_dp = os.path.abspath(f"data/{self.run_name}")
        os.makedirs(self.notebook_persist_dp, exist_ok=True)

        mlflow_uri = os.environ.get("MLFLOW_TRACKING_URI")
        if not mlflow_uri:
            logger.warning(
                "Environment variable MLFLOW_TRACKING_URI is not set. Setting self.log_to_mlflow to false."
            )
            self.log_to_mlflow = False

        if self.log_to_mlflow:
            logger.info(
                f"Setting up MLflow experiment {self.experiment_name} - run {self.run_name}..."
            )
            self._mlf_logger = MLFlowLogger(
                experiment_name=self.experiment_name,
                run_name=self.run_name,
                tracking_uri=mlflow_uri,
                log_model=True,
            )

        return self


args = Args().init()
# random_seed was declared but never applied: seed Python, NumPy and torch (and
# dataloader workers) so negative sampling and weight init are reproducible.
L.seed_everything(args.random_seed, workers=True)

print(args.model_dump_json(indent=2))

[32m2024-10-10 11:41:16.670[0m | [1mINFO    [0m | [36m__main__[0m:[36minit[0m:[36m36[0m - [1mSetting up MLflow experiment FSDS RecSys - L7 - Model Serving - run 000-first-attempt...[0m


{
  "testing": false,
  "log_to_mlflow": true,
  "experiment_name": "FSDS RecSys - L7 - Model Serving",
  "run_name": "000-first-attempt",
  "notebook_persist_dp": "/Users/quy.dinh/frostmourne/recsys-mvp/notebooks/data/000-first-attempt",
  "random_seed": 41,
  "device": null,
  "max_epochs": 1,
  "batch_size": 128,
  "num_negative_samples": 2,
  "window_size": 1,
  "embedding_dim": 128,
  "early_stopping_patience": 5,
  "learning_rate": 0.01,
  "l2_reg": 0.00001,
  "mlf_model_name": "item2vec",
  "min_roc_auc": 0.7
}


# Implementation

In [4]:
def init_model(n_items, embedding_dim):
    """Build a fresh, untrained SkipGram model.

    Args:
        n_items: Number of distinct items the model can embed.
        embedding_dim: Length of each item's embedding vector.

    Returns:
        A new ``SkipGram`` instance.
    """
    return SkipGram(n_items, embedding_dim)

# Test implementation

In [5]:
# Smoke-test the untrained model on a few hand-written (target, context) pairs.
n_items, window_size, negative_samples, batch_size = 1000, 1, 2, 2

model = init_model(n_items, args.embedding_dim)

target_items = torch.tensor([1, 2, 3])
context_items = torch.tensor([10, 20, 30])
# 1 = positive context pair, 0 = negative pair; not used by the forward pass below.
labels = torch.tensor([1, 0, 1])

# One score per (target, context) pair.
predictions = model(target_items, context_items)
print(predictions)

tensor([0.4960, 0.4963, 0.4978], grad_fn=<SigmoidBackward0>)


In [6]:
# Mock dataset
sequences = [
    ["b", "c", "d", "e", "a"],
    ["f", "b", "k"],
    ["g", "m", "k", "l", "h"],
    ["b", "c", "k"],
    ["j", "i", "c"],
]

sequences_fp = "sequences.jsonl"

with open(sequences_fp, "w") as f:
    for sequence in sequences:
        f.write(json.dumps(sequence) + "\n")

dataset = SkipGramDataset(
    sequences_fp, window_size=window_size, negative_samples=negative_samples
)
train_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    drop_last=False,
    collate_fn=dataset.collate_fn,
    num_workers=0,
)

[32m2024-10-10 11:41:16.757[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m62[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

In [7]:
# Peek at the first collated batch only; the rest of the loader is not consumed.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    batch_input = first_batch
    print(batch_input)

{'target_items': tensor([0, 0, 0, 1, 1, 1, 1, 1, 1]), 'context_items': tensor([ 1,  9,  7,  0,  2,  9, 10,  7, 10]), 'labels': tensor([1., 0., 0., 1., 1., 0., 0., 0., 0.])}


In [8]:
# model
lit_model = LitSkipGram(model, log_dir=args.notebook_persist_dp)

# train model
trainer = L.Trainer(default_root_dir=f"{args.notebook_persist_dp}/test", max_epochs=2)
trainer.fit(
    model=lit_model, train_dataloaders=train_loader, val_dataloaders=train_loader
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name           | Type     | Params | Mode 
----------------------------------------------------
0 | skipgram_model | SkipGram | 128 K  | train
----------------------------------------------------
128 K     Trainable params
0         Non-trainable params
128 K     Total params
0.513     Total estimated model params size (MB)
2         Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …

/Users/quy.dinh/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/quy.dinh/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=2` reached.


# Prep data

In [9]:
# Input artifacts, resolved relative to the notebooks/ directory.
data_dir = "../data"
sequences_fp = f"{data_dir}/item_sequence.jsonl"
val_sequences_fp = f"{data_dir}/val_item_sequence.jsonl"
idm = IDMapper().load(f"{data_dir}/idm.json")

In [10]:
# Train and validation datasets share the IDMapper's index; the validation set also
# reuses the training set's `interacted` and `item_freq` (passed positionally).
dataset = SkipGramDataset(
    sequences_fp,
    window_size=args.window_size,
    negative_samples=args.num_negative_samples,
    id_to_idx=idm.item_to_index,
)
val_dataset = SkipGramDataset(
    val_sequences_fp,
    dataset.interacted,
    dataset.item_freq,
    window_size=args.window_size,
    negative_samples=args.num_negative_samples,
    id_to_idx=idm.item_to_index,
)

# Both loaders share one batching policy: fixed order, incomplete final batch dropped.
loader_kwargs = dict(batch_size=args.batch_size, shuffle=False, drop_last=True)
train_loader = DataLoader(
    dataset,
    collate_fn=dataset.collate_fn,
    # TODO: Understand and make use of this parallel workers to make the model train faster
    # num_workers=4,
    # persistent_workers=True,
    **loader_kwargs,
)
val_loader = DataLoader(val_dataset, collate_fn=val_dataset.collate_fn, **loader_kwargs)

[32m2024-10-10 11:41:17.446[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m62[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

[32m2024-10-10 11:41:18.064[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m62[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

In [11]:
# The training dataset must index items exactly like the shared IDMapper.
assert idm.item_to_index == dataset.id_to_idx, "ID Mappings are not matched!"

# Train

In [12]:
# Number of items the full model must embed, taken from the training dataset.
n_items: int = len(dataset.items)

## Overfit 1 batch

In [13]:
#papermill_description=overfit-1-batch
# A small one-batch sample file is still needed: with negative sampling, Lightning's
# overfit_batches alone would see new data every epoch.
batch_sequences_fp = "../data/batch_item_sequence.jsonl"
batch_size = 1
window_size = 1
num_negative_samples = 2

batch_dataset = SkipGramDataset(
    batch_sequences_fp, window_size=window_size, negative_samples=num_negative_samples
)
batch_train_loader = DataLoader(
    batch_dataset,
    batch_size=batch_size,
    drop_last=False,
    # Collate with the dataset this loader actually iterates (was the full `dataset`).
    collate_fn=batch_dataset.collate_fn,
)

[32m2024-10-10 11:41:18.153[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m62[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

In [14]:
# Print the first two batches to eyeball positive/negative pairs.
i = 0
for i, batch_input in enumerate(batch_train_loader, start=1):
    print(batch_input)
    if i == 2:
        break

{'target_items': tensor([0, 0, 0]), 'context_items': tensor([1, 7, 8]), 'labels': tensor([1., 0., 0.])}
{'target_items': tensor([1, 1, 1, 1, 1, 1]), 'context_items': tensor([ 0,  2,  9,  8, 10,  9]), 'labels': tensor([1., 1., 0., 0., 0., 0.])}


In [15]:
log_dir = f"{args.notebook_persist_dp}/logs/overfit"

# Stop after 10 epochs without val_loss improvement.
early_stopping = EarlyStopping(monitor="val_loss", mode="min", patience=10, verbose=False)

# Fresh model sized to the tiny batch vocabulary; l2_reg=0.0 so nothing discourages memorising it.
model = init_model(len(batch_dataset.items), args.embedding_dim)
lit_model = LitSkipGram(model, learning_rate=0.01, l2_reg=0.0, log_dir=args.notebook_persist_dp)

trainer = L.Trainer(
    default_root_dir=log_dir,
    max_epochs=100,
    overfit_batches=1,
    callbacks=[early_stopping],
)
# Train and validate on the same loader: the goal is to confirm the model can overfit one batch.
trainer.fit(
    model=lit_model,
    train_dataloaders=batch_train_loader,
    val_dataloaders=batch_train_loader,
)
logger.info(f"Logs available at {trainer.log_dir}")

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.

  | Name           | Type     | Params | Mode 
----------------------------------------------------
0 | skipgram_model | SkipGram | 1.5 K  | train
----------------------------------------------------
1.5 K     Trainable params
0         Non-trainable params
1.5 K     Total params
0.006     Total estimated model params size (MB)
2         Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …


You requested to overfit but enabled val dataloader shuffling. We are turning off the val dataloader shuffling for you.


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


You requested to overfit but enabled train dataloader shuffling. We are turning off the train dataloader shuffling for you.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

[32m2024-10-10 11:41:18.979[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m25[0m - [1mLogs available at /Users/quy.dinh/frostmourne/recsys-mvp/notebooks/data/000-first-attempt/logs/overfit/lightning_logs/version_0[0m


In [16]:
# Open TensorBoard on the overfit run's Lightning log directory.
%tensorboard --logdir $trainer.log_dir

In [17]:
# Score (target 0, context 1), a positive pair in the first batch printed above.
model(torch.tensor([0], dtype=torch.long), torch.tensor([1], dtype=torch.long))

tensor([0.9820], grad_fn=<SigmoidBackward0>)

## Run with all data

In [18]:
#papermill_description=fit-model
early_stopping = EarlyStopping(
    monitor="val_loss", patience=args.early_stopping_patience, mode="min", verbose=False
)

checkpoint_callback = ModelCheckpoint(
    dirpath=f"{args.notebook_persist_dp}/checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    monitor="val_loss",
    mode="min",
)

# model
model = init_model(n_items, args.embedding_dim)
lit_model = LitSkipGram(
    model,
    learning_rate=args.learning_rate,
    l2_reg=args.l2_reg,
    log_dir=args.notebook_persist_dp,
)

log_dir = f"{args.notebook_persist_dp}/logs/run"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    max_epochs=args.max_epochs,
    callbacks=[early_stopping, checkpoint_callback],
    accelerator=args.device if args.device else "auto",
    logger=args._mlf_logger if args.log_to_mlflow else None,
)
trainer.fit(model=lit_model, train_dataloaders=train_loader, val_dataloaders=val_loader)
logger.info(f"Logs available at {trainer.log_dir}")

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name           | Type     | Params | Mode 
----------------------------------------------------
0 | skipgram_model | SkipGram | 592 K  | train
----------------------------------------------------
592 K     Trainable params
0         Non-trainable params
592 K     Total params
2.371     Total estimated model params size (MB)
2         Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.



Training: |                                                                                                   …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=1` reached.
2024/10/10 11:42:16 INFO mlflow.tracking._tracking_service.client: 🏃 View run 000-first-attempt at: http://localhost:5002/#/experiments/1/runs/e72b532535be4763a191c6da5591040c.
2024/10/10 11:42:16 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5002/#/experiments/1.
[32m2024-10-10 11:42:16.883[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m33[0m - [1mLogs available at None[0m


In [19]:
# Score the same (0, 1) item pair with the fully trained model.
pair_target, pair_context = torch.tensor([0]), torch.tensor([1])
model(pair_target, pair_context)

tensor([0.5856], grad_fn=<SigmoidBackward0>)

# Load best checkpoint

In [20]:
# Reload the lowest-val_loss checkpoint saved by ModelCheckpoint; the SkipGram module
# is passed again as the `skipgram_model` constructor argument.
best_ckpt_path = checkpoint_callback.best_model_path
if not best_ckpt_path:
    # ModelCheckpoint leaves best_model_path empty when it never saved a checkpoint.
    raise FileNotFoundError(
        "ModelCheckpoint saved no checkpoint (was val_loss ever logged?); nothing to load."
    )
best_trainer = LitSkipGram.load_from_checkpoint(
    best_ckpt_path,
    skipgram_model=init_model(n_items, args.embedding_dim),
)

In [21]:
# Unwrap the SkipGram module from the reloaded LitSkipGram
# (despite its name, `best_trainer` is the LightningModule, not an L.Trainer).
best_model = best_trainer.skipgram_model

In [22]:
# First five embedding values of item index 0 (module moved to CPU first).
best_model.cpu().embeddings(torch.tensor(0))[:5]

tensor([ 0.0724, -0.0767, -0.2191, -0.0854,  0.1054], grad_fn=<SliceBackward0>)

### Persist id mapping

In [23]:
# Save the id mapping so inference can take item_ids (strings) instead of item indices.
id_mapping_filename = "skipgram_id_mapping.json"
id_mapping_path = f"{args.notebook_persist_dp}/{id_mapping_filename}"
logger.info(f"Saving id_mapping to {id_mapping_path}...")
dataset.save_id_mappings(id_mapping_path)

if args.log_to_mlflow:
    # Attach the file to the training run so the registered model can reference it.
    run_id, mlf_client = trainer.logger.run_id, trainer.logger.experiment
    mlf_client.log_artifact(run_id, id_mapping_path)

[32m2024-10-10 11:42:17.198[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mSaving id_mapping to /Users/quy.dinh/frostmourne/recsys-mvp/notebooks/data/000-first-attempt/skipgram_id_mapping.json...[0m


### Wrap inference function and register best checkpoint as MLflow model

In [24]:
# Wrap the best SkipGram module; this object is logged below as an MLflow pyfunc python_model.
inferrer = SkipGramInferenceWrapper(best_model)

In [25]:
# Example request/response for the MLflow signature: the request holds raw item ids,
# while the wrapper is called directly with the matching indices 0 and 1.
item_1_id, item_2_id = dataset.idx_to_id[0], dataset.idx_to_id[1]
sample_input = {"item_1_ids": [item_1_id], "item_2_ids": [item_2_id]}
sample_output = inferrer.infer([0], [1])
sample_output

array([0.5856392], dtype=float32)

In [26]:
if args.log_to_mlflow:
    run_id = trainer.logger.run_id
    # Derive the model's input/output schema from the example pair above.
    signature = infer_signature(sample_input, sample_output)
    with mlflow.start_run(run_id=run_id):
        # Log the wrapper into the training run and register it under args.mlf_model_name.
        model_info = mlflow.pyfunc.log_model(
            python_model=inferrer,
            artifact_path="inferrer",
            # Ship the id_mapping with the model so predict() can accept item_ids and
            # convert them to item indices for the PyTorch model.
            artifacts={"id_mapping": mlflow.get_artifact_uri(id_mapping_filename)},
            signature=signature,
            input_example=sample_input,
            registered_model_name=args.mlf_model_name,
        )


Since MLflow 2.16.0, we no longer convert dictionary input example to pandas Dataframe, and directly save it as a json object. If the model expects a pandas DataFrame input instead, please pass the pandas DataFrame as input example directly.



Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Successfully registered model 'item2vec'.
2024/10/10 11:42:20 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: item2vec, version 1
Created version '1' of model 'item2vec'.


Downloading artifacts:   0%|          | 0/8 [00:00<?, ?it/s]

2024/10/10 11:42:21 INFO mlflow.tracking._tracking_service.client: 🏃 View run 000-first-attempt at: http://localhost:5002/#/experiments/1/runs/e72b532535be4763a191c6da5591040c.
2024/10/10 11:42:21 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5002/#/experiments/1.


# Alias the newly trained model as champion (only if its validation ROC-AUC exceeds `min_roc_auc`)

In [27]:
if args.log_to_mlflow:
    mlf_client = trainer.logger.experiment
    run_id = trainer.logger.run_id
    val_roc_auc = mlf_client.get_run(run_id).data.metrics.get("val_roc_auc")

    if val_roc_auc is None:
        logger.warning("Run has no val_roc_auc metric; champion alias left unchanged.")
    elif val_roc_auc > args.min_roc_auc:
        logger.info("Aliasing the new model as champion...")
        # Select the version registered from *this* run. `latest_versions` is deprecated
        # and returns one entry per stage in no guaranteed order, so [0] may be another version.
        run_versions = mlf_client.search_model_versions(
            f"name='{args.mlf_model_name}' AND run_id='{run_id}'"
        )
        if not run_versions:
            raise RuntimeError(
                f"No version of model {args.mlf_model_name!r} was registered from run {run_id}."
            )
        model_version = max(run_versions, key=lambda mv: int(mv.version)).version

        mlf_client.set_registered_model_alias(
            name=args.mlf_model_name, alias="champion", version=model_version
        )

        mlf_client.set_model_version_tag(
            name=args.mlf_model_name,
            version=model_version,
            key="author",
            value="quy.dinh",
        )
    else:
        logger.warning(
            f"val_roc_auc={val_roc_auc:.4f} does not exceed min_roc_auc={args.min_roc_auc}; "
            "champion alias left unchanged."
        )

[32m2024-10-10 11:42:21.597[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m7[0m - [1mAliasing the new model as champion...[0m


# Clean up: log run parameters to MLflow

In [28]:
# Config objects whose fields are logged as "<ClassName>.<field>" params on the training run.
all_params = [args]

if args.log_to_mlflow:
    with mlflow.start_run(run_id=trainer.logger.run_id):
        for params in all_params:
            # model_dump() replaces Pydantic v1's deprecated .dict(); private attrs are excluded.
            prefix = type(params).__name__
            mlflow.log_params({f"{prefix}.{k}": v for k, v in params.model_dump().items()})

2024/10/10 11:42:26 INFO mlflow.tracking._tracking_service.client: 🏃 View run 000-first-attempt at: http://localhost:5002/#/experiments/1/runs/e72b532535be4763a191c6da5591040c.
2024/10/10 11:42:26 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5002/#/experiments/1.
