
Tensorflow improvements #4530

Merged
merged 50 commits into from Jun 4, 2020

Conversation

@jplu (Contributor) commented May 22, 2020

Hello,

Here is a fairly big PR that proposes the following updates:

  • Loss computation is now attached to its respective model class, as in PyTorch.
  • Remove the now-useless mode and loss_name parameters from the TF Trainer.
  • Add missing task models to the different Transformer models.
  • Bugfix in T5 Keras serialization + tests.
  • Add tests for TF Flaubert and XLM-RoBERTa.
  • Bugfix in the TF Trainer for TensorFlow 2.2.
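The first item, attaching loss computation to the model class itself so the Trainer no longer needs mode/loss_name parameters, can be sketched in plain Python (a hypothetical toy, not the actual mixins from this PR, and deliberately free of any TF dependency):

```python
import math

class SequenceClassificationLossSketch:
    """Hypothetical sketch: the loss lives on the model class (as in the
    PyTorch models), instead of being selected inside the TF Trainer via
    mode/loss_name parameters. Class and method names are illustrative."""

    def compute_loss(self, labels, logits):
        # Toy softmax cross-entropy, averaged over the batch.
        losses = []
        for y, row in zip(labels, logits):
            exps = [math.exp(v) for v in row]
            prob = exps[y] / sum(exps)
            losses.append(-math.log(prob))
        return sum(losses) / len(losses)

class ToySequenceClassifier(SequenceClassificationLossSketch):
    def call(self, inputs):
        return [[x, -x] for x in inputs]  # stand-in forward pass
```

With this pattern, a trainer only has to call `model.compute_loss(labels, logits)`; it no longer needs to know which task the model solves.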

Reviews are welcome :)

/cc @julien-c @LysandreJik @thomwolf

@codecov-commenter commented May 22, 2020

Codecov Report

Merging #4530 into master will increase coverage by 0.38%.
The diff coverage is 41.45%.


@@            Coverage Diff             @@
##           master    #4530      +/-   ##
==========================================
+ Coverage   75.63%   76.01%   +0.38%     
==========================================
  Files         128      128              
  Lines       20979    21417     +438     
==========================================
+ Hits        15867    16280     +413     
- Misses       5112     5137      +25     
Impacted Files Coverage Δ
src/transformers/data/processors/squad.py 28.66% <ø> (ø)
src/transformers/training_args_tf.py 51.16% <ø> (-4.16%) ⬇️
src/transformers/trainer_tf.py 18.86% <17.94%> (+0.94%) ⬆️
src/transformers/modeling_tf_xlm.py 76.10% <27.47%> (-14.30%) ⬇️
src/transformers/modeling_tf_xlnet.py 80.53% <27.50%> (-9.80%) ⬇️
src/transformers/modeling_tf_distilbert.py 82.88% <32.00%> (-12.24%) ⬇️
src/transformers/modeling_tf_roberta.py 74.74% <34.21%> (-25.26%) ⬇️
src/transformers/modeling_tf_electra.py 91.17% <38.70%> (-7.89%) ⬇️
src/transformers/modeling_tf_albert.py 75.39% <45.45%> (-3.30%) ⬇️
src/transformers/modeling_tf_utils.py 87.20% <50.00%> (-1.60%) ⬇️
... and 22 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update d976ef2...5b456e2.

@jplu (Contributor, Author) commented May 22, 2020

Some commits are missing... I think it is due to the high error rate from GitHub.

@jplu marked this pull request as draft May 25, 2020 18:41
if not isinstance(config, PretrainedConfig):
    config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

for config_class, model_class in TF_MODEL_WITH_LM_HEAD_MAPPING.items():
-    if isinstance(config, config_class):
+    # Not using isinstance() here so as not to take inheritance into account
+    if config_class == type(config):
Member: in PyTorch the different configs are sorted so that you never get a child class before its parent (precisely to prevent this), but this is a reasonable solution too.
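The exact-type check in the diff above can be illustrated with minimal stand-in classes (hypothetical toys, not the real transformers configs, although in the real library RobertaConfig does subclass BertConfig):

```python
# Minimal stand-ins for the config hierarchy under discussion.
class PretrainedConfig:
    pass

class BertConfig(PretrainedConfig):
    pass

class RobertaConfig(BertConfig):  # child config inherits from the parent's
    pass

config = RobertaConfig()

# isinstance() also matches the parent class, so iterating a mapping could
# wrongly pick the BERT model class for a RoBERTa config...
assert isinstance(config, BertConfig)

# ...while the exact type check matches only the RoBERTa entry.
assert type(config) is RobertaConfig
assert type(config) is not BertConfig
```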

return loss_fn(labels, reduced_logits)


class TFSequenceClassificationAndMultipleChoiceLoss:
Member: should we split into two different classes?

Contributor Author: My thought was: since it is the exact same loss computation, why not merge the two names; but your proposal might indeed be more insightful.

Member: Maybe just alias one to the other or do a trivial sub-class.

Contributor Author: Very good point, I will do the update.

Contributor Author: Should be ok now.

@patrickvonplaten self-requested a review May 27, 2020 15:16
@jplu marked this pull request as ready for review May 28, 2020 15:17
@LysandreJik (Member) left a comment

This is great work; I love the added flexibility of the API and how similar to our PyTorch models' API it can be. I like the coding style.

I find this a bit different from the PyTorch API in that:

  1. it uses Mixins, and I'm okay with them as I think the readability is still good, while it does add a (maybe unnecessary?) layer of abstraction. It does greatly improve code sharing across models though, which is welcome.
  2. Loss isn't computed when passing labels, but by directly calling model.compute_loss(x, y)

I'm not opposed to the first point, but a bit more to the second point. Is there something that prevents using labels in TensorFlow as we do it in PyTorch? As we're aiming at API compatibility, I think this is something we should get right.

Comment on lines -1039 to -1042
-    print("isdict(1)")
-    input_ids = inputs.get("input_ids")
-    print(input_ids)
Member: Nice catch!

@jplu (Contributor, Author) commented May 29, 2020

Thanks @LysandreJik for your constructive comments!

For the second point, before answering, just to be sure: do you mean it would be more convenient for the call(...) methods of the TF task models to return the same tuple (loss), logits, (hidden_states), (attentions) as the forward(...) methods of the PT task models?

@LysandreJik (Member)

Yes, that's what I mean. I think having this to be the same as the PyTorch API would make sense. It wouldn't be a breaking change either, as it would require the labels to be passed to the model.

I think doing this could still leverage Mixins, by calling a self._compute_loss or self.compute_loss if we want to expose this method as well. I have no strong opinion on that last item.

@jplu (Contributor, Author) commented May 29, 2020

Ok, that indeed makes sense and I don't think it is a problem to do it that way; I will work on this today to see if there is any issue that would prevent us from doing it.

@julien-c (Member) commented Jun 1, 2020

I agree with @LysandreJik's 2nd point – maybe we can even take advantage of this to implement named tuples for TF models output, like @thomwolf and @patrickvonplaten intend to do for PyTorch (as it's going to be a breaking change in TF models anyways, maybe we can do this at the same time?)

@jplu (Contributor, Author) commented Jun 1, 2020

Since my last commit, the TF models now return the loss, like the PT ones, if labels are given.

About the named tuples: it looks like a good idea indeed, but I think we should implement it in another PR so we can release it at the same time as for PT. No?
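The output convention jplu describes, prepending the loss to the output tuple only when labels are passed, can be sketched with a hypothetical toy model (plain Python, illustrative names only; the real models build logits with a Transformer forward pass):

```python
class ToyTaskModel:
    """Hypothetical sketch of the convention adopted in this PR: call(...)
    returns (loss, logits, ...) when labels are passed and (logits, ...)
    otherwise, mirroring forward(...) in the PyTorch task models."""

    def compute_loss(self, labels, logits):
        # Mean squared error stand-in for the real task loss.
        return sum((p - y) ** 2 for p, y in zip(logits, labels)) / len(labels)

    def call(self, inputs, labels=None):
        logits = [2.0 * x for x in inputs]  # stand-in forward pass
        outputs = (logits,)
        if labels is not None:
            # Prepend the loss, as the PT models do.
            outputs = (self.compute_loss(labels, logits),) + outputs
        return outputs

model = ToyTaskModel()
(logits,) = model.call([1.0, 2.0])                        # no labels: no loss
loss, logits = model.call([1.0, 2.0], labels=[2.0, 4.0])  # labels: loss first
```

Because the loss only appears when labels are supplied, existing callers that never pass labels are unaffected, which is why this is not a breaking change.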

@julien-c (Member) commented Jun 1, 2020

About the named tuples [...] we should implement this in another PR in order to release this in same time than for PT. No?

Yes, makes sense!

@jplu (Contributor, Author) commented Jun 1, 2020

Ok, looks good to me. I have tested the new models with different examples that use the trainer and they all work; tests look to be ok as well, except the quality one that I don't know how to fix 😄

@@ -25,7 +25,8 @@
 from .configuration_t5 import T5Config
 from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list
Contributor: Thanks for the changes here! Looks good to me.

@@ -734,7 +734,7 @@ def call(self, inputs, **kwargs):
         return outputs


-class TFTransfoXLLMHead(tf.keras.layers.Layer):
+class TFTransfoXLWithLMHeadModel(tf.keras.layers.Layer):
Contributor: This is a breaking change, no? Maybe we want to add an alias for TFTransfoXLLMHead for backward compatibility @LysandreJik

Contributor Author: My bad, I have just pushed a commit to rename it.

     )

-    return optimizer
+    return optimizer, lr_schedule
Contributor: This is also a breaking change - we should document it well so that users know.

Contributor Author: Done!

Member: Indeed, this is not backwards compatible. @jplu, do you expect this method to currently be used by users outside of the Trainer? Would this breaking change impact those users?

@jplu (Contributor, Author) commented Jun 3, 2020

Honestly no, I am not expecting users to use this outside of the TF Trainer. Also, the TF Trainer has been updated to use this new return format, like the PT one, including the examples.

Member: Right, I thought so too!

@@ -3,9 +3,9 @@
from dataclasses import dataclass, field
from typing import Optional

import matplotlib.pyplot as plt
Contributor: I had problems with isort on this file as well :D I think you might just want to revert this change manually to fix isort. It seems like you also have the wrong isort version... quite annoying, this isort bug.

Contributor Author: Done! And thanks for the hint 😄

@patrickvonplaten (Contributor) commented Jun 3, 2020

A more general question regarding training in TensorFlow (I'm not super familiar with TF 2.0 training, so I'm asking primarily to learn a bit :-) ):
I remember that before TF 2.0 was out, most people used Keras to train a model with model.fit(x_train, y_train) => is this still the case? Or are people more and more switching to the TF 2.0 training style shown here: https://www.tensorflow.org/tutorials/quickstart/advanced, which basically consists of using optimizer.apply_gradients(zip(gradients, model.trainable_variables))? This is also what we do in the TF trainer, right?

Was it possible and recommended to train transformer models with Keras' model.fit() before the TF Trainer, and is it still possible now?

@jplu (Contributor, Author) commented Jun 3, 2020

This is a good question! Short answer: yes, it is still possible, but without any gradient accumulation; that is mostly why the trainer uses TensorFlow's advanced training loop.

I'm currently preparing a follow-up PR that will integrate the new Model.train_step feature added in TF 2.2. Basically this update lets you create your own train step, and hence integrate the missing gradient accumulation, but this new PR will be for TF >= 2.2 only.
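The gradient accumulation jplu mentions (which model.fit() alone does not provide) can be sketched in plain Python; this is a hypothetical stand-in with illustrative names, not the accumulator from the library, and it uses floats where the real code would hold gradient tensors:

```python
class GradientAccumulator:
    """Hypothetical sketch of gradient accumulation: per-micro-batch gradients
    are summed in a buffer, and the optimizer step is applied only once every
    `accum_steps` batches, simulating a larger effective batch size."""

    def __init__(self, accum_steps):
        self.accum_steps = accum_steps
        self._step = 0
        self._buffer = None

    def accumulate(self, grads):
        # grads: list of floats standing in for per-variable gradient tensors.
        if self._buffer is None:
            self._buffer = [0.0] * len(grads)
        self._buffer = [b + g for b, g in zip(self._buffer, grads)]
        self._step += 1
        # True when enough micro-batches have been seen to apply an update.
        return self._step % self.accum_steps == 0

    def gradients(self):
        # Mean gradient over the accumulated micro-batches; reset the buffer.
        out = [b / self.accum_steps for b in self._buffer]
        self._buffer = None
        return out

acc = GradientAccumulator(accum_steps=2)
acc.accumulate([1.0, 2.0])          # first micro-batch: keep buffering
ready = acc.accumulate([3.0, 4.0])  # second micro-batch: time to apply
```

In a custom train step, `optimizer.apply_gradients(...)` would only be called when `accumulate` returns True, using the averaged gradients from `gradients()`.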

@LysandreJik (Member)

@patrickvonplaten It was possible, and we definitely aim to keep compatibility with Keras' fit method. We don't have many tutorials that cover it, though; having some would probably make it easier for new users coming from Keras to use our lib.

@julien-c, we've had the offline approval from @thomwolf, feel free to merge when you want. Glad to welcome this in the library!

@julien-c (Member) commented Jun 4, 2020

Just tweaked training_args.logging_dir to keep the same default as PyTorch (I like that it creates a new subfolder each time you relaunch a training).

Great job @jplu, thank you 💪
