Adding tests for various transformer configs (NVIDIA-Merlin#48)
marcromeyn committed Jul 21, 2021
1 parent 5319135 commit 40149f1
Showing 3 changed files with 248 additions and 6 deletions.
30 changes: 30 additions & 0 deletions tests/config/test_transformer.py
@@ -0,0 +1,30 @@
import pytest
from transformers import PreTrainedModel, TFPreTrainedModel

from transformers4rec.config import transformer as tconf

config_classes = [
tconf.ReformerConfig,
tconf.XLNetConfig,
tconf.ElectraConfig,
tconf.LongformerConfig,
tconf.GPT2Config,
]


@pytest.mark.parametrize("config_cls", config_classes)
def test_to_torch_model(config_cls):
config = config_cls.for_rec(100, 4, 2, 20)

model = config.to_torch_model()

assert isinstance(model, PreTrainedModel)


@pytest.mark.parametrize("config_cls", list(set(config_classes) - {tconf.ReformerConfig}))
def test_to_tf_model(config_cls):
config = config_cls.for_rec(100, 4, 2, 20)

model = config.to_tf_model()

assert isinstance(model, TFPreTrainedModel)
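
For context, the four positional arguments passed to for_rec in these tests map onto (d_model, n_head, n_layer, total_seq_length) in the config classes added below. A minimal sketch of the same flow outside pytest — assuming transformers4rec, torch, and a compatible transformers release are installed — would be:

# Illustrative sketch, not part of the commit: build one config and its torch model.
from transformers import PreTrainedModel

from transformers4rec.config import transformer as tconf

# Same values as the tests: d_model=100, n_head=4, n_layer=2, total_seq_length=20.
config = tconf.GPT2Config.for_rec(100, 4, 2, 20)

# to_torch_model() instantiates the matching HuggingFace model for this config.
model = config.to_torch_model()
assert isinstance(model, PreTrainedModel)

# The tests above can be run with: pytest tests/config/test_transformer.py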
220 changes: 216 additions & 4 deletions transformers4rec/config/transformer.py
@@ -1,14 +1,226 @@
import transformers


 class T4RecConfig:
     def to_torch_model(self):
-        from transformers import MODEL_MAPPING
+        model_cls = transformers.MODEL_MAPPING[self.transformers_config_cls]

-        return MODEL_MAPPING[self]
+        return model_cls(self)

     def to_tf_model(self):
-        from transformers import TF_MODEL_MAPPING
+        model_cls = transformers.TF_MODEL_MAPPING[self.transformers_config_cls]

-        return TF_MODEL_MAPPING[self]
+        return model_cls(self)

     @property
     def transformers_config_cls(self):
         return self.__class__.__bases__[1]

@classmethod
def for_rec(cls, *args, **kwargs):
raise NotImplementedError


class ReformerConfig(T4RecConfig, transformers.ReformerConfig):
@classmethod
def for_rec(
cls,
d_model,
n_head,
n_layer,
total_seq_length,
hidden_act="gelu",
initializer_range=0.01,
layer_norm_eps=0.03,
dropout=0.3,
pad_token=0,
log_attention_weights=True,
**kwargs
):
return cls(
attention_head_size=d_model,
attn_layers=["local", "lsh"] * (n_layer // 2) if n_layer > 2 else ["local"],
feed_forward_size=d_model * 4,
hidden_size=d_model,
num_attention_heads=n_head,
hidden_act=hidden_act,
initializer_range=initializer_range,
layer_norm_eps=layer_norm_eps,
hidden_dropout_prob=dropout,
lsh_attention_probs_dropout_prob=dropout,
pad_token_id=pad_token,
output_attentions=log_attention_weights,
max_position_embeddings=total_seq_length,
axial_pos_embds_dim=[
d_model // 2,
d_model // 2,
],
vocab_size=1,
**kwargs
)


class GPT2Config(T4RecConfig, transformers.GPT2Config):
@classmethod
def for_rec(
cls,
d_model,
n_head,
n_layer,
total_seq_length,
hidden_act="gelu",
initializer_range=0.01,
layer_norm_eps=0.03,
dropout=0.3,
pad_token=0,
log_attention_weights=True,
**kwargs
):
return cls(
n_embd=d_model,
n_inner=d_model * 4,
n_layer=n_layer,
n_head=n_head,
activation_function=hidden_act,
initializer_range=initializer_range,
layer_norm_eps=layer_norm_eps,
resid_pdrop=dropout,
embd_pdrop=dropout,
attn_pdrop=dropout,
n_positions=total_seq_length,
n_ctx=total_seq_length,
output_attentions=log_attention_weights,
vocab_size=1,
**kwargs
)


class LongformerConfig(T4RecConfig, transformers.LongformerConfig):
@classmethod
def for_rec(
cls,
d_model,
n_head,
n_layer,
total_seq_length,
hidden_act="gelu",
initializer_range=0.01,
layer_norm_eps=0.03,
dropout=0.3,
pad_token=0,
log_attention_weights=True,
**kwargs
):
return cls(
hidden_size=d_model,
num_hidden_layers=n_layer,
num_attention_heads=n_head,
hidden_act=hidden_act,
initializer_range=initializer_range,
layer_norm_eps=layer_norm_eps,
dropout=dropout,
max_position_embeddings=total_seq_length,
pad_token_id=pad_token,
output_attentions=log_attention_weights,
vocab_size=1,
**kwargs
)


class ElectraConfig(T4RecConfig, transformers.ElectraConfig):
@classmethod
def for_rec(
cls,
d_model,
n_head,
n_layer,
total_seq_length,
hidden_act="gelu",
initializer_range=0.01,
layer_norm_eps=0.03,
dropout=0.3,
pad_token=0,
log_attention_weights=True,
**kwargs
):
return cls(
hidden_size=d_model,
embedding_size=d_model,
num_hidden_layers=n_layer,
num_attention_heads=n_head,
intermediate_size=d_model * 4,
hidden_act=hidden_act,
initializer_range=initializer_range,
layer_norm_eps=layer_norm_eps,
hidden_dropout_prob=dropout,
max_position_embeddings=total_seq_length,
pad_token_id=pad_token,
output_attentions=log_attention_weights,
vocab_size=1,
**kwargs
)


class AlbertConfig(T4RecConfig, transformers.AlbertConfig):
@classmethod
def for_rec(
cls,
d_model,
n_head,
n_layer,
total_seq_length,
hidden_act="gelu",
initializer_range=0.01,
layer_norm_eps=0.03,
dropout=0.3,
pad_token=0,
log_attention_weights=True,
**kwargs
):
return cls(
hidden_size=d_model,
num_attention_heads=n_head,
num_hidden_layers=n_layer,
hidden_act=hidden_act,
intermediate_size=d_model * 4,
hidden_dropout_prob=dropout,
attention_probs_dropout_prob=dropout,
max_position_embeddings=total_seq_length,
embedding_size=d_model, # should be same as dimension of the input to ALBERT
initializer_range=initializer_range,
layer_norm_eps=layer_norm_eps,
output_attentions=log_attention_weights,
vocab_size=1,
**kwargs
)


class XLNetConfig(T4RecConfig, transformers.XLNetConfig):
@classmethod
def for_rec(
cls,
d_model,
n_head,
n_layer,
hidden_act="gelu",
initializer_range=0.01,
layer_norm_eps=0.03,
dropout=0.3,
pad_token=0,
log_attention_weights=True,
**kwargs
):
return cls(
d_model=d_model,
d_inner=d_model * 4,
n_layer=n_layer,
n_head=n_head,
ff_activation=hidden_act,
initializer_range=initializer_range,
layer_norm_eps=layer_norm_eps,
dropout=dropout,
pad_token_id=pad_token,
output_attentions=log_attention_weights,
vocab_size=1,
**kwargs
)
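
Each subclass pairs T4RecConfig with a HuggingFace config via multiple inheritance: transformers_config_cls returns the second base class, which to_torch_model/to_tf_model use as the key into transformers.MODEL_MAPPING and TF_MODEL_MAPPING, while for_rec maps a small shared argument set onto that architecture's native parameters. As an illustration of the pattern only — the BertConfig wrapper below is hypothetical and not part of this commit — a new architecture could be wired up the same way:

# Hypothetical wrapper, not part of this commit; the parameter mapping is illustrative.
import transformers

from transformers4rec.config.transformer import T4RecConfig


class BertConfig(T4RecConfig, transformers.BertConfig):
    @classmethod
    def for_rec(
        cls,
        d_model,
        n_head,
        n_layer,
        total_seq_length,
        hidden_act="gelu",
        initializer_range=0.01,
        layer_norm_eps=0.03,
        dropout=0.3,
        pad_token=0,
        log_attention_weights=True,
        **kwargs
    ):
        return cls(
            hidden_size=d_model,
            num_hidden_layers=n_layer,
            num_attention_heads=n_head,
            intermediate_size=d_model * 4,
            hidden_act=hidden_act,
            initializer_range=initializer_range,
            layer_norm_eps=layer_norm_eps,
            hidden_dropout_prob=dropout,
            attention_probs_dropout_prob=dropout,
            max_position_embeddings=total_seq_length,
            pad_token_id=pad_token,
            output_attentions=log_attention_weights,
            vocab_size=1,
            **kwargs
        )


# transformers_config_cls resolves to transformers.BertConfig (the second base),
# so to_torch_model() looks up the corresponding torch model in MODEL_MAPPING.
config = BertConfig.for_rec(100, 4, 2, 20)
model = config.to_torch_model()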
4 changes: 2 additions & 2 deletions transformers4rec/torch/block/recurrent.py
@@ -4,7 +4,7 @@
from torch import nn
from transformers import PreTrainedModel

-from ..transformer import T4RecConfig
+from ...config.transformer import T4RecConfig
from ..typing import MaskedSequence, MaskSequence, ProcessedSequence
from .base import BuildableBlock, SequentialBlock

@@ -17,7 +17,7 @@ def __init__(
) -> None:
super().__init__()
if isinstance(body, T4RecConfig):
-            body = body.to_model()
+            body = body.to_torch_model()

self.masking = masking
self.body = body
