
Fix missing output_attentions in PT/Flax equivalence test #16271

Merged: 10 commits merged into huggingface:main from fix_pt_flax_equivalence_tests on Mar 29, 2022

Conversation

@ydshieh (Collaborator) commented Mar 19, 2022:

What does this PR do?

In a previous PR, #15841, output_attentions was not set: I had accidentally removed the whole block containing it.
This PR sets output_attentions to make the test more thorough.

The test still runs successfully with a tolerance of 1e-5 on both CPU and GPU. However, see the second point in the remarks below.

It also adds a has_attentions attribute to FlaxModelTesterMixin (as done in PyTorch's ModelTesterMixin).
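
Roughly, the pattern looks like the following minimal sketch (the class and helper names are illustrative only, not the actual transformers test code):

```python
import numpy as np


class EquivalenceTesterSketch:
    # Model testers for pure convolutional architectures (no attention) would set this to False.
    has_attentions = True

    def configure(self, config):
        # Output all for aggressive testing.
        config.output_hidden_states = True
        # The setting this PR restores: only request attentions when the architecture has them.
        config.output_attentions = self.has_attentions
        return config

    @staticmethod
    def assert_close(fx_output, pt_output, tol=1e-5):
        # The equivalence test still passes at this tolerance on both CPU and GPU.
        diff = np.max(np.abs(np.asarray(fx_output) - np.asarray(pt_output)))
        assert diff <= tol, f"PT/Flax outputs differ by {diff} (> {tol})"
```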

Remarks:

  • In a follow-up PR, we might use has_attentions in some existing methods (to make sure attentions are only tested if has_attentions is True); see [Tests] Add attentions_option to ModelTesterMixin #15909.
  • There are 4 Flax model testers that overwrite the common Flax tests test_equivalence_pt_to_flax and test_equivalence_flax_to_pt.
    • I will update them in a follow-up PR.
    • These include FlaxGPTJ and FlaxXGLM, which will fail at the 1e-5 tolerance. I need to debug them.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@ydshieh ydshieh force-pushed the fix_pt_flax_equivalence_tests branch from 3d07016 to cb6459b Compare March 20, 2022 11:33
@ydshieh ydshieh marked this pull request as ready for review March 21, 2022 07:47
@ydshieh ydshieh changed the title from "[WIP] Fix missing output_attentions" to "Fix missing output_attentions in PT/Flax equivalence test" on Mar 21, 2022
@patil-suraj (Member) left a comment:

LGTM, thanks a lot!

Review thread on tests/test_modeling_flax_common.py (outdated, resolved):
@@ -178,6 +179,12 @@ def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
Currently unused, but in the future, we could use this information to make the error message clearer
by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
"""
# `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version,
# an effort was done to return `attention_probs` (yet to be verified).
if type(names) == str and names.startswith("attentions"):
A contributor commented:

This is a bit too hacky for me here. Can't we just overwrite the test in test_modeling_flax_big_bird.py?
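
For illustration, the suggested override could look roughly like this sketch (the test class name and import are assumed, not the exact ones in the repository):

```python
import unittest

# The import path below is assumed; in the repository the mixin lives in the common Flax test module.
from test_modeling_flax_common import FlaxModelTesterMixin


class FlaxBigBirdModelTestSketch(FlaxModelTesterMixin, unittest.TestCase):
    def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
        # FlaxBigBird's block-sparse attention returns `attention_probs = None`,
        # so attention outputs cannot be compared against the PyTorch implementation.
        if isinstance(names, str) and names.startswith("attentions"):
            return
        super().check_outputs(fx_outputs, pt_outputs, model_class, names)
```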

@@ -274,7 +281,8 @@ def test_equivalence_flax_to_pt(self):

# Output all for aggressive testing
config.output_hidden_states = True
# Pure convolutional models have no attention
if self.has_attentions:
A contributor commented:

Don't like this too much here either. Can't we check whether there is an output_attentions argument in the signature of the forward function and, if that's the case, set config.output_attentions=True? This way we have one less dependency.
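
A rough sketch of that signature-based check (an assumption about how it could be done, not what was merged):

```python
import inspect


def wants_attentions(model_class) -> bool:
    # PyTorch models define `forward`; Flax modules are invoked through `__call__`.
    call_fn = getattr(model_class, "forward", None) or model_class.__call__
    return "output_attentions" in inspect.signature(call_fn).parameters


# The common test could then do, roughly:
#     config.output_attentions = wants_attentions(model_class)
```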

@ydshieh (Collaborator, author) replied:

This has_attentions attribute was introduced in ModelTesterMixin (#15909), and then in TFModelTesterMixin by me (#16259).

I think it would be good to have the same approach for testing across the 3 frameworks. Let me know if you still prefer the other approach(es).

cc @NielsRogge @sgugger for further comments if any.

A collaborator replied:

Yes, let's use existing attributes and make the three testers consistent with each other.

@ydshieh ydshieh force-pushed the fix_pt_flax_equivalence_tests branch from ede2bea to 151860f Compare March 25, 2022 16:19
@@ -314,6 +315,7 @@ def test_equivalence_flax_to_pt(self):

# send pytorch model to the correct device
pt_model_loaded.to(torch_device)
pt_model_loaded.eval()
@ydshieh (Collaborator, author) commented:

Don't forget to set the re-loaded PT model to eval mode.
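
As a small self-contained illustration of why this matters (not the test code itself), train-mode modules such as dropout make outputs stochastic, which would break the PT/Flax comparison:

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

model.train()
print(torch.allclose(model(x), model(x)))  # usually False: dropout is still active

model.eval()
with torch.no_grad():
    print(torch.allclose(model(x), model(x)))  # True: deterministic, as the equivalence check requires
```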

@ydshieh (Collaborator, author) commented Mar 25, 2022:

Think this (quite small) PR is ready. There is nothing particular beyond adding the missing config.output_attentions = self.has_attentions.
The super() thing was discussed in #16280.

Will merge it today.

@@ -168,6 +169,7 @@ def recursive_check(tuple_object, dict_object):
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

# (Copied from tests.test_modeling_common.ModelTesterMixin.check_outputs)
@ydshieh (Collaborator, author) commented Mar 25, 2022:

The intention is only to add this information; it is not meant to work with the current version of make fix-copies.

@sgugger Are you OK with this comment? Otherwise I can just remove it.

@ydshieh ydshieh merged commit aebca69 into huggingface:main Mar 29, 2022
@ydshieh ydshieh deleted the fix_pt_flax_equivalence_tests branch March 29, 2022 15:51