Fixing slow pipeline tests #14260

Merged: 7 commits merged into huggingface:master on Nov 4, 2021
Conversation

@Narsil (Contributor) commented on Nov 3, 2021:

Some tests were broken because of PyTorch's inference_mode.
This should cover all cases of in-place tensor modifications, as far as I know.

Let me know if there are better ways to fix those.
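For readers unfamiliar with the issue, here is a minimal, hypothetical sketch (not code from this PR, and the function names are made up for illustration) of the general pattern: replace in-place tensor mutations with out-of-place equivalents. Whether a given in-place op actually errors depends on where the tensor was created relative to the inference_mode context and on the PyTorch version.

```python
import torch

def zero_out_padding_inplace(x: torch.Tensor, padding_mask: torch.Tensor) -> torch.Tensor:
    # In-place: indexing assignment mutates x. This is the kind of mutation
    # that can trip up inference mode, depending on where x was created.
    x[padding_mask] = 0.0
    return x

def zero_out_padding(x: torch.Tensor, padding_mask: torch.Tensor) -> torch.Tensor:
    # Out-of-place: masked_fill returns a new tensor and leaves x untouched.
    return x.masked_fill(padding_mask, 0.0)

with torch.inference_mode():
    hidden = torch.randn(2, 4)
    padding = torch.tensor([[False, False, True, True],
                            [False, True, True, True]])
    out = zero_out_padding(hidden, padding)
```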

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@stas00
@patrickvonplaten

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
# Effectively attention_mask.sum(-1), but not inplace to be able to run
# on inference mode.
mask = attention_mask.cumsum(dim=-1)[:, -1]
Contributor:

(nit) Could we call it non_padded_lengths? The idea here is to extract the sub-sampled length from the "real", non-padded input length.

Suggested change:
- mask = attention_mask.cumsum(dim=-1)[:, -1]
+ non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]

Narsil (author):

True. Do you know any other way to do that operation? It's very surprising that .sum is in-place, and I am worried that using cumsum instead is super wasteful.

I tried to grep for it in our code, but I couldn't find anything of that sort.

Contributor:

Hmm, I don't really know, to be honest... torch.sum(...) doesn't work either? But I think using .cumsum(...) is totally fine as well.

Narsil (author):

No, torch.sum(..) doesn't work.
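For reference, a quick sanity check (not part of the PR) that the last element of a cumulative sum equals the plain sum, so the cumsum substitution is value-equivalent for a 0/1 attention mask:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

lengths_sum = attention_mask.sum(-1)                   # tensor([3, 5])
lengths_cumsum = attention_mask.cumsum(dim=-1)[:, -1]  # tensor([3, 5])

assert torch.equal(lengths_sum, lengths_cumsum)
```

The cumsum version materializes an intermediate tensor the same shape as the mask before taking the last column, which is the extra cost referred to as "wasteful" above; for typical attention-mask sizes the overhead is likely negligible.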


def sequential_inference(self, **inputs):
    """
    Inference used for models that need to process sequences in a sequential fashion, like the SQA models which
    handle conversational query related to a table.
    """
    with torch.no_grad():
Contributor:

Nice!
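For context, a minimal sketch (not from this PR) of how torch.no_grad() and torch.inference_mode() differ; the precise restrictions depend on the PyTorch version:

```python
import torch

model = torch.nn.Linear(4, 2)
x = torch.randn(1, 4)

# no_grad: gradients are not tracked, but the outputs are ordinary tensors
# that can still participate in autograd-recorded operations later.
with torch.no_grad():
    out_no_grad = model(x)

# inference_mode: gradients are not tracked either, and in addition view
# tracking and version-counter bumps are skipped, so the outputs are
# "inference tensors" that cannot be used in autograd afterwards.
with torch.inference_mode():
    out_inference = model(x)
```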

@patrickvonplaten (Contributor) left a review:

I would be happy if we could give mask a better name. Apart from that, thanks a lot for enabling inference mode for all models :-)

Comment on lines +651 to +654
if self.training:
    if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
        clamp_value = torch.finfo(hidden_states.dtype).max - 1000
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
Narsil (author):

@stas00 Is it OK to remove this at inference time?

@stas00 (Contributor) commented on Nov 3, 2021:

In theory yes. In practice, it depends on how the model was pre-trained.

The model weights don't change during inference, so we don't need to keep things in check all the time.

However, if the pre-trained model's weights lead to an overflow in a single iteration during training, as is the case with some mt5 models under mixed precision, then this can occur just as well during inference.

This is primarily an issue with models pre-trained in bf16 and then fine-tuned or run for inference in fp16 (mixed or non-mixed precision).

If a model was pretrained with fp16/mixed precision, it's pretty safe to say the clamping won't be needed.

To give you a more informed answer would require running some tests with the actual DETR models and checking the magnitude of their activations at the point you're asking about. That should be pretty trivial using https://huggingface.co/transformers/debugging.html#underflow-and-overflow-detection, which can be plugged into the HF Trainer and the examples with a single command-line argument, --debug underflow_overflow.
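For illustration, a minimal sketch of the overflow detector mentioned above, attached to a model directly rather than through the Trainer. The checkpoint name is just a placeholder, and the exact import path and behavior should be checked against the installed transformers version:

```python
import torch
from transformers import AutoModel
from transformers.debug_utils import DebugUnderflowOverflow

model = AutoModel.from_pretrained("bert-base-uncased")
# Hooks every module's forward pass and prints a report (and asserts) as soon
# as an inf/nan shows up in activations or weights.
debug_overflow = DebugUnderflowOverflow(model)

input_ids = torch.tensor([[101, 2023, 2003, 1037, 3231, 102]])
attention_mask = torch.ones_like(input_ids)
with torch.no_grad():
    _ = model(input_ids=input_ids, attention_mask=attention_mask)
```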

Contributor:

To be honest I think this code was just badly copy pasted, so I'm more in favor of disabling this hack for training (as it is done now)

Narsil (author):

OK, if everyone is in favor, then let's do this.

Contributor:

> To be honest I think this code was just badly copy pasted, so I'm more in favor of disabling this hack for training (as it is done now)

You must have meant for inference, right Patrick?

@patrickvonplaten (Contributor):
Good for merge for me

@Narsil Narsil merged commit 68427c9 into huggingface:master Nov 4, 2021
@Narsil Narsil deleted the fix_slow_tests branch November 4, 2021 08:49
Albertobegue pushed a commit to Albertobegue/transformers that referenced this pull request Jan 27, 2022
* Fixing slow pipeline tests

* Remove the image-segmentation override.

* Fixing clamping only in training.

* Wav2vec2.

* Remove last mention of `no_grad`.

* Fixing copies.

* Rename.
@sanchit-gandhi mentioned this pull request on Aug 24, 2023