[feat] Compositional attention #178
```diff
@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

+## TBD
+### Added
+- Compositional Attention [#41]

 ## [0.0.8] - 2022-01-07
 ### Fixed
```

> **Review comment** (on the new changelog entry): hotfix; this was actually needed, not directly tied to this PR.
```diff
@@ -4,5 +4,6 @@
 tensorboard>=2.3.0
 tensorflow>=2.3.1
 tensorflow-datasets>=4.0.1
 tensorflow-text>=2.7.3
 submitit
 fvcore
```
```diff
@@ -43,10 +43,12 @@ def _get_multihead(
         "dropout": attn_dropout,
         "causal": causal,
         "seq_len": SEQ,
-        "window_size": SEQ // 8 + 1,
+        "window_size": SEQ // 8 + 1,  # local attention
         "attention_query_mask": torch.rand((SEQ, 1)) < GLOBAL_ATTENTION_RATIO,
         "dim_model": MODEL,
         "num_heads": heads,
+        "dim_head": MODEL / heads,
+        "num_rules": 2,  # Compositional Attention
     }

     if skip_output_projection:
```

> **Review comment** (on the new config fields): I tried to align the different field names and reformulate wherever possible, but there are still some specificities. It is more visible here since this is an attention-specific unit test, but in reality the new fields were already needed by the MHA, and they do not need to be duplicated.
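As background for the shared config above: the test passes one dict carrying the union of all attention-specific fields, and each registered attention only consumes the fields it recognizes. A minimal sketch of that pattern, assuming (as the shared test config implies) that unknown kwargs are tolerated by each constructor:

```python
from xformers.components.attention import build_attention

# One shared config carries the union of all fields; each attention reads
# only the fields it knows about (assumption: extra kwargs are ignored,
# which is the pattern the shared test config above relies on).
shared = {
    "dropout": 0.0,
    "causal": False,
    "seq_len": 32,
    "window_size": 5,  # consumed by local attention only
    "dim_model": 64,
    "num_heads": 4,
    "num_rules": 2,  # consumed by compositional attention only
}

local_attn = build_attention({"name": "local", **shared})
compositional_attn = build_attention({"name": "compositional", **shared})
```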
@@ -0,0 +1,113 @@ (new test file)

```python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from xformers.components import MultiHeadDispatch

# Automatically test all the registered attentions
from xformers.components.attention import (
    _DENSITY_THRESHOLD,
    ATTENTION_REGISTRY,
    build_attention,
)

DEVICES = (
    [torch.device("cpu")] if not torch.cuda.is_available() else [torch.device("cuda")]
)

BATCH = 2
SEQ = 128 if torch.cuda.is_available() else 32
MODEL = 128 if torch.cuda.is_available() else 64
GLOBAL_ATTENTION_RATIO = (
    _DENSITY_THRESHOLD * 0.9
)  # Make sure that we test the sparse implementation, no matter the threshold

assert ATTENTION_REGISTRY.keys(), "Attention layers should have been registered"


@pytest.mark.parametrize("attn_dropout", [0.0, 0.3])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize("rules", [1, 4])
@pytest.mark.parametrize("q_compose", [False, True])
@pytest.mark.parametrize("dim_selection", [MODEL // 2, None])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("qk_rule", [True, False])
@pytest.mark.parametrize("nonlinear", [True, False])
@pytest.mark.parametrize("device", DEVICES)
def test_build_and_run(
    heads: int,
    attn_dropout: float,
    causal: bool,
    rules: int,
    q_compose: bool,
    dim_selection: int,
    bias: bool,
    qk_rule: bool,
    nonlinear: bool,
    device: torch.device,
):
    torch.manual_seed(42)

    test_config = {
        "name": "compositional",
        "dropout": attn_dropout,
        "causal": causal,
        "seq_len": SEQ,
        "window_size": SEQ // 8 + 1,  # local attention
        "attention_query_mask": torch.rand((SEQ, 1)) < GLOBAL_ATTENTION_RATIO,
        "dim_model": MODEL,
        "num_heads": heads,
        "num_rules": 2,  # Compositional Attention
        "q_compose": q_compose,
        "rules": rules,
        "dim_selection": dim_selection,
        "bias": bias,
        "qk_rule": qk_rule,
        "nonlinear": nonlinear,
    }

    # Add some blocksparse layout to test the corresponding attention
    block_size = 16
    test_config["layout"] = torch.eye(
        SEQ // block_size, SEQ // block_size, dtype=torch.long
    )
    test_config["block_size"] = block_size

    attention = build_attention(test_config)

    # build a multi head dispatch to test this attention mechanism
    multi_head = MultiHeadDispatch(
        seq_len=SEQ,
        dim_model=MODEL,
        num_heads=heads,
        attention=attention,
        residual_dropout=0.0,
    ).to(device)

    # Check that a shuffled input produces the same results
    seqs = [SEQ, SEQ - 16]

    for seq in seqs:
        # Check that we can pass a smaller sequence
        inputs = torch.rand(BATCH, seq, MODEL, device=device)
        shuffle = torch.randperm(inputs.shape[1])
        inputs_shuffled = inputs[:, shuffle, :].clone()

        results = multi_head(inputs, inputs, inputs)
        results_shuffled = multi_head(inputs_shuffled, inputs_shuffled, inputs_shuffled)

        torch.allclose(results[:, shuffle, :], results_shuffled)

        # Test the non-self-attention codepath
        att = multi_head(inputs, inputs_shuffled, inputs)

        # Check that dropout actually drops some values
        if attn_dropout > 0:
            att_2 = multi_head(inputs, inputs_shuffled, inputs)
            assert (att != att_2).any()
```

> **Review comment** (on the `pytest.mark.parametrize` stack): code coverage; this attention exposes quite a few knobs which are mostly orthogonal to the other attentions, so I figured it was best to cover them in a dedicated unit test.
```diff
@@ -80,7 +80,7 @@
             "eval_frequency": 50,
             "num_train_steps": 10000,
             "num_eval_steps": 62,
-            "gradient_accumulation": 1
+            "gradient_accumulation": 2
         },
         "model": {
             "pooling_mode": "mean",
```

> **Review comment** (on the `gradient_accumulation` bump): compositional takes more memory, so this task does not pass on an 8GB GPU (desktop 3080) without bumping up the accumulation.

```diff
@@ -94,7 +94,7 @@
         },
         "xformer": [
             {
-                "reversible": true,
+                "reversible": false,
                 "block_type": "encoder",
                 "num_layers": 2,
                 "layer_norm_style": "pre",
```
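As a reminder of what the `gradient_accumulation` knob trades off, here is a generic sketch (illustrative PyTorch, not the LRA trainer's actual loop): the effective batch is split into micro-batches, roughly halving peak activation memory per backward pass at the cost of more passes per optimizer step.

```python
import torch

model = torch.nn.Linear(64, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
accumulation = 2  # mirrors the "gradient_accumulation": 2 setting above

batch = torch.randn(8, 64)  # the full effective batch
optimizer.zero_grad()
for micro in batch.chunk(accumulation):
    # rescale so the accumulated gradient matches a single full-batch pass
    loss = model(micro).pow(2).mean() / accumulation
    loss.backward()  # gradients accumulate in .grad across micro-batches
optimizer.step()  # one optimizer step per effective batch
```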
```diff
@@ -48,6 +48,11 @@ def build_multi_head_attention(
             "num_heads"
         ]

+    if "dim_model" not in multi_head_config["attention"]:
+        multi_head_config["attention"]["dim_model"] = multi_head_config[
+            "dim_model"
+        ]
+
     if (
         "dim_features" not in multi_head_config["attention"]
         or multi_head_config["attention"]["dim_features"] is None
```

> **Review comment:** convenience; this removes the need to duplicate fields between the attention config and the MHA config.
```diff
@@ -24,7 +24,7 @@ class AttentionMask:
     """

     def __init__(self, additive_mask: torch.Tensor, is_causal: bool = False):
-        assert additive_mask.is_floating_point()
+        assert additive_mask.is_floating_point(), additive_mask.dtype
         assert not additive_mask.requires_grad

         if additive_mask.ndim == 2:
@@ -49,7 +49,7 @@ def from_bool(cls: Type[Self], x: torch.Tensor) -> Self:
         """
         assert x.dtype == torch.bool

-        additive_mask = torch.empty_like(x, dtype=torch.float)
+        additive_mask = torch.empty_like(x, dtype=torch.float, device=x.device)

         additive_mask.masked_fill_(x, 0.0)
         additive_mask.masked_fill_(~x, float("-inf"))

@@ -62,7 +62,7 @@ def from_multiplicative(cls: Type[Self], x: torch.Tensor) -> Self:
         """
         assert not x.dtype == torch.bool

-        additive_mask = torch.empty_like(x, dtype=torch.float)
+        additive_mask = torch.empty_like(x, dtype=torch.float, device=x.device)
         x = x.bool()

         additive_mask.masked_fill_(x, 0.0)
```

> **Review comment** (on the `device=x.device` changes): it is already doing that by default, I believe; this just makes it more explicit.
> **Review comment:** reducing the memory load, generating a new graph for everyone.
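As a small usage sketch of the `from_bool` path touched above (the import path is assumed): `True` positions become 0.0 and `False` positions become `-inf`, so the result can be added directly to the attention logits before the softmax.

```python
import torch

from xformers.components.attention import AttentionMask

# True = this position may be attended to, False = masked out
keep = torch.tensor([[True, True, False]])
mask = AttentionMask.from_bool(keep)

# The additive mask holds 0.0 where attention is allowed and -inf elsewhere,
# which is why __init__ asserts that it is a floating-point tensor.
```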