[EncoderDecoderModel] add a add_cross_attention boolean to config #6377

Merged
4 changes: 3 additions & 1 deletion src/transformers/configuration_encoder_decoder.py
@@ -58,6 +58,7 @@ class EncoderDecoderConfig(PretrainedConfig):
>>> config_decoder = model.config.decoder
>>> # set decoder config to causal lm
>>> config_decoder.is_decoder = True
>>> config_decoder.add_cross_attention = True

>>> # Saving the model, including its configuration
>>> model.save_pretrained('my-model')
@@ -94,8 +95,9 @@ def from_encoder_decoder_configs(
Returns:
:class:`EncoderDecoderConfig`: An instance of a configuration object
"""
logger.info("Set `config.is_decoder=True` for decoder_config")
logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True

return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict())
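As a quick sanity check of the new behaviour, a hedged sketch (plain `BertConfig` objects stand in for any encoder/decoder configs; not part of the diff):

    from transformers import BertConfig, EncoderDecoderConfig

    config_encoder = BertConfig()
    config_decoder = BertConfig()

    # both flags are now set on the decoder sub-config automatically
    config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
    print(config.decoder.is_decoder)            # True
    print(config.decoder.add_cross_attention)   # True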

3 changes: 3 additions & 0 deletions src/transformers/configuration_utils.py
@@ -56,6 +56,8 @@ class PretrainedConfig(object):
Whether the model is used as an encoder/decoder or not.
is_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether the model is used as decoder or not (in which case it's used as an encoder).
add_cross_attention (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether cross-attention layers should be added to the model. Note that this option is only relevant for models that can be used as decoder models within the `EncoderDecoderModel` class, which consists of all models in `AUTO_MODELS_FOR_CAUSAL_LM`.
prune_heads (:obj:`Dict[int, List[int]]`, `optional`, defaults to :obj:`{}`):
Pruned heads of the model. The keys are the selected layer indices and the associated values, the list
of heads to prune in said layer.
@@ -145,6 +147,7 @@ def __init__(self, **kwargs):
# Is decoder is used in encoder-decoder models to differentiate encoder from decoder
self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
self.is_decoder = kwargs.pop("is_decoder", False)
self.add_cross_attention = kwargs.pop("add_cross_attention", False)

# Parameters for sequence generation
self.max_length = kwargs.pop("max_length", 20)
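Because the flag is popped from the shared kwargs, every config class inherits it; a small hedged example using `BertConfig` as an arbitrary `PretrainedConfig` subclass:

    from transformers import BertConfig

    config = BertConfig(is_decoder=True, add_cross_attention=True)
    print(config.add_cross_attention)         # True
    print(BertConfig().add_cross_attention)   # False, the documented default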
13 changes: 10 additions & 3 deletions src/transformers/modeling_bert.py
@@ -378,7 +378,9 @@ def __init__(self, config):
super().__init__()
self.attention = BertAttention(config)
self.is_decoder = config.is_decoder
if self.is_decoder:
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
self.crossattention = BertAttention(config)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
@@ -399,6 +401,9 @@ def forward(
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights

if self.is_decoder and encoder_hidden_states is not None:
assert hasattr(
self, "crossattention"
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
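The assertion added in this hunk means that passing `encoder_hidden_states` to a model whose config lacks `add_cross_attention=True` now fails with an explicit message instead of a bare `AttributeError`. A minimal sketch of that failure mode (the tiny config sizes are made up for illustration):

    import torch
    from transformers import BertConfig, BertModel

    # decoder flag set, but no cross-attention layers requested
    config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=4,
                        intermediate_size=64, is_decoder=True)  # add_cross_attention stays False
    model = BertModel(config)

    try:
        model(torch.tensor([[101, 102]]), encoder_hidden_states=torch.rand(1, 3, 32))
    except AssertionError as err:
        print(err)  # message points the user to `config.add_cross_attention=True`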
@@ -695,8 +700,10 @@ class BertModel(BertPreTrainedModel):
Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

To behave as a decoder the model needs to be initialized with the
:obj:`is_decoder` argument of the configuration set to :obj:`True`; an
:obj:`encoder_hidden_states` is expected as an input to the forward pass.
:obj:`is_decoder` argument of the configuration set to :obj:`True`.
To be used in a Seq2Seq model, the model needs to be initialized with both the :obj:`is_decoder`
argument and :obj:`add_cross_attention` set to :obj:`True`; an
:obj:`encoder_hidden_states` is then expected as an input to the forward pass.

.. _`Attention is all you need`:
https://arxiv.org/abs/1706.03762
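Put together, the updated docstring corresponds to usage along these lines (a hedged sketch; the small config values are illustrative and the encoder states would normally come from a real encoder):

    import torch
    from transformers import BertConfig, BertModel

    config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=4,
                        intermediate_size=64, is_decoder=True, add_cross_attention=True)
    decoder = BertModel(config)

    input_ids = torch.tensor([[101, 2023, 2003, 102]])
    encoder_hidden_states = torch.rand(1, 7, 32)  # stand-in for an encoder's output

    outputs = decoder(input_ids, encoder_hidden_states=encoder_hidden_states)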
7 changes: 4 additions & 3 deletions src/transformers/modeling_encoder_decoder.py
@@ -168,17 +168,18 @@ def from_encoder_decoder_pretrained(
from .configuration_auto import AutoConfig

decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
if decoder_config.is_decoder is False:
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True

kwargs_decoder["config"] = decoder_config

if kwargs_decoder["config"].is_decoder is False:
if kwargs_decoder["config"].is_decoder is False or decoder_config.add_cross_attention is False:
logger.warning(
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attribute `is_decoder` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` is set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
)

decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
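With both flags handled inside `from_encoder_decoder_pretrained`, the common path needs no manual config editing; roughly (the checkpoints shown are the ones used in the tests below):

    from transformers import EncoderDecoderModel

    # cross-attention weights in the decoder are added and randomly initialized
    model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
    print(model.config.decoder.is_decoder)            # True
    print(model.config.decoder.add_cross_attention)   # True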
2 changes: 2 additions & 0 deletions tests/test_modeling_bert.py
@@ -176,6 +176,7 @@ def create_and_check_bert_model_as_decoder(
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = BertModel(config)
model.to(torch_device)
model.eval()
@@ -235,6 +236,7 @@ def create_and_check_bert_model_for_causal_lm_as_decoder(
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = BertLMHeadModel(config=config)
model.to(torch_device)
model.eval()
6 changes: 5 additions & 1 deletion tests/test_modeling_encoder_decoder.py
@@ -59,6 +59,9 @@ def prepare_config_and_inputs_bert(self):
encoder_hidden_states,
encoder_attention_mask,
) = decoder_config_and_inputs

# make sure that cross attention layers are added
decoder_config.add_cross_attention = True
return {
"config": config,
"input_ids": input_ids,
@@ -119,6 +122,7 @@ def create_and_check_bert_encoder_decoder_model(
decoder_model = BertLMHeadModel(decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
self.assertTrue(enc_dec_model.config.decoder.is_decoder)
self.assertTrue(enc_dec_model.config.decoder.add_cross_attention)
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
@@ -330,7 +334,7 @@ def test_real_bert_model_from_pretrained(self):
self.assertIsNotNone(model)

@slow
def test_real_bert_model_from_pretrained_has_cross_attention(self):
def test_real_bert_model_from_pretrained_add_cross_attention(self):
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
self.assertTrue(hasattr(model.decoder.bert.encoder.layer[0], "crossattention"))
