
[Almost all TF models] TF clean up: add missing CLM / MLM loss; fix T5 naming and keras compile #5395

Merged
merged 32 commits into huggingface:master on Jul 7, 2020

Conversation

patrickvonplaten (Contributor) commented Jun 30, 2020

This PR aligns the TF code more closely with the PT code and adds full training support to all CLM and MLM models by applying @jplu's loss design to the remaining models. In more detail, the following things are included in the PR:

  • Add TFMaskedLanguageModelingLoss and TFCausalLanguageModelingLoss to all CLM and MLM TF models (see the short training sketch after these lists). Only Transfo-XL and XLM are not included, since they use adaptive softmax (TF Transfo-XL currently has no adaptive softmax implemented, cc @TevenLeScao for notification).
  • Change the value used to mask the CE loss from -1 to -100 to align with PyTorch; the tf_ner script is updated accordingly (cc @jplu). Using -1 is deprecated here and will be removed in a future version.
  • Split Bert into a CLM and an MLM model, as was done in PyTorch (small break in backward compatibility here).
  • Split TFAutoModelWithLMHead into TFAutoModelForCLM, ...ForMLM, ...ForSeq2Seq, as was done in PyTorch, to make TF ready for the encoder-decoder wrapper.
  • Add various tests for modeling_tf_auto.py, e.g. that the mappings are correctly ordered.
  • Fix inconsistent naming in TF T5 and fix the TF T5 keras compilation bug (@sshleifer - the encoder-decoder TF tests are fixed, so this should concern TF Bart as well).

TODO:

  • add labels to all tests where it applies
  • add CLM loss to all other models
  • add MLM loss to all other models
  • Clean up TF T5

Future PRs:

  • Test that the TF Trainer works well with all new CLM / MLM models - we should definitely start adding tests for the TF Trainer as well @jplu @julien-c @LysandreJik
  • TF benchmarks can now be run on training as well -> update the benchmark scripts
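As a quick illustration of the new training support (a minimal sketch: the gpt2 checkpoint is just an example, and the output ordering assumes the tuple convention of the time, with the loss first when labels are passed):

from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2LMHeadModel.from_pretrained("gpt2")

# Standard CLM setup: the model shifts internally, so labels can simply
# be the input ids; positions set to -100 are ignored by the loss.
input_ids = tokenizer("Hello, my dog is cute", return_tensors="tf")["input_ids"]
outputs = model(input_ids, labels=input_ids)
loss = outputs[0]  # per-token losses (Reduction.NONE); logits follow in outputs[1]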

@@ -843,6 +847,80 @@ def call(self, inputs, **kwargs):
return outputs # prediction_scores, (hidden_states), (attentions)


class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
patrickvonplaten (Contributor Author):
@LysandreJik @thomwolf split Bert into two as was done for PyTorch -> small break in backward compatibility here.

class TFCausalLanguageModelingLoss:
    def compute_loss(self, labels, logits):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
patrickvonplaten (Contributor Author):

I set the reduction to Reduction.NONE as was done for the other losses -> is that correct, @LysandreJik @jplu?

jplu (Contributor) - Jun 30, 2020:

Yes, because NONE makes the loss computation compliant with custom training; the reduction is left to the trainer.
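For context, a minimal self-contained sketch of the whole pattern (masking value and boolean_mask plumbing as discussed in this thread; the merged implementation may differ in detail):

import tensorflow as tf

class SketchCausalLanguageModelingLoss:
    # Sketch of the loss mixin discussed above, not the exact merged code.
    def compute_loss(self, labels, logits):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        # Keep only the positions whose label is not the ignore index (-100).
        active_loss = tf.reshape(labels, (-1,)) != -100
        reduced_logits = tf.boolean_mask(
            tf.reshape(logits, (-1, logits.shape[-1])), active_loss
        )
        reduced_labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
        # Reduction.NONE -> one loss value per active token; the trainer
        # decides how to reduce them.
        return loss_fn(reduced_labels, reduced_logits)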


outputs = (logits,) + outputs[2:]  # Add hidden states and attention if they are here
if labels is not None:
    logits = logits[:, :-1]
patrickvonplaten (Contributor Author):

shift logits the same way it's done in PyTorch
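A toy example of the shift (shapes are illustrative):

import tensorflow as tf

batch_size, seq_length, vocab_size = 2, 5, 11  # toy sizes
logits = tf.random.normal((batch_size, seq_length, vocab_size))
labels = tf.random.uniform((batch_size, seq_length), maxval=vocab_size, dtype=tf.int32)

# The logit at position t predicts the token at position t + 1, so the
# last logit and the first label have no counterpart and are dropped.
shifted_logits = logits[:, :-1]  # (2, 4, 11)
shifted_labels = labels[:, 1:]   # (2, 4)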

patrickvonplaten (Contributor Author) commented Jun 30, 2020

@LysandreJik @julien-c @jplu @thomwolf @sgugger - can you take a look at this example and check whether the CLM loss is correctly added? If yes, I will add this loss to all other CLM models and add tests.

jplu (Contributor) commented Jun 30, 2020

Looks good to me!!

codecov bot commented Jun 30, 2020

Codecov Report

Merging #5395 into master will decrease coverage by 0.03%.
The diff coverage is 77.17%.


@@            Coverage Diff             @@
##           master    #5395      +/-   ##
==========================================
- Coverage   76.39%   76.35%   -0.04%     
==========================================
  Files         141      141              
  Lines       24617    24868     +251     
==========================================
+ Hits        18807    18989     +182     
- Misses       5810     5879      +69     
Impacted Files Coverage Δ
src/transformers/__init__.py 99.22% <ø> (ø)
src/transformers/modeling_auto.py 74.41% <ø> (ø)
src/transformers/modeling_xlnet.py 78.86% <ø> (ø)
src/transformers/modeling_tf_electra.py 26.02% <10.00%> (-0.91%) ⬇️
src/transformers/modeling_tf_mobilebert.py 23.38% <14.28%> (-0.24%) ⬇️
src/transformers/modeling_tf_auto.py 63.03% <40.42%> (-9.47%) ⬇️
src/transformers/modeling_tf_bert.py 96.97% <80.00%> (-1.40%) ⬇️
src/transformers/modeling_tf_t5.py 90.90% <83.65%> (-0.53%) ⬇️
src/transformers/modeling_tf_utils.py 88.88% <84.61%> (-0.23%) ⬇️
src/transformers/modeling_tf_albert.py 76.47% <100.00%> (+0.48%) ⬆️
... and 9 more

Continue to review full report at Codecov.

Legend
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 21cd8c4...c25aa53. Read the comment docs.

LysandreJik (Member) left a comment:

LGTM. Great addition!

Comment on lines 875 to 877
r"""
Return:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
LysandreJik (Member):

The additional argument labels should be added here

patrickvonplaten (Contributor Author):

OK, will add this for all TF CLM models then :-) and add tests.

        )
        # make sure only labels that are not equal to -100
        # are taken into account as loss
        active_loss = tf.reshape(labels, (-1,)) != -100
patrickvonplaten (Contributor Author):

@jplu @sgugger @LysandreJik - In PyTorch we use -100, not -1, to mask tokens for the loss. Should we do the same here?
It would slightly break backward compatibility, since -1 was already used for token classification - but I'm not sure how many people have already trained on token classification.

Member:

I think -100 would be the most rigorous, right?

Member:

Maybe we can release 3.0.1 immediately so that nearly no users are affected

patrickvonplaten (Contributor Author):

It would be nice to have a consistent value there. Before, it was only used for TokenClassification, and we don't have any notebooks/examples on TF token classification training, so not too many people should be affected. I think it's worth aligning the values here.

Collaborator:

+1 for consistency

Contributor:

OK to replace for consistency! But don't forget to update run_tf_ner.py and TFTokenClassificationLoss accordingly as well.

patrickvonplaten (Contributor Author):

Actually, it might be better to deprecate -1 here and still allow its usage for backward compatibility, no? @sgugger @LysandreJik @jplu

Contributor:

Yep! Good idea.
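A minimal sketch of what that deprecation path could look like (the helper name and warning text are hypothetical):

import warnings
import tensorflow as tf

def normalize_masked_labels(labels):
    # Hypothetical helper: still accept the old -1 masking value, but warn
    # and rewrite it to the new -100 convention. The truthiness check below
    # assumes eager execution.
    if tf.reduce_any(labels == -1):
        warnings.warn(
            "Using -1 to mask the loss is deprecated; use -100 instead.",
            FutureWarning,
        )
        labels = tf.where(labels == -1, -100 * tf.ones_like(labels), labels)
    return labels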

@@ -122,7 +135,9 @@ def compute_loss(self, labels, logits):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        active_loss = tf.reshape(labels, (-1,)) != -1
patrickvonplaten (Contributor Author):

@jplu @sgugger @LysandreJik - In PyTorch we use -100, not -1, to mask tokens for the loss. Should we do the same here?
It would slightly break backward compatibility, since -1 was already used for token classification - but I'm not sure how many people have already trained on token classification.

@patrickvonplaten patrickvonplaten changed the title [TF all CLM models] provide labels to forward for tf [Don't merge yet] [TF all CLM models] provide labels to forward for tf Jun 30, 2020
@patrickvonplaten patrickvonplaten changed the title [Don't merge yet] [TF all CLM models] provide labels to forward for tf [Don't merge yet] [TF all CLM / MLM models] provide labels to forward for tf Jun 30, 2020

hidden_states = distilbert_output[0] # (bs, seq_length, dim)
prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim)
prediction_logits = self.act(prediction_logits) # (bs, seq_length, dim)
prediction_logits = self.vocab_layer_norm(prediction_logits) # (bs, seq_length, dim)
prediction_logits = self.vocab_projector(prediction_logits)
prediction_logits = self.vocab_projector(prediction_logits, training=training)
patrickvonplaten (Contributor Author):

@LysandreJik @VictorSanh - I think this training=training was missing here when comparing to tf_bert. Not 100% sure though.

Member:

I don't think that's necessary, as self.vocab_projector is a linear layer. I believe the training parameter is only useful for dropout?

patrickvonplaten (Contributor Author) - Jul 2, 2020:

True! training=training is only relevant if there is a dropout or batchnorm keras layer: tensorflow/tensorflow#36936
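A quick demonstration of that point (toy layer sizes):

import tensorflow as tf

x = tf.ones((1, 4))
dense = tf.keras.layers.Dense(4)
dropout = tf.keras.layers.Dropout(rate=0.5)

# A linear layer has no train/inference mode, so `training` is a no-op:
assert tf.reduce_all(dense(x) == dense(x))

# Dropout is the identity at inference but zeroes and rescales at training:
print(dropout(x, training=False))  # all ones
print(dropout(x, training=True))   # some entries 0.0, the rest scaled to 2.0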

@@ -94,6 +95,8 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
            inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size)
        elif model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values():
            inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
        elif model_class in TF_MODEL_WITH_LM_HEAD_MAPPING.values():
patrickvonplaten (Contributor Author):

Love these tests, whoever made them @jplu

return outputs # (loss), prediction_scores, (hidden_states), (attentions)


class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
patrickvonplaten (Contributor Author):

split Bert the same way it's done in PyTorch


hidden_states = distilbert_output[0] # (bs, seq_length, dim)
prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim)
prediction_logits = self.act(prediction_logits) # (bs, seq_length, dim)
prediction_logits = self.vocab_layer_norm(prediction_logits) # (bs, seq_length, dim)
prediction_logits = self.vocab_layer_norm(prediction_logits, training=training) # (bs, seq_length, dim)
patrickvonplaten (Contributor Author):

I think layernorm weights should be conditioned on the training parameter in TF Keras

Collaborator:

It can't hurt but I don't see it used in the source code.

output_attentions=None,
output_hidden_states=None,
labels=None,
patrickvonplaten (Contributor Author):

just move labels to its correct position

# for model_name in DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
# model = DistilBertModel.from_pretrained(model_name)
# self.assertIsNotNone(model)
@slow
patrickvonplaten (Contributor Author):

enable slow test

@@ -25,6 +25,8 @@
from transformers import (
    AutoConfig,
    BertConfig,
    GPT2Config,
patrickvonplaten (Contributor Author):

Add missing tests from the PyTorch version

@@ -292,7 +301,7 @@ def test_compile_tf_model(self):
    "decoder_input_ids": tf.keras.Input(
        batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"
    ),
    "inputs": tf.keras.Input(batch_shape=(2, 2000), name="inputs", dtype="int32"),
    "input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
patrickvonplaten (Contributor Author):

fix T5 inputs vs input_ids @sshleifer
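For context, a minimal sketch of the compile pattern this test exercises (a stand-in Embedding layer replaces the actual TFT5Model):

import tensorflow as tf

# The keras Input names must match the model's argument names
# ("input_ids" rather than "inputs") so keras can route the tensors.
input_ids = tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32")
decoder_input_ids = tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32")

embed = tf.keras.layers.Embedding(input_dim=100, output_dim=8)  # stand-in for TFT5Model
outputs = embed(input_ids) + embed(decoder_input_ids)

keras_model = tf.keras.Model(inputs=[input_ids, decoder_input_ids], outputs=outputs)
keras_model.compile(optimizer="adam", loss="mse")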

if model.__class__ in TF_MODEL_FOR_CAUSAL_LM_MAPPING.values():
    # if the loss is a causal LM loss, the labels are shifted,
    # so one label per batch entry is cut off
    loss_size = loss_size - self.model_tester.batch_size
patrickvonplaten (Contributor Author):

This is a bit hacky, but I don't see another way around it at the moment @sgugger @jplu - the CLM loss shifts the tokens and thus cuts off one token per sequence.

Contributor:

For me it is ok to do it like this; it doesn't seem too odd.
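To make the arithmetic concrete (toy numbers): with Reduction.NONE the loss has one entry per token, and the causal shift removes one position per sequence.

batch_size, seq_length = 4, 10  # toy numbers

# Non-shifted losses: one loss value per token.
loss_size = batch_size * seq_length    # 40

# Causal LM loss: the shift cuts one label per sequence in the batch.
clm_loss_size = loss_size - batch_size  # 36 == batch_size * (seq_length - 1)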


if isinstance(inputs, dict):
    kwargs.update(inputs)
if isinstance(inputs, (tuple, list)):
patrickvonplaten (Contributor Author):

There was one inconsistency in TF T5 before, in that the variable input_ids was wrongly called inputs @sshleifer.
Also, TF T5 is made completely keras-compilation compatible here, which was not the case before.

@patrickvonplaten patrickvonplaten changed the title [Don't merge yet] [TF all CLM / MLM models] provide labels to forward for tf TF clean up: add missing CLM / MLM loss; fix T5 naming and keras compile Jul 3, 2020
@patrickvonplaten patrickvonplaten changed the title TF clean up: add missing CLM / MLM loss; fix T5 naming and keras compile [Almost all TF models] TF clean up: add missing CLM / MLM loss; fix T5 naming and keras compile Jul 3, 2020
@@ -140,126 +142,158 @@

TF_MODEL_MAPPING = OrderedDict(
    [
        (T5Config, TFT5Model),
patrickvonplaten (Contributor Author):

Reordered all the mappings here the same way it's done in PyTorch and added a test to check they're correct (also cc @Pierrci - I think you reordered Roberta here recently). A sketch of the ordering invariant is below.
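A sketch of the ordering invariant such a test can enforce (the assertion logic here is assumed, not the exact test): because the auto-classes walk the mapping in order, a config class must appear before any of its base classes (e.g. RobertaConfig before BertConfig), or it would be shadowed.

def check_mapping_order(mapping):
    # `mapping` is an OrderedDict like TF_MODEL_MAPPING: config class -> model class.
    config_classes = list(mapping.keys())
    for i, later in enumerate(config_classes):
        for earlier in config_classes[:i]:
            assert not issubclass(later, earlier), (
                f"{later.__name__} is listed after its base class "
                f"{earlier.__name__} and would never be matched."
            )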

)

# insert decoder past at right place
# to speed up decoding
if use_cache is True:
if cast_bool_to_primitive(use_cache) is True:
Contributor:

I suggest cast_bool_to_primitive(use_cache, True) is True

patrickvonplaten (Contributor Author):

cast_bool_to_primitive(use_cache, self.use_cache) is True

as you did in your other PR is actually much cleaner :-)
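For readers unfamiliar with the helper, a rough sketch of what cast_bool_to_primitive does (behavior inferred from its usage here; the real implementation in modeling_tf_utils.py may differ):

import tensorflow as tf

def cast_bool_to_primitive_sketch(bool_variable, default_tensor_to_true=False):
    # Inside a traced graph, `if some_tensor:` is not allowed, so boolean
    # flags that arrive as tensors are converted back to Python bools.
    if tf.is_tensor(bool_variable):
        if hasattr(bool_variable, "numpy"):
            return bool(bool_variable.numpy())
        # Symbolic tensor with no concrete value: fall back to the default.
        return default_tensor_to_true
    return bool_variable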

sgugger (Collaborator) left a comment:

This looks great! Just a few suggestions for docstrings (missing TF or tf.Tensor).

Outdated review comments (resolved) on:
  • src/transformers/modeling_tf_albert.py
  • src/transformers/modeling_tf_auto.py (3 comments)
  • src/transformers/modeling_tf_bert.py
  • src/transformers/modeling_tf_distilbert.py
  • src/transformers/modeling_tf_electra.py
  • src/transformers/modeling_tf_mobilebert.py
  • src/transformers/modeling_tf_roberta.py
patrickvonplaten and others added 9 commits July 7, 2020 17:42
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> (x8, from applied review suggestions)
@patrickvonplaten patrickvonplaten merged commit 4dc6559 into huggingface:master Jul 7, 2020
Development

Successfully merging this pull request may close these issues:

Add "labels" functionality for all TF Causal LM and Masked LM models
4 participants