# Training Skip Gram for Item2Vec

# Set up

In [1]:
# Re-import edited modules automatically before each cell runs (autoreload mode 2).
%load_ext autoreload
%autoreload 2
# Makes the %tensorboard magic available; it is used after the overfit run.
%load_ext tensorboard

In [2]:
import json
import os
import sys
from typing import Any

import lightning as L
import mlflow
import torch
from dotenv import load_dotenv
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import MLFlowLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from loguru import logger
from pydantic import BaseModel, PrivateAttr
from torch.utils.data import DataLoader

load_dotenv()

sys.path.insert(0, "..")

from src.eval import log_classification_metrics, visualize_training
from src.id_mapper import IDMapper
from src.skipgram.dataset import SkipGramDataset
from src.skipgram.model import SkipGram
from src.skipgram.trainer import LitSkipGram
from src.train_utils import MetricLogCallback, MLflowLogCallback
from src.viz import blueq_colors

# Controller

In [3]:
class Args(BaseModel):
    """Run configuration for this notebook.

    Public fields are the run settings / hyper-parameters and are what
    ``model_dump_json()`` prints; ``_mlf_logger`` is a pydantic private
    attribute and is not part of that dump.

    Call :meth:`init` once after construction: it creates the output
    directory and, when enabled, the MLflow logger.
    """

    testing: bool = False
    log_to_mlflow: bool = True
    # Explicit default so the attribute can be read safely even when init()
    # never creates a logger (e.g. MLflow disabled).
    _mlf_logger: Any = PrivateAttr(default=None)
    experiment_name: str = "FSDS RecSys - L7 - Model Serving"
    run_name: str = "001-first-run-mlflow"
    # Absolute output directory; None until init() fills it in.
    notebook_persist_dp: str | None = None
    random_seed: int = 41
    # None means "let Lightning choose" (the training cell falls back to 'auto').
    device: str | None = None

    top_K: int = 100
    top_k: int = 10

    max_epochs: int = 1000
    batch_size: int = 128

    num_negative_samples: int = 2
    window_size: int = 1

    embedding_dim: int = 128
    early_stopping_patience: int = 5
    learning_rate: float = 0.01
    l2_reg: float = 1e-5

    def init(self):
        """Create the persist directory and, if enabled, the MLflow logger.

        MLflow logging is switched off (with a warning) when the
        MLFLOW_TRACKING_URI environment variable is missing or empty.

        Returns:
            self, so the call can be chained as ``Args().init()``.
        """
        self.notebook_persist_dp = os.path.abspath(f"data/{self.run_name}")
        os.makedirs(self.notebook_persist_dp, exist_ok=True)

        mlflow_uri = os.environ.get("MLFLOW_TRACKING_URI")
        if not mlflow_uri:
            logger.warning(
                "Environment variable MLFLOW_TRACKING_URI is not set. Setting self.log_to_mlflow to false."
            )
            self.log_to_mlflow = False

        if self.log_to_mlflow:
            logger.info(
                f"Setting up MLflow experiment {self.experiment_name} - run {self.run_name}..."
            )
            self._mlf_logger = MLFlowLogger(
                experiment_name=self.experiment_name,
                run_name=self.run_name,
                tracking_uri=mlflow_uri,
                log_model=True,
            )

        return self


args = Args().init()

# Apply the configured seed to the Python, NumPy and torch RNGs (and to DataLoader
# workers) so that weight initialisation and random sampling are repeatable.
L.seed_everything(args.random_seed, workers=True)

print(args.model_dump_json(indent=2))

[32m2024-10-01 18:31:37.423[0m | [1mINFO    [0m | [36m__main__[0m:[36minit[0m:[36m36[0m - [1mSetting up MLflow experiment FSDS RecSys - L7 - Model Serving - run 001-first-run-mlflow...[0m


{
  "testing": false,
  "log_to_mlflow": true,
  "experiment_name": "FSDS RecSys - L7 - Model Serving",
  "run_name": "001-first-run-mlflow",
  "notebook_persist_dp": "/Users/quy.dinh/frostmourne/fsds-recsys/chapters/l7/notebooks/data/001-first-run-mlflow",
  "random_seed": 41,
  "device": null,
  "top_K": 100,
  "top_k": 10,
  "max_epochs": 1000,
  "batch_size": 128,
  "num_negative_samples": 2,
  "window_size": 1,
  "embedding_dim": 128,
  "early_stopping_patience": 5,
  "learning_rate": 0.01,
  "l2_reg": 0.00001
}


# Implementation

In [4]:
def init_model(n_items, embedding_dim):
    """Construct a fresh SkipGram model.

    Args:
        n_items: forwarded as the first positional argument of ``SkipGram``.
        embedding_dim: forwarded as the second positional argument of ``SkipGram``.

    Returns:
        The newly created ``SkipGram`` instance.
    """
    return SkipGram(n_items, embedding_dim)

# Test implementation

In [5]:
# Smoke-test settings. window_size, negative_samples and batch_size are reused by the
# mock-dataset cell that follows; n_items is reassigned later from the real dataset.
n_items = 1000
window_size = 1
negative_samples = 2
batch_size = 2

model = init_model(n_items, args.embedding_dim)

# Example inputs
target_items = torch.tensor([1, 2, 3])  # Target item IDs
context_items = torch.tensor([10, 20, 30])  # Context item IDs
# Not passed to the model in this cell; shown only to illustrate the label format.
labels = torch.tensor([1, 0, 1])  # Positive or negative context pairs

# Forward pass
predictions = model(target_items, context_items)
print(predictions)

tensor([0.4960, 0.5066, 0.5003], grad_fn=<SigmoidBackward0>)


In [6]:
# Mock dataset: a tiny hand-written corpus where each inner list is one item sequence.
sequences = [
    ["b", "c", "d", "e", "a"],
    ["f", "b", "k"],
    ["g", "m", "k", "l", "h"],
    ["b", "c", "k"],
    ["j", "i", "c"],
]

sequences_fp = "sequences.jsonl"

# Persist the corpus as JSON Lines: one JSON array per line.
with open(sequences_fp, "w") as fout:
    fout.writelines(json.dumps(seq) + "\n" for seq in sequences)

# Build the dataset from the file just written, using the smoke-test settings.
dataset = SkipGramDataset(
    sequences_fp,
    window_size=window_size,
    negative_samples=negative_samples,
)

# Single-process loader that keeps the final partial batch.
train_loader = DataLoader(
    dataset,
    collate_fn=dataset.collate_fn,
    batch_size=batch_size,
    num_workers=0,
    drop_last=False,
)

[32m2024-10-01 18:31:37.505[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m62[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

In [7]:
# Show only the first collated batch, then stop iterating.
for batch_input in train_loader:
    print(batch_input)
    break

{'target_items': tensor([0, 0, 0, 1, 1, 1, 1, 1, 1]), 'context_items': tensor([ 1,  9, 11,  0,  2,  7,  8,  9,  7]), 'labels': tensor([1., 0., 0., 1., 1., 0., 0., 0., 0.])}


In [8]:
# model
# Wrap the smoke-test model created above in the Lightning module.
lit_model = LitSkipGram(model, log_dir=args.notebook_persist_dp)

# train model
# Two-epoch smoke run; the mock train_loader is also used as the validation loader.
trainer = L.Trainer(default_root_dir=f"{args.notebook_persist_dp}/test", max_epochs=2)
trainer.fit(
    model=lit_model, train_dataloaders=train_loader, val_dataloaders=train_loader
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name           | Type     | Params | Mode 
----------------------------------------------------
0 | skipgram_model | SkipGram | 128 K  | train
----------------------------------------------------
128 K     Trainable params
0         Non-trainable params
128 K     Total params
0.513     Total estimated model params size (MB)
2         Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …

/Users/quy.dinh/frostmourne/fsds-recsys/chapters/l7/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/quy.dinh/frostmourne/fsds-recsys/chapters/l7/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=2` reached.


# Prep data

In [9]:
# Input files read (not written) by this notebook; sequences_fp replaces the mock
# "sequences.jsonl" path used in the smoke test above.
sequences_fp = "../data/item_sequence.jsonl"
val_sequences_fp = "../data/val_item_sequence.jsonl"
# ID mapper whose item_to_index mapping is handed to the datasets below.
idm = IDMapper().load("../data/idm.json")

In [10]:
# Keyword settings common to the training and validation datasets.
shared_dataset_kwargs = dict(
    window_size=args.window_size,
    negative_samples=args.num_negative_samples,
    id_to_idx=idm.item_to_index,
)

dataset = SkipGramDataset(sequences_fp, **shared_dataset_kwargs)
# The validation dataset additionally receives the training dataset's `interacted`
# and `item_freq` attributes as its 2nd and 3rd positional arguments.
val_dataset = SkipGramDataset(
    val_sequences_fp,
    dataset.interacted,
    dataset.item_freq,
    **shared_dataset_kwargs,
)

# Both loaders share every setting and differ only in dataset / collate function.
shared_loader_kwargs = dict(
    batch_size=args.batch_size,
    shuffle=False,
    drop_last=True,
    num_workers=4,
    persistent_workers=True,
)

train_loader = DataLoader(
    dataset, collate_fn=dataset.collate_fn, **shared_loader_kwargs
)
val_loader = DataLoader(
    val_dataset, collate_fn=val_dataset.collate_fn, **shared_loader_kwargs
)

[32m2024-10-01 18:31:38.201[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m62[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

[32m2024-10-01 18:31:38.816[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m62[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

In [11]:
# Guard: the dataset must use exactly the item -> index mapping loaded into `idm`.
assert dataset.id_to_idx == idm.item_to_index, "ID Mappings are not matched!"

# Train

In [12]:
# Item count of the full training dataset (replaces the smoke-test value of 1000);
# passed to init_model() for the full-data run.
n_items = len(dataset.items)

## Overfit 1 batch

In [13]:
# Lightning's `overfit_batches` alone is not enough here: because of negative sampling it
# would see new data every epoch, so a small single-batch sample file is loaded instead.
batch_sequences_fp = "../data/batch_item_sequence.jsonl"
batch_size = 1
window_size = 1
num_negative_samples = 2

# Built without `id_to_idx`, unlike the full dataset above.
batch_dataset = SkipGramDataset(
    batch_sequences_fp, window_size=window_size, negative_samples=num_negative_samples
)
batch_train_loader = DataLoader(
    batch_dataset,
    batch_size=batch_size,
    drop_last=False,
    # Collate with the dataset actually being loaded (previously the full `dataset`'s
    # collate_fn was used), matching every other loader in this notebook.
    collate_fn=batch_dataset.collate_fn,
)

[32m2024-10-01 18:31:38.902[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m62[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

In [14]:
# Print the first two batches of the overfit loader, then stop.
for i, batch_input in enumerate(batch_train_loader, start=1):
    print(batch_input)
    if i >= 2:
        break

{'target_items': tensor([0, 0, 0]), 'context_items': tensor([1, 8, 9]), 'labels': tensor([1., 0., 0.])}
{'target_items': tensor([1, 1, 1, 1, 1, 1]), 'context_items': tensor([ 0,  2,  8,  9, 10,  7]), 'labels': tensor([1., 1., 0., 0., 0., 0.])}


In [15]:
# Stop when val_loss has not improved for 10 consecutive validation checks.
early_stopping = EarlyStopping(
    monitor="val_loss", patience=10, mode="min", verbose=False
)

# Keep only the single best checkpoint (lowest val_loss). The directory sits directly
# under the run's persist dir, i.e. it is not specific to this overfit experiment.
checkpoint_callback = ModelCheckpoint(
    dirpath=f"{args.notebook_persist_dp}/checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    monitor="val_loss",
    mode="min",
)

# model
# Sized from the small batch dataset; L2 regularisation is disabled for this run.
model = init_model(len(batch_dataset.items), args.embedding_dim)
lit_model = LitSkipGram(
    model, learning_rate=0.01, l2_reg=0.0, log_dir=args.notebook_persist_dp
)

log_dir = f"{args.notebook_persist_dp}/logs/overfit"

# train model
# The same single-batch loader is used for both training and validation.
trainer = L.Trainer(
    default_root_dir=log_dir,
    max_epochs=100,
    overfit_batches=1,
    callbacks=[early_stopping, checkpoint_callback],
)
trainer.fit(
    model=lit_model,
    train_dataloaders=batch_train_loader,
    val_dataloaders=batch_train_loader,
)
logger.info(f"Logs available at {trainer.log_dir}")

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.

Checkpoint directory /Users/quy.dinh/frostmourne/fsds-recsys/chapters/l7/notebooks/data/001-first-run-mlflow/checkpoints exists and is not empty.


  | Name           | Type     | Params | Mode 
----------------------------------------------------
0 | skipgram_model | SkipGram | 1.5 K  | train
----------------------------------------------------
1.5 K     Trainable params
0         Non-trainable params
1.5 K     Total params
0.006     Total estimated model params size (MB)
2         Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …


You requested to overfit but enabled val dataloader shuffling. We are turning off the val dataloader shuffling for you.


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


You requested to overfit but enabled train dataloader shuffling. We are turning off the train dataloader shuffling for you.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

[32m2024-10-01 18:31:40.872[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m33[0m - [1mLogs available at /Users/quy.dinh/frostmourne/fsds-recsys/chapters/l7/notebooks/data/001-first-run-mlflow/logs/overfit/lightning_logs/version_3[0m


In [16]:
# Open TensorBoard on the log directory of the trainer created in the previous cell.
%tensorboard --logdir $trainer.log_dir

In [17]:
# Score target item 0 against context item 1 with the model wrapped by the overfit run.
model(torch.tensor([0]), torch.tensor([1]))

tensor([0.9997], grad_fn=<SigmoidBackward0>)

## Run with all data

In [18]:
# NOTE(review): the recorded output of this cell ends with
#   AttributeError: 'MlflowClient' object has no attribute 'add_scalar'
# `add_scalar` is a TensorBoard SummaryWriter method, so presumably some code outside this
# notebook still calls `logger.experiment.add_scalar(...)` while the MLflow logger is the
# active logger. TODO: confirm in LitSkipGram and fix there.
early_stopping = EarlyStopping(
    monitor="val_loss", patience=args.early_stopping_patience, mode="min", verbose=False
)

# model
# Fresh model sized for the full dataset, with the hyper-parameters from `args`.
model = init_model(n_items, args.embedding_dim)
lit_model = LitSkipGram(
    model,
    learning_rate=args.learning_rate,
    l2_reg=args.l2_reg,
    log_dir=args.notebook_persist_dp,
)

log_dir = f"{args.notebook_persist_dp}/logs/run"

# train model
# accelerator falls back to 'auto' when no device is configured; the MLflow logger is
# used only when MLflow logging is enabled, otherwise `logger=None` is passed.
trainer = L.Trainer(
    default_root_dir=log_dir,
    max_epochs=args.max_epochs,
    callbacks=[early_stopping],
    accelerator=args.device if args.device else 'auto',
    logger=args._mlf_logger if args.log_to_mlflow else None
)
trainer.fit(model=lit_model, train_dataloaders=train_loader, val_dataloaders=val_loader)
logger.info(f"Logs available at {trainer.log_dir}")

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Experiment with name FSDS RecSys - L7 - Model Serving not found. Creating it.

  | Name           | Type     | Params | Mode 
----------------------------------------------------
0 | skipgram_model | SkipGram | 592 K  | train
----------------------------------------------------
592 K     Trainable params
0         Non-trainable params
592 K     Total params
2.371     Total estimated model params size (MB)
2         Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …

Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

2024/10/01 18:49:10 INFO mlflow.tracking._tracking_service.client: 🏃 View run 001-first-run-mlflow at: http://localhost:5003/#/experiments/1/runs/1920b6493c1a470193e04a70683617d8.
2024/10/01 18:49:10 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5003/#/experiments/1.


AttributeError: 'MlflowClient' object has no attribute 'add_scalar'

In [None]:
# Score target item 0 against context item 1 with the model from the full-data run.
model(torch.tensor([0]), torch.tensor([1]))

# Persist model

In [None]:
# First five values of the embedding looked up for index 0, shown before saving so they
# can be compared with the same lookup after the model is reloaded.
model.embeddings(torch.tensor(0))[:5]

In [None]:
# Save the `model` object itself (not a state_dict) under the run's persist directory.
model_path = f"{args.notebook_persist_dp}/skipgram_model_full.pth"
logger.info(f"Saving model to {model_path}...")
torch.save(model, model_path)

In [None]:
# The file was written with torch.save(model, ...), i.e. it holds a whole pickled object
# rather than a state_dict. PyTorch versions that default torch.load to weights_only=True
# refuse such files, so opt out explicitly. Only do this for files you created yourself:
# unpickling can execute arbitrary code.
model = torch.load(model_path, weights_only=False)

In [None]:
# Same lookup as before saving; the values should match if the reload round-tripped.
model.embeddings(torch.tensor(0))[:5]

In [None]:
# Store the training dataset's ID mappings in the same directory as the saved model.
id_mapping_path = f"{args.notebook_persist_dp}/skipgram_id_mapping.json"
logger.info(f"Saving id_mapping to {id_mapping_path}...")
dataset.save_id_mappings(id_mapping_path)

# Clean up

In [None]:
all_params = [args]

if args.log_to_mlflow:
    for params in all_params:
        # model_dump() is the pydantic v2 replacement for the deprecated .dict().
        params_dict = params.model_dump()
        # Prefix every key with the config class name, e.g. "Args.batch_size".
        params_ = {f"{params.__repr_name__()}.{k}": v for k, v in params_dict.items()}
        # Log through the MLFlowLogger so the params land on the run the Trainer logged to.
        # This notebook never starts a fluent-API run, so mlflow.log_params() would target
        # a different, implicitly created run instead.
        args._mlf_logger.log_hyperparams(params_)

    # No-op unless a fluent-API run happens to be active; kept so any such run is closed.
    mlflow.end_run()