[T5] allow config.decoder_layers to control decoder size #7409

Merged · 5 commits · Sep 28, 2020
Changes from 1 commit
4 changes: 4 additions & 0 deletions src/transformers/configuration_t5.py
@@ -57,6 +57,8 @@ class T5Config(PretrainedConfig):
Size of the intermediate feed forward layer in each :obj:`T5Block`.
num_layers (:obj:`int`, `optional`, defaults to 6):
Number of hidden layers in the Transformer encoder.
Contributor Author

This was previously documented incorrectly. Now it is correct!

decoder_layers (:obj:`int`, `optional`, defaults to num_layers):
Collaborator

Documented defaults should match the signature. It will eventually default to num_layers, and that should be said, but on the line after. Suggestion:

        decoder_layers (:obj:`int`, `optional`):
            Number of hidden layers in the Transformer decoder. Will use the same value as :obj:`num_layers` if not set.

Number of hidden layers in the Transformer decoder.
num_heads (:obj:`int`, `optional`, defaults to 8):
Number of attention heads for each attention layer in
the Transformer encoder.
@@ -80,6 +82,7 @@ def __init__(
d_kv=64,
d_ff=2048,
num_layers=6,
decoder_layers=None,
num_heads=8,
relative_attention_num_buckets=32,
dropout_rate=0.1,
@@ -102,6 +105,7 @@ def __init__(
self.d_kv = d_kv
self.d_ff = d_ff
self.num_layers = num_layers
self.decoder_layers = decoder_layers if decoder_layers is not None else self.num_layers # default = symmetry
self.num_heads = num_heads
self.relative_attention_num_buckets = relative_attention_num_buckets
self.dropout_rate = dropout_rate
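
In short, the configuration change above lets decoder_layers be set independently of num_layers and falls back to num_layers when left as None. A minimal sketch of the resulting behaviour, assuming a transformers install that includes this change:

    from transformers import T5Config

    # Explicitly asymmetric: 6 encoder layers, 2 decoder layers.
    asymmetric = T5Config(num_layers=6, decoder_layers=2)
    assert asymmetric.num_layers == 6
    assert asymmetric.decoder_layers == 2

    # Leaving decoder_layers unset falls back to num_layers
    # (the old symmetric behaviour).
    symmetric = T5Config(num_layers=6)
    assert symmetric.decoder_layers == 6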
2 changes: 2 additions & 0 deletions src/transformers/modeling_t5.py
@@ -919,6 +919,7 @@ def __init__(self, config):
decoder_config = copy.deepcopy(config)
decoder_config.is_decoder = True
decoder_config.is_encoder_decoder = False
decoder_config.num_layers = config.decoder_layers
self.decoder = T5Stack(decoder_config, self.shared)

self.init_weights()
@@ -1077,6 +1078,7 @@ def __init__(self, config):
decoder_config = copy.deepcopy(config)
decoder_config.is_decoder = True
decoder_config.is_encoder_decoder = False
decoder_config.num_layers = config.decoder_layers
self.decoder = T5Stack(decoder_config, self.shared)

self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
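
Both model classes above derive the decoder's config from a deep copy of the shared config and override its num_layers with decoder_layers, so the two stacks can now differ in depth. A small sketch of the effect, assuming transformers with this change and PyTorch installed; the tiny dimensions are arbitrary and only keep instantiation cheap:

    from transformers import T5Config, T5ForConditionalGeneration

    # Tiny, arbitrary dimensions just to keep the example cheap.
    config = T5Config(
        vocab_size=100,
        d_model=64,
        d_kv=16,
        d_ff=128,
        num_heads=4,
        num_layers=3,      # encoder depth
        decoder_layers=1,  # decoder depth, now decoupled from the encoder
    )
    model = T5ForConditionalGeneration(config)

    # The encoder keeps num_layers blocks; the decoder gets decoder_layers blocks.
    assert len(model.encoder.block) == 3
    assert len(model.decoder.block) == 1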
40 changes: 40 additions & 0 deletions tests/test_modeling_t5.py
@@ -59,6 +59,7 @@ def __init__(
pad_token_id=0,
decoder_start_token_id=0,
scope=None,
decoder_layers=None,
):

self.parent = parent
@@ -83,6 +84,7 @@ def __init__(
self.pad_token_id = pad_token_id
self.decoder_start_token_id = decoder_start_token_id
self.scope = None
self.decoder_layers = decoder_layers

def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
@@ -105,6 +107,7 @@ def prepare_config_and_inputs(self):
d_ff=self.d_ff,
d_kv=self.hidden_size // self.num_attention_heads,
num_layers=self.num_hidden_layers,
decoder_layers=self.decoder_layers,
num_heads=self.num_attention_heads,
relative_attention_num_buckets=self.relative_attention_num_buckets,
dropout_rate=self.dropout_rate,
@@ -623,3 +626,40 @@ def test_translation_en_to_ro(self):
output = model.generate(**inputs)
translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
self.assertEqual(translation, expected_translation)


@require_torch
class TestAsymmetricT5(unittest.TestCase):
def build_model_and_check_forward_pass(self, **kwargs):
tester = T5ModelTester(self, **kwargs)
config, *inputs = tester.prepare_config_and_inputs()
(
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
) = inputs
model = T5ForConditionalGeneration(config=config).to(torch_device).eval()
outputs = model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
labels=lm_labels,
)
assert len(outputs) == 4
assert outputs["logits"].size() == (tester.batch_size, tester.decoder_seq_length, tester.vocab_size)
assert outputs["loss"].size() == ()
return model

def test_small_decoder(self):
# num_hidden_layers is passed to T5Config as num_layers
model = self.build_model_and_check_forward_pass(decoder_layers=1, num_hidden_layers=2)
assert len(model.encoder.block) == 2
assert len(model.decoder.block) == 1

def test_defaulting_to_symmetry(self):
# num_hidden_layers is passed to T5Config as num_layers
model = self.build_model_and_check_forward_pass(num_hidden_layers=2)
assert len(model.decoder.block) == len(model.encoder.block) == 2