
[fix] Nystrom + microGPT + some additive masking #75

Merged
merged 7 commits into main from nystrom_causal_fix on Nov 4, 2021

Conversation


@blefaudeux blefaudeux commented Nov 3, 2021

What does this PR do?

  • Fixes [bug] Nystrom fails with microGPT #74: microGPT and causal attention now work properly with Nystrom
  • Fixes attention masks being either boolean or additive (probably not all codepaths, but at least the ones that Nystrom touches); see the sketch right after this list
  • Fixes the small-sequence handling for sparse masks, which was broken prior to this PR and probably explains yesterday's hanging CI
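
For reference, here is a minimal sketch of the boolean-to-additive conversion involved (helper name and fill value are illustrative assumptions, not necessarily the exact xformers code): allowed positions become 0.0 and masked positions a very negative value, so they vanish after the softmax.

import torch

def bool_to_additive(mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # True = keep, False = masked out (hypothetical helper for illustration)
    additive = torch.zeros_like(mask, dtype=dtype)
    return additive.masked_fill(~mask, float("-inf"))

causal = torch.tril(torch.ones(4, 4, dtype=torch.bool))  # causal, lower-triangular mask
scores = torch.randn(4, 4) + bool_to_additive(causal)
attn = scores.softmax(dim=-1)  # masked positions end up with zero attention weight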

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

@facebook-github-bot facebook-github-bot added the CLA Signed label on Nov 3, 2021
@blefaudeux blefaudeux force-pushed the nystrom_causal_fix branch 2 times, most recently from 5eddcff to b52b135 on November 4, 2021 00:19
@blefaudeux blefaudeux changed the title from "[DRAFT] Nystrom + microGPT Now running" to "[DRAFT] Nystrom + microGPT" on Nov 4, 2021
seq = q.shape[-2]

if not att_mask.is_sparse:
    att_mask = att_mask[:seq, :seq]
@blefaudeux blefaudeux (Contributor, Author) commented:

We were testing with smaller sequences, but when the mask was sparse it was not adjusted accordingly, which could lead to a memory error; this showed up on CI later on.
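
Roughly what the guarded slice above amounts to (a sketch with assumed names, not the verbatim xformers code): only a dense mask can be narrowed down to the current sequence length, while a sparse mask cannot be sliced that way and has to be built at the right size to begin with.

import torch

def maybe_trim_mask(att_mask: torch.Tensor, seq: int) -> torch.Tensor:
    # Hypothetical helper: slice a dense mask down to the actual sequence
    # length; a torch sparse tensor does not support this kind of slicing,
    # so it must already match (seq, seq).
    if not att_mask.is_sparse:
        return att_mask[:seq, :seq]
    assert att_mask.shape[-2:] == (seq, seq), "sparse mask must be pre-sized"
    return att_mask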

@blefaudeux blefaudeux changed the title from "[DRAFT] Nystrom + microGPT" to "[fix] Nystrom + microGPT" on Nov 4, 2021
@blefaudeux blefaudeux changed the title from "[fix] Nystrom + microGPT" to "[fix] Nystrom + microGPT + some additive masking" on Nov 4, 2021
@blefaudeux blefaudeux marked this pull request as draft November 4, 2021 05:50
@blefaudeux blefaudeux (Contributor, Author) commented:

Converting back to draft: I think the key padding mask is not handled correctly with my changes.

mask = (
    key_padding_mask
    if mask is None
    else mask.logical_and(key_padding_mask)
@blefaudeux blefaudeux (Contributor, Author) commented:

@dianaml0 I'm not sure how that worked, since mask and key_padding_mask had different dimensions here, no?

@dianaml0 dianaml0 (Contributor) replied:

It should automatically broadcast key_padding_mask along the mismatched dimension, similar to https://github.com/pytorch/pytorch/blob/4262c8913c2bddb8d91565888b4871790301faba/torch/nn/functional.py#L5189
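
A small illustration of that broadcasting (shapes assumed for the example): a key padding mask with a singleton query dimension expands against a full attention mask when the two are combined with logical_and.

import torch

batch, seq = 2, 4
att_mask = torch.ones(batch, seq, seq, dtype=torch.bool)        # (B, S, S)
key_padding_mask = torch.ones(batch, 1, seq, dtype=torch.bool)  # (B, 1, S)
key_padding_mask[:, :, -1] = False  # last key position is padding

combined = att_mask.logical_and(key_padding_mask)  # broadcasts to (B, S, S)
assert combined.shape == (batch, seq, seq)
assert not combined[:, :, -1].any()  # padded keys are masked for every query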

@codecov-commenter commented:

Codecov Report

Merging #75 (013a927) into main (962db66) will increase coverage by 0.01%.
The diff coverage is 87.87%.

@@            Coverage Diff             @@
##             main      #75      +/-   ##
==========================================
+ Coverage   87.10%   87.12%   +0.01%     
==========================================
  Files          50       50              
  Lines        2428     2447      +19     
==========================================
+ Hits         2115     2132      +17     
- Misses        313      315       +2     
Flag      Coverage Δ
Python    87.12% <87.87%> (+0.01%) ⬆️

Impacted Files                                           Coverage Δ
...formers/components/attention/scaled_dot_product.py    97.29% <ø> (ø)
xformers/components/attention/core.py                    87.50% <81.81%> (-1.14%) ⬇️
xformers/components/attention/nystrom.py                 89.69% <85.71%> (+0.80%) ⬆️
xformers/components/attention/_sputnik_sparse.py         95.55% <100.00%> (+0.06%) ⬆️
xformers/components/attention/utils.py                   100.00% <100.00%> (ø)

Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 962db66...013a927.

@blefaudeux blefaudeux marked this pull request as ready for review November 4, 2021 15:35
@dianaml0 dianaml0 (Contributor) left a comment:

Thanks!! Really helpful changes! Makes masking easier to work with!

@@ -30,7 +30,7 @@ def test_core_attention():
 def test_core_attention_mask_types():

     b, s, d = 8, 900, 32
-    prob = 0.5
+    prob = 0.8  # make sure that we trigger the sparse kernels
@dianaml0 dianaml0 (Contributor) commented:

Oh oops, thanks for catching that!
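
For context, a sketch of the kind of mask such a test might build (exact construction and the library's density threshold are assumptions here): with prob = 0.8, a random boolean mask keeps roughly 20% of positions, sparse enough to route through the sparse kernels instead of the dense path.

import torch

b, s = 8, 900
prob = 0.8
mask = torch.rand(b, s, s) > prob      # roughly 20% of entries are True (kept)
density = mask.float().mean().item()
print(f"mask density: {density:.2f}")  # around 0.20

With the old prob = 0.5, about half of the positions were kept, which apparently was not enough to exercise the sparse kernels reliably.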

mask = (
    key_padding_mask
    if mask is None
    else mask.logical_and(key_padding_mask)
)
key_padding_mask = bool_mask_to_additive(key_padding_mask)

assert key_padding_mask is not None # mypy is drunk
@dianaml0 dianaml0 (Contributor) commented:

I think that if a return type is added to bool_mask_to_additive, it may fix the mypy error.

@blefaudeux blefaudeux (Contributor, Author) replied:

Ahh, good point, I'll fix that, thank you!
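
A sketch of the typing fix being discussed (signature and body assumed for illustration): with an explicit -> torch.Tensor annotation, mypy knows the helper always returns a Tensor, so the defensive assert becomes unnecessary.

import torch

def bool_mask_to_additive(mask: torch.Tensor) -> torch.Tensor:
    # hypothetical body; the annotation is the part that helps mypy
    return torch.zeros_like(mask, dtype=torch.float).masked_fill(~mask, float("-inf"))

key_padding_mask = bool_mask_to_additive(torch.ones(2, 1, 8, dtype=torch.bool))
# no `assert key_padding_mask is not None` workaround needed any more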

@blefaudeux blefaudeux merged commit 1c702fb into main Nov 4, 2021
@blefaudeux blefaudeux deleted the nystrom_causal_fix branch November 4, 2021 17:13
xwhan pushed a commit to xwhan/xformers that referenced this pull request Feb 8, 2022
* Add some 2d-specific attention patterns
* Add notebook with examples