
[feat][minor] 2/3 Make it explicit whether an attention mechanism supports a mask #266

Merged 1 commit into conda_ci on Apr 21, 2022

Conversation

blefaudeux (Contributor) commented Apr 9, 2022

What does this PR do?

Preamble to the Triton 2 PR (Triton 2 will change the blocksparse attention and no longer support attention masks).

The unit test that now fails on CI should be unrelated to this PR; it is fixed by the next PR in line (Triton 2).

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

@facebook-github-bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) Apr 9, 2022
@blefaudeux marked this pull request as draft April 9, 2022 17:15
@@ -150,7 +150,7 @@ run_unittests: &run_unittests
   - run:
       name: Run Unit Tests
       command: |
-        pytest --junitxml=test-results/junit.xml --verbose --timeout 600 tests
+        CUDA_LAUNCH_BLOCKING=1 pytest --junitxml=test-results/junit.xml --verbose --timeout 600 tests
blefaudeux (Contributor, Author) commented on this diff:
@fmassa the CI crash happens with or without this change; I added it to make sure that it was not an earlier test silently failing (CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous, so the crash is attributed to the right test). The crash is in a sputnik kernel, which makes little sense to me given the contents of this PR.

@@ -39,7 +39,7 @@
     "num_heads": 4,
     "residual_dropout": 0,
     "attention": {
-        "name": "linformer",
+        "name": "scaled_dot_product",
blefaudeux (Contributor, Author) commented on this diff:

We were actually passing an attention mask in this test, which makes little sense with Linformer. Now that we assert in this case, it had to be fixed :)
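
For reference, a minimal sketch of the kind of guard this refers to; the helper name is made up and only the supports_* flags come from this PR, so treat it as illustrative rather than the actual xformers dispatch code:

```python
# Illustrative sketch only, not the actual xformers multi-head dispatch code.
def _check_mask_support(attention, att_mask=None, key_padding_mask=None):
    # Fail loudly instead of silently dropping a mask the mechanism cannot honor.
    if att_mask is not None:
        assert getattr(attention, "supports_attention_mask", False), (
            f"{type(attention).__name__} does not support an attention mask"
        )
    if key_padding_mask is not None:
        assert getattr(attention, "supports_key_padding_mask", False), (
            f"{type(attention).__name__} does not support a key padding mask"
        )
```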

@blefaudeux force-pushed the label_attention_properties branch 6 times, most recently from db42817 to be72b26, April 13, 2022 04:11
@blefaudeux changed the base branch from main to conda_ci April 13, 2022 05:02
@blefaudeux changed the title from "[feat][minor] Make it explicit whether an attention mechanism supports a mask" to "[feat][minor] 2/3 Make it explicit whether an attention mechanism supports a mask" Apr 18, 2022
@blefaudeux marked this pull request as ready for review April 19, 2022 21:25
blefaudeux (Contributor, Author):

@dianaml0 @fmassa this goes with the Triton 2 PR again

fmassa (Contributor) left a comment:

LGTM, thanks!

I have only one comment (somewhat unrelated to this PR), which I'd love to get your opinion on.

@@ -53,6 +53,10 @@ def __init__(self, dropout: Optional[float] = None, *args, **kwargs):
         # so that the MHA wrapper should skip it
         self.requires_skip_multi_head = False

+        # Whether this attention mechanism supports attention masks
+        self.supports_attention_mask = True
+        self.supports_key_padding_mask = False
fmassa (Contributor) commented on this diff:
Ideally, I think I would prefer if we only supported attention_mask, and then supported the "optimized" case of key_padding_mask internally by checking the strides of the attention_mask tensor.
The key_padding_mask case just assumes that the stride of dimension -2 of the attention_mask is 0, so converting a key_padding_mask to an attention_mask can be done efficiently via

attn_mask = key_padding_mask[:, None, None, :].expand(batch, heads, query_len, key_len)

Indeed, are we sure that all our implementations support both attention_mask and key_padding_mask at the same time?
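
A small runnable illustration of that stride observation (the shapes here are made up for the example):

```python
import torch

batch, heads, query_len, key_len = 2, 4, 8, 8
key_padding_mask = torch.ones(batch, key_len, dtype=torch.bool)  # True = keep

# Broadcasting the key padding mask into a full attention mask is essentially free:
# expand() only creates a view, no memory is copied.
attn_mask = key_padding_mask[:, None, None, :].expand(batch, heads, query_len, key_len)

# The expanded view has stride 0 along the query dimension (-2), which is what an
# implementation could check to detect the key-padding case internally.
assert attn_mask.stride(-2) == 0
assert attn_mask.shape == (batch, heads, query_len, key_len)
```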

blefaudeux (Contributor, Author):

Agreed on the simpler interface; it's a bit confusing right now. One catch is that for some attentions (Nystrom, for instance) there's no attention mask, but there is a possible key padding mask (which can indeed generate a mask of sorts, but in that case the dimensions will differ). I think we could do a second pass on that indeed. Thoughts, @dianaml0?

dianaml0 (Contributor):

I agree that ideally we could have this simpler approach, but I ended up adding a key padding mask argument separately for Nystrom since Nystrom can't apply attention masks. I couldn't come up with a way around that, but maybe there's a better one.

blefaudeux (Contributor, Author):

Hmm, so ideally we would just support key padding for Nystrom and not make it visible everywhere? Right now there's a global flag because this happens in the MHA, but maybe we could move this one level down and not expose it for the other mechanisms?

blefaudeux (Contributor, Author):

Otherwise it feels like an abstraction leak on our end, which is not perfect... could be worth some more thinking.
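
As a hedged sketch of the "move it one level down" idea (the class and its forward body are hypothetical stand-ins, not the actual xformers Nystrom implementation):

```python
import torch

class NystromLikeAttention(torch.nn.Module):
    """Hypothetical mechanism that cannot honor a dense attention mask but
    handles key padding internally, so nothing leaks into the MHA wrapper."""

    def __init__(self):
        super().__init__()
        # Flags as introduced in this PR
        self.supports_attention_mask = False
        self.supports_key_padding_mask = True

    def forward(self, q, k, v, key_padding_mask=None):
        # q, k, v: (batch, seq, dim); key_padding_mask: (batch, seq), True = keep
        scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
        if key_padding_mask is not None:
            # Handle padding inside the mechanism: the real Nystrom code would fold
            # this into its landmark computation; plain masked SDP stands in here.
            scores = scores.masked_fill(~key_padding_mask[:, None, :], float("-inf"))
        att = torch.softmax(scores, dim=-1)
        return att @ v
```

With the flags set this way, an MHA-level guard like the one sketched earlier in the thread would reject an attention mask for this mechanism but let a key padding mask through.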

@blefaudeux merged commit e3b57de into conda_ci Apr 21, 2022
blefaudeux added a commit that referenced this pull request Apr 21, 2022
…h combo (#271)

* testing using conda to get the pytorch nightlies and matching cuda

* [fix] Making it explicit whether the attention mechanism supports an attention mask or not (#266)

check the assert

* [backend] 3/3 Triton 2 update (#272)

* parent be72b26
author Kashif Rasul <kashif.rasul@gmail.com> 1648069860 +0100
committer Benjamin Lefaudeux <benjamin.lefaudeux@pm.me> 1650256563 -0700

Move to Triton 2

Author:    Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Benjamin Lefaudeux <benjamin.lefaudeux@pm.me>

Tentatively fixing layernorm

- faster all around
- bugfix

better take on sparse tensors, put layout on the correct device
update the pip packages, minor cleanup

* catering for triton blocksparse being probably more reliable in fp16

* faster layernorm

* Minor blocksparse refactoring, update block size restrictions, relax power of two constraint (#277)

* Relax device size restrictions

* Refactor device creation and run all tests

* linting

Co-authored-by: Cole Hawkins <colehawk@amazon.com>

* code review, thanks @fmassa !

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: colepshawkins <31542048+colehawkins@users.noreply.github.com>
Co-authored-by: Cole Hawkins <colehawk@amazon.com>

@blefaudeux deleted the label_attention_properties branch April 25, 2022 03:37