[Tests] Add Common Test for Training + Fix a couple of bugs (#8415)
* add training tests

* correct longformer

* fix docs

* fix some tests

* fix some more train tests

* remove ipdb

* fix multiple edge case model training

* fix funnel and prophetnet

* clean gpt models

* undo renaming of albert
patrickvonplaten committed Nov 9, 2020
1 parent 5204051 commit 9c83b96
Showing 30 changed files with 445 additions and 34 deletions.
7 changes: 7 additions & 0 deletions docs/source/model_doc/auto.rst
@@ -81,6 +81,13 @@ AutoModelForMultipleChoice
:members:


AutoModelForNextSentencePrediction
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.AutoModelForNextSentencePrediction
:members:


AutoModelForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

2 changes: 1 addition & 1 deletion examples/lxmert/modeling_frcnn.py
@@ -1801,7 +1801,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
)
@@ -29,7 +29,7 @@ For further information or requests, please go to [BERTimbau repository](https:/

```python
from transformers import AutoTokenizer # Or BertTokenizer
from transformers import AutoModelForPretraining # Or BertForPreTraining for loading pretraining heads
from transformers import AutoModelForPreTraining # Or BertForPreTraining for loading pretraining heads
from transformers import AutoModel # or BertModel, for BERT without pretraining heads

model = AutoModelForPreTraining.from_pretrained('neuralmind/bert-base-portuguese-cased')
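
As a quick sanity check of the corrected class name, the sketch below shows one way the pretraining heads can be exercised. The `return_dict=True` flag and the output attribute names (`prediction_logits`, `seq_relationship_logits`) are assumptions about the surrounding library version, not part of this change.

```python
import torch
from transformers import AutoTokenizer, AutoModelForPreTraining

tokenizer = AutoTokenizer.from_pretrained('neuralmind/bert-base-portuguese-cased')
model = AutoModelForPreTraining.from_pretrained('neuralmind/bert-base-portuguese-cased')

inputs = tokenizer("Tinha uma pedra no meio do caminho.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, return_dict=True)

# BertForPreTraining-style heads expose masked-LM logits and next-sentence-prediction logits.
print(outputs.prediction_logits.shape)        # (1, seq_len, vocab_size)
print(outputs.seq_relationship_logits.shape)  # (1, 2)
```
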
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -329,6 +329,7 @@
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
MODEL_FOR_PRETRAINING_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
@@ -340,6 +341,7 @@
AutoModelForCausalLM,
AutoModelForMaskedLM,
AutoModelForMultipleChoice,
AutoModelForNextSentencePrediction,
AutoModelForPreTraining,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
116 changes: 115 additions & 1 deletion src/transformers/modeling_auto.py
@@ -77,6 +77,7 @@
from .modeling_bert import (
BertForMaskedLM,
BertForMultipleChoice,
BertForNextSentencePrediction,
BertForPreTraining,
BertForQuestionAnswering,
BertForSequenceClassification,
@@ -128,6 +129,7 @@
from .modeling_funnel import (
FunnelForMaskedLM,
FunnelForMultipleChoice,
FunnelForPreTraining,
FunnelForQuestionAnswering,
FunnelForSequenceClassification,
FunnelForTokenClassification,
@@ -143,12 +145,13 @@
LongformerForTokenClassification,
LongformerModel,
)
from .modeling_lxmert import LxmertForPreTraining, LxmertModel
from .modeling_lxmert import LxmertForPreTraining, LxmertForQuestionAnswering, LxmertModel
from .modeling_marian import MarianMTModel
from .modeling_mbart import MBartForConditionalGeneration
from .modeling_mobilebert import (
MobileBertForMaskedLM,
MobileBertForMultipleChoice,
MobileBertForNextSentencePrediction,
MobileBertForPreTraining,
MobileBertForQuestionAnswering,
MobileBertForSequenceClassification,
@@ -166,6 +169,7 @@
from .modeling_reformer import (
ReformerForMaskedLM,
ReformerForQuestionAnswering,
ReformerForSequenceClassification,
ReformerModel,
ReformerModelWithLMHead,
)
@@ -285,6 +289,7 @@
(CTRLConfig, CTRLLMHeadModel),
(ElectraConfig, ElectraForPreTraining),
(LxmertConfig, LxmertForPreTraining),
(FunnelConfig, FunnelForPreTraining),
]
)

@@ -396,6 +401,7 @@
(DebertaConfig, DebertaForSequenceClassification),
(GPT2Config, GPT2ForSequenceClassification),
(OpenAIGPTConfig, OpenAIGPTForSequenceClassification),
(ReformerConfig, ReformerForSequenceClassification),
]
)

@@ -417,6 +423,7 @@
(ElectraConfig, ElectraForQuestionAnswering),
(ReformerConfig, ReformerForQuestionAnswering),
(FunnelConfig, FunnelForQuestionAnswering),
(LxmertConfig, LxmertForQuestionAnswering),
]
)

@@ -460,6 +467,13 @@
]
)

MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
[
(BertConfig, BertForNextSentencePrediction),
(MobileBertConfig, MobileBertForNextSentencePrediction),
]
)

AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either
@@ -1519,3 +1533,103 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
", ".join(c.__name__ for c in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys()),
)
)


class AutoModelForNextSentencePrediction:
r"""
This is a generic model class that will be instantiated as one of the model classes of the library---with a
next sentence prediction head---when created with the
:meth:`~transformers.AutoModelForNextSentencePrediction.from_pretrained` class method or the
:meth:`~transformers.AutoModelForNextSentencePrediction.from_config` class method.
This class cannot be instantiated directly using ``__init__()`` (throws an error).
"""

def __init__(self):
raise EnvironmentError(
"AutoModelForNextSentencePrediction is designed to be instantiated "
"using the `AutoModelForNextSentencePrediction.from_pretrained(pretrained_model_name_or_path)` or "
"`AutoModelForNextSentencePrediction.from_config(config)` methods."
)

@classmethod
@replace_list_option_in_docstrings(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, use_model_types=False)
def from_config(cls, config):
r"""
Instantiates one of the model classes of the library---with a next sentence prediction head---from a
configuration.
Note:
Loading a model from its configuration file does **not** load the model weights. It only affects the
model's configuration. Use :meth:`~transformers.AutoModelForNextSentencePrediction.from_pretrained` to load
the model weights.
Args:
config (:class:`~transformers.PretrainedConfig`):
The model class to instantiate is selected based on the configuration class:
List options
Examples::
>>> from transformers import AutoConfig, AutoModelForNextSentencePrediction
>>> # Download configuration from S3 and cache.
>>> config = AutoConfig.from_pretrained('bert-base-uncased')
>>> model = AutoModelForNextSentencePrediction.from_config(config)
"""
if type(config) in MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys():
return MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)](config)

raise ValueError(
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
"Model type should be one of {}.".format(
config.__class__,
cls.__name__,
", ".join(c.__name__ for c in MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys()),
)
)

@classmethod
@replace_list_option_in_docstrings(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING)
@add_start_docstrings(
"Instantiate one of the model classes of the library---with a multiple choice classification head---from a "
"pretrained model.",
AUTO_MODEL_PRETRAINED_DOCSTRING,
)
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r"""
Examples::
>>> from transformers import AutoConfig, AutoModelForNextSentencePrediction
>>> # Download model and configuration from S3 and cache.
>>> model = AutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased')
>>> # Update configuration during loading
>>> model = AutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased', output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
>>> model = AutoModelForNextSentencePrediction.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig):
config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)

if type(config) in MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys():
return MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)].from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)

raise ValueError(
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
"Model type should be one of {}.".format(
config.__class__,
cls.__name__,
", ".join(c.__name__ for c in MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys()),
)
)
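
The new Auto class follows the same dispatch-by-config-type pattern as the other Auto models in this file. A minimal, self-contained sketch of that pattern (hypothetical stand-in classes, not the real configs or models):

```python
from collections import OrderedDict


class FakeConfig:
    """Stand-in for a model-specific config class such as BertConfig (hypothetical)."""


class FakeModelForNextSentencePrediction:
    """Stand-in for a task head such as BertForNextSentencePrediction (hypothetical)."""

    def __init__(self, config):
        self.config = config


# Mirrors MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING above: config class -> model class.
MAPPING = OrderedDict([(FakeConfig, FakeModelForNextSentencePrediction)])


def from_config(config):
    # Dispatch on the exact type of the config object, as from_config does above.
    if type(config) in MAPPING:
        return MAPPING[type(config)](config)
    raise ValueError(f"Unrecognized configuration class {config.__class__}")


model = from_config(FakeConfig())
assert isinstance(model, FakeModelForNextSentencePrediction)
```
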
19 changes: 14 additions & 5 deletions src/transformers/modeling_bert.py
@@ -1228,13 +1228,14 @@ def forward(
position_ids=None,
head_mask=None,
inputs_embeds=None,
next_sentence_label=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs
):
r"""
next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the next sentence prediction (classification) loss. Input should be a sequence pair
(see ``input_ids`` docstring). Indices should be in ``[0, 1]``:
@@ -1255,10 +1256,18 @@ def forward(
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
>>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
>>> outputs = model(**encoding, next_sentence_label=torch.LongTensor([1]))
>>> outputs = model(**encoding, labels=torch.LongTensor([1]))
>>> logits = outputs.logits
>>> assert logits[0, 0] < logits[0, 1] # next sentence was random
"""

if "next_sentence_label" in kwargs:
warnings.warn(
"The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.",
FutureWarning,
)
labels = kwargs.pop("next_sentence_label")

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.bert(
Expand All @@ -1278,9 +1287,9 @@ def forward(
seq_relationship_scores = self.cls(pooled_output)

next_sentence_loss = None
if next_sentence_label is not None:
if labels is not None:
loss_fct = CrossEntropyLoss()
next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), next_sentence_label.view(-1))
next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))

if not return_dict:
output = (seq_relationship_scores,) + outputs[2:]
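
The renamed ``next_sentence_label`` argument stays usable through the ``**kwargs`` shim added above. A stripped-down sketch of that back-compat pattern (not the actual model method):

```python
import warnings


def forward(labels=None, **kwargs):
    # Back-compat shim: accept the old keyword, warn, and feed the new parameter.
    if "next_sentence_label" in kwargs:
        warnings.warn(
            "The `next_sentence_label` argument is deprecated and will be removed in a "
            "future version, use `labels` instead.",
            FutureWarning,
        )
        labels = kwargs.pop("next_sentence_label")
    return labels


assert forward(labels=1) == 1               # new call sites
assert forward(next_sentence_label=1) == 1  # old call sites still work, with a FutureWarning
```
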
2 changes: 1 addition & 1 deletion src/transformers/modeling_gpt2.py
Expand Up @@ -1069,7 +1069,7 @@ def forward(
if self.num_labels == 1:
# We are doing regression
loss_fct = MSELoss()
loss = loss_fct(pooled_logits.view(-1), labels.view(-1))
loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
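
The `.to(self.dtype)` cast above matters in the regression branch (`num_labels == 1`): `MSELoss` expects the logits and the targets to share a floating-point dtype, so integer or mismatched-precision label tensors otherwise fail (the exact error depends on the PyTorch version). A small sketch of the issue, assuming plain PyTorch:

```python
import torch
from torch.nn import MSELoss

pooled_logits = torch.randn(4, 1)      # float regression logits
labels = torch.tensor([1, 0, 2, 1])    # integer labels, as a common training test might generate

loss_fct = MSELoss()
try:
    loss_fct(pooled_logits.view(-1), labels.view(-1))
except RuntimeError as exc:
    print(f"without the cast: {exc}")

# Casting the labels to the logits' dtype, as the fix above does, makes the loss well defined.
loss = loss_fct(pooled_logits.view(-1), labels.to(pooled_logits.dtype).view(-1))
print(f"with the cast: {loss.item():.4f}")
```
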
3 changes: 1 addition & 2 deletions src/transformers/modeling_longformer.py
Expand Up @@ -1069,7 +1069,7 @@ def forward(

def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return module(*inputs, is_global_attn)

return custom_forward

@@ -1079,7 +1079,6 @@ def custom_forward(*inputs):
attention_mask,
is_index_masked,
is_index_global_attn,
is_global_attn,
)
else:
layer_outputs = layer_module(
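
For context, the lines above adjust Longformer's gradient-checkpointing wrapper so that the `is_global_attn` flag is captured by the `custom_forward` closure instead of being passed positionally through `torch.utils.checkpoint`. A generic sketch of that closure pattern (a toy layer, not the Longformer one), assuming plain PyTorch:

```python
import torch
from torch.utils.checkpoint import checkpoint


class ToyLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, hidden_states, attention_mask, use_extra_path):
        out = self.linear(hidden_states) + attention_mask
        return out * 2 if use_extra_path else out


layer = ToyLayer()
hidden_states = torch.randn(2, 8, requires_grad=True)
attention_mask = torch.zeros(2, 8)
use_extra_path = True  # plain Python bool, analogous to `is_global_attn`


def create_custom_forward(module):
    def custom_forward(*inputs):
        # Non-tensor flags are baked into the closure; only tensors go through checkpoint().
        return module(*inputs, use_extra_path)

    return custom_forward


out = checkpoint(create_custom_forward(layer), hidden_states, attention_mask)
out.sum().backward()
print(hidden_states.grad.shape)  # torch.Size([2, 8])
```
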
22 changes: 17 additions & 5 deletions src/transformers/modeling_lxmert.py
Expand Up @@ -17,6 +17,7 @@

import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple

@@ -1154,16 +1155,17 @@ def forward(
visual_attention_mask=None,
token_type_ids=None,
inputs_embeds=None,
masked_lm_labels=None,
labels=None,
obj_labels=None,
matched_label=None,
ans=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
r"""
masked_lm_labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`):
labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`):
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
@@ -1183,6 +1185,15 @@
Returns:
"""

if "masked_lm_labels" in kwargs:
warnings.warn(
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
FutureWarning,
)
labels = kwargs.pop("masked_lm_labels")

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

device = input_ids.device if input_ids is not None else inputs_embeds.device
lxmert_output = self.lxmert(
input_ids=input_ids,
@@ -1210,13 +1221,13 @@

total_loss = (
None
if (masked_lm_labels is None and matched_label is None and obj_labels is None and ans is None)
if (labels is None and matched_label is None and obj_labels is None and ans is None)
else torch.tensor(0.0, device=device)
)
if masked_lm_labels is not None and self.task_mask_lm:
if labels is not None and self.task_mask_lm:
masked_lm_loss = self.loss_fcts["ce"](
lang_prediction_scores.view(-1, self.config.vocab_size),
masked_lm_labels.view(-1),
labels.view(-1),
)
total_loss += masked_lm_loss
if matched_label is not None and self.task_matched:
@@ -1391,6 +1402,7 @@ def forward(
Returns:
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

lxmert_output = self.lxmert(
input_ids=input_ids,
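
LXMERT combines several optional objectives (masked LM, matched/unmatched pairs, object labels, QA answers), and the code above only accumulates the losses whose labels were actually passed. A condensed sketch of that pattern (a hypothetical helper, not the model code):

```python
import torch


def combine_losses(device, **optional_losses):
    # Keep only the objectives whose labels were provided (i.e. whose loss is not None).
    active = [loss for loss in optional_losses.values() if loss is not None]
    if not active:
        return None  # no labels at all -> no loss
    total_loss = torch.tensor(0.0, device=device)
    for loss in active:
        total_loss = total_loss + loss  # accumulate each active objective
    return total_loss


total = combine_losses(
    torch.device("cpu"),
    masked_lm_loss=torch.tensor(0.7),
    matched_loss=torch.tensor(0.2),
    qa_loss=None,  # e.g. no answer labels in this batch
)
print(total)  # tensor(0.9000)
```
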