In [32]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [33]:
import json
import os
import sys
from typing import Any

import lightning as L
import torch
from dotenv import load_dotenv
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import MLFlowLogger
from loguru import logger
from mlflow.models.signature import infer_signature
from pydantic import BaseModel, PrivateAttr
from torch.utils.data import DataLoader
import pandas as pd
import mlflow

sys.path.insert(0, "..")

from src.utils.embedding_id_mapper import IDMapper
from src.algo.item2vec.dataset import SkipGramDataset
# from src.algo.item2vec.inference import SkipGramInferenceWrapper
from src.algo.item2vec.model import SkipGram
from src.algo.item2vec.trainer import LitSkipGram

_ = load_dotenv(override = True)

## Controller

In [34]:
class Args(BaseModel):
    testing: bool = False
    log_to_mlflow: bool = True
    _mlf_logger: Any = PrivateAttr()
    experiment_name: str = "Item2vec"
    run_name: str = "001-increse-8k-users-5-negative-samples-dim256"
    notebook_persist_dp: str = None
    random_seed: int = 41
    device: str = None

    max_epochs: int = 100
    batch_size: int = 128

    num_negative_samples: int = 2
    window_size: int = 1

    embedding_dim: int = 256
    early_stopping_patience: int = 15
    
    learning_rate: float = 0.01
    l2_reg: float = 1e-5

    mlf_model_name: str = "item2vec"
    min_roc_auc: float = 0.7

    train_data_path: str = os.path.abspath("../data_for_ai/interim/train_sample_interactions_16407u.parquet")
    val_data_path: str = os.path.abspath("../data_for_ai/interim/val_sample_interactions_16407u.parquet")
    # test_data_path:str = "../data_for_ai/interim/test_sample_interactions_8000u.parquet"

    def init(self):
        self.notebook_persist_dp = os.path.abspath(f"data/{self.run_name}")
        os.makedirs(self.notebook_persist_dp, exist_ok=True)

        if not (mlflow_uri := os.environ.get("MLFLOW_TRACKING_URI")):
            logger.warning(
                f"Environment variable MLFLOW_TRACKING_URI is not set. Setting self.log_to_mlflow to false."
            )
            self.log_to_mlflow = False

        if self.log_to_mlflow:
            logger.info(
                f"Setting up MLflow experiment {self.experiment_name} - run {self.run_name}..."
            )
            self._mlf_logger = MLFlowLogger(
                experiment_name=self.experiment_name,
                run_name=self.run_name,
                tracking_uri=mlflow_uri,
                log_model=True,
            )

        return self


args = Args().init()

print(args.model_dump_json(indent=2))

[32m2025-06-23 13:11:47.552[0m | [1mINFO    [0m | [36m__main__[0m:[36minit[0m:[36m41[0m - [1mSetting up MLflow experiment Item2vec - run 001-increse-8k-users-5-negative-samples-dim256...[0m


{
  "testing": false,
  "log_to_mlflow": true,
  "experiment_name": "Item2vec",
  "run_name": "001-increse-8k-users-5-negative-samples-dim256",
  "notebook_persist_dp": "c:\\Users\\Trieu\\OneDrive\\Desktop\\recsys\\real_time_recsys\\notebooks\\data\\001-increse-8k-users-5-negative-samples-dim256",
  "random_seed": 41,
  "device": null,
  "max_epochs": 100,
  "batch_size": 128,
  "num_negative_samples": 2,
  "window_size": 1,
  "embedding_dim": 256,
  "early_stopping_patience": 15,
  "learning_rate": 0.01,
  "l2_reg": 0.00001,
  "mlf_model_name": "item2vec",
  "min_roc_auc": 0.7,
  "train_data_path": "c:\\Users\\Trieu\\OneDrive\\Desktop\\recsys\\real_time_recsys\\data_for_ai\\interim\\train_sample_interactions_16407u.parquet",
  "val_data_path": "c:\\Users\\Trieu\\OneDrive\\Desktop\\recsys\\real_time_recsys\\data_for_ai\\interim\\val_sample_interactions_16407u.parquet"
}


In [35]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
# device = 'cpu'
logger.info(f"Using {device} device")

[32m2025-06-23 13:11:47.929[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m7[0m - [1mUsing cpu device[0m


In [36]:
def init_model(n_items, embedding_dim, device):
    model = SkipGram(n_items, embedding_dim).to(device)
    return model

# Test implementation

In [37]:
mock_df = pd.DataFrame(
    {
        'timestamp': [1,1,1,2,2,3,3,4,4,4,4,5,5,5],
        'user_id': [101, 101, 103, 104, 103, 105, 107, 108, 109, 110, 111, 112, 113, 114],
        'parent_asin': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
    }
)

In [38]:
mock_df

Unnamed: 0,timestamp,user_id,parent_asin
0,1,101,1
1,1,101,2
2,1,103,3
3,2,104,4
4,2,103,5
5,3,105,6
6,3,107,7
7,4,108,8
8,4,109,9
9,4,110,10


In [39]:
mock_dataset = SkipGramDataset(
    train_interaction_df= mock_df, 
    val_interaction_df= None,
    mode = "train",
    negative_samples= 1,
)

[32m2025-06-23 13:11:49.206[0m | [1mINFO    [0m | [36msrc.algo.item2vec.dataset[0m:[36m__init__[0m:[36m67[0m - [1mProcessing sequences...[0m


In [40]:
mock_dataset.item_id_to_idx

{'1': 0, '2': 1, '3': 2, '5': 3}

In [41]:
mock_dataset.interacted

defaultdict(set, {0: {0, 1}, 1: {0, 1}, 2: {2, 3}, 3: {2, 3}})

In [42]:
for i in mock_dataset:
    print(i)
    

{'target_items': tensor([0, 0]), 'context_items': tensor([1, 3]), 'labels': tensor([1., 0.])}
{'target_items': tensor([1, 1]), 'context_items': tensor([0, 3]), 'labels': tensor([1., 0.])}
{'target_items': tensor([2, 2]), 'context_items': tensor([3, 0]), 'labels': tensor([1., 0.])}
{'target_items': tensor([3, 3]), 'context_items': tensor([2, 0]), 'labels': tensor([1., 0.])}


In [43]:
n_items = 1000
window_size = 1
negative_samples = 2
batch_size = 2

model = init_model(n_items, args.embedding_dim, device = "cpu")

# Example inputs
target_items = torch.tensor([1, 2, 3, 1000])  # Target item IDs
context_items = torch.tensor([10, 20, 30, 40])  # Context item IDs
labels = torch.tensor([1, 0, 1])  # Positive or negative context pairs

predictions = model(target_items, context_items)
print(predictions)

[32m2025-06-23 13:11:51.286[0m | [1mINFO    [0m | [36msrc.algo.item2vec.model[0m:[36m__init__[0m:[36m12[0m - [1mInitializing item embeddings with num items 1000, embedding dim 256[0m


tensor([0.4974, 0.5052, 0.4910, 0.5040], grad_fn=<SigmoidBackward0>)


# Prep data

In [44]:
# Read data
train_df = pd.read_parquet(args.train_data_path)
val_df = pd.read_parquet(args.val_data_path)

In [45]:
train_df

Unnamed: 0,user_id,parent_asin,rating,timestamp
3194,AEYGPUCRKH7G4VM22FM3VAKSQ23Q,B06XKCPK5W,2.0,2012-06-11 16:41:10
3199,AEYGPUCRKH7G4VM22FM3VAKSQ23Q,B000CKVOOY,3.0,2012-08-02 02:04:13
3200,AEYGPUCRKH7G4VM22FM3VAKSQ23Q,B006GWO5WK,5.0,2012-09-15 16:34:46
3204,AEYGPUCRKH7G4VM22FM3VAKSQ23Q,B008LURQ76,5.0,2013-01-03 23:08:45
3208,AEYGPUCRKH7G4VM22FM3VAKSQ23Q,B00AQRUW4Q,4.0,2013-05-06 01:24:39
...,...,...,...,...
40882304,AFB4DWWKZBQFS22FAWDEP37EL2FA,B00KAF5RQ2,5.0,2016-02-22 17:44:10
40882305,AFB4DWWKZBQFS22FAWDEP37EL2FA,B001F6TXME,5.0,2016-02-22 17:44:40
40882306,AFB4DWWKZBQFS22FAWDEP37EL2FA,B007VGGIB6,5.0,2016-02-22 17:45:10
40882307,AFB4DWWKZBQFS22FAWDEP37EL2FA,B00WUID73W,5.0,2016-02-22 17:45:37


In [46]:
val_df.head(1)

Unnamed: 0,user_id,parent_asin,rating,timestamp
4668,AGZE3IYHOEGKUTJZSQCSFSQ4IFFQ,B0B787CN26,5.0,2021-10-27 19:43:57.873


In [47]:
val_df

Unnamed: 0,user_id,parent_asin,rating,timestamp
4668,AGZE3IYHOEGKUTJZSQCSFSQ4IFFQ,B0B787CN26,5.0,2021-10-27 19:43:57.873
10425,AEANO5BIASSZNFWNXBR2ECHCPJQQ,B0002MQGOA,5.0,2021-02-02 14:20:48.424
10426,AEANO5BIASSZNFWNXBR2ECHCPJQQ,B07HZLHPKP,5.0,2021-03-08 13:56:57.795
13265,AHDXCFTV7RS3AM6E2TRPWOG3A33Q,B07QWPVZJY,3.0,2021-12-11 00:34:19.152
14423,AEFHRRLFCZQ3TWNYCBA7UD3NIXCA,B00D96J8IM,1.0,2021-10-17 20:54:19.325
...,...,...,...,...
33760091,AHIIISHZP6YAVVHMDEBLJ5CWZ7ZA,B0BZ62FQ13,3.0,2021-07-16 17:08:55.044
34470392,AFTE3G43QHXWD3DJGDCI2DHEWQJQ,B08DMXDPW5,5.0,2021-01-14 01:48:09.423
35019360,AFENZZDPVUYFVBS47YDOWJCDYBSQ,B09XBT6DS9,4.0,2021-12-05 00:35:40.874
35323250,AFMBZYPDAXT5VO3ME67HW5Q5TAOQ,B097KBF8JK,5.0,2022-02-18 11:32:46.732


In [48]:
idm = IDMapper().load("../data_for_ai/interim/idm_16407u.json")

print(len(idm.item_to_index))

train_dataset = SkipGramDataset(
    train_interaction_df= train_df, 
    val_interaction_df= None,
    mode = "train",
    item_id_to_idx=idm.item_to_index,
    negative_samples= 5
)
val_dataset = SkipGramDataset(
    train_interaction_df= train_df, 
    val_interaction_df= val_df,
    mode = "val",
    item_id_to_idx=idm.item_to_index,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    drop_last=True,
    collate_fn=train_dataset.collate_fn,
    num_workers=4,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    drop_last=True,
    collate_fn=val_dataset.collate_fn,
)


[32m2025-06-23 13:11:55.035[0m | [1mINFO    [0m | [36msrc.algo.item2vec.dataset[0m:[36m__init__[0m:[36m67[0m - [1mProcessing sequences...[0m


4817 items in the dataset
4817


[32m2025-06-23 13:11:56.108[0m | [1mINFO    [0m | [36msrc.algo.item2vec.dataset[0m:[36m__init__[0m:[36m67[0m - [1mProcessing sequences...[0m


In [None]:
print(len(idm.index_to_user), len(idm.user_to_index))
print(len(idm.index_to_item), len(idm.item_to_index))
print(train_dataset.item_id_to_idx)
print()

16407 16407
4817 4817
{'0972683275': 0, '1449410243': 1, 'B000001OM5': 2, 'B00000K2YR': 3, 'B00002EQCW': 4, 'B00004TBLW': 5, 'B00004THD0': 6, 'B00004WCGF': 7, 'B00004Z5D1': 8, 'B00004Z5M1': 9, 'B00004ZCJF': 10, 'B00004ZCJJ': 11, 'B00005N6KG': 12, 'B00005N9D3': 13, 'B00005NIMJ': 14, 'B000063TJY': 15, 'B0000645RH': 16, 'B000065BP9': 17, 'B000065UQA': 18, 'B00006B82A': 19, 'B00006B8K2': 20, 'B00006BBAC': 21, 'B00006HVLW': 22, 'B00006I5J7': 23, 'B00006JN3G': 24, 'B00006JPE1': 25, 'B00006JPEA': 26, 'B00006JQ5O': 27, 'B00007056H': 28, 'B00007AP2O': 29, 'B00007E7C8': 30, 'B00007FGU7': 31, 'B00007KDX6': 32, 'B00007LA0T': 33, 'B00007M1TZ': 34, 'B00008NJEP': 35, 'B00008SCFL': 36, 'B00009KYCN': 37, 'B00009OY9U': 38, 'B0000AI0N1': 39, 'B0000AQR8F': 40, 'B0000BYDKO': 41, 'B0000E1VRT': 42, 'B0000TO0BQ': 43, 'B0000VYJRY': 44, 'B0001FTVEA': 45, 'B0001FTVEK': 46, 'B00026BQJ6': 47, 'B000289DC6': 48, 'B00029MTMQ': 49, 'B0002BEQAM': 50, 'B0002CE0XO': 51, 'B0002GX1XA': 52, 'B0002J1WTC': 53, 'B0002J2B8I': 5

In [None]:
assert train_dataset.item_id_to_idx == idm.item_to_index, "ID Mappings are not matched!"
print(f"Number of items in train dataset: {len(train_dataset.item_id_to_idx)}")
print(f"Number of items in IDMapper: {len(idm.item_to_index)}")
assert train_df["parent_asin"].nunique() == len(idm.item_to_index)  == len(train_dataset.item_id_to_idx), "Mismatch in user mappings"

Number of items in train dataset: 4817
Number of items in IDMapper: 4817


# Train

In [None]:
n_items = len(train_dataset.items)
n_items

4817

#### Optinal: Overfit 1 batch

In [27]:
batch_size = 1  # Need to set to 1 else can not learn
embedding_dim = 32
window_size = 1
num_negative_samples = 2

batch_dataset = SkipGramDataset(
    train_interaction_df= train_df[0:10], 
    val_interaction_df= None,
    mode = "train",
    negative_samples= num_negative_samples,
    window_size= window_size,
)

batch_dataloader = DataLoader(
    batch_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    collate_fn=batch_dataset.collate_fn,
)


model = init_model(len(batch_dataset.items), embedding_dim, device = device)

[32m2025-06-23 12:24:27.959[0m | [1mINFO    [0m | [36msrc.algo.item2vec.dataset[0m:[36m__init__[0m:[36m30[0m - [1mProcessing sequences...[0m
[32m2025-06-23 12:24:27.959[0m | [1mINFO    [0m | [36msrc.algo.item2vec.model[0m:[36m__init__[0m:[36m12[0m - [1mInitializing item embeddings with num items 10, embedding dim 32[0m


In [29]:
early_stopping = EarlyStopping(
    monitor="val_loss", patience=20, mode="min", verbose=False
)

# model
model = init_model(len(batch_dataset.items), args.embedding_dim, device= device)
lit_model = LitSkipGram(
    model, learning_rate=0.01, l2_reg=0.0, log_dir=args.notebook_persist_dp
)

log_dir = f"{args.notebook_persist_dp}/logs/overfit"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    max_epochs=100,
    overfit_batches=1,
    callbacks=[early_stopping],
)
trainer.fit(
    model=lit_model,
    train_dataloaders=batch_dataloader,
    val_dataloaders=batch_dataloader,
)
logger.info(f"Logs available at {trainer.log_dir}")

[32m2025-06-23 12:25:04.610[0m | [1mINFO    [0m | [36msrc.algo.item2vec.model[0m:[36m__init__[0m:[36m12[0m - [1mInitializing item embeddings with num items 10, embedding dim 256[0m
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.

  | Name           | Type     | Params | Mode 
----------------------------------------------------
0 | skipgram_model | SkipGram | 2.8 K  | train
----------------------------------------------------
2.8 K     Trainable params
0         Non-trainable params
2.8 K     Total params
0.011     Total estimated model params size (MB)
2         Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]


You requested to overfit but enabled val dataloader shuffling. We are turning off the val dataloader shuffling for you.


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


You requested to overfit but enabled train dataloader shuffling. We are turning off the train dataloader shuffling for you.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[32m2025-06-23 12:25:08.357[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m25[0m - [1mLogs available at c:\Users\Trieu\OneDrive\Desktop\recsys\real_time_recsys\notebooks\data\001-increse-8k-users-5-negative-samples-dim256\logs\overfit\lightning_logs\version_1[0m


In [30]:
model(torch.tensor([0]), torch.tensor([2]))

tensor([0.5147], grad_fn=<SigmoidBackward0>)

# Run with all data

In [None]:
# papermill_description=fit-model
early_stopping = EarlyStopping(
    monitor="val_loss", patience=args.early_stopping_patience, mode="min", verbose=False
)

checkpoint_callback = ModelCheckpoint(
    dirpath=f"{args.notebook_persist_dp}/checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    monitor="val_loss",
    mode="min",
)

# model
model = init_model(n_items, args.embedding_dim, device=device)
lit_model = LitSkipGram(
    model,
    learning_rate=args.learning_rate,
    l2_reg=args.l2_reg,
    log_dir=args.notebook_persist_dp,
)

log_dir = f"{args.notebook_persist_dp}/logs/run"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    max_epochs=args.max_epochs,
    callbacks=[early_stopping, checkpoint_callback],
    accelerator=args.device if args.device else "auto",
    logger=args._mlf_logger if args.log_to_mlflow else None,
)
trainer.fit(model=lit_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

[32m2025-06-23 12:25:59.066[0m | [1mINFO    [0m | [36msrc.algo.item2vec.model[0m:[36m__init__[0m:[36m12[0m - [1mInitializing item embeddings with num items 4817, embedding dim 256[0m
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Experiment with name Item2vec not found. Creating it.

  | Name           | Type     | Params | Mode 
----------------------------------------------------
0 | skipgram_model | SkipGram | 1.2 M  | train
----------------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.934     Total estimated model params size (MB)
2         Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


In [None]:
best_trainer = LitSkipGram.load_from_checkpoint(
    checkpoint_callback.best_model_path,
    skipgram_model=init_model(n_items, args.embedding_dim, device="cpu"),
)
# best_trainer = LitSkipGram.load_from_checkpoint(
#     "/home/dinhln/Desktop/MLOPS/recsys/HM-ScalableRecs/notebooks/data/000-first-attempt/checkpoints/best-checkpoint.ckpt",
#     skipgram_model=init_model(n_items, args.embedding_dim, device="cpu"),
# )

[32m2025-03-20 21:23:37.449[0m | [1mINFO    [0m | [36msrc.skipgram.model[0m:[36m__init__[0m:[36m12[0m - [1mInitializing item embeddings with num items 4323, embedding dim 256[0m


In [None]:
best_model = best_trainer.skipgram_model
best_model.eval()

SkipGram(
  (embeddings): Embedding(4324, 256, padding_idx=4323)
)

In [None]:
best_model.to("cpu").embeddings(torch.tensor([8]))[: ,: 10]


tensor([[-0.0070,  0.0178, -0.0015,  0.1067, -0.0469, -0.0087, -0.0016, -0.0003,
          0.0172,  0.1052]], grad_fn=<SliceBackward0>)

In [None]:
best_model(torch.tensor([600]), torch.tensor([603]))

tensor([0.6239], grad_fn=<SigmoidBackward0>)

## Register model

In [None]:
from src.algo.item2vec.inference import SkipGramInferenceWrapper
inferrer = SkipGramInferenceWrapper(best_model)


[33mAdd type hints to the `predict` method to enable data validation and automatic signature inference during model logging. Check https://mlflow.org/docs/latest/model/python_model.html#type-hint-usage-in-pythonmodel for more details.[0m



In [None]:
sample_input = {
    "item_1_ids": [train_dataset.item_idx_to_id[0]],
    "item_2_ids": [train_dataset.item_idx_to_id[1]],
}
sample_output = inferrer.infer([0], [1])
sample_output

array([0.5327789], dtype=float32)

In [None]:
if args.log_to_mlflow:
    run_id = trainer.logger.run_id
    sample_output_np = sample_output
    signature = infer_signature(sample_input, sample_output_np)
    with mlflow.start_run(run_id=run_id):
        mlflow.pyfunc.log_model(
            python_model=inferrer,
            artifact_path="inferrer",
            # We log the id_mapping to the predict function so that it can accept item_id and automatically convert ot item_indice for PyTorch model to use
            artifacts={"id_mapping": "C:/Users/Trieu/OneDrive/Desktop/recsys/real_time_recsys/data_for_ai/interim/idm_16407u.json"},
            signature=signature,
            input_example=sample_input,
            registered_model_name=args.mlf_model_name,
        )

2025/03/20 21:23:37 INFO mlflow.pyfunc: Validating input example against model signature


Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Registered model 'item2vec' already exists. Creating a new version of this model...
2025/03/20 21:23:42 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: item2vec, version 7


[32m2025-03-20 21:23:41.865[0m | [1mINFO    [0m | [36msrc.skipgram.inference[0m:[36mload_context[0m:[36m19[0m - [1mLength of item: 4323[0m
[32m2025-03-20 21:23:41.866[0m | [1mINFO    [0m | [36msrc.skipgram.inference[0m:[36mload_context[0m:[36m20[0m - [1mLength of user: 9495[0m
🏃 View run 001-increse-8k-users-5-negative-samples-dim256 at: http://localhost:5002/#/experiments/8/runs/19bb6989961548e084636529c65dc214
🧪 View experiment at: http://localhost:5002/#/experiments/8


Created version '7' of model 'item2vec'.


In [None]:
# # Set the new model as champion
# mlf_client = trainer.logger.experiment
# if args.log_to_mlflow:
#     val_roc_auc = trainer.logger.experiment.get_run(trainer.logger.run_id).data.metrics[
#         "val_roc_auc"
#     ]

#     if val_roc_auc > args.min_roc_auc:
#         logger.info(f"Aliasing the new model as champion...")
#         model_version = (
#             mlf_client.get_registered_model(args.mlf_model_name)
#             .latest_versions[0]
#             .version
#         )

#         mlf_client.set_registered_model_alias(
#             name=args.mlf_model_name, alias="champion", version=model_version
#         )

#         mlf_client.set_model_version_tag(
#             name=args.mlf_model_name,
#             version=model_version,
#             key="author",
#             value="dinhnehehe",
#         )

In [None]:
# model_uri = f"models:/{args.mlf_model_name}@champion"
# model = mlflow.pyfunc.load_model(model_uri)

In [None]:
# sample_input = {
#     "item_1_ids": ["1111111111"],
#     "item_2_ids": ["2222"],
# }
# model.predict(sample_input)


In [None]:
# sample_input = {
#     "item_1_ids": [train_dataset.item_idx_to_id[2]],
#     "item_2_ids": [train_dataset.item_idx_to_id[1]],
# }
# model.predict(sample_input)

## Log params and clean up

In [None]:
all_params = [args]

if args.log_to_mlflow:
    with mlflow.start_run(run_id=run_id):
        for params in all_params:
            params_dict = params.model_dump()
            params_ = {
                f"{params.__repr_name__()}.{k}": v for k, v in params_dict.items()
            }
            mlflow.log_params(params_)

🏃 View run 001-increse-8k-users-5-negative-samples-dim256 at: http://localhost:5002/#/experiments/8/runs/19bb6989961548e084636529c65dc214
🧪 View experiment at: http://localhost:5002/#/experiments/8
