# Nikita Lobachev

Solution was inspired by https://github.com/sb-ai-lab/RePlay

In [1]:
!pip install lightning
!pip install replay-rec

Collecting lightning
  Downloading lightning-2.4.0-py3-none-any.whl.metadata (38 kB)
Collecting lightning-utilities<2.0,>=0.10.0 (from lightning)
  Downloading lightning_utilities-0.11.8-py3-none-any.whl.metadata (5.2 kB)
Collecting torchmetrics<3.0,>=0.7.0 (from lightning)
  Downloading torchmetrics-1.5.1-py3-none-any.whl.metadata (20 kB)
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.4.0-py3-none-any.whl.metadata (21 kB)
Downloading lightning-2.4.0-py3-none-any.whl (810 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m811.0/811.0 kB[0m [31m27.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.11.8-py3-none-any.whl (26 kB)
Downloading torchmetrics-1.5.1-py3-none-any.whl (890 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m890.6/890.6 kB[0m [31m50.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pytorch_lightning-2.4.0-py3-none-any.whl (815 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [144]:
import lightning as L
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
import torch

from replay.metrics import OfflineMetrics, Recall, Precision, MAP, NDCG, HitRate, MRR
from replay.metrics.torch_metrics_builder import metrics_to_df
from replay.splitters import LastNSplitter
from replay.utils import get_spark_session
from replay.data import (
    FeatureHint,
    FeatureInfo,
    FeatureSchema,
    FeatureSource,
    FeatureType,
    Dataset,
)
from replay.models.nn.optimizer_utils import FatOptimizerFactory
from replay.models.nn.sequential.callbacks import (
    ValidationMetricsCallback,
    SparkPredictionCallback,
    PandasPredictionCallback,
    TorchPredictionCallback,
    QueryEmbeddingsPredictionCallback,
)
from replay.models.nn.sequential.postprocessors import RemoveSeenItems
from replay.data.nn import (
    SequenceTokenizer,
    SequentialDataset,
    TensorFeatureSource,
    TensorSchema,
    TensorFeatureInfo
)
from replay.models.nn.sequential import Bert4Rec
from replay.models.nn.sequential.bert4rec import (
    Bert4RecPredictionDataset,
    Bert4RecTrainingDataset,
    Bert4RecValidationDataset,
    Bert4RecPredictionBatch,
    Bert4RecModel
)

import pandas as pd

In [145]:
#spark session
spark_session = get_spark_session()



In [146]:
import torch
import numpy as np
import random

def set_global_seed(seed: int) -> None:
    """
    Set global seed for reproducibility.
    """

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_global_seed(42)

# load dataset

In [4]:
!unzip /content/hse-rec-sys-challenge-2024.zip

Archive:  /content/hse-rec-sys-challenge-2024.zip
replace events.csv? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

In [147]:
interactions = pd.read_csv('/content/events.csv')
user_features = pd.read_csv('/content/user_features.csv')
item_features = pd.read_csv('/content/item_features.csv')

# Split data

In [148]:
splitter = LastNSplitter(
    N=1,
    divide_column="user_id",
    query_column="user_id",
    strategy="interactions",
)

raw_test_events, raw_test_gt = splitter.split(interactions)
raw_validation_events, raw_validation_gt = splitter.split(raw_test_events)
raw_train_events = raw_validation_events

# Create "dataset"

In [149]:
def prepare_feature_schema(is_ground_truth: bool) -> FeatureSchema:
    base_features = FeatureSchema(
        [
            FeatureInfo(
                column="user_id",
                feature_hint=FeatureHint.QUERY_ID,
                feature_type=FeatureType.CATEGORICAL,
            ),
            FeatureInfo(
                column="item_id",
                feature_hint=FeatureHint.ITEM_ID,
                feature_type=FeatureType.CATEGORICAL,
            ),
        ]
    )
    if is_ground_truth:
        return base_features

    all_features = base_features + FeatureSchema(
        [
            FeatureInfo(
                column="timestamp",
                feature_type=FeatureType.NUMERICAL,
                feature_hint=FeatureHint.TIMESTAMP,
            ),
        ]
    )
    return all_features

In [150]:
train_dataset = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=raw_train_events,
    query_features=user_features,
    item_features=item_features,
    check_consistency=True,
    categorical_encoded=False,
)

In [151]:
validation_dataset = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=raw_validation_events,
    query_features=user_features,
    item_features=item_features,
    check_consistency=True,
    categorical_encoded=False,
)
validation_gt = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=True),
    interactions=raw_validation_gt,
    check_consistency=True,
    categorical_encoded=False,
)

In [152]:
test_dataset = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=raw_test_events,
    query_features=user_features,
    item_features=item_features,
    check_consistency=True,
    categorical_encoded=False,
)
test_gt = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=True),
    interactions=raw_test_gt,
    check_consistency=True,
    categorical_encoded=False,
)

In [153]:
ITEM_FEATURE_NAME = "item_id_seq"

tensor_schema = TensorSchema(
    TensorFeatureInfo(
        name=ITEM_FEATURE_NAME,
        is_seq=True,
        feature_type=FeatureType.CATEGORICAL,
        feature_sources=[TensorFeatureSource(FeatureSource.INTERACTIONS, train_dataset.feature_schema.item_id_column)],
        feature_hint=FeatureHint.ITEM_ID,
        embedding_dim=300,
    )
)

# Sequence Tokeniser

In [154]:
tokenizer = SequenceTokenizer(tensor_schema, allow_collect_to_master=True)
tokenizer.fit(train_dataset)

sequential_train_dataset = tokenizer.transform(train_dataset)

sequential_validation_dataset = tokenizer.transform(validation_dataset)
sequential_validation_gt = tokenizer.transform(validation_gt, [tensor_schema.item_id_feature_name])

sequential_validation_dataset, sequential_validation_gt = SequentialDataset.keep_common_query_ids(
    sequential_validation_dataset, sequential_validation_gt
)

In [155]:
test_query_ids = test_gt.query_ids
test_query_ids_np = tokenizer.query_id_encoder.transform(test_query_ids)["user_id"].values
sequential_test_dataset = tokenizer.transform(test_dataset).filter_by_query_id(test_query_ids_np)

In [156]:
print(tokenizer.query_id_encoder.mapping, tokenizer.query_id_encoder.inverse_mapping)
print(tokenizer.item_id_encoder.mapping, tokenizer.item_id_encoder.inverse_mapping)

{'user_id': {4855: 0, 4065: 1, 3331: 2, 5373: 3, 2032: 4, 5875: 5, 3984: 6, 4062: 7, 5117: 8, 5822: 9, 174: 10, 5188: 11, 595: 12, 2538: 13, 5031: 14, 4765: 15, 1819: 16, 3970: 17, 568: 18, 4007: 19, 2641: 20, 2646: 21, 3839: 22, 3263: 23, 281: 24, 2009: 25, 5836: 26, 1581: 27, 679: 28, 3634: 29, 2401: 30, 2184: 31, 5532: 32, 3638: 33, 4159: 34, 1770: 35, 3754: 36, 637: 37, 1452: 38, 5412: 39, 5345: 40, 3078: 41, 4772: 42, 3484: 43, 1064: 44, 2812: 45, 3120: 46, 4295: 47, 491: 48, 3283: 49, 5595: 50, 622: 51, 4428: 52, 1570: 53, 4561: 54, 3927: 55, 127: 56, 1950: 57, 1877: 58, 2285: 59, 656: 60, 462: 61, 4055: 62, 4477: 63, 2148: 64, 1582: 65, 272: 66, 3556: 67, 883: 68, 5295: 69, 3223: 70, 4070: 71, 3: 72, 5314: 73, 4225: 74, 1341: 75, 5909: 76, 1413: 77, 4463: 78, 3900: 79, 4426: 80, 811: 81, 3491: 82, 5118: 83, 2018: 84, 1308: 85, 4379: 86, 4351: 87, 2995: 88, 3680: 89, 1336: 90, 3758: 91, 1286: 92, 5003: 93, 3574: 94, 1703: 95, 1855: 96, 32: 97, 5901: 98, 5207: 99, 1516: 100, 5457:

# Train model

In [157]:
MAX_SEQ_LEN = 100
BATCH_SIZE = 512
NUM_WORKERS = 4

model = Bert4Rec(
    tensor_schema,
    block_count=2,
    head_count=4,
    max_seq_len=MAX_SEQ_LEN,
    hidden_size=300,
    dropout_rate=0.5,
    optimizer_factory=FatOptimizerFactory(learning_rate=0.001),
)
checkpoint_callback = ModelCheckpoint(
    dirpath=".checkpoints",
    save_top_k=1,
    verbose=True,
    # if you use multiple dataloaders, then add the serial number of the dataloader to the suffix of the metric name.
    # For example,"recall@10/dataloader_idx_0"
    monitor="recall@10",
    mode="max",
)

validation_metrics_callback = ValidationMetricsCallback(
    metrics=["map", "ndcg", "recall"],
    ks=[1, 5, 10, 20],
    item_count=train_dataset.item_count,
    postprocessors=[RemoveSeenItems(sequential_validation_dataset)]
)

csv_logger = CSVLogger(save_dir=".logs/train", name="Bert4Rec_example")

trainer = L.Trainer(
    max_epochs=100,
    callbacks=[checkpoint_callback, validation_metrics_callback],
    logger=csv_logger,
)

train_dataloader = DataLoader(
    dataset=Bert4RecTrainingDataset(
        sequential_train_dataset,
        max_sequence_length=MAX_SEQ_LEN,
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

validation_dataloader = DataLoader(
    dataset=Bert4RecValidationDataset(
        sequential_validation_dataset,
        sequential_validation_gt,
        sequential_train_dataset,
        max_sequence_length=MAX_SEQ_LEN,
    ),
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=validation_dataloader,
)

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /content/.checkpoints exists and is not empty.
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name   | Type             | Params | Mode 
----------------------------------------------------
0 | _model | Bert4RecModel    | 4.4 M  | train
1 | _loss  | CrossEntropyLoss | 0      | train
----------------------------------------------------
4.4 M     Trainable params
0         Non-trainable params
4.4 M     Total p

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (12) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 0, global step 12: 'recall@10' reached 0.03146 (best 0.03146), saving model to '/content/.checkpoints/epoch=0-step=12-v1.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 0, global step 12: 'recall@10' reached 0.03146 (best 0.03146), saving model to '/content/.checkpoints/epoch=0-step=12-v1.ckpt' as top 1


k              1        10        20         5
map     0.004139  0.009982  0.011998  0.008019
ndcg    0.004139  0.014892  0.022471  0.010068
recall  0.004139  0.031457  0.061921  0.016391





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 1, global step 24: 'recall@10' reached 0.03510 (best 0.03510), saving model to '/content/.checkpoints/epoch=1-step=24.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 1, global step 24: 'recall@10' reached 0.03510 (best 0.03510), saving model to '/content/.checkpoints/epoch=1-step=24.ckpt' as top 1


k              1        10        20         5
map     0.003146  0.009892  0.011839  0.007668
ndcg    0.003146  0.015656  0.022904  0.010193
recall  0.003146  0.035099  0.064073  0.018046





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 2, global step 36: 'recall@10' reached 0.03974 (best 0.03974), saving model to '/content/.checkpoints/epoch=2-step=36.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 2, global step 36: 'recall@10' reached 0.03974 (best 0.03974), saving model to '/content/.checkpoints/epoch=2-step=36.ckpt' as top 1


k              1        10        20         5
map     0.003808  0.010865  0.012668  0.008060
ndcg    0.003808  0.017389  0.024099  0.010374
recall  0.003808  0.039735  0.066556  0.017550





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 3, global step 48: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 3, global step 48: 'recall@10' was not in top 1


k              1        10        20         5
map     0.004967  0.011679  0.013897  0.009630
ndcg    0.004967  0.017000  0.025079  0.011903
recall  0.004967  0.034934  0.066887  0.018874





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 4, global step 60: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 4, global step 60: 'recall@10' was not in top 1


k              1        10        20         5
map     0.003477  0.009836  0.012000  0.007828
ndcg    0.003477  0.015186  0.023200  0.010130
recall  0.003477  0.033278  0.065232  0.017219





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 5, global step 72: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 5, global step 72: 'recall@10' was not in top 1


k              1        10        20         5
map     0.003146  0.009060  0.011390  0.007056
ndcg    0.003146  0.014267  0.022927  0.009245
recall  0.003146  0.031954  0.066556  0.016060





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 6, global step 84: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 6, global step 84: 'recall@10' was not in top 1


k              1        10        20         5
map     0.002152  0.008783  0.010759  0.006703
ndcg    0.002152  0.014234  0.021569  0.009042
recall  0.002152  0.032616  0.061921  0.016225





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 7, global step 96: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 7, global step 96: 'recall@10' was not in top 1


k             1        10        20         5
map     0.00298  0.009137  0.011093  0.007001
ndcg    0.00298  0.014590  0.021792  0.009246
recall  0.00298  0.033113  0.061755  0.016225





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 8, global step 108: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 8, global step 108: 'recall@10' was not in top 1


k              1        10        20         5
map     0.003477  0.010498  0.012537  0.008328
ndcg    0.003477  0.016078  0.023769  0.010827
recall  0.003477  0.034768  0.065728  0.018543





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 9, global step 120: 'recall@10' reached 0.04089 (best 0.04089), saving model to '/content/.checkpoints/epoch=9-step=120.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 9, global step 120: 'recall@10' reached 0.04089 (best 0.04089), saving model to '/content/.checkpoints/epoch=9-step=120.ckpt' as top 1


k              1        10        20         5
map     0.004139  0.012466  0.014374  0.010030
ndcg    0.004139  0.018992  0.025997  0.012941
recall  0.004139  0.040894  0.068709  0.021854





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 10, global step 132: 'recall@10' reached 0.04338 (best 0.04338), saving model to '/content/.checkpoints/epoch=10-step=132.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 10, global step 132: 'recall@10' reached 0.04338 (best 0.04338), saving model to '/content/.checkpoints/epoch=10-step=132.ckpt' as top 1


k              1        10        20         5
map     0.005298  0.013916  0.015885  0.011476
ndcg    0.005298  0.020698  0.027999  0.014705
recall  0.005298  0.043377  0.072517  0.024669





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 11, global step 144: 'recall@10' reached 0.04570 (best 0.04570), saving model to '/content/.checkpoints/epoch=11-step=144.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 11, global step 144: 'recall@10' reached 0.04570 (best 0.04570), saving model to '/content/.checkpoints/epoch=11-step=144.ckpt' as top 1


k              1        10        20         5
map     0.005795  0.015026  0.017411  0.012423
ndcg    0.005795  0.022075  0.030802  0.015606
recall  0.005795  0.045695  0.080298  0.025331





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 12, global step 156: 'recall@10' reached 0.04785 (best 0.04785), saving model to '/content/.checkpoints/epoch=12-step=156.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 12, global step 156: 'recall@10' reached 0.04785 (best 0.04785), saving model to '/content/.checkpoints/epoch=12-step=156.ckpt' as top 1


k              1        10        20         5
map     0.005132  0.014220  0.016504  0.011333
ndcg    0.005132  0.021932  0.030314  0.014892
recall  0.005132  0.047848  0.081126  0.025993





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 13, global step 168: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 13, global step 168: 'recall@10' was not in top 1


k              1        10        20         5
map     0.004636  0.013968  0.016732  0.011278
ndcg    0.004636  0.021185  0.031405  0.014533
recall  0.004636  0.045364  0.086093  0.024503





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 14, global step 180: 'recall@10' reached 0.04934 (best 0.04934), saving model to '/content/.checkpoints/epoch=14-step=180.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 14, global step 180: 'recall@10' reached 0.04934 (best 0.04934), saving model to '/content/.checkpoints/epoch=14-step=180.ckpt' as top 1


k              1        10        20         5
map     0.006291  0.016077  0.018615  0.013107
ndcg    0.006291  0.023711  0.033097  0.016404
recall  0.006291  0.049338  0.086755  0.026490





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 15, global step 192: 'recall@10' reached 0.05348 (best 0.05348), saving model to '/content/.checkpoints/epoch=15-step=192.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 15, global step 192: 'recall@10' reached 0.05348 (best 0.05348), saving model to '/content/.checkpoints/epoch=15-step=192.ckpt' as top 1


k              1        10        20         5
map     0.006291  0.017427  0.020107  0.014241
ndcg    0.006291  0.025728  0.035731  0.017889
recall  0.006291  0.053477  0.093543  0.028974





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 16, global step 204: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 16, global step 204: 'recall@10' was not in top 1


k              1        10        20         5
map     0.007616  0.017088  0.020155  0.013985
ndcg    0.007616  0.024782  0.036261  0.017181
recall  0.007616  0.050662  0.096689  0.026987





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 17, global step 216: 'recall@10' reached 0.05397 (best 0.05397), saving model to '/content/.checkpoints/epoch=17-step=216.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 17, global step 216: 'recall@10' reached 0.05397 (best 0.05397), saving model to '/content/.checkpoints/epoch=17-step=216.ckpt' as top 1


k              1        10        20         5
map     0.005629  0.016860  0.019497  0.013921
ndcg    0.005629  0.025416  0.035368  0.018137
recall  0.005629  0.053974  0.094040  0.031126





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 18, global step 228: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 18, global step 228: 'recall@10' was not in top 1


k             1        10        20         5
map     0.00596  0.016437  0.019616  0.013240
ndcg    0.00596  0.024695  0.036504  0.016750
recall  0.00596  0.052483  0.099669  0.027483





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 19, global step 240: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 19, global step 240: 'recall@10' was not in top 1


k              1        10        20         5
map     0.005629  0.016464  0.019555  0.013104
ndcg    0.005629  0.024912  0.036310  0.016805
recall  0.005629  0.053146  0.098510  0.028146





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 20, global step 252: 'recall@10' reached 0.05679 (best 0.05679), saving model to '/content/.checkpoints/epoch=20-step=252.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 20, global step 252: 'recall@10' reached 0.05679 (best 0.05679), saving model to '/content/.checkpoints/epoch=20-step=252.ckpt' as top 1


k              1        10        20         5
map     0.005795  0.017373  0.020551  0.013794
ndcg    0.005795  0.026419  0.038310  0.017653
recall  0.005795  0.056788  0.104470  0.029470





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 21, global step 264: 'recall@10' reached 0.06159 (best 0.06159), saving model to '/content/.checkpoints/epoch=21-step=264.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 21, global step 264: 'recall@10' reached 0.06159 (best 0.06159), saving model to '/content/.checkpoints/epoch=21-step=264.ckpt' as top 1


k              1        10        20         5
map     0.007616  0.019732  0.023028  0.016200
ndcg    0.007616  0.029359  0.041561  0.020703
recall  0.007616  0.061589  0.110265  0.034603





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 22, global step 276: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 22, global step 276: 'recall@10' was not in top 1


k              1        10        20         5
map     0.007285  0.019021  0.022360  0.015292
ndcg    0.007285  0.028507  0.041041  0.019339
recall  0.007285  0.060430  0.110762  0.031788





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 23, global step 288: 'recall@10' reached 0.06606 (best 0.06606), saving model to '/content/.checkpoints/epoch=23-step=288.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 23, global step 288: 'recall@10' reached 0.06606 (best 0.06606), saving model to '/content/.checkpoints/epoch=23-step=288.ckpt' as top 1


k             1        10        20         5
map     0.00596  0.019264  0.022998  0.014903
ndcg    0.00596  0.029966  0.043928  0.019222
recall  0.00596  0.066060  0.122020  0.032450





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 24, global step 300: 'recall@10' reached 0.07699 (best 0.07699), saving model to '/content/.checkpoints/epoch=24-step=300.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 24, global step 300: 'recall@10' reached 0.07699 (best 0.07699), saving model to '/content/.checkpoints/epoch=24-step=300.ckpt' as top 1


k             1        10        20         5
map     0.00894  0.024267  0.028347  0.019605
ndcg    0.00894  0.036374  0.051539  0.024954
recall  0.00894  0.076987  0.137583  0.041391





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 25, global step 312: 'recall@10' reached 0.09619 (best 0.09619), saving model to '/content/.checkpoints/epoch=25-step=312.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 25, global step 312: 'recall@10' reached 0.09619 (best 0.09619), saving model to '/content/.checkpoints/epoch=25-step=312.ckpt' as top 1


k              1        10        20         5
map     0.012748  0.031274  0.036400  0.025538
ndcg    0.012748  0.046151  0.065072  0.032019
recall  0.012748  0.096192  0.171523  0.051987





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 26, global step 324: 'recall@10' reached 0.11159 (best 0.11159), saving model to '/content/.checkpoints/epoch=26-step=324.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 26, global step 324: 'recall@10' reached 0.11159 (best 0.11159), saving model to '/content/.checkpoints/epoch=26-step=324.ckpt' as top 1


k              1        10        20         5
map     0.014735  0.036338  0.041971  0.029754
ndcg    0.014735  0.053589  0.074488  0.037353
recall  0.014735  0.111589  0.195033  0.060762





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 27, global step 336: 'recall@10' reached 0.12483 (best 0.12483), saving model to '/content/.checkpoints/epoch=27-step=336.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 27, global step 336: 'recall@10' reached 0.12483 (best 0.12483), saving model to '/content/.checkpoints/epoch=27-step=336.ckpt' as top 1


k              1        10        20         5
map     0.016225  0.040914  0.046262  0.033855
ndcg    0.016225  0.060223  0.079969  0.042948
recall  0.016225  0.124834  0.203477  0.071026





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 28, global step 348: 'recall@10' reached 0.12500 (best 0.12500), saving model to '/content/.checkpoints/epoch=28-step=348.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 28, global step 348: 'recall@10' reached 0.12500 (best 0.12500), saving model to '/content/.checkpoints/epoch=28-step=348.ckpt' as top 1


k             1        10        20         5
map     0.01755  0.042450  0.048453  0.035510
ndcg    0.01755  0.061448  0.083605  0.044353
recall  0.01755  0.125000  0.213245  0.071523





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 29, global step 360: 'recall@10' reached 0.13626 (best 0.13626), saving model to '/content/.checkpoints/epoch=29-step=360.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 29, global step 360: 'recall@10' reached 0.13626 (best 0.13626), saving model to '/content/.checkpoints/epoch=29-step=360.ckpt' as top 1


k              1        10        20         5
map     0.021689  0.047598  0.053546  0.039964
ndcg    0.021689  0.068005  0.090124  0.049397
recall  0.021689  0.136258  0.224669  0.078477





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 30, global step 372: 'recall@10' reached 0.14255 (best 0.14255), saving model to '/content/.checkpoints/epoch=30-step=372.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 30, global step 372: 'recall@10' reached 0.14255 (best 0.14255), saving model to '/content/.checkpoints/epoch=30-step=372.ckpt' as top 1


k              1        10        20         5
map     0.020695  0.049248  0.055074  0.041565
ndcg    0.020695  0.070773  0.092216  0.051866
recall  0.020695  0.142550  0.227815  0.083444





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 31, global step 384: 'recall@10' reached 0.15182 (best 0.15182), saving model to '/content/.checkpoints/epoch=31-step=384.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 31, global step 384: 'recall@10' reached 0.15182 (best 0.15182), saving model to '/content/.checkpoints/epoch=31-step=384.ckpt' as top 1


k              1        10        20         5
map     0.023013  0.053349  0.059941  0.045235
ndcg    0.023013  0.076111  0.100300  0.056329
recall  0.023013  0.151821  0.247848  0.090397





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 32, global step 396: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 32, global step 396: 'recall@10' was not in top 1


k              1        10        20         5
map     0.022517  0.052341  0.058435  0.044142
ndcg    0.022517  0.074825  0.097044  0.054707
recall  0.022517  0.149834  0.237748  0.087086





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 33, global step 408: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 33, global step 408: 'recall@10' was not in top 1


k             1        10        20         5
map     0.02351  0.053794  0.059941  0.046258
ndcg    0.02351  0.075946  0.098534  0.057438
recall  0.02351  0.149503  0.239238  0.091722





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 34, global step 420: 'recall@10' reached 0.15381 (best 0.15381), saving model to '/content/.checkpoints/epoch=34-step=420.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 34, global step 420: 'recall@10' reached 0.15381 (best 0.15381), saving model to '/content/.checkpoints/epoch=34-step=420.ckpt' as top 1


k              1        10        20         5
map     0.024007  0.054409  0.060695  0.046272
ndcg    0.024007  0.077376  0.100578  0.057572
recall  0.024007  0.153808  0.246192  0.092384





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 35, global step 432: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 35, global step 432: 'recall@10' was not in top 1


k              1        10        20         5
map     0.022848  0.054246  0.060352  0.046416
ndcg    0.022848  0.077278  0.099787  0.058036
recall  0.022848  0.153808  0.243377  0.093709





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 36, global step 444: 'recall@10' reached 0.15861 (best 0.15861), saving model to '/content/.checkpoints/epoch=36-step=444.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 36, global step 444: 'recall@10' reached 0.15861 (best 0.15861), saving model to '/content/.checkpoints/epoch=36-step=444.ckpt' as top 1


k              1        10        20         5
map     0.024669  0.057235  0.063697  0.048907
ndcg    0.024669  0.080695  0.104503  0.060354
recall  0.024669  0.158609  0.253311  0.095364





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 37, global step 456: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 37, global step 456: 'recall@10' was not in top 1


k              1        10        20         5
map     0.026821  0.058955  0.065455  0.051576
ndcg    0.026821  0.081920  0.105743  0.063695
recall  0.026821  0.157947  0.252483  0.100828





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 38, global step 468: 'recall@10' reached 0.16507 (best 0.16507), saving model to '/content/.checkpoints/epoch=38-step=468.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 38, global step 468: 'recall@10' reached 0.16507 (best 0.16507), saving model to '/content/.checkpoints/epoch=38-step=468.ckpt' as top 1


k              1        10        20         5
map     0.027318  0.060090  0.066678  0.052152
ndcg    0.027318  0.084378  0.108726  0.064743
recall  0.027318  0.165066  0.262086  0.103477





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 39, global step 480: 'recall@10' reached 0.17003 (best 0.17003), saving model to '/content/.checkpoints/epoch=39-step=480.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 39, global step 480: 'recall@10' reached 0.17003 (best 0.17003), saving model to '/content/.checkpoints/epoch=39-step=480.ckpt' as top 1


k             1        10        20         5
map     0.02947  0.064175  0.070595  0.056206
ndcg    0.02947  0.088745  0.112446  0.069078
recall  0.02947  0.170033  0.264404  0.108444





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 40, global step 492: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 40, global step 492: 'recall@10' was not in top 1


k              1        10        20         5
map     0.028146  0.062237  0.068977  0.053540
ndcg    0.028146  0.086939  0.111699  0.065726
recall  0.028146  0.168874  0.267219  0.102980





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 41, global step 504: 'recall@10' reached 0.17533 (best 0.17533), saving model to '/content/.checkpoints/epoch=41-step=504.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 41, global step 504: 'recall@10' reached 0.17533 (best 0.17533), saving model to '/content/.checkpoints/epoch=41-step=504.ckpt' as top 1


k              1        10        20         5
map     0.029801  0.064896  0.071634  0.056095
ndcg    0.029801  0.090500  0.115420  0.069094
recall  0.029801  0.175331  0.274669  0.108940





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 42, global step 516: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 42, global step 516: 'recall@10' was not in top 1


k              1        10        20         5
map     0.027318  0.062043  0.068611  0.053198
ndcg    0.027318  0.087069  0.111175  0.065368
recall  0.027318  0.170199  0.265894  0.102483





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 43, global step 528: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 43, global step 528: 'recall@10' was not in top 1


k              1        10        20         5
map     0.031788  0.065349  0.071937  0.057249
ndcg    0.031788  0.089488  0.113810  0.069752
recall  0.031788  0.169371  0.266225  0.108113





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 44, global step 540: 'recall@10' reached 0.17632 (best 0.17632), saving model to '/content/.checkpoints/epoch=44-step=540.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 44, global step 540: 'recall@10' reached 0.17632 (best 0.17632), saving model to '/content/.checkpoints/epoch=44-step=540.ckpt' as top 1


k              1        10        20         5
map     0.030298  0.067346  0.074058  0.059015
ndcg    0.030298  0.092702  0.117618  0.072298
recall  0.030298  0.176325  0.275828  0.112748





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 45, global step 552: 'recall@10' reached 0.18411 (best 0.18411), saving model to '/content/.checkpoints/epoch=45-step=552.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 45, global step 552: 'recall@10' reached 0.18411 (best 0.18411), saving model to '/content/.checkpoints/epoch=45-step=552.ckpt' as top 1


k             1        10        20         5
map     0.03245  0.069535  0.075965  0.060560
ndcg    0.03245  0.096065  0.119752  0.073948
recall  0.03245  0.184106  0.278311  0.114901





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 46, global step 564: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 46, global step 564: 'recall@10' was not in top 1


k              1        10        20         5
map     0.028974  0.065052  0.071636  0.055146
ndcg    0.028974  0.092147  0.116188  0.067967
recall  0.028974  0.182450  0.277649  0.107285





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 47, global step 576: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 47, global step 576: 'recall@10' was not in top 1


k              1        10        20         5
map     0.029139  0.064901  0.071360  0.056018
ndcg    0.029139  0.091183  0.114956  0.069227
recall  0.029139  0.178642  0.273179  0.109768





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 48, global step 588: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 48, global step 588: 'recall@10' was not in top 1


k              1        10        20         5
map     0.033444  0.069528  0.076559  0.060853
ndcg    0.033444  0.095212  0.121104  0.074119
recall  0.033444  0.180132  0.283113  0.114735





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 49, global step 600: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 49, global step 600: 'recall@10' was not in top 1


k              1        10        20         5
map     0.031457  0.067272  0.073889  0.058579
ndcg    0.031457  0.093154  0.117923  0.071821
recall  0.031457  0.178974  0.278311  0.112417





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 50, global step 612: 'recall@10' reached 0.18593 (best 0.18593), saving model to '/content/.checkpoints/epoch=50-step=612.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 50, global step 612: 'recall@10' reached 0.18593 (best 0.18593), saving model to '/content/.checkpoints/epoch=50-step=612.ckpt' as top 1


k              1        10        20         5
map     0.029636  0.067998  0.075062  0.058891
ndcg    0.029636  0.095366  0.121299  0.073084
recall  0.029636  0.185927  0.288907  0.116556





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 51, global step 624: 'recall@10' reached 0.18957 (best 0.18957), saving model to '/content/.checkpoints/epoch=51-step=624.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 51, global step 624: 'recall@10' reached 0.18957 (best 0.18957), saving model to '/content/.checkpoints/epoch=51-step=624.ckpt' as top 1


k              1        10        20         5
map     0.033609  0.070652  0.077752  0.060720
ndcg    0.033609  0.098097  0.124212  0.073704
recall  0.033609  0.189570  0.293377  0.113411





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 52, global step 636: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 52, global step 636: 'recall@10' was not in top 1


k              1        10        20         5
map     0.031291  0.067403  0.074235  0.058720
ndcg    0.031291  0.093428  0.118794  0.072325
recall  0.031291  0.179470  0.280795  0.114073





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 53, global step 648: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 53, global step 648: 'recall@10' was not in top 1


k              1        10        20         5
map     0.033775  0.070281  0.077449  0.060593
ndcg    0.033775  0.097756  0.123986  0.074047
recall  0.033775  0.189238  0.293212  0.115397





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 54, global step 660: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 54, global step 660: 'recall@10' was not in top 1


k              1        10        20         5
map     0.030464  0.065931  0.072717  0.055996
ndcg    0.030464  0.092698  0.117925  0.068456
recall  0.030464  0.181954  0.282781  0.106623





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 55, global step 672: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 55, global step 672: 'recall@10' was not in top 1


k              1        10        20         5
map     0.028146  0.064655  0.072054  0.055486
ndcg    0.028146  0.091501  0.118762  0.069032
recall  0.028146  0.180629  0.289073  0.110596





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 56, global step 684: 'recall@10' reached 0.18990 (best 0.18990), saving model to '/content/.checkpoints/epoch=56-step=684.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 56, global step 684: 'recall@10' reached 0.18990 (best 0.18990), saving model to '/content/.checkpoints/epoch=56-step=684.ckpt' as top 1


k              1        10        20         5
map     0.032285  0.069273  0.076342  0.059183
ndcg    0.032285  0.097081  0.122994  0.072245
recall  0.032285  0.189901  0.292715  0.112252





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 57, global step 696: 'recall@10' reached 0.19222 (best 0.19222), saving model to '/content/.checkpoints/epoch=57-step=696.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 57, global step 696: 'recall@10' reached 0.19222 (best 0.19222), saving model to '/content/.checkpoints/epoch=57-step=696.ckpt' as top 1


k              1        10        20         5
map     0.032781  0.070535  0.077375  0.060786
ndcg    0.032781  0.098656  0.123868  0.074694
recall  0.032781  0.192219  0.292550  0.117384





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 58, global step 708: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 58, global step 708: 'recall@10' was not in top 1


k              1        10        20         5
map     0.034272  0.071534  0.078906  0.061752
ndcg    0.034272  0.098690  0.125805  0.074851
recall  0.034272  0.188907  0.296689  0.114901





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 59, global step 720: 'recall@10' reached 0.19288 (best 0.19288), saving model to '/content/.checkpoints/epoch=59-step=720.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 59, global step 720: 'recall@10' reached 0.19288 (best 0.19288), saving model to '/content/.checkpoints/epoch=59-step=720.ckpt' as top 1


k              1        10        20         5
map     0.033113  0.071649  0.078468  0.062196
ndcg    0.033113  0.099731  0.124777  0.076477
recall  0.033113  0.192881  0.292384  0.120199





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 60, global step 732: 'recall@10' reached 0.19371 (best 0.19371), saving model to '/content/.checkpoints/epoch=60-step=732.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 60, global step 732: 'recall@10' reached 0.19371 (best 0.19371), saving model to '/content/.checkpoints/epoch=60-step=732.ckpt' as top 1


k             1        10        20         5
map     0.03394  0.072342  0.079337  0.062867
ndcg    0.03394  0.100414  0.126128  0.076947
recall  0.03394  0.193709  0.295861  0.120033





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 61, global step 744: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 61, global step 744: 'recall@10' was not in top 1


k              1        10        20         5
map     0.032285  0.069600  0.076772  0.060370
ndcg    0.032285  0.096342  0.122769  0.073721
recall  0.032285  0.185099  0.290232  0.114570





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 62, global step 756: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 62, global step 756: 'recall@10' was not in top 1


k              1        10        20         5
map     0.031788  0.071211  0.078158  0.061542
ndcg    0.031788  0.099495  0.125175  0.075520
recall  0.031788  0.193543  0.295861  0.118212





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 63, global step 768: 'recall@10' reached 0.20000 (best 0.20000), saving model to '/content/.checkpoints/epoch=63-step=768.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 63, global step 768: 'recall@10' reached 0.20000 (best 0.20000), saving model to '/content/.checkpoints/epoch=63-step=768.ckpt' as top 1


k              1        10        20         5
map     0.033775  0.074159  0.081170  0.064227
ndcg    0.033775  0.103309  0.129038  0.079040
recall  0.033775  0.200000  0.302152  0.124503





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 64, global step 780: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 64, global step 780: 'recall@10' was not in top 1


k              1        10        20         5
map     0.031788  0.069865  0.077038  0.060091
ndcg    0.031788  0.097699  0.124157  0.073712
recall  0.031788  0.190232  0.295530  0.115397





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 65, global step 792: 'recall@10' reached 0.20430 (best 0.20430), saving model to '/content/.checkpoints/epoch=65-step=792.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 65, global step 792: 'recall@10' reached 0.20430 (best 0.20430), saving model to '/content/.checkpoints/epoch=65-step=792.ckpt' as top 1


k              1        10        20         5
map     0.032616  0.073643  0.080836  0.062897
ndcg    0.032616  0.103837  0.130316  0.077530
recall  0.032616  0.204305  0.309603  0.122351





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 66, global step 804: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 66, global step 804: 'recall@10' was not in top 1


k              1        10        20         5
map     0.036424  0.075393  0.082858  0.066087
ndcg    0.036424  0.103789  0.131267  0.080838
recall  0.036424  0.198013  0.307285  0.126159





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 67, global step 816: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 67, global step 816: 'recall@10' was not in top 1


k              1        10        20         5
map     0.035265  0.074076  0.081002  0.064067
ndcg    0.035265  0.102692  0.128206  0.078314
recall  0.035265  0.197682  0.299172  0.122020





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 68, global step 828: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 68, global step 828: 'recall@10' was not in top 1


k              1        10        20         5
map     0.032119  0.070933  0.078118  0.060889
ndcg    0.032119  0.099568  0.126038  0.074841
recall  0.032119  0.194868  0.300166  0.117550





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 69, global step 840: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 69, global step 840: 'recall@10' was not in top 1


k              1        10        20         5
map     0.034106  0.072909  0.079876  0.062420
ndcg    0.034106  0.102035  0.127712  0.076581
recall  0.034106  0.198841  0.300993  0.120033





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 70, global step 852: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 70, global step 852: 'recall@10' was not in top 1


k              1        10        20         5
map     0.031457  0.072157  0.079917  0.062351
ndcg    0.031457  0.101161  0.129642  0.077207
recall  0.031457  0.197185  0.310265  0.122682





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 71, global step 864: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 71, global step 864: 'recall@10' was not in top 1


k              1        10        20         5
map     0.031126  0.070542  0.078271  0.061347
ndcg    0.031126  0.098568  0.127232  0.076011
recall  0.031126  0.191225  0.305629  0.120861





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 72, global step 876: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 72, global step 876: 'recall@10' was not in top 1


k             1        10        20         5
map     0.03394  0.074228  0.081458  0.063626
ndcg    0.03394  0.103887  0.130554  0.077789
recall  0.03394  0.202649  0.308775  0.121026





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 73, global step 888: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 73, global step 888: 'recall@10' was not in top 1


k             1        10        20         5
map     0.03543  0.074925  0.082577  0.064912
ndcg    0.03543  0.104017  0.132145  0.079354
recall  0.03543  0.200828  0.312583  0.123675





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 74, global step 900: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 74, global step 900: 'recall@10' was not in top 1


k              1        10        20         5
map     0.035265  0.074551  0.081775  0.064111
ndcg    0.035265  0.104093  0.130807  0.078465
recall  0.035265  0.202483  0.308940  0.122517





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 75, global step 912: 'recall@10' reached 0.20546 (best 0.20546), saving model to '/content/.checkpoints/epoch=75-step=912.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 75, global step 912: 'recall@10' reached 0.20546 (best 0.20546), saving model to '/content/.checkpoints/epoch=75-step=912.ckpt' as top 1


k              1        10        20         5
map     0.034603  0.076272  0.083386  0.066319
ndcg    0.034603  0.106207  0.132592  0.081730
recall  0.034603  0.205464  0.310762  0.128974





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 76, global step 924: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 76, global step 924: 'recall@10' was not in top 1


k              1        10        20         5
map     0.034934  0.075420  0.082494  0.065381
ndcg    0.034934  0.104618  0.130755  0.080022
recall  0.034934  0.201490  0.305629  0.124834





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 77, global step 936: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 77, global step 936: 'recall@10' was not in top 1


k              1        10        20         5
map     0.033775  0.074993  0.082699  0.064931
ndcg    0.033775  0.104188  0.132520  0.079787
recall  0.033775  0.200662  0.313245  0.125166





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 78, global step 948: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 78, global step 948: 'recall@10' was not in top 1


k              1        10        20         5
map     0.031788  0.072153  0.080078  0.062216
ndcg    0.031788  0.101363  0.130298  0.077057
recall  0.031788  0.198179  0.312748  0.122517





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 79, global step 960: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 79, global step 960: 'recall@10' was not in top 1


k              1        10        20         5
map     0.033113  0.072883  0.080432  0.063342
ndcg    0.033113  0.101745  0.129520  0.078399
recall  0.033113  0.197351  0.307781  0.124669





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 80, global step 972: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 80, global step 972: 'recall@10' was not in top 1


k              1        10        20         5
map     0.032616  0.073385  0.081309  0.062525
ndcg    0.032616  0.103907  0.133061  0.077421
recall  0.032616  0.205464  0.321358  0.123179





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 81, global step 984: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 81, global step 984: 'recall@10' was not in top 1


k              1        10        20         5
map     0.035265  0.074764  0.082254  0.065224
ndcg    0.035265  0.103850  0.131413  0.080413
recall  0.035265  0.200331  0.309934  0.127152





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 82, global step 996: 'recall@10' reached 0.21010 (best 0.21010), saving model to '/content/.checkpoints/epoch=82-step=996.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 82, global step 996: 'recall@10' reached 0.21010 (best 0.21010), saving model to '/content/.checkpoints/epoch=82-step=996.ckpt' as top 1


k              1        10        20         5
map     0.033278  0.077268  0.084396  0.067279
ndcg    0.033278  0.108130  0.134585  0.083583
recall  0.033278  0.210099  0.315728  0.133444





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 83, global step 1008: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 83, global step 1008: 'recall@10' was not in top 1


k              1        10        20         5
map     0.031954  0.074447  0.082000  0.064658
ndcg    0.031954  0.104735  0.132629  0.080638
recall  0.031954  0.204967  0.316060  0.129636





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 84, global step 1020: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 84, global step 1020: 'recall@10' was not in top 1


k             1        10        20         5
map     0.03394  0.076469  0.084025  0.066338
ndcg    0.03394  0.107034  0.134680  0.082345
recall  0.03394  0.208113  0.317715  0.131457





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 85, global step 1032: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 85, global step 1032: 'recall@10' was not in top 1


k              1        10        20         5
map     0.032119  0.073359  0.080498  0.062384
ndcg    0.032119  0.104246  0.130548  0.077326
recall  0.032119  0.207119  0.311755  0.123179





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 86, global step 1044: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 86, global step 1044: 'recall@10' was not in top 1


k              1        10        20         5
map     0.034106  0.073927  0.080933  0.063595
ndcg    0.034106  0.103222  0.129138  0.078048
recall  0.034106  0.200497  0.303808  0.122351





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 87, global step 1056: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 87, global step 1056: 'recall@10' was not in top 1


k              1        10        20         5
map     0.037086  0.076446  0.083854  0.066438
ndcg    0.037086  0.105430  0.132836  0.080989
recall  0.037086  0.201656  0.310927  0.125662





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 88, global step 1068: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 88, global step 1068: 'recall@10' was not in top 1


k              1        10        20         5
map     0.031457  0.072669  0.080228  0.062955
ndcg    0.031457  0.102468  0.130368  0.078605
recall  0.031457  0.201159  0.312252  0.126656





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 89, global step 1080: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 89, global step 1080: 'recall@10' was not in top 1


k              1        10        20         5
map     0.028974  0.069992  0.077587  0.059953
ndcg    0.028974  0.099354  0.127329  0.074981
recall  0.028974  0.196523  0.307781  0.121026





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 90, global step 1092: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 90, global step 1092: 'recall@10' was not in top 1


k              1        10        20         5
map     0.033113  0.074051  0.081501  0.063013
ndcg    0.033113  0.104553  0.131932  0.077605
recall  0.033113  0.206126  0.314901  0.122351





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 91, global step 1104: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 91, global step 1104: 'recall@10' was not in top 1


k              1        10        20         5
map     0.035265  0.076843  0.084199  0.066763
ndcg    0.035265  0.106545  0.133504  0.082157
recall  0.035265  0.204636  0.311589  0.129305





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 92, global step 1116: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 92, global step 1116: 'recall@10' was not in top 1


k              1        10        20         5
map     0.037252  0.078456  0.086259  0.068659
ndcg    0.037252  0.107913  0.136421  0.083844
recall  0.037252  0.205464  0.318377  0.130298





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 93, global step 1128: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 93, global step 1128: 'recall@10' was not in top 1


k              1        10        20         5
map     0.033278  0.073839  0.081365  0.063397
ndcg    0.033278  0.104033  0.131645  0.078436
recall  0.033278  0.204470  0.314073  0.124669





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 94, global step 1140: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 94, global step 1140: 'recall@10' was not in top 1


k              1        10        20         5
map     0.030298  0.071546  0.078808  0.060706
ndcg    0.030298  0.102022  0.128944  0.075431
recall  0.030298  0.203477  0.310927  0.120530





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 95, global step 1152: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 95, global step 1152: 'recall@10' was not in top 1


k              1        10        20         5
map     0.031954  0.073688  0.081289  0.063146
ndcg    0.031954  0.104495  0.132422  0.078706
recall  0.031954  0.206788  0.317715  0.126490





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 96, global step 1164: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 96, global step 1164: 'recall@10' was not in top 1


k              1        10        20         5
map     0.035265  0.075622  0.083117  0.064426
ndcg    0.035265  0.105984  0.133588  0.078530
recall  0.035265  0.207285  0.317053  0.121689





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 97, global step 1176: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 97, global step 1176: 'recall@10' was not in top 1


k              1        10        20         5
map     0.035762  0.075370  0.082993  0.064884
ndcg    0.035762  0.104329  0.132340  0.078659
recall  0.035762  0.200662  0.311921  0.120695





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 98, global step 1188: 'recall@10' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 98, global step 1188: 'recall@10' was not in top 1


k              1        10        20         5
map     0.032947  0.075710  0.083447  0.065400
ndcg    0.032947  0.106332  0.134680  0.081025
recall  0.032947  0.207781  0.320199  0.128808





Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 99, global step 1200: 'recall@10' reached 0.21192 (best 0.21192), saving model to '/content/.checkpoints/epoch=99-step=1200.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 99, global step 1200: 'recall@10' reached 0.21192 (best 0.21192), saving model to '/content/.checkpoints/epoch=99-step=1200.ckpt' as top 1


k              1        10        20         5
map     0.035596  0.078605  0.086401  0.068253
ndcg    0.035596  0.109491  0.138429  0.084049
recall  0.035596  0.211921  0.327483  0.132450



INFO: `Trainer.fit` stopped: `max_epochs=100` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=100` reached.


# Get Results

In [158]:
best_model = Bert4Rec.load_from_checkpoint(checkpoint_callback.best_model_path)

In [159]:
prediction_dataloader = DataLoader(
    dataset=Bert4RecPredictionDataset(
        sequential_test_dataset,
        max_sequence_length=MAX_SEQ_LEN,
    ),
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

csv_logger = CSVLogger(save_dir=".logs/test", name="Bert4Rec_example")



In [160]:
TOPK = [10]

postprocessors = [RemoveSeenItems(sequential_test_dataset)]

pandas_prediction_callback = PandasPredictionCallback(
    top_k=max(TOPK),
    query_column="user_id",
    item_column="item_id",
    rating_column="score",
    postprocessors=postprocessors,
)

query_embeddings_callback = QueryEmbeddingsPredictionCallback()

trainer = L.Trainer(
    callbacks=[
        pandas_prediction_callback,
        query_embeddings_callback,
    ],
    logger=csv_logger,
    inference_mode=True
)
trainer.predict(best_model, dataloaders=prediction_dataloader, return_predictions=False)

pandas_res = pandas_prediction_callback.get_result()
user_embeddings = query_embeddings_callback.get_result()

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

In [161]:
pandas_res

Unnamed: 0,user_id,item_id,score
0,790,2003,10.521249
0,790,434,10.419925
0,790,2551,10.212043
0,790,2543,10.068316
0,790,3699,10.067029
...,...,...,...
6039,4520,2609,6.179309
6039,4520,2389,6.158994
6039,4520,1177,6.128626
6039,4520,3655,6.088515


In [162]:
recommendations = tokenizer.query_and_item_id_encoder.inverse_transform(pandas_res)

In [163]:
recommendations

Unnamed: 0,user_id,item_id,score
0,0,2003,10.521249
0,0,434,10.419925
0,0,2551,10.212043
0,0,2543,10.068316
0,0,3699,10.067029
...,...,...,...
6039,6039,2609,6.179309
6039,6039,2389,6.158994
6039,6039,1177,6.128626
6039,6039,3655,6.088515


In [164]:
result = recommendations.groupby('user_id')['item_id']\
                    .apply(lambda x: ' '.join(x.astype(str))).reset_index()

In [165]:
result

Unnamed: 0,user_id,item_id
0,0,2003 434 2551 2543 3699 1479 708 3214 3167 2138
1,1,232 1246 3101 1884 3656 1459 560 1822 1686 2476
2,2,234 2311 2774 382 887 1371 1560 221 2354 2428
3,3,3390 3562 1814 3365 94 605 956 810 1893 2908
4,4,1160 983 755 1835 394 3002 2237 3602 672 2925
...,...,...
6035,6035,1706 2387 2265 2148 513 387 3216 2512 3247 880
6036,6036,1859 3142 2470 922 3105 3692 3059 2054 1057 2502
6037,6037,1375 2664 1439 3059 1747 2833 12 2800 1296 2256
6038,6038,316 1514 450 1355 83 3343 3309 336 2231 797


In [166]:
result.to_csv('result_10.csv', index=False)