Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix torch.compile with fullgraph=True when attention_mask input is used #29211

Merged
merged 3 commits into from Feb 22, 2024

Conversation

fxmarty
Copy link
Collaborator

@fxmarty fxmarty commented Feb 22, 2024

As per title.

Fixes #29190

@fxmarty
Copy link
Collaborator Author

fxmarty commented Feb 22, 2024

Let's consider using pytorch/pytorch#120400 if this is accepted and released.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we don't have a choice?

# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1, keepdim=True)).to(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kind of related to #29210

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ArthurZucker this would conflict but is unrelated

Comment on lines +1081 to +1085
is_tracing = (
torch.jit.is_tracing()
or isinstance(input_tensor, torch.fx.Proxy)
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤢

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agree

@fxmarty
Copy link
Collaborator Author

fxmarty commented Feb 22, 2024

I guess we don't have a choice?

The other choice would be:

            is_tracing = (
                torch.jit.is_tracing()
                or isinstance(input_tensor, torch.fx.Proxy)
                or torch._dynamo.is_fullgraph_tracing())
            )

but torch._dynamo.is_fullgraph_tracing does not exist in PyTorch.

One other possibility is to always do causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1, keepdim=True)) no matter what.

One other possibility is to not use causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1, keepdim=True)) at all, but that means that we drop support for memory-efficient attention backend in Transformers cc @drisspg

One other possibility is to move the causal mask logic outside of the modeling code.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright, let's make sure CIs are green en bench are not slower!

@fxmarty fxmarty merged commit 2cc8cf6 into huggingface:main Feb 22, 2024
19 checks passed
@fxmarty fxmarty mentioned this pull request Feb 22, 2024
4 tasks
@kwen2501
Copy link
Contributor

Thanks for the fix!
Just wanted to share new API offerings from PyTorch (as of a couple days ago):
pytorch/pytorch#119602

Summary:

  • A more general flag is named torch.compiler.is_compiling(). This flag indicates whether a graph is traced/compiled via torch.export() or torch.compile(). The flag works even in non-strict mode (i.e. even without TorchDynamo).
  • A more specific flag is named torch.compiler.is_dynamo_compiling(), it's stricter, because it's only set to True when TorchDynamo is used, so, in non-strict mode it would be False.

Cc: @mreso @khabinov

@fxmarty
Copy link
Collaborator Author

fxmarty commented Feb 27, 2024

@kwen2501 Thank you! torch.export is not always using dynamo? Reading https://pytorch.org/docs/stable/export.html I thought so!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

torch.export fails for llama model
4 participants