In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import sys
sys.path.append('../')
print(sys.path)

['/home/hd31/anaconda3/envs/py312/lib/python312.zip', '/home/hd31/anaconda3/envs/py312/lib/python3.12', '/home/hd31/anaconda3/envs/py312/lib/python3.12/lib-dynload', '', '/home/hd31/anaconda3/envs/py312/lib/python3.12/site-packages', '__editable__.recommenders-1.2.1.finder.__path_hook__', '../']


In [2]:
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from hydra import compose, initialize
from omegaconf import OmegaConf
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ModelSummary
from torch.utils.data import DataLoader
from mamba_ssm import Mamba
from src.datasets import (CausalLMDataset, CausalLMPredictionDataset, MaskedLMDataset,
                          MaskedLMPredictionDataset, PaddingCollateFn)
from src.metrics import compute_metrics
from src.models import RNN, BERT4Rec, SASRec, MAMBA4Rec
from src.modules import SeqRec, SeqRecWithSampling
from src.postprocess import preds2recs
from src.preprocess import add_time_idx
print("Successfully imported package")

Successfully imported package


## Config

In [6]:
# # for SASRec
# with initialize(version_base=None, config_path="../src/configs/"):
#     config = compose(config_name="SASRec")

# # for BERT4Rec
# with initialize(version_base=None, config_path="../src/configs/"):
#     config = compose(config_name="BERT4Rec")

# # for GRU4Rec
# with initialize(version_base=None, config_path="../src/configs/"):
#     config = compose(config_name="RNN")

# for MAMBA4Rec
with initialize(version_base=None, config_path="../src/configs/"):
    config = compose(config_name="MAMBA4Rec")

print(config)

{'cuda_visible_devices': 0, 'data_path': '../data/ml-1m.txt', 'dataset': {'max_length': 50, 'full_negative_sampling': False}, 'dataloader': {'batch_size': 128, 'test_batch_size': 256, 'num_workers': 8, 'validation_size': 10000}, 'model': 'MAMBA4Rec', 'model_params': {'maxlen': 200, 'hidden_units': 64, 'num_blocks': 2, 'dropout_rate': 0.1, 'initializer_range': 0.02, 'mamba_config': {'d_model': 32, 'd_state': 8, 'd_conv': 2, 'expand': 2}}, 'seqrec_module': {'lr': 0.001, 'predict_top_k': 10, 'filter_seen': True}, 'trainer_params': {'max_epochs': 100}, 'patience': 10, 'sampled_metrics': False, 'top_k_metrics': [10, 100], 'hidden_size': 64, 'num_layers': 1, 'dropout_prob': 0.2, 'loss_type': 'CE', 'd_state': 32, 'd_conv': 4, 'expand': 2, 'USER_ID_FIELD': 'user_id', 'ITEM_ID_FIELD': 'item_id', 'load_col': {'inter': ['user_id', 'item_id', 'timestamp']}, 'user_inter_num_interval': '[5,inf)', 'item_inter_num_interval': '[5,inf)', 'epochs': 300, 'train_batch_size': 2048, 'learner': 'adam', 'learn

In [7]:
OmegaConf.set_struct(config, False)

config.dataset.max_length = 200

# # for training with negative sampling
# config.dataset.num_negatives = 1000

# # for original SASRec training with BCE loss and 1 negative example
# config.seqrec_module.loss = 'bce'
# config.dataset.num_negatives = 1
# config.dataset.full_negative_sampling = True

In [8]:
print(OmegaConf.to_yaml(config))

cuda_visible_devices: 0
data_path: ../data/ml-1m.txt
dataset:
  max_length: 200
  full_negative_sampling: false
dataloader:
  batch_size: 128
  test_batch_size: 256
  num_workers: 8
  validation_size: 10000
model: MAMBA4Rec
model_params:
  maxlen: 200
  hidden_units: 64
  num_blocks: 2
  dropout_rate: 0.1
  initializer_range: 0.02
  mamba_config:
    d_model: 32
    d_state: 8
    d_conv: 2
    expand: 2
seqrec_module:
  lr: 0.001
  predict_top_k: 10
  filter_seen: true
trainer_params:
  max_epochs: 100
patience: 10
sampled_metrics: false
top_k_metrics:
- 10
- 100
hidden_size: 64
num_layers: 1
dropout_prob: 0.2
loss_type: CE
d_state: 32
d_conv: 4
expand: 2
USER_ID_FIELD: user_id
ITEM_ID_FIELD: item_id
load_col:
  inter:
  - user_id
  - item_id
  - timestamp
user_inter_num_interval: '[5,inf)'
item_inter_num_interval: '[5,inf)'
epochs: 300
train_batch_size: 2048
learner: adam
learning_rate: 0.001
eval_step: 1
stopping_step: 10
train_neg_sample_args: null
metrics:
- Hit
- NDCG
- MRR
valid

## Load data

In [9]:
data = pd.read_csv(config.data_path, sep=' ', header=None, names=['user_id', 'item_id'])
data = add_time_idx(data, sort=False)

# index 1 is used for masking value
if config.model == 'MAMBA4Rec':
    data.item_id += 1

print(data.shape)
data.head()

(999611, 4)


Unnamed: 0,user_id,item_id,time_idx,time_idx_reversed
0,1,2,0,78
1,1,3,1,77
2,1,4,2,76
3,1,5,3,75
4,1,6,4,74


In [10]:
train = data[data.time_idx_reversed >= 2]
validation = data[data.time_idx_reversed == 1]
validation_full = data[data.time_idx_reversed >= 1]
test = data[data.time_idx_reversed == 0]

## Dataloaders

In [11]:
validation_size = config.dataloader.validation_size
validation_users = validation_full.user_id.unique()
if validation_size and (validation_size < len(validation_users)):
    validation_users = np.random.choice(validation_users, size=validation_size, replace=False)

if config.model in ['SASRec', 'RNN', 'MAMBA4Rec']:
    train_dataset = CausalLMDataset(train, **config['dataset'])
    eval_dataset = CausalLMPredictionDataset(
        validation_full[validation_full.user_id.isin(validation_users)],
        max_length=config.dataset.max_length, validation_mode=True)
elif config.model == 'BERT4Rec':
    train_dataset = MaskedLMDataset(train, **config['dataset'])
    eval_dataset = MaskedLMPredictionDataset(
        validation_full[validation_full.user_id.isin(validation_users)],
        max_length=config.dataset.max_length, validation_mode=True)

train_loader = DataLoader(
    train_dataset, shuffle=True,
    collate_fn=PaddingCollateFn(),
    batch_size=config.dataloader.batch_size,
    num_workers=config.dataloader.num_workers)
eval_loader = DataLoader(
    eval_dataset, shuffle=False,
    collate_fn=PaddingCollateFn(),
    batch_size=config.dataloader.test_batch_size,
    num_workers=config.dataloader.num_workers)

In [12]:
batch = next(iter(train_loader))
print(batch['input_ids'].shape)

torch.Size([128, 200])


## Model

In [15]:
item_count = data.item_id.max()

from mamba_ssm import Mamba

if hasattr(config.dataset, 'num_negatives') and config.dataset.num_negatives:
    add_head = False
else:
    add_head = True

if config.model == 'SASRec':
    model = SASRec(item_num=item_count, add_head=add_head, **config.model_params)
if config.model == 'BERT4Rec':
    model = BERT4Rec(vocab_size=item_count + 1, add_head=add_head,
                     bert_config=config.model_params)
if config.model == 'RNN':
    model = RNN(vocab_size=item_count + 1, add_head=add_head,
                rnn_config=config.model_params)
elif config.model == 'MAMBA4Rec':
    model = MAMBA4Rec(vocab_size=item_count + 1, add_head=add_head,
                mamba_config=config.model_params.mamba_config)

In [16]:
out = model(batch['input_ids'], batch['attention_mask'])
out.shape

RuntimeError: Expected u.is_cuda() to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)
Exception raised from selective_scan_fwd at /home/runner/work/mamba/mamba/csrc/selective_scan/selective_scan.cpp:246 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x98 (0x7f521c0cc788 in /home/hd31/anaconda3/envs/py312/lib/python3.12/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x6a (0x7f521c075fbc in /home/hd31/anaconda3/envs/py312/lib/python3.12/site-packages/torch/lib/libc10.so)
frame #2: selective_scan_fwd(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor> const&, std::optional<at::Tensor> const&, std::optional<at::Tensor> const&, bool) + 0x22f (0x7f51044c89ff in /home/hd31/anaconda3/envs/py312/lib/python3.12/site-packages/selective_scan_cuda.cpython-312-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0x92ed4 (0x7f51044e1ed4 in /home/hd31/anaconda3/envs/py312/lib/python3.12/site-packages/selective_scan_cuda.cpython-312-x86_64-linux-gnu.so)
frame #4: <unknown function> + 0x8fc95 (0x7f51044dec95 in /home/hd31/anaconda3/envs/py312/lib/python3.12/site-packages/selective_scan_cuda.cpython-312-x86_64-linux-gnu.so)
frame #5: /home/hd31/anaconda3/envs/py312/bin/python3.12() [0x54d2d4]
frame #6: _PyObject_MakeTpCall + 0x2fb (0x51e38b in /home/hd31/anaconda3/envs/py312/bin/python3.12)
frame #7: _PyEval_EvalFrameDefault + 0x6ce (0x528ede in /home/hd31/anaconda3/envs/py312/bin/python3.12)
frame #8: THPFunction_apply(_object*, _object*) + 0xe30 (0x7f5213e0fe90 in /home/hd31/anaconda3/envs/py312/lib/python3.12/site-packages/torch/lib/libtorch_python.so)
frame #9: /home/hd31/anaconda3/envs/py312/bin/python3.12() [0x54d2fc]
frame #10: _PyObject_Call + 0xb5 (0x55e0f5 in /home/hd31/anaconda3/envs/py312/bin/python3.12)
frame #11: _PyEval_EvalFrameDefault + 0x503a (0x52d84a in /home/hd31/anaconda3/envs/py312/bin/python3.12)
frame #12: /home/hd31/anaconda3/envs/py312/bin/python3.12() [0x578d57]
frame #13: _PyEval_EvalFrameDefault + 0x503a (0x52d84a in /home/hd31/anaconda3/envs/py312/bin/python3.12)
frame #14: /home/hd31/anaconda3/envs/py312/bin/python3.12() [0x578d57]
frame #15: _PyEval_EvalFrameDefault + 0x503a (0x52d84a in /home/hd31/anaconda3/envs/py312/bin/python3.12)
frame #16: _PyObject_FastCallDictTstate + 0x1e7 (0x520f07 in /home/hd31/anaconda3/envs/py312/bin/python3.12)
frame #17: _PyObject_Call_Prepend + 0x66 (0x55b4c6 in /home/hd31/anaconda3/envs/py312/bin/python3.12)
frame #18: /home/hd31/anaconda3/envs/py312/bin/python3.12() [0x630ff6]
frame #19: _PyObject_MakeTpCall + 0x2fb (0x51e38b in /home/hd31/anaconda3/envs/py312/bin/python3.12)
frame #20: _PyEval_EvalFrameDefault + 0x6ce (0x528ede in /home/hd31/anaconda3/envs/py312/bin/python3.12)
frame #21: /home/hd31/anaconda3/envs/py312/bin/python3.12() [0x578d57]
frame #22: _PyEval_EvalFrameDefault + 0x503a (0x52d84a in /home/hd31/anaconda3/envs/py312/bin/python3.12)
frame #23: /home/hd31/anaconda3/envs/py312/bin/python3.12() [0x578d57]
frame #24: _PyEval_EvalFrameDefault + 0x503a (0x52d84a in /home/hd31/anaconda3/envs/py312/bin/python3.12)
frame #25: _PyObject_FastCallDictTstate + 0x1e7 (0x520f07 in /home/hd31/anaconda3/envs/py312/bin/python3.12)
frame #26: _PyObject_Call_Prepend + 0x66 (0x55b4c6 in /home/hd31/anaconda3/envs/py312/bin/python3.12)
frame #27: /home/hd31/anaconda3/envs/py312/bin/python3.12() [0x630ff6]
frame #28: _PyObject_MakeTpCall + 0x2fb (0x51e38b in /home/hd31/anaconda3/envs/py312/bin/python3.12)
frame #29: _PyEval_EvalFrameDefault + 0x6ce (0x528ede in /home/hd31/anaconda3/envs/py312/bin/python3.12)
frame #30: PyEval_EvalCode + 0xae (0x5e581e in /home/hd31/anaconda3/envs/py312/bin/python3.12)
frame #31: /home/hd31/anaconda3/envs/py312/bin/python3.12() [0x603679]
frame #32: _PyEval_EvalFrameDefault + 0x3a89 (0x52c299 in /home/hd31/anaconda3/envs/py312/bin/python3.12)
frame #33: /home/hd31/anaconda3/envs/py312/bin/python3.12() [0x5fe027]
frame #34: /home/hd31/anaconda3/envs/py312/bin/python3.12() [0x5ff2a6]
frame #35: _PyEval_EvalFrameDefault + 0x4698 (0x52cea8 in /home/hd31/anaconda3/envs/py312/bin/python3.12)
frame #36: /home/hd31/anaconda3/envs/py312/bin/python3.12() [0x5792ad]
frame #37: /home/hd31/anaconda3/envs/py312/bin/python3.12() [0x578dbd]
frame #38: _PyObject_Call + 0x122 (0x55e162 in /home/hd31/anaconda3/envs/py312/bin/python3.12)
frame #39: _PyEval_EvalFrameDefault + 0x503a (0x52d84a in /home/hd31/anaconda3/envs/py312/bin/python3.12)
frame #40: /home/hd31/anaconda3/envs/py312/bin/python3.12() [0x5fe027]
frame #41: <unknown function> + 0x841b (0x7f5272e1941b in /home/hd31/anaconda3/envs/py312/lib/python3.12/lib-dynload/_asyncio.cpython-312-x86_64-linux-gnu.so)
frame #42: <unknown function> + 0x8c27 (0x7f5272e19c27 in /home/hd31/anaconda3/envs/py312/lib/python3.12/lib-dynload/_asyncio.cpython-312-x86_64-linux-gnu.so)
frame #43: /home/hd31/anaconda3/envs/py312/bin/python3.12() [0x54be6b]
frame #44: /home/hd31/anaconda3/envs/py312/bin/python3.12() [0x67bfd1]
frame #45: /home/hd31/anaconda3/envs/py312/bin/python3.12() [0x4dd071]
frame #46: /home/hd31/anaconda3/envs/py312/bin/python3.12() [0x5418be]
frame #47: _PyEval_EvalFrameDefault + 0x503a (0x52d84a in /home/hd31/anaconda3/envs/py312/bin/python3.12)
frame #48: PyEval_EvalCode + 0xae (0x5e581e in /home/hd31/anaconda3/envs/py312/bin/python3.12)
frame #49: /home/hd31/anaconda3/envs/py312/bin/python3.12() [0x603679]
frame #50: /home/hd31/anaconda3/envs/py312/bin/python3.12() [0x5418be]
frame #51: PyObject_Vectorcall + 0x51 (0x5416a1 in /home/hd31/anaconda3/envs/py312/bin/python3.12)
frame #52: _PyEval_EvalFrameDefault + 0x6ce (0x528ede in /home/hd31/anaconda3/envs/py312/bin/python3.12)
frame #53: /home/hd31/anaconda3/envs/py312/bin/python3.12() [0x6181af]
frame #54: Py_RunMain + 0x3d8 (0x617d68 in /home/hd31/anaconda3/envs/py312/bin/python3.12)
frame #55: Py_BytesMain + 0x39 (0x5d03c9 in /home/hd31/anaconda3/envs/py312/bin/python3.12)
frame #56: <unknown function> + 0x29d90 (0x7f5273fe7d90 in /lib/x86_64-linux-gnu/libc.so.6)
frame #57: __libc_start_main + 0x80 (0x7f5273fe7e40 in /lib/x86_64-linux-gnu/libc.so.6)
frame #58: /home/hd31/anaconda3/envs/py312/bin/python3.12() [0x5d01f9]


## Train

In [13]:
if hasattr(config.dataset, 'num_negatives') and config.dataset.num_negatives:
    seqrec_module = SeqRecWithSampling(model, **config['seqrec_module'])
else:
    seqrec_module = SeqRec(model, **config['seqrec_module'])
    
early_stopping = EarlyStopping(monitor="val_ndcg", mode="max",
                               patience=config.patience, verbose=False)
model_summary = ModelSummary(max_depth=2)
checkpoint = ModelCheckpoint(save_top_k=1, monitor="val_ndcg",
                             mode="max", save_weights_only=True)
callbacks=[early_stopping, model_summary, checkpoint]

trainer = pl.Trainer(callbacks=callbacks, enable_checkpointing=True,
                     gpus=1, **config['trainer_params'])

trainer.fit(model=seqrec_module,
            train_dataloaders=train_loader,
            val_dataloaders=eval_loader)

Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]

  | Name                       | Type       | Params
----------------------------------------------------------
0 | model                      | SASRec     | 282 K 
1 | model.item_emb             | Embedding  | 218 K 
2 | model.pos_emb              | Embedding  | 12.8 K
3 | model.emb_dropout          | Dropout    | 0     
4 | model.attention_layernorms | ModuleList | 256   
5 | model.attention_layers     | ModuleList | 33.3 K
6 | model.forward_layernorms   | ModuleList | 256   
7 | model.forward_layers       | ModuleList | 16.6 K
8 | model.last_layernorm       | LayerNorm  | 128   
----------------------------------------------

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [14]:
seqrec_module.load_state_dict(torch.load(checkpoint.best_model_path)['state_dict'])

<All keys matched successfully>

## Validation metrics

In [15]:
if config.model in ['SASRec', 'RNN']:
    predict_dataset = CausalLMPredictionDataset(train, max_length=config.dataset.max_length)
elif config.model  == 'BERT4Rec':
    predict_dataset = MaskedLMPredictionDataset(train, max_length=config.dataset.max_length)

predict_loader = DataLoader(
        predict_dataset, shuffle=False,
        collate_fn=PaddingCollateFn(),
        batch_size=config.dataloader.test_batch_size,
        num_workers=config.dataloader.num_workers)

seqrec_module.predict_top_k = max(config.top_k_metrics)
preds = trainer.predict(model=seqrec_module, dataloaders=predict_loader)

recs = preds2recs(preds)
print(recs.shape)
recs.head()

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]


Predicting: 48it [00:00, ?it/s]

(604000, 3)


Unnamed: 0,user_id,item_id,prediction
0,1,737,7.598266
1,1,663,7.24827
2,1,102,7.237531
3,1,745,7.148354
4,1,1125,7.065181


In [16]:
for k in config.top_k_metrics:
    metrics_val = compute_metrics(validation, recs, k=k)
    print('k = ', k)
    print(metrics_val)

k =  10
{'ndcg@10': 0.18906012825101937, 'hit_rate@10': 0.33228476821192054, 'mrr@10': 0.14530563439503838}
k =  100
{'ndcg@100': 0.2657238370955717, 'hit_rate@100': 0.7066225165562914, 'mrr@100': 0.15960866987332128}


## Test metrics

In [17]:
if config.model in ['SASRec', 'RNN']:
    test_predict_dataset = CausalLMPredictionDataset(validation_full, max_length=config.dataset.max_length)
elif config.model  == 'BERT4Rec':
    test_predict_dataset = MaskedLMPredictionDataset(validation_full, max_length=config.dataset.max_length)
    
test_predict_loader = DataLoader(
        test_predict_dataset, shuffle=False,
        collate_fn=PaddingCollateFn(),
        batch_size=config.dataloader.test_batch_size,
        num_workers=config.dataloader.num_workers)

seqrec_module.predict_top_k = max(config.top_k_metrics)
preds_test = trainer.predict(model=seqrec_module, dataloaders=test_predict_loader)

recs_test = preds2recs(preds_test)
print(recs_test.shape)
recs_test.head()

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]


Predicting: 48it [00:00, ?it/s]

(604000, 3)


Unnamed: 0,user_id,item_id,prediction
0,1,872,6.629054
1,1,667,6.363072
2,1,745,6.264328
3,1,679,6.156502
4,1,737,6.107877


In [18]:
for k in config.top_k_metrics:
    metrics_test = compute_metrics(test, recs_test, k=k)
    print('k = ', k)
    print(metrics_test)

k =  10
{'ndcg@10': 0.18250363127629773, 'hit_rate@10': 0.31473509933774835, 'mrr@10': 0.14229160096709764}
k =  100
{'ndcg@100': 0.25577861370312055, 'hit_rate@100': 0.6736754966887417, 'mrr@100': 0.15588159170341734}
