Switch from return_tuple to return_dict #6138
Conversation
@@ -51,12 +51,6 @@
model. Initializing with a config file does not load the weights associated with the model, only the
configuration.
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
This is the documentation of the init, not the forward, so this shouldn't have been added here.
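To illustrate the distinction this comment draws, here is a minimal sketch (class and parameter names are illustrative, not the actual transformers source): a per-call argument like `output_attentions` is documented on `forward`, while construction-time arguments like `config` are documented on `__init__`.

```python
class SomeModel:
    def __init__(self, config):
        """
        Args:
            config (SomeConfig): Model configuration class with all the parameters
                of the model. Initializing with a config file does not load the
                weights, only the configuration.
        """
        self.config = config

    def forward(self, input_ids, output_attentions=None):
        """
        Args:
            output_attentions (bool, optional): Whether to also return the attention
                tensors. This is a per-call argument, so it is documented here on
                ``forward``, not on ``__init__``.
        """
        ...
```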
@@ -53,12 +53,6 @@ | |||
config (:class:`~transformers.XLMRobertaConfig`): Model configuration class with all the parameters of the | |||
model. Initializing with a config file does not load the weights associated with the model, only the configuration. | |||
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. | |||
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`): |
This is the documentation of the init, not the forward, so this shouldn't have been added here.
@@ -661,9 +661,7 @@ def _prepare_inputs(
if self.args.past_index >= 0 and self._past is not None:
    inputs["mems"] = self._past
# Our model outputs do not work with DataParallel, so forcing return tuple.
The new model output works with DataParallel, so no precautions needed anymore.
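The reason a dict-like output can coexist with tuple-based code paths is that it supports both attribute/key access and positional indexing. A minimal sketch of that idea (an assumption for illustration, not the actual transformers `ModelOutput` implementation):

```python
from collections import OrderedDict


class ModelOutput(OrderedDict):
    """Dict-like model output that also supports tuple-style positional indexing."""

    def __getitem__(self, key):
        # Integer index -> behave like a tuple; string key -> behave like a dict.
        if isinstance(key, int):
            return list(self.values())[key]
        return super().__getitem__(key)

    def __getattr__(self, name):
        # Attribute access falls back to the dict entries (out.loss, out.logits).
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def to_tuple(self):
        # Legacy consumers can still get a plain tuple.
        return tuple(self.values())


out = ModelOutput(loss=0.5, logits=[1.0, 2.0])
out.loss, out["logits"], out[0]  # attribute, key, and positional access all work
```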
Codecov Report
@@ Coverage Diff @@
## master #6138 +/- ##
==========================================
+ Coverage 78.49% 79.48% +0.99%
==========================================
Files 146 146
Lines 26335 26441 +106
==========================================
+ Hits 20671 21017 +346
+ Misses 5664 5424 -240
Continue to review full report at Codecov.
* Fully rework training/prediction loops
* fix method name
* Fix variable name
* Fix property name
* Fix scope
* Fix method name
* Fix tuple index
* Fix tuple index
* Fix indentation
* Fix variable name
* fix eval before log
* Add drop remainder for test dataset
* Fix step number + fix logging datetime
* fix eval loss value
* use global step instead of step + fix logging at step 0
* Fix logging datetime
* Fix global_step usage
* Fix breaking loop + logging datetime
* Fix step in prediction loop
* Fix step breaking
* Fix train/test loops
* Force TF at least 2.2 for the trainer
* Use assert_cardinality to facilitate the dataset size computation
* Log steps per epoch
* Make tfds compliant with TPU
* Make tfds compliant with TPU
* Use TF dataset enumerate instead of the Python one
* revert previous commit
* Fix data_dir
* Apply style
* rebase on master
* Address Sylvain's comments
* Address Sylvain's and Lysandre comments
* Trigger CI
* Remove unused import
LGTM, thanks for taking care of it!
I think we'll have to take care of the XLMForMultipleChoice
model which was added when you were coding this. Sorry about that 😬
Great!
@sgugger thanks very much for this PR!
I still want to be able to use It looks like I could pass How should this be handled? Would the solution be something like this:
I tried this solution, but it didn't work; it gave me the following error:
The right line is:
This is the first step in the change of model outputs as described on the forum.

This PR removes the argument `return_tuple` and introduces `return_dict` (which works the other way round): all models now return tuples by default (100% backward compatible) unless you opt in to the new model output types with `return_dict=True`. The model output class is changed to the dict-like one, which should work equally well for TensorFlow.

I have updated all examples in the docs to instantiate the model with `return_dict=True`, but more docs will follow in other PRs. For the tests, I have set `return_dict=True` in one of the common tests just to make sure it actually works. Step 2 (in a follow-up PR) will be to use it in all tests. Step 3 will then be to update the TensorFlow models to use this `ModelOutput`.
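The opt-in behavior described above can be sketched with a toy model (the class and output names here are illustrative assumptions, not the real transformers API): the default call returns a plain tuple, and `return_dict=True` switches to the attribute-accessible output.

```python
from collections import OrderedDict


class ToyOutput(OrderedDict):
    """Toy stand-in for a dict-like model output with attribute access."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)


class ToyModel:
    def forward(self, x, return_dict=False):
        hidden_states = [v * 2 for v in x]
        logits = sum(hidden_states)
        if not return_dict:
            # Default path: plain tuple, fully backward compatible.
            return (logits, hidden_states)
        # Opt-in path: dict-like output with named fields.
        return ToyOutput(logits=logits, hidden_states=hidden_states)


model = ToyModel()
logits, hidden = model.forward([1, 2])          # legacy tuple unpacking
out = model.forward([1, 2], return_dict=True)   # new-style named access: out.logits
```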