In [1]:
import argparse
from logging import getLogger
import os
from recbole.config import Config
from recbole.data import create_dataset
from recbole.data.utils import get_dataloader, create_samplers
from recbole.model.sequential_recommender.mbht import MBHT
from recbole.utils import init_logger, init_seed, get_model, get_trainer, set_color

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', '-m', type=str, default='MBHT', help='Model for session-based rec.')
    parser.add_argument('--dataset', '-d', type=str, default='tmall_beh', help='Benchmarks for session-based rec.')
    parser.add_argument('--validation', action='store_true', help='Whether evaluating on validation set (split from train set), otherwise on test set.')
    parser.add_argument('--valid_portion', type=float, default=0.1, help='ratio of validation set.')
    parser.add_argument('--gpu_id', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=2048)
    return parser.parse_known_args()[0]

In [3]:
args = get_args()

In [4]:
config_dict = {
    'USER_ID_FIELD': 'session_id',
    'load_col': None,
    # 'neg_sampling': {'uniform':1},
    'neg_sampling': None,
    'benchmark_filename': ['train', 'test'],
    'alias_of_item_id': ['item_id_list'],
    'topk': [5, 10, 101],
    'metrics': ['Recall', 'NDCG', 'MRR'],
    'valid_metric': 'NDCG@10',
    'eval_args':{
        'mode':'full',
        'order':'TO'
        },
    'gpu_id':args.gpu_id,
    "MAX_ITEM_LIST_LENGTH":200,
    "train_batch_size": 32 if args.dataset == "ijcai_beh" else 64,
    "eval_batch_size":24 if args.dataset == "ijcai_beh" else 128,
    "hyper_len":10 if args.dataset == "ijcai_beh" else 6,
    "scales":[10, 4, 20],
    "enable_hg":1,
    "enable_ms":1,
    "customized_eval":1,
    "abaltion":""
}

In [5]:
if args.dataset == "retail_beh":
    config_dict['scales'] = [5, 4, 20]
    config_dict['hyper_len'] = 6

In [6]:
config = Config(model="MBHT", dataset=f'{args.dataset}', config_dict=config_dict)
# config['device']="cpu"
init_seed(config['seed'], config['reproducibility'])

# logger initialization
init_logger(config, log_root="log")
logger = getLogger()

logger.info(f"PID: {os.getpid()}")
logger.info(args)
logger.info(config)

# dataset filtering
dataset = create_dataset(config)
logger.info(dataset)

01 Aug 19:51    INFO  PID: 22352
01 Aug 19:51    INFO  Namespace(batch_size=2048, dataset='tmall_beh', gpu_id=0, model='MBHT', valid_portion=0.1, validation=False)
01 Aug 19:51    INFO  
General Hyper Parameters:
gpu_id = 0
use_gpu = True
seed = 2020
state = INFO
reproducibility = True
data_path = dataset/tmall_beh
show_progress = True
save_dataset = False
save_dataloaders = False
benchmark_filename = ['train', 'test']

Training Hyper Parameters:
checkpoint_dir = saved
epochs = 300
train_batch_size = 64
learner = adam
learning_rate = 0.001
eval_step = 1
stopping_step = 10
clip_grad_norm = None
weight_decay = 0.0
loss_decimal_place = 4

Evaluation Hyper Parameters:
eval_args = {'mode': 'full', 'order': 'TO', 'split': {'RS': [0.8, 0.1, 0.1]}, 'group_by': 'user'}
metrics = ['Recall', 'NDCG', 'MRR']
topk = [5, 10, 101]
valid_metric = NDCG@10
valid_metric_bigger = True
eval_batch_size = 128
metric_decimal_place = 4

Dataset Hyper Parameters:
field_separator = 	
seq_separator =  
USER_ID_FIE

In [7]:
# dataset splitting
train_dataset, test_dataset = dataset.build()
train_sampler, test_sampler = create_samplers(config, dataset, [train_dataset, test_dataset])
if args.validation:
    train_dataset.shuffle()
    new_train_dataset, new_test_dataset = train_dataset.split_by_ratio([1 - args.valid_portion, args.valid_portion])
    train_data = get_dataloader(config, 'train')(config, new_train_dataset, None, shuffle=True)
    test_data = get_dataloader(config, 'test')(config, new_test_dataset, None, shuffle=False)
else:
    train_data = get_dataloader(config, 'train')(config, train_dataset, train_sampler, shuffle=True)
    test_data = get_dataloader(config, 'test')(config, test_dataset, test_sampler, shuffle=False)


In [9]:
print(train_data)

<recbole.data.dataloader.general_dataloader.TrainDataLoader object at 0x0000027C2A654D48>


In [10]:
# model loading and initialization
model = get_model(config['model'])(config, train_data.dataset).to(config['device'])
logger.info(model)

# trainer loading and initialization
trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model)

01 Aug 19:51    INFO  MBHT(
  (item_embedding_ls): Embedding(99039, 64, padding_idx=0)
  (sequenceMixer_1): PreNormResidual(
    (fn): Sequential(
      (0): Conv1d(200, 800, kernel_size=(1,), stride=(1,))
      (1): GELU(approximate=none)
      (2): Dropout(p=0.5, inplace=False)
      (3): Conv1d(800, 200, kernel_size=(1,), stride=(1,))
      (4): Dropout(p=0.5, inplace=False)
    )
    (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  )
  (channelMixer_1): PreNormResidual(
    (fn): Sequential(
      (0): Linear(in_features=64, out_features=256, bias=True)
      (1): GELU(approximate=none)
      (2): Dropout(p=0.5, inplace=False)
      (3): Linear(in_features=256, out_features=64, bias=True)
      (4): Dropout(p=0.5, inplace=False)
    )
    (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  )
  (LayerNorm_1): LayerNorm((64,), eps=1e-12, elementwise_affine=True)
  (sequenceMixer_2): PreNormResidual(
    (fn): Sequential(
      (0): Conv1d(200, 800, kernel_siz

In [11]:
print(trainer)

<recbole.trainer.trainer.Trainer object at 0x0000027C34A68608>


In [14]:
show_progress=config['show_progress']

In [15]:
from tqdm import tqdm

In [16]:
for epoch_idx in range(0, config['epochs']):
    iter_data = (
            tqdm(
                train_data,
                total=len(train_data),
                ncols=100,
                desc=set_color(f"Train {epoch_idx:>5}", 'pink'),
            ) if show_progress else train_data
        )


Train     0:   0%|                                                         | 0/7234 [00:00<?, ?it/s]
Train     0:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]
Train     1:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]

Train     2:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]
Train     3:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]

Train     4:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                              

Train    42:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]
Train    43:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]

Train    44:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]
Train    45:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]

Train    46:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]
Train    47:   0%|                              

Train    85:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]

Train    86:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]
Train    87:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]

Train    88:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]
Train    89:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]

Train    90:   0%|                             

Train   128:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]
Train   129:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]

Train   130:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]
Train   131:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]

Train   132:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]
Train   133:   0%|                              

Train   171:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]

Train   172:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]
Train   173:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]

Train   174:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]
Train   175:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]

Train   176:   0%|                             

Train   214:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]
Train   215:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]

Train   216:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]
Train   217:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]

Train   218:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]
Train   219:   0%|                              

Train   257:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]

Train   258:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]
Train   259:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]

Train   260:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]
Train   261:   0%|                                                         | 0/7234 [00:00<?, ?it/s]:   0%|                                                         | 0/7234 [00:00<?, ?it/s]

Train   262:   0%|                             

In [19]:
for batch_idx, interaction in enumerate(iter_data):
    interaction = interaction
    print(interaction)
    #item_seq = interaction[self.ITEM_SEQ]
    item_seq_len = interaction['item_length']
    #session_id = interaction['session_id']
    #item_type = interaction["item_type_list"]
    #last_buy = interaction["item_id"]
    #print(interaction)
    #print(item_seq)
    print(item_seq_len)
    #print(item_type)
    #print(last_buy)
    break
        

The batch_size of interaction: 64
    session_id, torch.Size([64]), cpu, torch.int64
    item_id_list, torch.Size([64, 199]), cpu, torch.int64
    item_type_list, torch.Size([64, 199]), cpu, torch.int64
    item_id, torch.Size([64]), cpu, torch.int64
    item_length, torch.Size([64]), cpu, torch.int64


tensor([ 19,  16,  10,  22,   5,  54,  49,   2,  18,   5,  14,   4,  36,  12,
         28,  12,  44,  46,  11,  21,  54,  42,   6,  12,   2,  43,  30,  14,
         14,  26,   7,   2,  26,   7,  15,  61,  18,  67,   6,  67,  58,  74,
        122,  55,  52,  52,  20,  51,   9, 102,  24,  13,  63,  11,  34,  18,
         22,   1,  86,  49,  26,  38, 103,  23])


In [None]:
from torch import torch
print(torch.__version__)

In [None]:
item_seq_len = interaction[self.ITEM_SEQ_LEN]