In [1]:
import os

# Expose only GPU 0 to this process. The assignment is deliberately placed
# before `import torch` below — presumably so it is in effect before CUDA is
# initialised; keep this ordering if the imports are rearranged.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import glob
import torch 

from transformers4rec import torch as tr
from transformers4rec.torch.ranking_metric import NDCGAt, AvgPrecisionAt, RecallAt
from transformers4rec.torch.utils.examples_utils import wipe_memory

  warn(f"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}")
  warn(f"Triton dtype mappings did not load successfully due to an error: {exc.msg}")
  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from merlin.schema import Schema
from merlin.io import Dataset

# Root folder of the preprocessed data. The INPUT_DATA_DIR environment variable
# takes precedence; the literal below is only the fallback.
INPUT_DATA_DIR = os.environ.get(
    "INPUT_DATA_DIR",
    "/home/ec2-user/SageMaker/token_recommender/data/",
)

# Training window: one parquet part wrapped in a merlin Dataset. The schema
# carried by that dataset is what the model inputs are built from further down.
train_parquet_file = os.path.join(INPUT_DATA_DIR, "202201-202203/data.parquet/part_0.parquet")
train = Dataset(train_parquet_file)
schema = train.schema

In [4]:
# Display the schema attached to the training dataset (column names, tags, dtypes).
schema

Unnamed: 0,name,tags,dtype,is_list,is_ragged,properties.freq_threshold,properties.num_buckets,properties.embedding_sizes.dimension,properties.embedding_sizes.cardinality,properties.cat_path,properties.max_size,properties.domain.min,properties.domain.max,properties.domain.name,properties.value_count.min,properties.value_count.max
0,recipient,(),"DType(name='unknown', element_type=<ElementTyp...",False,False,,,,,,,,,,,
1,timestamp-first,(),"DType(name='unknown', element_type=<ElementTyp...",False,False,,,,,,,,,,,
2,buyAsset-list,"(Tags.CATEGORICAL, Tags.LIST, Tags.ID, Tags.ITEM)","DType(name='int64', element_type=<ElementType....",True,True,0.0,,97.0,1515.0,.//categories/unique.buyAsset.parquet,0.0,0.0,1514.0,buyAsset,2.0,20.0
3,et_dayofweek_sin-list,"(Tags.LIST, Tags.CONTINUOUS)","DType(name='float64', element_type=<ElementTyp...",True,True,,,,,,,,,,2.0,20.0
4,txFee_eth_log_norm-list,"(Tags.LIST, Tags.CONTINUOUS)","DType(name='float32', element_type=<ElementTyp...",True,True,,,,,,,,,,2.0,20.0
5,buyQty1_log_norm-list,"(Tags.LIST, Tags.CONTINUOUS)","DType(name='float32', element_type=<ElementTyp...",True,True,,,,,,,,,,2.0,20.0
6,buyPrice_log_norm-list,"(Tags.LIST, Tags.CONTINUOUS)","DType(name='float32', element_type=<ElementTyp...",True,True,,,,,,,,,,2.0,20.0
7,token_category-list,"(Tags.CATEGORICAL, Tags.LIST)","DType(name='int64', element_type=<ElementType....",True,True,0.0,,16.0,48.0,.//categories/unique.token_category.parquet,0.0,0.0,47.0,token_category,2.0,20.0
8,token_rank_category-list,"(Tags.CATEGORICAL, Tags.LIST)","DType(name='int64', element_type=<ElementType....",True,True,0.0,,16.0,8.0,.//categories/unique.token_rank_category.parquet,0.0,0.0,7.0,token_rank_category,2.0,20.0
9,risky_flags-list,"(Tags.CATEGORICAL, Tags.LIST)","DType(name='int64', element_type=<ElementType....",True,True,0.0,,16.0,11.0,.//categories/unique.risky_flags.parquet,0.0,0.0,10.0,risky_flags,2.0,20.0


In [5]:
# Input module built from the dataset schema. Sequences are capped at 20 steps,
# `masking="mlm"` selects the masked-language-modeling scheme, and the
# projection sizes (64 for continuous features, 100 for the module output)
# are passed through unchanged to the library.
inputs = tr.TabularSequenceFeatures.from_schema(
    schema,
    max_sequence_length=20,
    continuous_projection=64,
    masking="mlm",
    d_output=100,
)

In [6]:
# XLNet configuration (HF defaults plus the sizes given here).
transformer_config = tr.XLNetConfig.build(
    d_model=64,
    n_head=4,
    n_layer=2,
    total_seq_length=20,
)

# Model body: input module -> MLP projection to 64 dims -> transformer block.
# The transformer shares the masking object owned by the input module.
transformer_block = tr.TransformerBlock(transformer_config, masking=inputs.masking)
body = tr.SequentialBlock(inputs, tr.MLPBlock([64]), transformer_block)

# Top-N evaluation metrics, all computed at the same cut-offs (5, 10, 20).
# A fresh list literal is created for every metric instance.
metrics = [
    metric_cls(top_ks=[5, 10, 20], labels_onehot=True)
    for metric_cls in (NDCGAt, RecallAt, AvgPrecisionAt)
]

# Prediction head for the next-item task, with weight tying enabled.
next_item_task = tr.NextItemPredictionTask(weight_tying=True, metrics=metrics)
head = tr.Head(body, next_item_task, inputs=inputs)

# End-to-end model.
model = tr.Model(head)

In [7]:
# Per-device batch sizes. Each one can be overridden through an environment
# variable of the same name and falls back to 256 when that variable is unset.
per_device_train_batch_size = int(os.environ.get("per_device_train_batch_size", '256'))
per_device_eval_batch_size = int(os.environ.get("per_device_eval_batch_size", '256'))

In [8]:

from transformers4rec.config.trainer import T4RecTrainingArguments
from transformers4rec.torch import Trainer
# Training hyperparameters, grouped by concern.
train_args = T4RecTrainingArguments(
    # --- data loading ---
    data_loader_engine='merlin',
    dataloader_drop_last=True,
    max_sequence_length=20,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    # --- optimisation ---
    learning_rate=0.0005,
    lr_scheduler_type='cosine',
    learning_rate_num_cosine_cycles_by_epoch=1.5,
    gradient_accumulation_steps=1,
    num_train_epochs=10,
    # --- output / logging ---
    output_dir="/home/ec2-user/SageMaker/model/",
    report_to=[],
    logging_steps=50,
    # --- device ---
    no_cuda=False,
)

In [9]:
# Wire the model, the training arguments and the dataset schema into a Trainer;
# metric computation is switched on.
trainer = Trainer(model=model, args=train_args, schema=schema, compute_metrics=True)

In [10]:
# Despite its name, `train_paths` holds the merlin Dataset object created
# earlier (not a path); it is handed to the trainer as-is.
train_paths = train
print(train_paths)

# Fit the model on the training window.
separator = '*' * 20
print(separator)
print(separator + '\n')
trainer.train_dataset_or_path = train_paths
trainer.reset_lr_scheduler()
trainer.train()
print('finished')



<merlin.io.dataset.Dataset object at 0x7ff6e40bd210>
********************
********************





Step,Training Loss
50,6.4444
100,5.2768
150,4.9634
200,4.9347
250,4.89
300,4.8522
350,4.7791
400,4.7478
450,4.6893
500,4.6743


finished


In [11]:
# Parquet part used for evaluation.
# NOTE(review): the 202202-202204 window appears to overlap the 202201-202203
# training window, which could inflate the reported metrics — confirm this is intended.
eval_paths = os.path.join(INPUT_DATA_DIR, "202202-202204/data.parquet/part_0.parquet")

# Point the trainer at the evaluation data and compute the metrics.
# `train_metrics` actually holds the evaluation results; the name is kept as-is.
trainer.eval_dataset_or_path = eval_paths
train_metrics = trainer.evaluate(metric_key_prefix='eval')

separator = '*' * 20
print(separator)
print('\n' + separator + '\n')

# Report every metric, ordered by name.
for metric_name in sorted(train_metrics):
    print(f" {metric_name} = {train_metrics[metric_name]!s}")

# Free memory held after evaluation.
wipe_memory()

********************

********************

 eval_/loss = 4.497103691101074
 eval_/next-item/avg_precision_at_10 = 0.3204691708087921
 eval_/next-item/avg_precision_at_20 = 0.3257555365562439
 eval_/next-item/avg_precision_at_5 = 0.30787158012390137
 eval_/next-item/ndcg_at_10 = 0.38628265261650085
 eval_/next-item/ndcg_at_20 = 0.40538936853408813
 eval_/next-item/ndcg_at_5 = 0.35615113377571106
 eval_/next-item/recall_at_10 = 0.594357430934906
 eval_/next-item/recall_at_20 = 0.6696138978004456
 eval_/next-item/recall_at_5 = 0.5020332336425781
 eval_runtime = 10.8517
 eval_samples_per_second = 10922.52
 eval_steps_per_second = 42.666


In [12]:
# Destination for the serialized model: the OUTPUT_DIR environment variable when
# set, otherwise a `saved_model` folder under INPUT_DATA_DIR.
# The fallback is built with os.path.join (as the data paths elsewhere in this
# notebook are) so that a trailing slash on INPUT_DATA_DIR does not produce a
# doubled separator such as ".../data//saved_model".
model_path = os.environ.get("OUTPUT_DIR", os.path.join(INPUT_DATA_DIR, "saved_model"))
model.save(model_path)