Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Almost all TF models] TF clean up: add missing CLM / MLM loss; fix T5 naming and keras compile #5395

Merged
merged 32 commits into from
Jul 7, 2020
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
8cf4476
add first version of clm tf
patrickvonplaten Jun 30, 2020
59216aa
make style
patrickvonplaten Jun 30, 2020
29570d7
add more tests for bert
patrickvonplaten Jun 30, 2020
2ea1100
update tf clm loss
patrickvonplaten Jun 30, 2020
df7a948
fix tests
patrickvonplaten Jun 30, 2020
cceed5d
correct tf ner script
patrickvonplaten Jul 1, 2020
66c0ede
add mlm loss
patrickvonplaten Jul 1, 2020
2960946
delete bogus file
patrickvonplaten Jul 2, 2020
b49767e
clean tf auto model + add tests
patrickvonplaten Jul 2, 2020
06a3b62
finish adding clm loss everywhere
patrickvonplaten Jul 2, 2020
7056d93
fix training in distilbert
patrickvonplaten Jul 2, 2020
ae543c0
fix flake8
patrickvonplaten Jul 2, 2020
f439b78
save intermediate
patrickvonplaten Jul 2, 2020
298b7a4
fix tf t5 naming
patrickvonplaten Jul 3, 2020
ad8a84c
remove prints
patrickvonplaten Jul 3, 2020
65e9a63
finish up
patrickvonplaten Jul 3, 2020
5795c19
up
patrickvonplaten Jul 3, 2020
794697a
Merge branch 'master' into add_tf_clm_loss
patrickvonplaten Jul 3, 2020
ac4e2b1
fix tf gpt2
patrickvonplaten Jul 3, 2020
4bc8413
Merge branch 'add_tf_clm_loss' of https://github.com/patrickvonplaten…
patrickvonplaten Jul 3, 2020
c75b906
fix new test utils import
patrickvonplaten Jul 3, 2020
8164122
fix flake8
patrickvonplaten Jul 3, 2020
05bf021
keep backward compatibility
patrickvonplaten Jul 3, 2020
24728dd
Update src/transformers/modeling_tf_albert.py
patrickvonplaten Jul 7, 2020
5c0b471
Update src/transformers/modeling_tf_auto.py
patrickvonplaten Jul 7, 2020
ba59316
Update src/transformers/modeling_tf_electra.py
patrickvonplaten Jul 7, 2020
e9abed9
Update src/transformers/modeling_tf_roberta.py
patrickvonplaten Jul 7, 2020
c80a727
Update src/transformers/modeling_tf_mobilebert.py
patrickvonplaten Jul 7, 2020
d3ddcf2
Update src/transformers/modeling_tf_auto.py
patrickvonplaten Jul 7, 2020
ae4dcbd
Update src/transformers/modeling_tf_bert.py
patrickvonplaten Jul 7, 2020
c5cbae1
Update src/transformers/modeling_tf_distilbert.py
patrickvonplaten Jul 7, 2020
c25aa53
apply sylvains suggestions
patrickvonplaten Jul 7, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/token-classification/run_tf_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def align_predictions(predictions: np.ndarray, label_ids: np.ndarray) -> Tuple[L

for i in range(batch_size):
for j in range(seq_len):
if label_ids[i, j] != -1:
if label_ids[i, j] != -100:
out_label_list[i].append(label_map[label_ids[i][j]])
preds_list[i].append(label_map[preds[i][j]])

Expand Down
7 changes: 7 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,13 +421,19 @@
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
TF_MODEL_WITH_LM_HEAD_MAPPING,
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_MASKED_LM_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TFAutoModel,
TFAutoModelForMultipleChoice,
TFAutoModelForPreTraining,
TFAutoModelForQuestionAnswering,
TFAutoModelForSequenceClassification,
TFAutoModelForTokenClassification,
TFAutoModelWithLMHead,
TFAutoModelForCausalLM,
TFAutoModelForMaskedLM,
TFAutoModelForSeq2SeqLM,
)

from .modeling_tf_albert import (
Expand All @@ -446,6 +452,7 @@
from .modeling_tf_bert import (
TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFBertEmbeddings,
TFBertLMHeadModel,
TFBertForMaskedLM,
TFBertForMultipleChoice,
TFBertForNextSentencePrediction,
Expand Down
3 changes: 2 additions & 1 deletion src/transformers/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
from .modeling_camembert import (
CamembertForMaskedLM,
CamembertForMultipleChoice,
CamembertForQuestionAnswering,
CamembertForSequenceClassification,
CamembertForTokenClassification,
CamembertModel,
Expand Down Expand Up @@ -300,6 +301,7 @@
[
(DistilBertConfig, DistilBertForQuestionAnswering),
(AlbertConfig, AlbertForQuestionAnswering),
(CamembertConfig, CamembertForQuestionAnswering),
(BartConfig, BartForQuestionAnswering),
(LongformerConfig, LongformerForQuestionAnswering),
(XLMRobertaConfig, XLMRobertaForQuestionAnswering),
Expand Down Expand Up @@ -329,7 +331,6 @@
]
)


MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
[
(CamembertConfig, CamembertForMultipleChoice),
Expand Down
48 changes: 44 additions & 4 deletions src/transformers/modeling_tf_albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from .modeling_tf_bert import ACT2FN, TFBertSelfAttention
from .modeling_tf_utils import (
TFMaskedLanguageModelingLoss,
TFMultipleChoiceLoss,
TFPreTrainedModel,
TFQuestionAnsweringLoss,
Expand Down Expand Up @@ -822,7 +823,7 @@ def call(self, pooled_output, training: bool):


@add_start_docstrings("""Albert Model with a `language modeling` head on top. """, ALBERT_START_DOCSTRING)
class TFAlbertForMaskedLM(TFAlbertPreTrainedModel):
class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)

Expand All @@ -834,8 +835,26 @@ def get_output_embeddings(self):

@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="albert-base-v2")
def call(self, inputs, **kwargs):
def call(
self,
inputs=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
labels=None,
training=False,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
patrickvonplaten marked this conversation as resolved.
Show resolved Hide resolved
Labels for computing the masked language modeling loss.
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``

Returns:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
prediction_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`
Expand All @@ -852,14 +871,35 @@ def call(self, inputs, **kwargs):
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
outputs = self.albert(inputs, **kwargs)
if isinstance(inputs, (tuple, list)):
labels = inputs[8] if len(inputs) > 8 else labels
if len(inputs) > 8:
inputs = inputs[:8]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)

outputs = self.albert(
inputs,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
)

sequence_output = outputs[0]
prediction_scores = self.predictions(sequence_output, training=kwargs.get("training", False))
prediction_scores = self.predictions(sequence_output, training=training)

# Add hidden states and attention if they are here
outputs = (prediction_scores,) + outputs[2:]

if labels is not None:
loss = self.compute_loss(labels, prediction_scores)
outputs = (loss,) + outputs

return outputs # prediction_scores, (hidden_states), (attentions)


Expand Down
Loading