
Better booleans handling in the TF models #8777

Merged: 53 commits merged into huggingface:master on Dec 4, 2020

Conversation

@jplu (Contributor) commented Nov 25, 2020

What does this PR do?

This PR provides better handling of the booleans. More precisely, the execution mode (eager or graph) is detected and the booleans are set accordingly so that execution behaves properly. Nevertheless, this brings a small breaking change in graph mode: it is no longer possible to update the booleans through the model call parameters, only through the config, and return_dict is forced to True.

Now, to activate the output_attentions or output_hidden_states values in graph mode, one has to create the model config like:

config = XConfig.from_pretrained("name", output_attentions=True, output_hidden_states=True)
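
For example, a minimal sketch of the end-to-end usage under these assumptions (a BERT checkpoint is used for illustration; the generic XConfig above stands for a concrete config class such as BertConfig):

import tensorflow as tf
from transformers import BertConfig, TFBertModel

# The booleans are read from the config, so they are set at config creation time.
config = BertConfig.from_pretrained("bert-base-uncased", output_attentions=True, output_hidden_states=True)
model = TFBertModel.from_pretrained("bert-base-uncased", config=config)

@tf.function  # graph mode: return_dict is forced to True
def forward(input_ids):
    outputs = model(input_ids)
    return outputs.last_hidden_state, outputs.attentions

hidden_state, attentions = forward(tf.constant([[101, 2023, 102]]))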

or ("return_dict" in kwargs and kwargs["return_dict"] is not None)
):
logger.warn(
"Cannot update the boolean parameters behavior in graph mode and the return_dict parameter is always True in that mode."
Contributor:

I'd make the warning a bit more explicit and split it into two (see the sketch after this list):

  1. Cannot set boolean arguments use_cache, output_hidden_states, and output_attentions to True in graph mode
    and then under new if ("return_dict" in kwargs and kwargs["return_dict"] is not None)
  2. return_dict is always set to True in graph mode
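
A hypothetical sketch of what that split might look like (the surrounding condition, kwargs, tf, and logger are assumed from the input_processing context; this is not the merged code):

if not tf.executing_eagerly():
    if any(
        kwargs.get(name) is not None
        for name in ("use_cache", "output_hidden_states", "output_attentions")
    ):
        logger.warning(
            "Cannot set boolean arguments use_cache, output_hidden_states, and "
            "output_attentions to True in graph mode."
        )
    if "return_dict" in kwargs and kwargs["return_dict"] is not None:
        logger.warning("return_dict is always set to True in graph mode.")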

Contributor Author:

BTW you can have use_cache, output_hidden_states, and output_attentions set to True in graph mode, but it has to be done when instantiating the config.

Collaborator:

Maybe we can make that more explicit in the warning then? I hadn't understood it from the current one.

final_booleans["output_hidden_states"] = config.output_hidden_states

if "return_dict" in kwargs:
    final_booleans["return_dict"] = True
Contributor:

I thought this was always True -> this looks like, if someone sets config.use_return_dict=False, return_dict would stay False?

Contributor:

Ah ok, never mind! It's always in kwargs because every forward has it... but do we really need the `if "return_dict" in kwargs` then? It should always be in there, no?

Contributor Author:

No, because only the values in final_booleans are taken into account afterwards in the model/layers.
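
To make the overall logic being discussed here concrete, a hypothetical sketch (the function name and structure are illustrative, not the actual input_processing implementation): in graph mode the booleans come from the config and return_dict is forced to True, while in eager mode the call arguments may override the config.

import tensorflow as tf

def resolve_booleans(config, **kwargs):
    # Hypothetical helper mirroring the behavior described in this PR.
    final_booleans = {}
    names = ("output_attentions", "output_hidden_states", "use_cache")
    if tf.executing_eagerly():
        # Eager mode: call arguments take precedence over the config.
        for name in names:
            value = kwargs.get(name)
            final_booleans[name] = value if value is not None else getattr(config, name, None)
        return_dict = kwargs.get("return_dict")
        final_booleans["return_dict"] = return_dict if return_dict is not None else config.use_return_dict
    else:
        # Graph mode: only the config is used, and return_dict is always True.
        for name in names:
            final_booleans[name] = getattr(config, name, None)
        final_booleans["return_dict"] = True
    return final_booleans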

else config.output_hidden_states
)

if "return_dict" in kwargs:
Contributor:

do we need that check? Can't we assume that every forward method has it?

Contributor Author:

A few don't.

Contributor:

Which ones? They should have it, I think.

Contributor Author:

The main layer of T5 doesn't.

Contributor:

Ok for me then!

@@ -348,6 +424,15 @@ def input_processing(func, input_ids, **kwargs):
if "kwargs" in output:
del output["kwargs"]

boolean_dict = {k: v for k, v in output.items() if k in boolean_properties}
Contributor:

(nit) I'd actually move ["return_dict", "output_attentions", "output_hidden_states", "use_cache"] directly into the line here -> replace boolean_properties with ["return_dict", "output_attentions", "output_hidden_states", "use_cache"] - it'd make it easier to read for me
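
For reference, a sketch of what the suggested inlining might look like (mirroring the quoted line, with boolean_properties replaced by the literal list):

boolean_dict = {
    k: v
    for k, v in output.items()
    if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"]
}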

self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.return_dict = config.use_return_dict
self.config = config
Contributor:

I always thought there was a reason why we never save the config in lower TF modules... but if there is not, self.config = config is totally fine for me @LysandreJik

Contributor Author:

Nothing prevents you from doing this as long as you have a proper get_config() method, which is our case with the keras_serializable decorator. Afterwards we can even think about removing all the self.parameter = config.parameter assignments.
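
As a general illustration of the Keras pattern being referred to, a minimal sketch of a custom layer that stores the whole config (this is not the transformers keras_serializable implementation, just the underlying idea):

import tensorflow as tf

class MyMainLayer(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        # Keep the whole config instead of copying individual attributes onto self.
        self.config = config

    def call(self, hidden_states):
        # The booleans can be read straight from the stored config when needed.
        if self.config.output_attentions:
            pass  # ... compute and collect attentions here
        return hidden_states

    def get_config(self):
        # Serializing the config is what keeps the layer saveable/reloadable;
        # in the library this is handled by the keras_serializable decorator.
        base_config = super().get_config()
        base_config["config"] = self.config.to_dict()
        return base_config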

Collaborator:

If we can save the config instead of storing the attributes, it's way cleaner, so I'm all for it :-)

Member:

Ah, I had to add these so that the saved model tests could pass. I re-ran the saved model tests on your PR and they pass! Great, thanks @jplu

Member:

Ah, the changes I did were done to the TFAlbertEmbeddings class (which is not keras_serializable), not the TFAlbertMainLayer class.

@@ -1302,29 +1284,33 @@ def call(
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)

past = (
(encoder_outputs, decoder_outputs[1]) if cast_bool_to_primitive(use_cache, self.config.use_cache) else None
(inputs["encoder_outputs"], decoder_outputs[1])
if cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache)
Contributor:

Do we still need cast_bool_to_primitive? I kinda thought this PR could resolve the problem.

Contributor Author:

Indeed, but I wanted to make sure of that in a later PR that will focus only on the T5 issues.

@patrickvonplaten (Contributor) left a comment

In general I'm fine with the PR, but I was kind of hoping that this PR would resolve our cast_bool_to... problem in TFT5 and TFBart? Will this still not be possible?

Also, it would be great if this PR could manage to remove the skipped tests for TFT5 and TFBart, e.g.:

def test_saved_model_with_attentions_output(self):

Tbh, I don't see the big gain of the PR if it won't resolve this bigger problem we are having in TFT5 and TFBart (I guess with all TFSeq2Seq models...).

@jplu (Contributor Author) commented Nov 25, 2020

Thanks @patrickvonplaten!

As detailed in the first post, boolean parameters cannot be set during the model call in graph mode. This is the major feature brought by this PR. I wanted to focus on TF T5 and TF Bart in a later PR, once this logic is OK at least for all the others.

@sgugger (Collaborator) left a comment

This is way cleaner like this, thanks for fixing! And thanks for separating this from the PR that will fix T5 and then BART, as those are indeed separate issues and should be addressed in separate PRs :-)

or ("return_dict" in kwargs and kwargs["return_dict"] is not None)
):
logger.warn(
"Cannot update the boolean parameters behavior in graph mode and the return_dict parameter is always True in that mode."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mayne we can make that more explicit in the warning then? I hadn't understood it from the current one,

self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.return_dict = config.use_return_dict
self.config = config
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we can save the config instead of storing the attributes, it's way cleaner, so I'm all for it :-)

@@ -846,8 +841,10 @@ def call(
training=False,
**kwargs,
):
print(input_ids)
Collaborator:

Leftovers from debugging? Should be removed.

@@ -860,6 +857,7 @@ def call(
training=training,
kwargs_call=kwargs,
)
print(inputs)
Collaborator:

Leftovers from debugging? Should be removed.

@jplu (Contributor Author) commented Nov 25, 2020

There is now a better warning message.

@LysandreJik (Member) left a comment

I think this does a very good job at cleaning the models and making them more understandable. I like how you've put everything in the pre-processing function.

There are two things which seem important to me before we can consider merging:

  • You say there is a breaking change in graph mode. Does it mean that currently, both eager & graph mode can handle arguments through the configuration & through the function call? I'm unsure on where we stand on this currently.
  • It seems like the tests that would be impacted by these changes are the slow tests. Have you run the slow tests? If not, could you run the slow tensorflow tests on this PR? If you don't know how to do that, happy to show you how for next time.

if "return_dict" in kwargs:
if kwargs["return_dict"] is not None:
logger.warn(
Member:

Nitpicking, but I believe logging.warn is deprecated and should be replaced by warning. We have plenty of occurrences where we use warn, so it's the nittiest nit ever, but it might be good to keep that in mind for future work.
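
For reference, a minimal sketch of the non-deprecated call (the message text is just an example):

import logging

logger = logging.getLogger(__name__)
logger.warning("return_dict is always set to True in graph mode.")  # preferred over logger.warn(...)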

@jplu (Contributor Author) commented Nov 26, 2020

> You say there is a breaking change in graph mode. Does it mean that currently, both eager & graph mode can handle arguments through the configuration & through the function call? I'm unsure on where we stand on this currently.

Yes, both can be done, but it raises issues when done through the function call in graph mode. So this PR fixes this with better handling of that case.

> It seems like the tests that would be impacted by these changes are the slow tests. Have you run the slow tests? If not, could you run the slow tensorflow tests on this PR? If you don't know how to do that, happy to show you how for next time.

This PR partially fixes these tests. Remember that they do not pass for T5 and BART for the reasons expressed by Patrick. These models, including the saved model tests, will be fixed at the same time in a PR just after this one.

Also, in a future PR I will rethink the way the attributes are handled in all the layers.

@LysandreJik (Member)

> Yes, both can be done, but it raises issues when done through the function call in graph mode. So this PR fixes this with better handling of that case.

So right now it fails, and with this PR it also fails but with better error handling?

> This PR partially fixes these tests. Remember that they do not pass for T5 and BART for the reasons expressed by Patrick. These models, including the saved model tests, will be fixed at the same time in a PR just after this one.

I meant all the slow tests, not only the saved models with saved attentions tests. And this PR doesn't only impact the T5 and BART models, so re-running all the slow tests on this PR seems necessary.

@jplu (Contributor Author) commented Nov 27, 2020

> So right now it fails, and with this PR it also fails but with better error handling?

No, before, nothing was working in graph mode when the booleans were updated through the function call. Now I have disabled this functionality and there are no more failures; everything works properly and as expected in eager and graph mode, except for T5 and BART in graph mode, which will be handled in a later PR.

> I meant all the slow tests, not only the saved models with saved attentions tests. And this PR doesn't only impact the T5 and BART models, so re-running all the slow tests on this PR seems necessary.

Ok, I will run all of them.

@jplu (Contributor Author) commented Nov 29, 2020

@LysandreJik All the slow tests are passing but two:

  • tests/test_modeling_tf_transfo_xl.py::TFTransfoXLModelLanguageGenerationTest::test_lm_generate_transfo_xl_wt103, I started to see that with @patrickvonplaten
  • tests/test_utils_check_copies.py::CopyCheckTester::test_is_copy_consistent, @sgugger any idea why this test doesn't pass anymore? Here is the output:
def test_is_copy_consistent(self):
        # Base copy consistency
>       self.check_copy_consistency(
            "# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead",
            "BertLMPredictionHead",
            REFERENCE_CODE + "\n",
        )

tests\test_utils_check_copies.py:71:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _  
tests\test_utils_check_copies.py:59: in check_copy_consistency
    self.assertTrue(len(check_copies.is_copy_consistent(fname)) == 0)
E   AssertionError: False is not true

@jplu (Contributor Author) commented Dec 3, 2020

@LysandreJik Anything else needed for this PR to be merged?

@LysandreJik (Member)

I investigated why the test_is_copy_consistent test failed: it is probably because you launched your command from inside the tests/ directory, and the check has a path hardcoded to src/transformers, so it cannot find the path tests/src/transformers.

No issues there it seems! Reviewing a final time and merging if all is good.

@LysandreJik (Member) left a comment

Okay, this looks good to me. Thanks a lot @jplu.

@LysandreJik (Member)

@patrickvonplaten you haven't approved this PR, do you want to give it a final look and merge if ok for you?

@patrickvonplaten (Contributor)

> @LysandreJik All the slow tests are passing but two:
>
>   • tests/test_modeling_tf_transfo_xl.py::TFTransfoXLModelLanguageGenerationTest::test_lm_generate_transfo_xl_wt103, I started to see that with @patrickvonplaten
>   • tests/test_utils_check_copies.py::CopyCheckTester::test_is_copy_consistent, @sgugger any idea why this test doesn't pass anymore?

I'll investigate for tests/test_modeling_tf_transfo_xl.py::TFTransfoXLModelLanguageGenerationTest::test_lm_generate_transfo_xl_wt103 -> thanks for pinging me on that! PR is good for me!

@LysandreJik LysandreJik merged commit dcd3046 into huggingface:master Dec 4, 2020
@jplu jplu deleted the bool-proc branch December 4, 2020 16:59