In [1]:
from transformers4rec import torch as tr
from transformers4rec.torch.ranking_metric import NDCGAt, RecallAt

# Create a schema or read one from disk: tr.Schema().from_json(SCHEMA_PATH).
schema: tr.Schema = tr.data.tabular_sequence_testing_data.schema

max_sequence_length, d_model = 20, 64

# Define the input module to process the tabular input features.
input_module = tr.TabularSequenceFeatures.from_schema(
    schema,
    max_sequence_length=max_sequence_length,
    continuous_projection=d_model,
    aggregation="concat",
    masking="causal",
)

# Define a transformer-config like the XLNet architecture.
transformer_config = tr.XLNetConfig.build(
    d_model=d_model, n_head=4, n_layer=2, total_seq_length=max_sequence_length
)

# Define the model block including: inputs, masking, projection and transformer block.
body = tr.SequentialBlock(
    input_module,
    tr.MLPBlock([d_model]),
    tr.TransformerBlock(transformer_config, masking=input_module.masking)
)

# Define the evaluation top-N metrics and the cut-offs
metrics = [NDCGAt(top_ks=[20, 40], labels_onehot=True),
           RecallAt(top_ks=[20, 40], labels_onehot=True)]

# Define a head with NextItemPredictionTask.
head = tr.Head(
    body,
    tr.NextItemPredictionTask(weight_tying=True, metrics=metrics),
    inputs=input_module,
)

# Get the end-to-end Model class.
model = tr.Model(head)

  warn(f"Triton dtype mappings did not load successfully due to an error: {exc.msg}")
  from .autonotebook import tqdm as notebook_tqdm


In [9]:
import transformers4rec

In [7]:
schema

[{'name': 'timestamp/age_days/LogOp/Normalize/list', 'value_count': {'min': '2', 'max': '185'}, 'type': 'FLOAT', 'float_domain': {'name': 'timestamp/age_days/LogOp/Normalize/list', 'min': -2.917729139328003, 'max': 1.5231701135635376}, 'annotation': {'tag': ['continuous', 'list']}}, {'name': 'timestamp/hour/list', 'value_count': {'min': '2', 'max': '185'}, 'type': 'FLOAT', 'float_domain': {'name': 'timestamp/hour/list', 'min': 5.7866054703481495e-06, 'max': 1.605135440826416}, 'annotation': {'tag': ['continuous', 'time', 'list']}}, {'name': 'timestamp/weekday/list', 'value_count': {'min': '2', 'max': '185'}, 'type': 'FLOAT', 'float_domain': {'name': 'timestamp/weekday/list', 'min': 0.00013345430488698184, 'max': 1.568290114402771}, 'annotation': {'tag': ['continuous', 'time', 'list']}}, {'name': 'timestamp/day/list', 'value_count': {'min': '2', 'max': '185'}, 'type': 'FLOAT', 'float_domain': {'name': 'timestamp/day/list', 'min': 0.055881768465042114, 'max': 1.7342302799224854}, 'annota

In [10]:
transformers4rec.torch.TabularModule(

)

transformers4rec.torch.tabular.base.TabularModule

In [2]:
model

Model(
  (heads): ModuleList(
    (0): Head(
      (body): SequentialBlock(
        (0): TabularSequenceFeatures(
          (_aggregation): ConcatFeatures()
          (to_merge): ModuleDict(
            (continuous_module): SequentialBlock(
              (0): ContinuousFeatures(
                (filter_features): FilterFeatures()
                (_aggregation): ConcatFeatures()
              )
              (1): SequentialBlock(
                (0): DenseBlock(
                  (0): Linear(in_features=11, out_features=64, bias=True)
                  (1): ReLU(inplace=True)
                )
              )
              (2): AsTabular()
            )
            (categorical_module): SequenceEmbeddingFeatures(
              (filter_features): FilterFeatures()
              (embedding_tables): ModuleDict(
                (item_id/list): Embedding(51997, 64, padding_idx=0)
                (category/list): Embedding(333, 64, padding_idx=0)
                (user_country): Embedding(63, 6

In [6]:
from transformers4rec import torch as tconf

# Define schema to include text embedding feature
schema = tconf.Schema([tconf.FeatureConfig(name="user_id"),
                       tconf.FeatureConfig(name="item_id"),
                       tconf.FeatureConfig(name="text_embeddings",
                                           embedding_dim=text_embeddings.shape[1])])

TypeError: FeatureConfig.__init__() missing 1 required positional argument: 'table'