
[feat] Compositional attention #178

Merged — 8 commits merged into main on Jan 20, 2022

Conversation

@blefaudeux (Contributor) commented Jan 10, 2022:

What does this PR do?

Implements Compositional Attention (based on the reference implementation), as mentioned in #41

Paper

TODOs

  • Sane defaults
  • Speed up wherever possible. It also looks like it uses a lot of memory at the moment, probably due to some silly mistakes
  • Maybe a self-attention optimization (single projection) -> doable by moving the projections within the attention to the in-proj class; worth it?
  • Add explanations/documentation
  • Some IR results? -> that would probably be for another task

cc @sarthmit if interested
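
For readers new to the mechanism, a minimal self-contained sketch of the idea from the paper: a pool of searches (query/key attention maps) is shared with a pool of retrievals (value projections), and a second softmax of "value scores" decides, per position, which retrieval each search uses. Names, shapes and weight initialisation below are purely illustrative and are not this PR's API.

```python
import torch

def compositional_attention(x, S=4, R=2, d=16, dv=16):
    """Illustrative forward pass: S searches share a pool of R retrievals (no masking, no dropout)."""
    B, N, D = x.shape

    # Stand-in projection weights (the real module uses nn.Linear layers)
    w_q  = torch.randn(S, D, d) / D**0.5    # search queries
    w_k  = torch.randn(S, D, d) / D**0.5    # search keys
    w_v  = torch.randn(R, D, dv) / D**0.5   # retrieval values
    w_rq = torch.randn(S, D, d) / D**0.5    # retrieval ("value score") queries
    w_rk = torch.randn(dv, d) / dv**0.5     # retrieval keys, computed from the retrieved values
    w_o  = torch.randn(S * dv, D) / (S * dv)**0.5

    q = torch.einsum("bnd,sde->bsne", x, w_q)    # (B, S, N, d)
    k = torch.einsum("bnd,sde->bsne", x, w_k)    # (B, S, N, d)
    v = torch.einsum("bnd,rde->brne", x, w_v)    # (B, R, N, dv)

    # 1) Searches: one attention map per search
    attn = torch.softmax(q @ k.transpose(-2, -1) / d**0.5, dim=-1)     # (B, S, N, N)

    # 2) Every search retrieves with every retrieval's values
    retrieved = torch.einsum("bsnm,brmv->bsrnv", attn, v)              # (B, S, R, N, dv)

    # 3) Value scores: each search softly picks which retrieval to use, per position
    rq = torch.einsum("bnd,sde->bsne", x, w_rq)                        # (B, S, N, d)
    rk = retrieved @ w_rk                                              # (B, S, R, N, d)
    scores = torch.einsum("bsne,bsrne->bsrn", rq, rk) / d**0.5         # (B, S, R, N)
    scores = torch.softmax(scores, dim=2)                              # softmax over retrievals

    out = (scores.unsqueeze(-1) * retrieved).sum(dim=2)                # (B, S, N, dv)
    out = out.permute(0, 2, 1, 3).reshape(B, N, S * dv)                # concatenate the searches
    return out @ w_o                                                   # (B, N, D)

x = torch.randn(2, 32, 64)
print(compositional_attention(x).shape)  # torch.Size([2, 32, 64])
```

With R == S and the value-score softmax forced to a one-hot diagonal, each search keeps its own retrieval and the whole thing collapses to standard multi-head attention, which is the sanity check discussed later in this thread.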

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

@facebook-github-bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Jan 10, 2022
@sarthmit commented:

Please feel free to reach out. I'd be happy to help :)

@blefaudeux force-pushed the comp_attention branch 3 times, most recently from 88e0226 to 6c2ae94, on January 10, 2022 05:53
@blefaudeux (Contributor, Author) commented:

> Please feel free to reach out. I'd be happy to help :)

@sarthmit actually, I have a question: I've seen quite a few `.contiguous().do_something()` calls (for instance). It's not completely obvious to me that they're beneficial, since the tensor structure is changed right after them, and you don't seem to be using specific kernels which would require a contiguous tensor in the first place (there are actually a bunch of kernels in xformers which would love that, but then `.contiguous()` would come last, not first).
Is there another reason that I'm missing, TPUs for instance?

@sarthmit commented:

@blefaudeux I think you are right, there is probably no use for those `.contiguous()` calls. I just used the fairseq multi-head codebase as a starting point and they seem to have them too, so I never bothered to remove them. See: multi-head
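
For reference, a small standalone illustration of the point above (not the PR's code): `.contiguous()` only forces a copy, so it only pays off right before an op or kernel that actually needs contiguous memory, e.g. `.view()` after a transpose; `.reshape()` already handles that case.

```python
import torch

x = torch.randn(8, 16, 64)
y = x.transpose(0, 1)                 # non-contiguous view, no data copied
print(y.is_contiguous())              # False

# Redundant pattern: copying before an op that does not care about memory layout
z1 = y.contiguous().transpose(1, 2)   # extra copy, same values as below
z2 = y.transpose(1, 2)
print(torch.equal(z1, z2))            # True

# Where contiguity actually matters: .view() needs compatible strides
# y.view(16, -1)                      # would raise a RuntimeError here
z3 = y.reshape(16, -1)                # reshape only copies when it has to
z4 = y.contiguous().view(16, -1)      # equivalent result, but always copies
print(torch.equal(z3, z4))            # True
```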

@codecov-commenter commented Jan 11, 2022:

Codecov Report

Merging #178 (ca0a693) into main (c16078b) will increase coverage by 0.39%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #178      +/-   ##
==========================================
+ Coverage   90.62%   91.02%   +0.39%     
==========================================
  Files          58       59       +1     
  Lines        2892     3019     +127     
==========================================
+ Hits         2621     2748     +127     
  Misses        271      271              
Flag     Coverage Δ
Python   91.02% <100.00%> (+0.39%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files                                      Coverage Δ
xformers/components/attention/favor.py             100.00% <ø> (ø)
xformers/components/__init__.py                     100.00% <100.00%> (ø)
xformers/components/attention/attention_mask.py     98.50% <100.00%> (ø)
xformers/components/attention/base.py               97.22% <100.00%> (+0.16%) ⬆️
xformers/components/attention/compositional.py     100.00% <100.00%> (ø)
xformers/components/multi_head_dispatch.py         100.00% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update c16078b...ca0a693. Read the comment docs.

@blefaudeux changed the title from [DRAFT] Compositional attention to [feat] Compositional attention on Jan 12, 2022
@blefaudeux blefaudeux marked this pull request as draft January 12, 2022 16:48
@blefaudeux (Contributor, Author) commented Jan 14, 2022:

microGPT test (causal, autoregressive)

Friends of my soul: set down.
What stands before I came to come?

BENVOLIO:
Madam, her were born.

MERCUTIO:
Tut, such more sorrow we did not so long.
I have no stand of small as a torment for a
point a man else.

ANTONIO:
Where is the brother of him?

SEBASTIAN:
Or despite of all the sun.

ANTONIO:
Go, marry, of many choice, my mother,
Would the near I let thee be freed:
As with, sir, a brave form'd her by my mood
Shall twenty years a traitor to turn: tell them
Show'd their misfer when seems but when I shall never
Their princely hate with the stealing death.

DUKE OF AUMERLE:
I do fear thee, my daughter not stay.

KING RICHARD II:
Say that's not swine may know thou art.

QUEEN ELIZABETH:
To thou art deceived: if thou both part to tell,
Thou wouldst be thing in she foul born in presence.
If I know not, Grumio that thought it good
Is mutually: here's a stale, if thou beholdst him
Are our bodies cousin with customary here,
A fine schoolmaster; there be we fined,
No, nor do as so fast; it is too long

But the loss is not as good as plain attention with the defaults (though that's not really the target for this attention mechanism):

EDIT: This was due to the rotary embeddings being used for the MHA and not here (they are not supported in this PR). Without the rotary embeddings there's not much difference on this task, and the perf loss disappears.
Screenshot from 2022-01-13 21-20-29
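
For context, a minimal sketch of what the rotary embeddings do (one common variant, independent of the xformers implementation): each query/key channel pair is rotated by a position-dependent angle before the dot product, which injects relative position information.

```python
import torch

def apply_rotary(q: torch.Tensor, theta_base: float = 10000.0) -> torch.Tensor:
    """Rotate channel pairs of q by position-dependent angles (rotary embeddings, 'half' layout)."""
    B, N, d = q.shape
    half = d // 2
    freqs = theta_base ** (-torch.arange(half, dtype=torch.float) / half)   # (d/2,)
    angles = torch.arange(N, dtype=torch.float)[:, None] * freqs[None, :]   # (N, d/2)
    cos, sin = angles.cos(), angles.sin()
    q1, q2 = q[..., :half], q[..., half:]
    return torch.cat([q1 * cos - q2 * sin, q1 * sin + q2 * cos], dim=-1)

q = torch.randn(2, 32, 64)
print(apply_rotary(q).shape)  # torch.Size([2, 32, 64]) -- applied to queries and keys alike
```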

@blefaudeux blefaudeux marked this pull request as ready for review January 14, 2022 05:21
@blefaudeux (Contributor, Author) commented:

@sarthmit if you have cycles for a review, that would be great! I've removed a few options (not that many actually), aligned the terms with the other attentions here, and reused existing building blocks that we have (input projections for instance; there's an optimization to be used for self-attention).

@@ -11,7 +11,7 @@ Please note that:
- These numbers are dependent of hyperparameters (dimensions chosen for Linformer, sparsity of the pattern), they are mostly an illustration
- The sparse attention patterns tested here are just presets, as explained in the linked notebook generating any new sparse attention pattern should be relatively easy, while keeping the benefits of optimized computations.

-Some examples, generated with `python3 xformers/benchmarks/benchmark_encoder.py --activations gelu --plot -emb 256 -bs 32 -heads 16`
+Some examples, generated with `python3 xformers/benchmarks/benchmark_encoder.py --activations gelu --plot -emb 256 -bs 8 -heads 4`
@blefaudeux (Contributor, Author) commented:

Reducing the memory load, and generating a new graph for everyone.

@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## TBD
+### Added
+- Compositional Attention [#41]
@blefaudeux (Contributor, Author) commented:

Hotfix: this was actually needed, though it is not directly tied to this PR.

@@ -43,10 +43,12 @@ def _get_multihead(
"dropout": attn_dropout,
"causal": causal,
"seq_len": SEQ,
"window_size": SEQ // 8 + 1,
"window_size": SEQ // 8 + 1, # local attention
@blefaudeux (Contributor, Author) commented:

I tried to align the different field names and reformulate wherever possible, but there are still some specificities. It's more visible here since this is an attention-specific unit test, but in practice the new fields were already needed for the MHA, and they don't need to be duplicated.

assert ATTENTION_REGISTRY.keys(), "Attention layers should have been registered"


@pytest.mark.parametrize("attn_dropout", [0.0, 0.3])
@blefaudeux (Contributor, Author) commented:

Code coverage: this attention exposes quite a few knobs which are mostly orthogonal to the other attentions, so I figured it was best to cover them in a dedicated unit test.

@@ -80,7 +80,7 @@
"eval_frequency": 50,
"num_train_steps": 10000,
"num_eval_steps": 62,
"gradient_accumulation": 1
"gradient_accumulation": 2
@blefaudeux (Contributor, Author) commented:

Compositional attention takes more memory, so this task does not pass on an 8GB GPU (desktop 3080) without bumping up the accumulation.
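
For reference, a generic sketch of what bumping `gradient_accumulation` from 1 to 2 means (names are illustrative, not the benchmark's code): the effective batch is split into micro-batches, so peak activation memory drops while the optimizer step sees the same overall gradient.

```python
import torch
from torch import nn

model = nn.Linear(16, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
micro_batches = [(torch.randn(8, 16), torch.randn(8, 1)) for _ in range(2)]

# One optimizer step spread over 2 micro-batches: same effective batch of 16 samples,
# but roughly half the peak activation memory of running them in a single forward pass.
opt.zero_grad()
for x, y in micro_batches:
    loss = loss_fn(model(x), y) / len(micro_batches)  # scale so gradients match the full batch
    loss.backward()                                   # gradients accumulate in .grad
opt.step()
```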

@@ -48,6 +48,11 @@ def build_multi_head_attention(
"num_heads"
]

if "dim_model" not in multi_head_config["attention"]:
@blefaudeux (Contributor, Author) commented:

Convenience: removes the need to duplicate the field between the attention config and the MHA config.
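
A hypothetical before/after of the convenience being added here (the dict contents are made up; only the `if` guard mirrors the diff above):

```python
# Hypothetical configs; only the `if` guard mirrors the actual change
multi_head_config = {
    "dim_model": 384,
    "num_heads": 4,
    "attention": {"name": "compositional"},  # dim_model no longer repeated here
}

# Fill in the attention sub-config from the MHA-level value when it is missing
if "dim_model" not in multi_head_config["attention"]:
    multi_head_config["attention"]["dim_model"] = multi_head_config["dim_model"]

print(multi_head_config["attention"])  # {'name': 'compositional', 'dim_model': 384}
```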

@@ -49,7 +49,7 @@ def from_bool(cls: Type[Self], x: torch.Tensor) -> Self:
"""
assert x.dtype == torch.bool

-additive_mask = torch.empty_like(x, dtype=torch.float)
+additive_mask = torch.empty_like(x, dtype=torch.float, device=x.device)
@blefaudeux (Contributor, Author) commented:

I believe it already does that by default; this just makes it more explicit.

dim_key = _either_or(dim_key, dim_model)
dim_value = _either_or(dim_value, dim_model)

self.in_proj_container = (
@blefaudeux (Contributor, Author) commented:

Difference with the original implementation: reuse the xformers input projection, because the task is the same and there are some optimizations to be had here (for self-attention, project once to get q, k and v instead of making three separate calls, for instance).
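
A minimal sketch of that self-attention optimization, independent of the actual in-projection container: when q, k and v all come from the same tensor, the three projections can be fused into a single matmul and split afterwards.

```python
import torch
from torch import nn

D = 64
x = torch.randn(2, 32, D)

# Generic cross-attention path: three separate projections for q, k and v
q_proj, k_proj, v_proj = nn.Linear(D, D), nn.Linear(D, D), nn.Linear(D, D)
q, k, v = q_proj(x), k_proj(x), v_proj(x)

# Self-attention path: one fused projection followed by a single split
qkv_proj = nn.Linear(D, 3 * D)
with torch.no_grad():  # reuse the same weights so both paths are comparable
    qkv_proj.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight]))
    qkv_proj.bias.copy_(torch.cat([q_proj.bias, k_proj.bias, v_proj.bias]))
q2, k2, v2 = qkv_proj(x).chunk(3, dim=-1)

print(torch.allclose(q, q2, atol=1e-6), torch.allclose(v, v2, atol=1e-6))  # True True
```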

@dianaml0 (Contributor) left a comment:

Interesting approach!

if att_mask_additive is not None:
attn_weights += att_mask_additive.values

if _is_triton_available:
@dianaml0 (Contributor) commented:

Could we just use the softmax here?

@blefaudeux (Contributor, Author) replied:

Good catch, fixing that!

@sarthmit commented:

> @sarthmit if you have cycles for a review, that would be great! I've removed a few options (not that many actually), aligned the terms with the other attentions here, and reused existing building blocks that we have (input projections for instance; there's an optimization to be used for self-attention).

@blefaudeux What do you mean by cycles for a review?

By the way, there is also an easy way to test/debug the code. If you set the Value Scores in Equation 13 to the Identity matrix (here), then you should recover the standard Multi-Head Attention performance.
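
In sketch form, the sanity check being suggested (shapes and names are illustrative): with as many retrievals as searches and the value-score softmax replaced by the identity, search s always consumes retrieval s, which is exactly what standard multi-head attention does.

```python
import torch

B, S, R, N, dv = 2, 4, 4, 32, 16                 # illustrative sizes, with R == S
retrieved = torch.randn(B, S, R, N, dv)          # retrieved values per (search, retrieval) pair

# Force the value scores to the identity instead of the learned softmax
scores = torch.eye(S).view(1, S, R, 1)
out_identity = (scores.unsqueeze(-1) * retrieved).sum(dim=2)   # (B, S, N, dv)

# Equivalent: search s just keeps retrieval s, i.e. plain multi-head attention heads
out_mha = retrieved[:, torch.arange(S), torch.arange(S)]       # (B, S, N, dv)
print(torch.allclose(out_identity, out_mha))     # True
```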

@blefaudeux (Contributor, Author) commented:

> @blefaudeux What do you mean by cycles for a review?

I meant "brain cycles" :) (aka free time). I can have a look later today, but the logic should be maintained

@blefaudeux (Contributor, Author) commented:

> By the way, there is also an easy way to test/debug the code. If you set the Value Scores in Equation 13 to the Identity matrix (here), then you should recover the standard Multi-Head Attention performance.

It's a good idea, I think; testing that right now. I'm still seeing a difference in perf, investigating. I can see that there are other differences with the Vaswani paper, for instance the normalization is not the same (/sqrt(dim_head) vs /sqrt(dim_model)), but I don't think that's enough to explain the difference.
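
Written out, the normalization difference being referred to (my reading of it; to be double-checked against the reference code):

```latex
% Standard scaled dot-product attention (Vaswani et al.), per head h:
\operatorname{Attn}(Q_h, K_h, V_h)
  = \operatorname{softmax}\!\left(\frac{Q_h K_h^{\top}}{\sqrt{d_{\text{head}}}}\right) V_h,
\qquad d_{\text{head}} = \frac{d_{\text{model}}}{n_{\text{heads}}}

% versus the variant noted above, which scales by the full model width:
\operatorname{softmax}\!\left(\frac{Q_h K_h^{\top}}{\sqrt{d_{\text{model}}}}\right) V_h
```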

@blefaudeux (Contributor, Author) commented Jan 20, 2022:

OK, confirmed @sarthmit @dianaml0: when forcing the value scores to be identity, the perf is in line with MHA. The difference above is due to... the rotary embeddings! These are not part of this PR but they were used for the classical mechanism in the comparison above, and they explain the gap (without them the two mechanisms are aligned when the value score is identity, as expected). All good for landing I think, and it means that supporting rotary embeddings in compositional attention could be a follow-up.

@blefaudeux blefaudeux merged commit cdbe195 into main Jan 20, 2022
@blefaudeux blefaudeux deleted the comp_attention branch January 20, 2022 04:32
@sarthmit commented:

@blefaudeux Awesome! I can try sparing some time if needed, is there anything in particular you want me to look at?

I would also be curious to see if you see performance gains on any of the tasks (without the identity matrix)

@blefaudeux (Contributor, Author) commented:

> @blefaudeux Awesome! I can try sparing some time if needed, is there anything in particular you want me to look at?
>
> I would also be curious to see if you see performance gains on any of the tasks (without the identity matrix)

I've not checked LRA, but it could be interesting to get the scores, even if I think that LRA could do with more and better tasks (see this paper from @xwhan for instance: https://arxiv.org/pdf/2112.07210.pdf). Rotary embeddings also seem to make a big difference for NLP; it would be nice to support them with compositional, and it should not be a big change.

There's a definite perf impact though (at least with our implementations, even when pulling in some semi-optimized parts from xformers). It could be another axis to look at, because I think that a lot of folks will "just" bump up the number of heads with scaled dot-product instead if compositional has too much of a speed/memory impact.

@sarthmit commented:

Thanks, I will check the paper out.

How big is the perf impact? In favour of MHA or in favour of compositional?

@blefaudeux (Contributor, Author) commented:

> How big is the perf impact? In favour of MHA or in favour of compositional?

By perf I meant compute-time impact, sorry that wasn't clear. Compositional (this PR at least) is significantly slower than MHA (from this repo again) for the same number of heads, even if I understand that it captures many more relations. It could be worth a second pair of eyes to check whether there are low-hanging fruits.
