Better booleans handling in the TF models #8777
Conversation
or ("return_dict" in kwargs and kwargs["return_dict"] is not None)
):
    logger.warn(
        "Cannot update the boolean parameters behavior in graph mode and the return_dict parameter is always True in that mode."
I'd make the warning a bit more explicit and split it into two:

- Cannot set the boolean arguments `use_cache`, `output_hidden_states`, and `output_attentions` to True in graph mode

and then, under a new `if ("return_dict" in kwargs and kwargs["return_dict"] is not None)`:

- `return_dict` is always set to True in graph mode
BTW, you can set `use_cache`, `output_hidden_states`, and `output_attentions` to True in graph mode, but it has to be done when instantiating the config.
Maybe we can make that more explicit in the warning then? I hadn't understood it from the current one.
final_booleans["output_hidden_states"] = config.output_hidden_states

if "return_dict" in kwargs:
    final_booleans["return_dict"] = True
I thought this is always True -> this looks like if someone sets `config.use_return_dict=False`, `return_dict` would stay False?
Ah ok, never mind! It's always in kwargs because every forward has it... but do we really need the `if "return_dict" in kwargs` then? -> it should always be in there, no?
No, because only the values in `final_booleans` are taken into account afterwards in the model/layers.
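For readers following along, the resolution logic being discussed can be sketched roughly like this. This is a simplified, hypothetical stand-in for the real `input_processing` behavior (the actual function also handles `use_cache` and `output_attentions`, and the config is a real object rather than a dict):

```python
def resolve_booleans(config, kwargs, eager_mode):
    """Hedged sketch of the boolean resolution discussed above (not the actual
    transformers implementation). `config` is a plain dict here for simplicity."""
    final_booleans = {}
    if eager_mode:
        # in eager mode, a call-time kwarg overrides the config value
        v = kwargs.get("output_hidden_states")
        final_booleans["output_hidden_states"] = (
            v if v is not None else config["output_hidden_states"]
        )
        if "return_dict" in kwargs:
            v = kwargs["return_dict"]
            final_booleans["return_dict"] = (
                v if v is not None else config["use_return_dict"]
            )
    else:
        # in graph mode, config values win and return_dict is forced to True
        final_booleans["output_hidden_states"] = config["output_hidden_states"]
        if "return_dict" in kwargs:
            final_booleans["return_dict"] = True
    return final_booleans
```

Only the keys present in `final_booleans` are consumed by the model/layers afterwards, which is why the `if "return_dict" in kwargs` guard matters for the few call signatures that lack the argument.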
    else config.output_hidden_states
)

if "return_dict" in kwargs:
do we need that check? Can't we assume that every forward method has it?
A few don't.
Which ones? They should have it, I think.
The main layer of T5 doesn't.
Ok for me then!
@@ -348,6 +424,15 @@ def input_processing(func, input_ids, **kwargs):
     if "kwargs" in output:
         del output["kwargs"]

+    boolean_dict = {k: v for k, v in output.items() if k in boolean_properties}
(nit) I'd actually move `["return_dict", "output_attentions", "output_hidden_states", "use_cache"]` directly into the line here -> replace `boolean_properties` with the literal list; it'd make it easier to read for me.
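The suggested inline version could look something like this (a runnable sketch; `output` here is a toy stand-in for the processed kwargs dict):

```python
# toy stand-in for the processed inputs dict produced by input_processing
output = {"input_ids": [0, 1], "return_dict": True, "use_cache": None}

# keep only the boolean-typed arguments, with the allowed names inlined
boolean_dict = {
    k: v
    for k, v in output.items()
    if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"]
}
```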
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.return_dict = config.use_return_dict
self.config = config
I always thought there was a reason why we never save the config in lower TF modules... but if there is not, `self.config = config` is totally fine for me @LysandreJik
Nothing prevents you from doing this as long as you have a proper `get_config()` method, which is our case with the `keras_serializable` decorator. Afterwards we could even think about removing all the `self.parameter = config.parameter` lines.
If we can save the config instead of storing the attributes, it's way cleaner, so I'm all for it :-)
Ah, I had to add these so that the saved model tests could pass. I re-ran the saved model tests on your PR and they pass! Great, thanks @jplu
Ah, the changes I did were done to the `TFAlbertEmbeddings` class (which is not keras_serializable), not the `TFAlbertMainLayer` class.
@@ -1302,29 +1284,33 @@ def call(
     loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)

     past = (
-        (encoder_outputs, decoder_outputs[1]) if cast_bool_to_primitive(use_cache, self.config.use_cache) else None
+        (inputs["encoder_outputs"], decoder_outputs[1])
+        if cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache)
We still need `cast_bool_to_primitive`? I kinda thought this PR could resolve the problem.
Indeed, but I wanted to make sure of that in a later PR that will focus only on the T5 issues.
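For context, the helper in question can be sketched roughly like this. This is a simplified, hypothetical stand-in; the real `cast_bool_to_primitive` in transformers also has to handle TF tensor inputs, which is exactly the graph-mode complication discussed here:

```python
def cast_bool_to_primitive(bool_arg, default):
    """Hedged sketch: resolve an argument that may be None into a plain
    Python bool, falling back to the config default. Illustrative only;
    the real helper also unwraps TF tensors."""
    if bool_arg is None:
        return bool(default)
    return bool(bool_arg)
```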
In general I'm fine with the PR, but I was kind of hoping that this PR would resolve our `cast_bool_to...` problem in TFT5 and TFBart? - will this still not be possible?

Also, it would be great if this PR could manage to remove the skipped tests for TFT5 and TFBart, e.g.:

transformers/tests/test_modeling_tf_t5.py, line 290 in 90d5ab3:

def test_saved_model_with_attentions_output(self):

Tbh, I don't see the big gain of the PR if it won't resolve this bigger problem we are having in TFT5 and TFBart (I guess with all TFSeq2Seq models...)
Thanks @patrickvonplaten! As detailed in the first post, boolean parameters cannot be set during the model call in graph mode. This is the major feature brought by this PR. I wanted to focus on TF T5 and TF Bart in a later PR once this logic is ok at least for all the others.
This is way cleaner like this, thanks for fixing! And thanks for separating this from the PR that will fix T5 and then BART, as those are indeed separate issues and should be addressed in separate PRs :-)
@@ -846,8 +841,10 @@ def call(
     training=False,
     **kwargs,
 ):
+    print(input_ids)
Leftovers from debugging? Should be removed.
@@ -860,6 +857,7 @@ def call(
     training=training,
     kwargs_call=kwargs,
 )
+    print(inputs)
Leftovers from debugging? Should be removed.
There is now a better warning message.
I think this does a very good job at cleaning the models and making them more understandable. I like how you've put everything in the pre-processing function.
There are two things which seem important to me before we can consider merging:
- You say there is a breaking change in graph mode. Does it mean that currently, both eager & graph mode can handle arguments through the configuration & through the function call? I'm unsure where we stand on this currently.
- It seems like the tests that would be impacted by these changes are the slow tests. Have you run the slow tests? If not, could you run the slow tensorflow tests on this PR? If you don't know how to do that, happy to show you how for next time.
if "return_dict" in kwargs:
    if kwargs["return_dict"] is not None:
        logger.warn(
Nitpicking, but I believe `logging.warn` is deprecated and should be replaced by `warning`. We have plenty of occurrences where we use `warn`, so it's the nittiest nit ever, but it might be good to keep that in mind for future work.
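A minimal illustration of the preferred spelling (the logger name here is illustrative, not the one used in the library):

```python
import logging

logger = logging.getLogger("transformers_sketch")  # illustrative name

# Logger.warn is a deprecated alias kept for backwards compatibility;
# Logger.warning is the supported spelling.
logger.warning("return_dict is always set to True in graph mode.")
```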
Yes, both can be done, but it raises issues when done through the function call in graph mode. So this PR fixes this with a better handling of this case.
This PR partially fixes these tests. Remember that they do not pass for T5 and BART for the reasons expressed by Patrick. These models, including the saved model tests, will be fixed at the same time in a PR just after this one. Also, in a future PR I will rethink the way the attributes are handled in all the layers.
So right now it fails, and with this PR it also fails but with better error handling?
I meant all the slow tests, not only the saved models with saved attentions tests. And this PR doesn't only impact the T5 and BART models, so re-running all the slow tests on this PR seems necessary.
No, before nothing was working in graph mode when the boolean was updated through the function call. Now, I disabled this functionality and there are no more failures, and everything works properly and as expected in eager+graph mode except T5 and BART in graph mode, which will be handled in a later PR.
Ok, I will run all of them.
@LysandreJik All the slow tests are passing but two:
@LysandreJik Any other needs for this PR to be merged?
I investigated why the … No issues there, it seems! Reviewing a final time and merging if all is good.
Okay, this looks good to me. Thanks a lot @jplu.
@patrickvonplaten you haven't approved this PR, do you want to give it a final look and merge if ok for you?
I'll investigate for …
What does this PR do?

This PR provides a better handling of the booleans. More precisely, the execution mode (eager or graph) is detected and the booleans are set accordingly to ensure a proper execution. Nevertheless, this brings a small breaking change in graph mode: it is no longer possible to update the booleans through the model call parameters, only through the config, and `return_dict` is forced to be `True`.

Now, to activate the `output_attentions` or `output_hidden_states` values in graph mode, one has to create the model config like:
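A minimal sketch of what that config creation could look like. `BertConfig` is chosen here purely as an illustrative example; any model config class in the library follows the same pattern:

```python
# Hedged sketch: enable the boolean outputs at config-instantiation time,
# since after this PR they can no longer be toggled per-call in graph mode.
from transformers import BertConfig

config = BertConfig(output_attentions=True, output_hidden_states=True)
```

A model built from this config then outputs attentions and hidden states in graph mode without any flags being passed at call time.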