# Set up

In [1]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

In [None]:
import json
import os
import sys
from typing import Any

import lightning as L
import mlflow
import torch
from dotenv import load_dotenv
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import MLFlowLogger
from loguru import logger
from mlflow.models.signature import infer_signature
from pydantic import BaseModel, PrivateAttr
from torch.utils.data import DataLoader

load_dotenv()

sys.path.insert(0, "..")

from src.id_mapper import IDMapper
from src.skipgram.dataset import SkipGramDataset
from src.skipgram.model import SkipGram
from src.skipgram.trainer import LitSkipGram
from src.skipgram.inference import SkipGramInferenceWrapper



# Controller

In [3]:
max_epochs = 1

In [None]:
class Args(BaseModel):
    testing: bool = False
    log_to_mlflow: bool = True
    _mlf_logger: Any = PrivateAttr()
    experiment_name: str = "recsys"
    run_name: str = "001-item2vec"
    notebook_persist_dp: str = None
    random_seed: int = 41
    device: str = None

    batch_size: int = 128

    num_negative_samples: int = 2
    window_size: int = 1

    embedding_dim: int = 128
    early_stopping_patience: int = 5
    learning_rate: float = 0.01
    l2_reg: float = 1e-5

    mlf_model_name: str = "item2vec"
    min_roc_auc: float = 0.7

    def init(self):
        self.notebook_persist_dp = os.path.abspath(f"data/{self.run_name}")
        os.makedirs(self.notebook_persist_dp, exist_ok=True)

        if not (mlflow_uri := os.environ.get("MLFLOW_TRACKING_URI")):
            logger.warning(
                f"Environment variable MLFLOW_TRACKING_URI is not set. Setting self.log_to_mlflow to false."
            )
            self.log_to_mlflow = False

        if self.log_to_mlflow:
            logger.info(
                f"Setting up MLflow experiment {self.experiment_name} - run {self.run_name}..."
            )
            self._mlf_logger = MLFlowLogger(
                experiment_name=self.experiment_name,
                run_name=self.run_name,
                tracking_uri=mlflow_uri,
                log_model=True,
            )

        return self


args = Args().init()

print(args.model_dump_json(indent=2))

[32m2025-11-11 12:33:04.739[0m | [1mINFO    [0m | [36m__main__[0m:[36minit[0m:[36m37[0m - [1mSetting up MLflow experiment recsys - run 000...[0m


{
  "testing": false,
  "log_to_mlflow": true,
  "experiment_name": "recsys",
  "run_name": "000",
  "notebook_persist_dp": "/mnt/d/projects/recsys/notebooks/data/000",
  "random_seed": 41,
  "device": null,
  "max_epochs": 1,
  "batch_size": 128,
  "num_negative_samples": 2,
  "window_size": 1,
  "embedding_dim": 128,
  "early_stopping_patience": 5,
  "learning_rate": 0.01,
  "l2_reg": 0.00001,
  "mlf_model_name": "item2vec",
  "min_roc_auc": 0.7
}


# Implement

In [5]:
def init_model(n_items, embedding_dim):
    model = SkipGram(n_items, embedding_dim)
    return model

# Test implementation

In [6]:
n_items = 1000
window_size = 1
negative_samples = 2
batch_size = 2

model = init_model(n_items, args.embedding_dim)

# Example inputs
target_items = torch.tensor([1, 2, 3])  # Target item IDs
context_items = torch.tensor([10, 20, 30])  # Context item IDs
labels = torch.tensor([1, 0, 1])  # Positive or negative context pairs

# Forward pass
predictions = model(target_items, context_items)
print(predictions)

tensor([0.5028, 0.4947, 0.5017], grad_fn=<SigmoidBackward0>)


In [7]:
# Mock dataset
sequences = [
    ["b", "c", "d", "e", "a"],
    ["f", "b", "k"],
    ["g", "m", "k", "l", "h"],
    ["b", "c", "k"],
    ["j", "i", "c"],
]

sequences_fp = "sequences.jsonl"

with open(sequences_fp, "w") as f:
    for sequence in sequences:
        f.write(json.dumps(sequence) + "\n")

dataset = SkipGramDataset(
    sequences_fp, window_size=window_size, negative_samples=negative_samples
)
train_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    drop_last=False,
    collate_fn=dataset.collate_fn,
    num_workers=0,
)

[32m2025-11-11 12:33:53.561[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m62[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

In [8]:
i = 0
for batch in train_loader:
    print(batch)
    i += 1
    if i >= 2:
        break

{'target_items': tensor([0, 0, 0, 1, 1, 1, 1, 1, 1]), 'context_items': tensor([ 1, 10,  9,  0,  2,  7,  8,  5,  9]), 'labels': tensor([1., 0., 0., 1., 1., 0., 0., 0., 0.])}
{'target_items': tensor([2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3]), 'context_items': tensor([ 1,  3,  5,  7,  8,  6,  2,  4, 11,  7,  6,  9]), 'labels': tensor([1., 1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0.])}


In [9]:
for batch_input in train_loader:
    print(batch_input)
    break

{'target_items': tensor([0, 0, 0, 1, 1, 1, 1, 1, 1]), 'context_items': tensor([ 1,  9,  7,  0,  2,  8,  7,  7, 10]), 'labels': tensor([1., 0., 0., 1., 1., 0., 0., 0., 0.])}


In [10]:
# model
lit_model = LitSkipGram(model, log_dir=args.notebook_persist_dp)

# train model
trainer = L.Trainer(default_root_dir=f"{args.notebook_persist_dp}/test", max_epochs=2)
trainer.fit(
    model=lit_model, train_dataloaders=train_loader, val_dataloaders=train_loader
)

üí° Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
  return _C._get_float32_matmul_precision()
You are using a CUDA device ('NVIDIA GeForce RTX 3050 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type     | Params | Mode 
----------------------------------------------------
0 | skipgram_model | SkipGram | 128 K  | train
----------------------------------------------------
128 K     Trainable params
0         Non-trainable

Sanity Checking: |                                                         | 0/? [00:00<?, ?it/s]

/mnt/d/projects/recsys/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/mnt/d/projects/recsys/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


Training: |                                                                | 0/? [00:00<?, ?it/s]

Validation: |                                                              | 0/? [00:00<?, ?it/s]

Validation: |                                                              | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.


# Prep data

In [11]:
sequences_fp = "../data/item_sequence.jsonl"
val_sequences_fp = "../data/val_item_sequence.jsonl"
idm = IDMapper().load("../data/idm.json")

In [12]:
dataset = SkipGramDataset(
    sequences_fp,
    window_size=args.window_size,
    negative_samples=args.num_negative_samples,
    id_to_idx=idm.item_to_index,
)
val_dataset = SkipGramDataset(
    val_sequences_fp,
    dataset.interacted,
    dataset.item_freq,
    window_size=args.window_size,
    negative_samples=args.num_negative_samples,
    id_to_idx=idm.item_to_index,
)

dataloader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    drop_last=True,
    collate_fn=dataset.collate_fn,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    drop_last=True,
    collate_fn=val_dataset.collate_fn,
)

[32m2025-11-11 12:34:52.734[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m62[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

[32m2025-11-11 12:34:53.893[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m62[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

In [13]:
assert dataset.id_to_idx == idm.item_to_index, "ID Mappings are not matched!"

# Train

In [14]:
n_items = len(dataset.items)

#### Training loop

In [31]:
early_stopping = EarlyStopping(
    monitor="val_loss", patience=args.early_stopping_patience, mode="min", verbose=False
)

checkpoint_callback = ModelCheckpoint(
    dirpath=f"{args.notebook_persist_dp}/checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    monitor="val_loss",
    mode="min",
)

# model

model = init_model(n_items, args.embedding_dim)
lit_model = LitSkipGram(
    model,
    learning_rate=args.learning_rate,
    l2_reg=args.l2_reg,
    log_dir=args.notebook_persist_dp,
)

log_dir = f"{args.notebook_persist_dp}/logs/run"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    max_epochs=max_epochs,
    callbacks=[early_stopping, checkpoint_callback],
    accelerator=args.device if args.device else "auto",
    logger=args._mlf_logger if args.log_to_mlflow else None,
)
trainer.fit(model=lit_model, train_dataloaders=dataloader, val_dataloaders=val_dataloader)
logger.info(f"Logs available at {trainer.log_dir}")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores

Checkpoint directory /mnt/d/projects/recsys/notebooks/data/000/checkpoints exists and is not empty.

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type     | Params | Mode 
----------------------------------------------------
0 | skipgram_model | SkipGram | 592 K  | train
----------------------------------------------------
592 K     Trainable params
0         Non-trainable params
592 K     Total params
2.371     Total estimated model params size (MB)
2         Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                     | 0/? [00:00<?, ?it/s]


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.



Training: |                                                            | 0/? [00:00<?, ?it/s]

Validation: |                                                          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.
[32m2025-11-11 14:31:04.523[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m34[0m - [1mLogs available at None[0m


üèÉ View run 000 at: http://localhost:5002/#/experiments/1/runs/735c41b6b17f44c2a82289dc3b3b2d79
üß™ View experiment at: http://localhost:5002/#/experiments/1


In [32]:
model(torch.tensor([0]), torch.tensor([1]))

tensor([0.3521], grad_fn=<SigmoidBackward0>)

# Load best checkpoint

In [33]:
best_trainer = LitSkipGram.load_from_checkpoint(
    checkpoint_callback.best_model_path,
    skipgram_model=init_model(n_items, args.embedding_dim),
)

In [34]:
best_trainer

LitSkipGram(
  (skipgram_model): SkipGram(
    (embeddings): Embedding(4631, 128, padding_idx=4630)
  )
)

In [35]:
best_model = best_trainer.skipgram_model

In [36]:
best_model.to("cpu").embeddings(torch.tensor(0))[:5]

tensor([ 0.0927, -0.1572, -0.2369, -0.0550, -0.0166], grad_fn=<SliceBackward0>)

### Persist id mapping

In [38]:
# Persist id_mapping so that at inference we can predict based on item_ids (string) instead of item_index
id_mapping_filename = "skipgram_id_mapping.json"
id_mapping_path = f"{args.notebook_persist_dp}/{id_mapping_filename}"
logger.info(f"Saving id_mapping to {id_mapping_path}...")
dataset.save_id_mappings(id_mapping_path)

if args.log_to_mlflow:
    # Get the MLflow run ID from the trainer's logger if available, otherwise use the run_id from Args init
    run_id = trainer.logger.run_id if hasattr(trainer, 'logger') and trainer.logger else mlflow.active_run().info.run_id
    mlf_client = mlflow.tracking.MlflowClient() # Get MLflow client
    mlf_client.log_artifact(run_id, id_mapping_path) # Log id_mapping as a separate artifact

[32m2025-11-11 14:33:24.172[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mSaving id_mapping to /mnt/d/projects/recsys/notebooks/data/000/skipgram_id_mapping.json...[0m


### Wrap inference function and register best checkpoint as MLflow model

In [39]:
# Chu·∫©n b·ªã input m·∫´u (c√πng format v·ªõi predict)
sample_input = {"item_1_ids": [dataset.idx_to_id[0]], "item_2_ids": [dataset.idx_to_id[1]]}
sample_output = best_model(
    torch.tensor([0]),
    torch.tensor([1])
).detach().numpy()

In [40]:
model_path =checkpoint_callback.best_model_path

In [41]:
checkpoint = torch.load(model_path, map_location="cpu")
checkpoint

{'epoch': 0,
 'global_step': 1283,
 'pytorch-lightning_version': '2.5.6',
 'state_dict': OrderedDict([('skipgram_model.embeddings.weight',
               tensor([[ 9.2688e-02, -1.5724e-01, -2.3695e-01,  ..., -2.2617e-01,
                         2.8986e-01, -4.5999e-01],
                       [-1.3436e-01,  1.0897e-01, -1.1607e-01,  ..., -3.2578e-01,
                         6.6828e-02,  1.3031e-01],
                       [-7.6246e-02, -1.9842e-01,  2.4648e-02,  ..., -1.8524e-01,
                        -2.3274e-02, -6.9895e-02],
                       ...,
                       [-5.2232e-02, -4.9461e-02,  1.5768e-01,  ..., -2.9106e-02,
                         4.8381e-02, -2.8762e-02],
                       [-2.1745e-02, -2.3504e-01, -5.0211e-02,  ...,  1.2026e-01,
                         6.4290e-02, -4.1482e-02],
                       [ 6.8481e-33, -2.6812e-32, -4.4699e-32,  ...,  1.4479e-32,
                        -9.7187e-33, -1.0884e-32]]))]),
 'loops': {'fit_loop': {'state

In [None]:
if args.log_to_mlflow:
    run_id = trainer.logger.run_id
    signature = infer_signature(sample_input, sample_output)

    with mlflow.start_run(run_id=run_id, nested=True):
        artifacts = {
            "model_path": checkpoint_callback.best_model_path,
            "id_mapping": mlflow.get_artifact_uri(id_mapping_filename),
        }

        mlflow.pyfunc.log_model(
            artifact_path="inferrer",
            python_model=SkipGramInferenceWrapper(),
            artifacts=artifacts,
            signature=signature,
            input_example=sample_input,
            registered_model_name=args.mlf_model_name,
        )

    print(f"Model logged to MLflow run {run_id}")


2025/11/11 14:34:37 INFO mlflow.pyfunc: Validating input example against model signature


Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

  "inputs": {
    "item_1_ids": [
      "0375869.... Alternatively, you can avoid passing input example and pass model signature instead when logging the model. To ensure the input example is valid prior to serving, please try calling `mlflow.models.validate_serving_input` on the model uri and serving input example. A serving input example can be generated from model input example using `mlflow.models.convert_input_example_to_serving_input` function.
Got error: 'hyper_parameters'
Successfully registered model 'item2vec'.
2025/11/11 14:37:28 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: item2vec, version 1
Created version '1' of model 'item2vec'.


üèÉ View run 000 at: http://localhost:5002/#/experiments/1/runs/735c41b6b17f44c2a82289dc3b3b2d79
üß™ View experiment at: http://localhost:5002/#/experiments/1
Model logged to MLflow run 735c41b6b17f44c2a82289dc3b3b2d79


# Set the newly trained model as champion

In [43]:
if args.log_to_mlflow:
    val_roc_auc = trainer.logger.experiment.get_run(trainer.logger.run_id).data.metrics[
        "val_roc_auc"
    ]

    if val_roc_auc > args.min_roc_auc:
        logger.info(f"Aliasing the new model as champion...")
        model_version = (
            mlf_client.get_registered_model(args.mlf_model_name)
            .latest_versions[0]
            .version
        )

        mlf_client.set_registered_model_alias(
            name=args.mlf_model_name, alias="champion", version=model_version
        )

        mlf_client.set_model_version_tag(
            name=args.mlf_model_name,
            version=model_version,
            key="author",
            value="minh",
        )

[32m2025-11-11 14:37:43.987[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m7[0m - [1mAliasing the new model as champion...[0m


In [44]:
all_params = [args]

if args.log_to_mlflow:
    with mlflow.start_run(run_id=run_id, nested=True):
        for params in all_params:
            params_dict = params.dict()
            params_ = {
                f"{params.__repr_name__()}.{k}": v for k, v in params_dict.items()
            }
            mlflow.log_params(params_)

/tmp/ipykernel_2538/833195162.py:6: PydanticDeprecatedSince20:

The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.12/migration/



üèÉ View run 000 at: http://localhost:5002/#/experiments/1/runs/735c41b6b17f44c2a82289dc3b3b2d79
üß™ View experiment at: http://localhost:5002/#/experiments/1
