[feat] Compositional attention #178

Merged
merged 8 commits on Jan 20, 2022
Changes from 5 commits
2 changes: 1 addition & 1 deletion BENCHMARKS.md
@@ -11,7 +11,7 @@ Please note that:
- These numbers are dependent on hyperparameters (dimensions chosen for Linformer, sparsity of the pattern); they are mostly an illustration
- The sparse attention patterns tested here are just presets; as explained in the linked notebook, generating any new sparse attention pattern should be relatively easy, while keeping the benefits of optimized computations.

Some examples, generated with `python3 xformers/benchmarks/benchmark_encoder.py --activations gelu --plot -emb 256 -bs 32 -heads 16`
Some examples, generated with `python3 xformers/benchmarks/benchmark_encoder.py --activations gelu --plot -emb 256 -bs 8 -heads 4`
Contributor Author
Reducing the memory load, and regenerating the graphs for everyone.


![Memory use for different attentions](docs/plots/memory_vs_attention.png) ![Runtime for different attentions](docs/plots/runtime_vs_attention.png)

2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## TBD
### Added
- Compositional Attention [#41]
Contributor Author
Hotfix: this was actually needed, but is not directly tied to this PR.


## [0.0.8] - 2022-01-07
### Fixed
3 changes: 3 additions & 0 deletions README.md
@@ -139,6 +139,9 @@ Patrick et al., 2021](https://arxiv.org/abs/2106.05392)*
- See BigBird, Longformers,..
- [FourierMix](xformers/components/attention/fourier_mix.py)
- *[FNet: Mixing Tokens with Fourier Transforms, Lee-Thorp et al.](https://arxiv.org/abs/2105.03824v1)*
- [CompositionalAttention](xformers/components/attention/compositional.py)
- *[Compositional Attention: Disentangling search and retrieval, S. Mittal et al.](https://arxiv.org/pdf/2110.09419v1.pdf)*

- ... add a new one [see Contribution.md](CONTRIBUTING.md)

</p></details>
Binary file modified docs/plots/memory_vs_attention.png
Binary file modified docs/plots/runtime_vs_attention.png
1 change: 1 addition & 0 deletions examples/microGPT.py
@@ -68,6 +68,7 @@ def __init__(
"dropout": self.hparams.attn_pdrop,
"causal": True,
"seq_len": self.hparams.block_size,
"num_rules": self.hparams.n_head,
},
},
"feedforward_config": {
1 change: 1 addition & 0 deletions requirements-lra.txt
@@ -4,5 +4,6 @@
tensorboard>=2.3.0
tensorflow>=2.3.1
tensorflow-datasets>=4.0.1
tensorflow-text>=2.7.3
submitit
fvcore
4 changes: 3 additions & 1 deletion tests/test_attentions.py
@@ -43,10 +43,12 @@ def _get_multihead(
"dropout": attn_dropout,
"causal": causal,
"seq_len": SEQ,
"window_size": SEQ // 8 + 1,
"window_size": SEQ // 8 + 1, # local attention
Contributor Author
I tried to align the different field names and reformulate wherever possible, but there are still some specificities. It's more visible here since this is an attention-specific unit test, but in reality the new fields were already needed for the MHA, so they don't have to be duplicated.
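As a rough illustration of that shared-superset pattern (a sketch, not code from this PR; it assumes the registered attentions tolerate the extra keyword arguments, which is what the test config above relies on):

```python
# Sketch: one superset config drives several attentions, each picking the
# fields it knows and ignoring the rest. Names mirror the test config above.
from xformers.components.attention import build_attention

SEQ, MODEL, HEADS = 128, 128, 4

shared_config = {
    "dropout": 0.0,
    "causal": False,
    "seq_len": SEQ,
    "window_size": SEQ // 8 + 1,  # only consumed by local attention
    "dim_model": MODEL,
    "num_heads": HEADS,
    "num_rules": 2,               # only consumed by compositional attention
}

local_att = build_attention({"name": "local", **shared_config})
compositional_att = build_attention({"name": "compositional", **shared_config})
```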

"attention_query_mask": torch.rand((SEQ, 1)) < GLOBAL_ATTENTION_RATIO,
"dim_model": MODEL,
"num_heads": heads,
"dim_head": MODEL / heads,
"num_rules": 2, # Compositional Attention
}

if skip_output_projection:
8 changes: 5 additions & 3 deletions tests/test_block_factory.py
@@ -58,10 +58,12 @@ def test_xformer_encoder_block(
"window_size": SEQ // 8 + 1,
"seq_len": SEQ,
"attention_query_mask": torch.rand((SEQ, 1)) < GLOBAL_ATTENTION_RATIO,
"dim_model": MODEL,
"num_heads": heads,
"dim_head": MODEL / heads,
"dim_head": MODEL // heads,
"layout": torch.eye(SEQ // block_size, SEQ // block_size, dtype=torch.long),
"block_size": block_size,
"num_rules": 2, # Compositional Attention
}

multi_head_config = {
@@ -146,11 +148,11 @@ def test_xformer_decoder_block(
"causal": causal,
"window_size": SEQ // 8 + 1,
"seq_len": SEQ,
"dim_head": MODEL // heads,
"attention_query_mask": torch.rand((SEQ, 1)) < GLOBAL_ATTENTION_RATIO,
"num_heads": heads,
"dim_head": MODEL / heads,
"layout": torch.eye(SEQ // block_size, SEQ // block_size, dtype=torch.long),
"block_size": block_size,
"num_rules": 2, # Compositional Attention
}

multi_head_config = {
113 changes: 113 additions & 0 deletions tests/test_compositional_attention.py
@@ -0,0 +1,113 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from xformers.components import MultiHeadDispatch

# Automatically test all the registered attentions
from xformers.components.attention import (
_DENSITY_THRESHOLD,
ATTENTION_REGISTRY,
build_attention,
)

DEVICES = (
[torch.device("cpu")] if not torch.cuda.is_available() else [torch.device("cuda")]
)

BATCH = 2
SEQ = 128 if torch.cuda.is_available() else 32
MODEL = 128 if torch.cuda.is_available() else 64
GLOBAL_ATTENTION_RATIO = (
_DENSITY_THRESHOLD * 0.9
) # Make sure that we test the sparse implementation, no matter the threshold

assert ATTENTION_REGISTRY.keys(), "Attention layers should have been registered"


@pytest.mark.parametrize("attn_dropout", [0.0, 0.3])
Contributor Author
Code coverage: this attention exposes quite a few knobs which are mostly orthogonal to the other attentions, so I figured it was best to cover them in a dedicated unit test.
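To make "quite a few knobs" concrete, the decorators below multiply out to 2^9 = 512 cases per device; a throwaway count, not test code:

```python
# Throwaway count of the parametrized combinations below; the dim_selection
# values stand in for MODEL // 2 and None.
from itertools import product

knobs = {
    "attn_dropout": [0.0, 0.3],
    "causal": [True, False],
    "heads": [1, 4],
    "rules": [1, 4],
    "q_compose": [False, True],
    "dim_selection": [64, None],
    "bias": [True, False],
    "qk_rule": [True, False],
    "nonlinear": [True, False],
}
print(len(list(product(*knobs.values()))))  # 512 combinations per device
```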

@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("heads", [1, 4])
@pytest.mark.parametrize("rules", [1, 4])
@pytest.mark.parametrize("q_compose", [False, True])
@pytest.mark.parametrize("dim_selection", [MODEL // 2, None])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("qk_rule", [True, False])
@pytest.mark.parametrize("nonlinear", [True, False])
@pytest.mark.parametrize("device", DEVICES)
def test_build_and_run(
heads: int,
attn_dropout: float,
causal: bool,
rules: int,
q_compose: bool,
dim_selection: int,
bias: bool,
qk_rule: bool,
nonlinear: bool,
device: torch.device,
):

torch.manual_seed(42)

test_config = {
"name": "compositional",
"dropout": attn_dropout,
"causal": causal,
"seq_len": SEQ,
"window_size": SEQ // 8 + 1, # local attention
"attention_query_mask": torch.rand((SEQ, 1)) < GLOBAL_ATTENTION_RATIO,
"dim_model": MODEL,
"num_heads": heads,
"num_rules": 2, # Compositional Attention
"q_compose": q_compose,
"rules": rules,
"dim_selection": dim_selection,
"bias": bias,
"qk_rule": qk_rule,
"nonlinear": nonlinear,
}

# Add some blocksparse layout to test the corresponding attention
block_size = 16
test_config["layout"] = torch.eye(
SEQ // block_size, SEQ // block_size, dtype=torch.long
)
test_config["block_size"] = block_size

attention = build_attention(test_config)

# build a multi head dispatch to test this attention mechanism
multi_head = MultiHeadDispatch(
seq_len=SEQ,
dim_model=MODEL,
num_heads=heads,
attention=attention,
residual_dropout=0.0,
).to(device)

# Check that a shuffled input produces the same results
seqs = [SEQ, SEQ - 16]

for seq in seqs:
# Check that we can pass a smaller sequence
inputs = torch.rand(BATCH, seq, MODEL, device=device)
shuffle = torch.randperm(inputs.shape[1])
inputs_shuffled = inputs[:, shuffle, :].clone()

results = multi_head(inputs, inputs, inputs)
results_shuffled = multi_head(inputs_shuffled, inputs_shuffled, inputs_shuffled)

torch.allclose(results[:, shuffle, :], results_shuffled)

# Test the non-self-attention codepath
att = multi_head(inputs, inputs_shuffled, inputs)

# Check that dropout actually drops some values
if attn_dropout > 0:
att_2 = multi_head(inputs, inputs_shuffled, inputs)
assert (att != att_2).any()
4 changes: 2 additions & 2 deletions xformers/benchmarks/LRA/code/config.json
Expand Up @@ -80,7 +80,7 @@
"eval_frequency": 50,
"num_train_steps": 10000,
"num_eval_steps": 62,
"gradient_accumulation": 1
"gradient_accumulation": 2
Contributor Author
Compositional attention takes more memory, so this task does not pass on an 8GB GPU (desktop 3080) without bumping up the gradient accumulation.
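For context, a generic PyTorch-style sketch of what bumping the accumulation buys (placeholder model/optimizer/loss, not LRA code): the optimizer still sees the same effective batch, but each forward/backward only holds half of it in memory.

```python
# Generic gradient-accumulation sketch; model, optimizer, loss_fn and
# micro_batches are placeholders, not part of the LRA codebase.
def train_step(model, optimizer, loss_fn, micro_batches, accumulation=2):
    optimizer.zero_grad()
    for x, y in micro_batches:  # len(micro_batches) == accumulation
        # Scale so the summed gradients match a single large-batch step
        loss = loss_fn(model(x), y) / accumulation
        loss.backward()         # activations for this micro-batch are freed here
    optimizer.step()            # one update for the whole effective batch
```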

},
"model": {
"pooling_mode": "mean",
@@ -94,7 +94,7 @@
},
"xformer": [
{
"reversible": true,
"reversible": false,
"block_type": "encoder",
"num_layers": 2,
"layer_norm_style": "pre",
5 changes: 5 additions & 0 deletions xformers/components/__init__.py
@@ -48,6 +48,11 @@ def build_multi_head_attention(
"num_heads"
]

if "dim_model" not in multi_head_config["attention"]:
Contributor Author
Convenience: removes the need to duplicate fields between the attention and MHA configs.
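In other words, the nested attention config no longer has to repeat `dim_model`/`num_heads`; a hedged example of what a caller can now write (only keys visible in this diff and in the tests are used, everything else is elided):

```python
# Sketch: dim_model and num_heads are declared once at the MHA level and
# propagated into the attention sub-config by build_multi_head_attention.
from xformers.components import build_multi_head_attention

mha = build_multi_head_attention(
    {
        "dim_model": 128,
        "num_heads": 4,
        "residual_dropout": 0.0,
        "attention": {
            "name": "compositional",
            "dropout": 0.1,
            "causal": False,
            "num_rules": 4,
            # no "dim_model" / "num_heads" duplication needed here anymore
        },
    }
)
```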

multi_head_config["attention"]["dim_model"] = multi_head_config[
"dim_model"
]

if (
"dim_features" not in multi_head_config["attention"]
or multi_head_config["attention"]["dim_features"] is None
6 changes: 3 additions & 3 deletions xformers/components/attention/attention_mask.py
@@ -24,7 +24,7 @@ class AttentionMask:
"""

def __init__(self, additive_mask: torch.Tensor, is_causal: bool = False):
assert additive_mask.is_floating_point()
assert additive_mask.is_floating_point(), additive_mask.dtype
assert not additive_mask.requires_grad

if additive_mask.ndim == 2:
@@ -49,7 +49,7 @@ def from_bool(cls: Type[Self], x: torch.Tensor) -> Self:
"""
assert x.dtype == torch.bool

additive_mask = torch.empty_like(x, dtype=torch.float)
additive_mask = torch.empty_like(x, dtype=torch.float, device=x.device)
Contributor Author
It's already doing that by default I believe, just making it more explicit.
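Either way, the resulting behaviour is the one spelled out just below; a small usage sketch (True = keep, False = masked):

```python
# Sketch of the from_bool semantics written below: True positions map to 0.0,
# False positions map to -inf, giving an additive mask for the attention logits.
import torch
from xformers.components.attention import AttentionMask

keep = torch.tensor([[True, True, False]])
mask = AttentionMask.from_bool(keep)
# Adding this mask to the attention scores before the softmax drives the
# probability of the masked positions to ~0.
```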

additive_mask.masked_fill_(x, 0.0)
additive_mask.masked_fill_(~x, float("-inf"))

@@ -62,7 +62,7 @@ def from_multiplicative(cls: Type[Self], x: torch.Tensor) -> Self:
"""
assert not x.dtype == torch.bool

additive_mask = torch.empty_like(x, dtype=torch.float)
additive_mask = torch.empty_like(x, dtype=torch.float, device=x.device)
x = x.bool()

additive_mask.masked_fill_(x, 0.0)
8 changes: 7 additions & 1 deletion xformers/components/attention/base.py
@@ -11,6 +11,8 @@
import torch
import torch.nn as nn

from xformers.components.attention import AttentionMask


@dataclass
class AttentionConfig:
@@ -29,7 +31,7 @@ class AttentionConfig:
class Attention(nn.Module, metaclass=ABCMeta):
r"""The base Attention mechanism, which is typically a sub-part of the multi-head attention"""

_causal_mask: Optional[torch.Tensor] = None
_causal_mask: Optional[AttentionMask] = None

@abstractmethod
def __init__(self, dropout: Optional[float] = None, *args, **kwargs):
@@ -47,6 +49,10 @@ def __init__(self, dropout: Optional[float] = None, *args, **kwargs):
# Requires that K and Q have the same sequence length
self.requires_same_k_q_dimensions = False

# Whether the attention owns the single head/multihead mechanism
# so that the MHA wrapper should skip it
self.requires_skip_multi_head = False

@classmethod
def from_config(cls: Type[Self], config: AttentionConfig) -> Self:
# Generate the class inputs from the config
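A hedged sketch of how the new flag would be used (hypothetical subclass, not part of this PR; how MultiHeadDispatch reacts to the flag is an assumption based on the comment above):

```python
# Hypothetical attention that manages its own heads; illustration only.
from typing import Optional

import torch
from xformers.components.attention import Attention


class HeadAwareAttention(Attention):
    def __init__(self, dropout: Optional[float] = 0.0, *args, **kwargs):
        super().__init__()
        self.drop = torch.nn.Dropout(dropout or 0.0)
        # This attention already handles head splitting/merging internally,
        # so signal the MHA wrapper to skip its own multi-head machinery.
        self.requires_skip_multi_head = True

    def forward(self, q, k, v, *args, **kwargs):
        scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
        return self.drop(torch.softmax(scores, dim=-1)) @ v
```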