# Training Skip Gram for Item2Vec

# Set up

In [1]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

In [2]:
import json
import os
import sys

import mlflow
import pandas as pd
import torch
from dotenv import load_dotenv
import lightning as L
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from loguru import logger
from pydantic import BaseModel
from torch.utils.data import DataLoader

load_dotenv()

sys.path.insert(0, "..")

from src.eval import log_classification_metrics, visualize_training
from src.id_mapper import IDMapper
from src.skipgram.dataset import SkipGramDataset
from src.skipgram.model import SkipGram
from src.skipgram.trainer import LitSkipGram
from src.train_utils import MetricLogCallback, MLflowLogCallback
from src.viz import blueq_colors

# Controller

In [3]:
class Args(BaseModel):
    testing: bool = False
    log_to_mlflow: bool = False
    experiment_name: str = "FSDS RecSys - L6 - Scale training"
    run_name: str = "004-test-lightning"
    notebook_persist_dp: str = None
    random_seed: int = 41

    top_K: int = 100
    top_k: int = 10

    max_epochs: int = 1000
    batch_size: int = 128

    num_negative_samples: int = 2
    window_size: int = 1

    embedding_dim: int = 128
    early_stopping_patience: int = 5
    learning_rate: float = 0.01
    l2_reg: float = 1e-5

    def init(self):
        self.notebook_persist_dp = os.path.abspath(f"data/{self.run_name}")
        os.makedirs(self.notebook_persist_dp, exist_ok=True)

        if not os.environ.get("MLFLOW_TRACKING_URI"):
            logger.warning(
                f"Environment variable MLFLOW_TRACKING_URI is not set. Setting self.log_to_mlflow to false."
            )
            self.log_to_mlflow = False

        if self.log_to_mlflow:
            logger.info(
                f"Setting up MLflow experiment {self.experiment_name} - run {self.run_name}..."
            )

            mlflow.set_experiment(self.experiment_name)
            mlflow.start_run(run_name=self.run_name)

        return self


args = Args().init()

print(args.model_dump_json(indent=2))

{
  "testing": false,
  "log_to_mlflow": false,
  "experiment_name": "FSDS RecSys - L6 - Scale training",
  "run_name": "004-test-lightning",
  "notebook_persist_dp": "/Users/dvq/frostmourne/fsds/fsds-recsys/chapters/l6/notebooks/data/004-test-lightning",
  "random_seed": 41,
  "top_K": 100,
  "top_k": 10,
  "max_epochs": 1000,
  "batch_size": 128,
  "num_negative_samples": 2,
  "window_size": 1,
  "embedding_dim": 128,
  "early_stopping_patience": 5,
  "learning_rate": 0.01,
  "l2_reg": 0.00001
}


# Implementation

In [4]:
def init_model(n_items, embedding_dim):
    model = SkipGram(n_items, embedding_dim)
    return model

# Test implementation

In [5]:
n_items = 1000
window_size = 1
negative_samples = 2
batch_size = 2

model = init_model(n_items, args.embedding_dim)

# Example inputs
target_items = torch.tensor([1, 2, 3])  # Target item IDs
context_items = torch.tensor([10, 20, 30])  # Context item IDs
labels = torch.tensor([1, 0, 1])  # Positive or negative context pairs

# Forward pass
predictions = model(target_items, context_items)
print(predictions)

tensor([0.5037, 0.5028, 0.5031], grad_fn=<SigmoidBackward0>)


In [6]:
# Mock dataset
sequences = [
    ["b", "c", "d", "e", "a"],
    ["f", "b", "k"],
    ["g", "m", "k", "l", "h"],
    ["b", "c", "k"],
    ["j", "i", "c"],
]

sequences_fp = "sequences.jsonl"

with open(sequences_fp, "w") as f:
    for sequence in sequences:
        f.write(json.dumps(sequence) + "\n")

dataset = SkipGramDataset(
    sequences_fp, window_size=window_size, negative_samples=negative_samples
)
train_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    drop_last=False,
    collate_fn=dataset.collate_fn,
    num_workers=0
)

[32m2024-09-30 10:56:39.647[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m58[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

In [7]:
for batch_input in train_loader:
    print(batch_input)
    break

{'target_items': tensor([0, 0, 0, 1, 1, 1, 1, 1, 1]), 'context_items': tensor([ 1, 11, 10,  0,  2, 10,  7,  8,  9]), 'labels': tensor([1., 0., 0., 1., 1., 0., 0., 0., 0.])}


In [8]:
# model
lit_model = LitSkipGram(model, log_dir=args.notebook_persist_dp)

# train model
trainer = L.Trainer(default_root_dir=f"{args.notebook_persist_dp}/test", max_epochs=2)
trainer.fit(model=lit_model, train_dataloaders=train_loader, val_dataloaders=train_loader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name           | Type     | Params | Mode 
----------------------------------------------------
0 | skipgram_model | SkipGram | 128 K  | train
----------------------------------------------------
128 K     Trainable params
0         Non-trainable params
128 K     Total params
0.513     Total estimated model params size (MB)
2         Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …

/Users/dvq/frostmourne/fsds/fsds-recsys/chapters/l6/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.
/Users/dvq/frostmourne/fsds/fsds-recsys/chapters/l6/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=2` reached.


# Prep data

In [9]:
sequences_fp = "../data/item_sequence.jsonl"
val_sequences_fp = "../data/val_item_sequence.jsonl"
idm = IDMapper().load("../data/idm.json")

In [10]:
dataset = SkipGramDataset(
    sequences_fp,
    window_size=args.window_size,
    negative_samples=args.num_negative_samples,
    id_to_idx=idm.item_to_index,
)
val_dataset = SkipGramDataset(
    val_sequences_fp,
    dataset.interacted,
    dataset.item_freq,
    window_size=args.window_size,
    negative_samples=args.num_negative_samples,
    id_to_idx=idm.item_to_index,
)

train_loader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=False,
    drop_last=True,
    collate_fn=dataset.collate_fn,
    num_workers=4,
    persistent_workers=True
)
val_loader = DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    drop_last=True,
    collate_fn=val_dataset.collate_fn,
    num_workers=4,
    persistent_workers=True
)

[32m2024-09-30 10:56:40.205[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m58[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

[32m2024-09-30 10:56:40.710[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m58[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

In [11]:
assert dataset.id_to_idx == idm.item_to_index, "ID Mappings are not matched!"

# Train

In [12]:
n_items = len(dataset.items)

## Overfit 1 batch

In [13]:
# Still need to load the small 1 batch sample due to negative sampling which would cause the overfit_batches of Lightning to give new data every epochs
batch_sequences_fp = "../data/batch_item_sequence.jsonl"
batch_size = 1
window_size = 1
num_negative_samples = 2

batch_dataset = SkipGramDataset(
    batch_sequences_fp, window_size=window_size, negative_samples=num_negative_samples
)
batch_train_loader = DataLoader(
    batch_dataset,
    batch_size=batch_size,
    drop_last=False,
    collate_fn=dataset.collate_fn,
)

[32m2024-09-30 10:56:40.779[0m | [1mINFO    [0m | [36msrc.skipgram.dataset[0m:[36m__init__[0m:[36m58[0m - [1mProcessing sequences to build interaction data...[0m


Building interactions: 0it [00:00, ?it/s]

In [14]:
i = 0
for batch_input in batch_train_loader:
    print(batch_input)
    i += 1
    if i >= 2:
        break

{'target_items': tensor([0, 0, 0]), 'context_items': tensor([ 1, 12,  9]), 'labels': tensor([1., 0., 0.])}
{'target_items': tensor([1, 1, 1, 1, 1, 1]), 'context_items': tensor([ 0,  2, 10,  8, 10, 12]), 'labels': tensor([1., 1., 0., 0., 0., 0.])}


In [15]:
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=10,
    mode='min',
    verbose=False
)

# model
model = init_model(len(batch_dataset.items), args.embedding_dim)
lit_model = LitSkipGram(model, learning_rate=0.01, l2_reg=0.0, log_dir=args.notebook_persist_dp)

log_dir = f"{args.notebook_persist_dp}/overfit"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    max_epochs=100,
    overfit_batches=1,
    callbacks=[early_stopping]
)
trainer.fit(
    model=lit_model,
    train_dataloaders=batch_train_loader,
    val_dataloaders=batch_train_loader
)
logger.info(f"Logs available at {trainer.log_dir}")

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.

  | Name           | Type     | Params | Mode 
----------------------------------------------------
0 | skipgram_model | SkipGram | 1.9 K  | train
----------------------------------------------------
1.9 K     Trainable params
0         Non-trainable params
1.9 K     Total params
0.008     Total estimated model params size (MB)
2         Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …


You requested to overfit but enabled val dataloader shuffling. We are turning off the val dataloader shuffling for you.


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


You requested to overfit but enabled train dataloader shuffling. We are turning off the train dataloader shuffling for you.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

[32m2024-09-30 10:56:42.328[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1mLogs available at /Users/dvq/frostmourne/fsds/fsds-recsys/chapters/l6/notebooks/data/004-test-lightning/overfit/lightning_logs/version_3[0m


In [16]:
%tensorboard --logdir $trainer.log_dir

In [17]:
model(torch.tensor([0]), torch.tensor([1]))

tensor([0.9997], grad_fn=<SigmoidBackward0>)

## Run with all data

In [18]:
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=args.early_stopping_patience,
    mode='min',
    verbose=False
)

# model
model = init_model(n_items, args.embedding_dim)
lit_model = LitSkipGram(model, learning_rate=args.learning_rate, l2_reg=args.l2_reg, log_dir=args.notebook_persist_dp)

log_dir = f"{args.notebook_persist_dp}/run"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    max_epochs=args.max_epochs,
    callbacks=[early_stopping]
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader
)
logger.info(f"Logs available at {trainer.log_dir}")

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name           | Type     | Params | Mode 
----------------------------------------------------
0 | skipgram_model | SkipGram | 601 K  | train
----------------------------------------------------
601 K     Trainable params
0         Non-trainable params
601 K     Total params
2.405     Total estimated model params size (MB)
2         Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …

Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

[32m2024-09-30 11:11:05.827[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m25[0m - [1mLogs available at /Users/dvq/frostmourne/fsds/fsds-recsys/chapters/l6/notebooks/data/004-test-lightning/run/lightning_logs/version_3[0m


In [19]:
model(torch.tensor([0]), torch.tensor([1]))

tensor([0.5337], grad_fn=<SigmoidBackward0>)

# Persist model

In [20]:
model.embeddings(torch.tensor(0))[:5]

tensor([ 0.1420, -0.0175, -0.0848, -0.0265,  0.0492], grad_fn=<SliceBackward0>)

In [21]:
model_path = f"{args.notebook_persist_dp}/skipgram_model_full.pth"
logger.info(f"Saving model to {model_path}...")
torch.save(model, model_path)

[32m2024-09-30 11:11:06.022[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mSaving model to /Users/dvq/frostmourne/fsds/fsds-recsys/chapters/l6/notebooks/data/004-test-lightning/skipgram_model_full.pth...[0m


In [22]:
model = torch.load(model_path)


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



In [23]:
model.embeddings(torch.tensor(0))[:5]

tensor([ 0.1420, -0.0175, -0.0848, -0.0265,  0.0492], grad_fn=<SliceBackward0>)

In [24]:
id_mapping_path = f"{args.notebook_persist_dp}/skipgram_id_mapping.json"
logger.info(f"Saving id_mapping to {id_mapping_path}...")
dataset.save_id_mappings(id_mapping_path)

[32m2024-09-30 11:11:06.091[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mSaving id_mapping to /Users/dvq/frostmourne/fsds/fsds-recsys/chapters/l6/notebooks/data/004-test-lightning/skipgram_id_mapping.json...[0m
