
[EncoderDecoder] Add encoder-decoder for roberta / vanilla longformer #6411

Merged
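This PR adds RobertaForCausalLM and registers it in the causal-LM auto-mapping so that RoBERTa-style checkpoints can serve as the decoder of an EncoderDecoderModel. As a rough usage sketch, not taken from the diff below and assuming a roberta-base checkpoint on both sides:

>>> from transformers import EncoderDecoderModel, RobertaTokenizer

>>> tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
>>> # the decoder side is expected to be loaded as a causal LM with is_decoder=True
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("roberta-base", "roberta-base")

>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=inputs["input_ids"])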
1 change: 1 addition & 0 deletions src/transformers/__init__.py
@@ -299,6 +299,7 @@
from .tokenization_marian import MarianTokenizer
from .modeling_roberta import (
RobertaForMaskedLM,
RobertaForCausalLM,
RobertaModel,
RobertaForSequenceClassification,
RobertaForMultipleChoice,
2 changes: 2 additions & 0 deletions src/transformers/modeling_auto.py
@@ -133,6 +133,7 @@
)
from .modeling_retribert import RetriBertModel
from .modeling_roberta import (
RobertaForCausalLM,
RobertaForMaskedLM,
RobertaForMultipleChoice,
RobertaForQuestionAnswering,
@@ -248,6 +249,7 @@

MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
[
(RobertaConfig, RobertaForCausalLM),
(BertConfig, BertLMHeadModel),
(OpenAIGPTConfig, OpenAIGPTLMHeadModel),
(GPT2Config, GPT2LMHeadModel),
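With RobertaConfig registered in MODEL_FOR_CAUSAL_LM_MAPPING, AutoModelForCausalLM should now resolve a RoBERTa checkpoint to RobertaForCausalLM. A minimal sketch under that assumption, with is_decoder set by hand as in the class docstring further down:

>>> from transformers import AutoConfig, AutoModelForCausalLM

>>> config = AutoConfig.from_pretrained("roberta-base")
>>> config.is_decoder = True  # causal (left-to-right) self-attention
>>> model = AutoModelForCausalLM.from_pretrained("roberta-base", config=config)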
10 changes: 5 additions & 5 deletions src/transformers/modeling_bert.py
@@ -956,7 +956,7 @@ def __init__(self, config):
super().__init__(config)

if not config.is_decoder:
logger.info("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`")
logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`")

self.bert = BertModel(config)
self.cls = BertOnlyMLMHead(config)
@@ -976,9 +976,9 @@ def forward(
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
Comment on lines -979 to +983 (Member): good catch

output_attentions=None,
output_hidden_states=None,
return_dict=None,
@@ -1061,8 +1061,8 @@ def __init__(self, config):
super().__init__(config)

if config.is_decoder:
logger.info(
"If you want to use `TFBertForMaskedLM` make sure `config.is_decoder=False` for "
logger.warning(
"If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
"bi-directional self-attention."
)

@@ -1089,9 +1089,9 @@ def forward(
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
136 changes: 135 additions & 1 deletion src/transformers/modeling_roberta.py
@@ -24,9 +24,15 @@
from torch.nn import CrossEntropyLoss, MSELoss

from .configuration_roberta import RobertaConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_callable,
replace_return_docstrings,
)
from .modeling_bert import BertEmbeddings, BertLayerNorm, BertModel, BertPreTrainedModel, gelu
from .modeling_outputs import (
CausalLMOutput,
MaskedLMOutput,
MultipleChoiceModelOutput,
QuestionAnsweringModelOutput,
@@ -139,6 +145,14 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
if the model is configured as a decoder.
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask
is used in the cross-attention if the model is configured as a decoder.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
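To illustrate the two new docstring entries, here is a small sketch of a decoder-configured RobertaModel consuming encoder states through cross-attention; the encoder output is faked with random tensors purely to show the expected shapes:

>>> import torch
>>> from transformers import RobertaConfig, RobertaModel, RobertaTokenizer

>>> tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
>>> config = RobertaConfig.from_pretrained("roberta-base")
>>> config.is_decoder = True
>>> decoder = RobertaModel.from_pretrained("roberta-base", config=config)

>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> encoder_hidden_states = torch.randn(1, 8, config.hidden_size)  # stand-in for real encoder output
>>> encoder_attention_mask = torch.ones(1, 8, dtype=torch.long)  # 1 = attend, 0 = masked

>>> outputs = decoder(
...     input_ids=inputs["input_ids"],
...     encoder_hidden_states=encoder_hidden_states,
...     encoder_attention_mask=encoder_attention_mask,
... )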
@@ -175,6 +189,116 @@ def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value


@add_start_docstrings(
"""RoBERTa Model with a `language modeling` head on top for CLM fine-tuning. """, ROBERTA_START_DOCSTRING
)
class RobertaForCausalLM(BertPreTrainedModel):
Member: Wouldn't it be more coherent to have it as RobertaLMHeadModel?
Member: (But I do prefer RobertaForCausalLM.)
Collaborator (@sgugger, Aug 12, 2020): It's the same name as for BERT.
Contributor (author): Following internal discussion, we'll leave the name as it is, since it is more precise; BertLMHeadModel should change in the future.
config_class = RobertaConfig
base_model_prefix = "roberta"

def __init__(self, config):
super().__init__(config)

if not config.is_decoder:
logger.warning("If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True.`")

self.roberta = RobertaModel(config)
self.lm_head = RobertaLMHead(config)

self.init_weights()

def get_output_embeddings(self):
return self.lm_head.decoder

@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs
Collaborator: Are you using those kwargs? If so, change the docstrings, since there are no legacy arguments here.
Contributor (author): Removed them from the corresponding BERT model as well.
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the left-to-right language modeling loss (next word prediction).
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see the ``input_ids`` docstring).
Tokens with indices set to ``-100`` are ignored (masked); the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``.
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
Used to hide legacy arguments that have been deprecated.
Member: No need for this, I think.

Returns:

Example::

>>> from transformers import RobertaTokenizer, RobertaLMHeadModel, RobertaConfig
>>> import torch

>>> tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
>>> config = RobertaConfig.from_pretrained("roberta-base")
>>> config.is_decoder = True
>>> model = RobertaLMHeadModel.from_pretrained('roberta-base', config=config, return_dict=True)

>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)

>>> prediction_logits = outputs.logits
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.roberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)

lm_loss = None
if labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((lm_loss,) + output) if lm_loss is not None else output

return CausalLMOutput(
loss=lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
)
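The loss computed above is a plain next-token objective: the score at position i is matched against the token at position i + 1, so the last position receives no label. A self-contained sketch of that shift with dummy tensors (shapes chosen arbitrarily, not part of this PR):

>>> import torch
>>> from torch.nn import CrossEntropyLoss

>>> batch_size, seq_len, vocab_size = 2, 6, 50265
>>> prediction_scores = torch.randn(batch_size, seq_len, vocab_size)
>>> labels = torch.randint(0, vocab_size, (batch_size, seq_len))

>>> shifted_scores = prediction_scores[:, :-1, :].contiguous()  # drop the last position
>>> shifted_labels = labels[:, 1:].contiguous()  # drop the first token
>>> loss = CrossEntropyLoss()(shifted_scores.view(-1, vocab_size), shifted_labels.view(-1))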

def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape

# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)

return {"input_ids": input_ids, "attention_mask": attention_mask}


@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
class RobertaForMaskedLM(BertPreTrainedModel):
config_class = RobertaConfig
@@ -183,6 +307,12 @@ class RobertaForMaskedLM(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)

if config.is_decoder:
logger.warning(
"If you want to use `RobertaForMaskedLM` make sure `config.is_decoder=False` for "
"bi-directional self-attention."
)

self.roberta = RobertaModel(config)
self.lm_head = RobertaLMHead(config)

@@ -206,6 +336,8 @@ def forward(
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
@@ -237,6 +369,8 @@ def forward(
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
4 changes: 2 additions & 2 deletions src/transformers/modeling_tf_bert.py
@@ -862,7 +862,7 @@ def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)

if config.is_decoder:
logger.info(
logger.warning(
"If you want to use `TFBertForMaskedLM` make sure `config.is_decoder=False` for "
"bi-directional self-attention."
)
@@ -941,7 +941,7 @@ def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)

if not config.is_decoder:
logger.info("If you want to use `TFBertLMHeadModel` as a standalone, add `is_decoder=True.`")
logger.warning("If you want to use `TFBertLMHeadModel` as a standalone, add `is_decoder=True.`")

self.bert = TFBertMainLayer(config, name="bert")
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")