
Proposal for separate passing of causal flag and other types of bias #640

Open
jfc4050 opened this issue Jan 12, 2023 · 5 comments

@jfc4050 (Contributor) commented Jan 12, 2023

🚀 Feature

There are performance advantages to treating causal masking as a special case, since the kernel can use it as a signal to skip computation. This is already supported by the CUTLASS implementation; it's just not exposed in the Python API.

Motivation

Here are some benchmarks with an additive bias and causal masking applied simultaneously; the kernel is generally about 2x faster when causal is enabled:

[----------------------- attention (attn_bias=<class 'torch.Tensor'>) -----------------------]
                                                                        |  optimized  |  eager
1 threads: -----------------------------------------------------------------------------------
      f16 B=4, M=4096, H=64, K=128, p=0.3,  BiasT=Tensor, Causal=False  |     38.2    |   73.2
      b16 B=4, M=4096, H=64, K=128, p=0.3,  BiasT=Tensor, Causal=False  |     38.8    |   73.5
      f16 B=4, M=4096, H=64, K=128, p=0.3,  BiasT=Tensor, Causal=True   |     19.7    |   73.1
      b16 B=4, M=4096, H=64, K=128, p=0.3,  BiasT=Tensor, Causal=True   |     20.1    |   73.6
      f16 B=8, M=4096, H=64, K=128, p=0.3,  BiasT=Tensor, Causal=False  |     76.2    |
      b16 B=8, M=4096, H=64, K=128, p=0.3,  BiasT=Tensor, Causal=False  |     77.6    |
      f16 B=8, M=4096, H=64, K=128, p=0.3,  BiasT=Tensor, Causal=True   |     39.3    |
      b16 B=8, M=4096, H=64, K=128, p=0.3,  BiasT=Tensor, Causal=True   |     39.9    |
      f16 B=4, M=8192, H=64, K=128, p=0.3,  BiasT=Tensor, Causal=False  |    151.9    |
      b16 B=4, M=8192, H=64, K=128, p=0.3,  BiasT=Tensor, Causal=False  |    154.5    |
      f16 B=4, M=8192, H=64, K=128, p=0.3,  BiasT=Tensor, Causal=True   |     77.3    |
      b16 B=4, M=8192, H=64, K=128, p=0.3,  BiasT=Tensor, Causal=True   |     78.6    |
      f16 B=8, M=8192, H=64, K=128, p=0.3,  BiasT=Tensor, Causal=False  |    303.4    |
      b16 B=8, M=8192, H=64, K=128, p=0.3,  BiasT=Tensor, Causal=False  |    308.6    |
      f16 B=8, M=8192, H=64, K=128, p=0.3,  BiasT=Tensor, Causal=True   |    154.0    |
      b16 B=8, M=8192, H=64, K=128, p=0.3,  BiasT=Tensor, Causal=True   |    156.7    |
[----------- attention backward (attn_bias=<class 'torch.Tensor'>) -----------]
                                                                       |
1 threads: --------------------------------------------------------------------
      f16 B=4, M=4096, H=64, K=128, p=0.3, BiasT=Tensor, Causal=False  |  134.2
      b16 B=4, M=4096, H=64, K=128, p=0.3, BiasT=Tensor, Causal=False  |  138.1
      f16 B=4, M=4096, H=64, K=128, p=0.3, BiasT=Tensor, Causal=True   |   71.4
      b16 B=4, M=4096, H=64, K=128, p=0.3, BiasT=Tensor, Causal=True   |   73.4
      f16 B=8, M=4096, H=64, K=128, p=0.3, BiasT=Tensor, Causal=False  |  230.7
      b16 B=8, M=4096, H=64, K=128, p=0.3, BiasT=Tensor, Causal=False  |  237.3
      f16 B=8, M=4096, H=64, K=128, p=0.3, BiasT=Tensor, Causal=True   |  123.2
      b16 B=8, M=4096, H=64, K=128, p=0.3, BiasT=Tensor, Causal=True   |  126.8
      f16 B=4, M=8192, H=64, K=128, p=0.3, BiasT=Tensor, Causal=False  |  529.6
      b16 B=4, M=8192, H=64, K=128, p=0.3, BiasT=Tensor, Causal=False  |  545.5
      f16 B=4, M=8192, H=64, K=128, p=0.3, BiasT=Tensor, Causal=True   |  275.5
      b16 B=4, M=8192, H=64, K=128, p=0.3, BiasT=Tensor, Causal=True   |  283.2
      f16 B=8, M=8192, H=64, K=128, p=0.3, BiasT=Tensor, Causal=False  |  916.9
      b16 B=8, M=8192, H=64, K=128, p=0.3, BiasT=Tensor, Causal=False  |  937.3
      f16 B=8, M=8192, H=64, K=128, p=0.3, BiasT=Tensor, Causal=True   |  474.3
      b16 B=8, M=8192, H=64, K=128, p=0.3, BiasT=Tensor, Causal=True   |  487.6

Pitch

Here's a commit with a rough draft of what it might look like: 7df2df0

def memory_efficient_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[Union[torch.Tensor, AttentionMask]] = None,
    causal: bool = False,
    p: float = 0.0,
    scale: Optional[float] = None,
    *,
    op: Optional[AttentionOp] = None,
) -> torch.Tensor:
    ...

For now, causal masks can still be passed via the attn_bias parameter to preserve backward compatibility.

That said, one of the primary purposes of this issue is to discuss the API.
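To make the proposal concrete, here is a minimal sketch of how the draft signature above might be called, with the additive bias and the causal flag passed separately. The causal parameter is the proposed addition, not the current API, and the tensor shapes are only illustrative:

```
import torch
from xformers.ops import memory_efficient_attention  # draft signature from the commit above

B, M, H, K = 2, 1024, 8, 64
dtype, device = torch.float16, "cuda"
q = torch.randn(B, M, H, K, dtype=dtype, device=device)
k = torch.randn(B, M, H, K, dtype=dtype, device=device)
v = torch.randn(B, M, H, K, dtype=dtype, device=device)
bias = torch.randn(B, H, M, M, dtype=dtype, device=device)  # additive attention bias

# Proposed call: the tensor bias goes through attn_bias, while causality is a
# separate flag, so kernels that support it can skip the fully-masked tiles.
out = memory_efficient_attention(q, k, v, attn_bias=bias, causal=True, p=0.3)
```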

Alternatives

Additional context

Also see PR #587.

@danthe3rd (Contributor) commented Jan 12, 2023

Thanks for opening this issue!
The idea of attn_bias is to hold any information about a bias to be added to the attention scores before the softmax. Technically, causality is a bias, which is why it's passed through the attn_bias parameter.
We also want to reduce API complexity and keep a single bias argument; that's why there is no causal parameter at the moment.

Now I totally agree that we might want to combine multiple bias types together:

  • causality
  • arbitrary bias (torch.Tensor)
  • (soon) a specific bias to use when different batch elements have different sequence lengths (input tensors are effectively of shape [1, sum_i{M_i}, H, K]); a packing sketch follows below
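For the variable-sequence-length case, a minimal sketch of what that packed layout looks like (the lengths and dimensions are made up for illustration):

```
import torch

# Three batch elements with different sequence lengths, packed along the
# sequence dimension into a single "batch" of size 1.
seqlens = [512, 1024, 256]
H, K = 8, 64
per_sample = [torch.randn(m, H, K) for m in seqlens]
packed = torch.cat(per_sample, dim=0).unsqueeze(0)  # shape [1, sum_i{M_i}, H, K]
print(packed.shape)  # torch.Size([1, 1792, 8, 64])
```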

We're not totally clear on how this would work out in terms of API, but we definitely want to support combinations of those, and ideally we would pass them through the attn_bias arg (I wanted to prototype this once your PR is merged, to avoid further merge conflicts).

Also cc @fmassa

@jfc4050 (Contributor, Author) commented Jan 12, 2023

Perhaps we could pass a set of biases via attn_bias? That would also work for the performance purposes mentioned above.

@jfc4050 (Contributor, Author) commented Jan 12, 2023

At least for causal + arbitrary bias, implementations that don't support handling both separately could simply sum them in the Python layer, while implementations like CUTLASS would retain the information needed to skip computation.
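As a rough sketch of that fallback (the helper name is hypothetical, not an existing xformers function):

```
import torch

def merge_causal_into_bias(bias: torch.Tensor, seqlen_q: int, seqlen_k: int) -> torch.Tensor:
    # Hypothetical Python-layer fallback: materialize an additive lower-triangular
    # mask (-inf strictly above the diagonal, 0 elsewhere) and fold it into the
    # tensor bias, so backends without a separate causal fast path still get the
    # correct result -- just without the skipped computation.
    causal = torch.triu(
        torch.full((seqlen_q, seqlen_k), float("-inf"), dtype=bias.dtype, device=bias.device),
        diagonal=1,
    )
    return bias + causal  # broadcasts over any leading batch/head dimensions
```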

@danthe3rd (Contributor) commented:

> Perhaps we could pass a set of biases via attn_bias? That would also work for the performance purposes mentioned above.

We were thinking about how we could make that obvious for the user, maybe using Python operators (e.g. attn_bias=LowerTriangularMask() + my_tensor), which would be represented as a list of additive masks. We would also provide a to_tensor() method to make debugging and testing easy (alternatively, those bias classes could be subclasses of torch.Tensor).
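A minimal sketch of how that operator-based composition could look (class and method names here are only illustrative, not the eventual xformers API):

```
import torch
from typing import List, Union

class LowerTriangularMask:
    def __add__(self, other):
        return AdditiveMaskList([self, other])

    def to_tensor(self, seqlen: int, dtype=torch.float32) -> torch.Tensor:
        # -inf strictly above the diagonal, 0 on and below it
        return torch.triu(torch.full((seqlen, seqlen), float("-inf"), dtype=dtype), diagonal=1)

class AdditiveMaskList:
    """Represents attn_bias = mask_1 + mask_2 + ... as a list of additive terms."""
    def __init__(self, masks: List[Union[LowerTriangularMask, torch.Tensor]]):
        self.masks = masks

    def __add__(self, other):
        return AdditiveMaskList(self.masks + [other])

    def to_tensor(self, seqlen: int, dtype=torch.float32) -> torch.Tensor:
        # Materialize each term and sum them -- handy for debugging and tests.
        out = torch.zeros(seqlen, seqlen, dtype=dtype)
        for m in self.masks:
            out = out + (m if isinstance(m, torch.Tensor) else m.to_tensor(seqlen, dtype))
        return out

# attn_bias = LowerTriangularMask() + my_tensor  ->  AdditiveMaskList([LowerTriangularMask(), my_tensor])
```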

danthe3rd pushed a commit that referenced this issue Jan 13, 2023
See also #640

Adds support for combinations of different sorts of biases:
- Causal
- Bias (coming with #587)
- Block-diagonal (used for different seqlen per batch element)

We need to rename "LowerTriangularMask" because, when combined with a block-diagonal mask, the result is no longer causal:

```
# A (block-diagonal)
0 0 0 * *
0 0 0 * *
* * * 0 0
* * * 0 0
# B (lower triangular)
0 * * * *
0 0 * * *
0 0 0 * *
0 0 0 0 *
# A + B
0 * * * *
0 0 * * *
* * * * *
* * * 0 *
# A + causal (what most people want)
0 * * * *
0 0 * * *
* * * 0 *
* * * 0 0
```
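To make the distinction concrete, here is a small sketch (the helper functions are illustrative) that builds the block-diagonal mask and the per-block causal mask from the diagram, and shows that adding a plain global lower-triangular mask gives a different result when the blocks are uneven:

```
import torch

NEG_INF = float("-inf")

def block_diagonal_mask(q_seqlens, k_seqlens):
    # 0 inside each (query block, key block) pair, -inf everywhere else (matrix "A" above)
    mask = torch.full((sum(q_seqlens), sum(k_seqlens)), NEG_INF)
    q0 = k0 = 0
    for q_len, k_len in zip(q_seqlens, k_seqlens):
        mask[q0:q0 + q_len, k0:k0 + k_len] = 0
        q0, k0 = q0 + q_len, k0 + k_len
    return mask

def block_causal_mask(q_seqlens, k_seqlens):
    # per-block causal ("A + causal" above): the triangle is applied inside each block
    mask = torch.full((sum(q_seqlens), sum(k_seqlens)), NEG_INF)
    q0 = k0 = 0
    for q_len, k_len in zip(q_seqlens, k_seqlens):
        mask[q0:q0 + q_len, k0:k0 + k_len] = torch.triu(
            torch.full((q_len, k_len), NEG_INF), diagonal=1
        )
        q0, k0 = q0 + q_len, k0 + k_len
    return mask

# Two uneven blocks, as in the diagram: 2 queries over 3 keys, then 2 queries over 2 keys.
A = block_diagonal_mask([2, 2], [3, 2])
global_tri = torch.triu(torch.full((4, 5), NEG_INF), diagonal=1)  # a plain "LowerTriangularMask"
print(A + global_tri)                     # "A + B": the third query row is masked out entirely
print(block_causal_mask([2, 2], [3, 2]))  # "A + causal": what most people actually want
```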
danthe3rd pushed a commit that referenced this issue Jan 16, 2023
danthe3rd pushed a commit that referenced this issue Jan 17, 2023
facebook-github-bot pushed a commit that referenced this issue Jan 19, 2023

Pull Request resolved: https://github.com/fairinternal/xformers/pull/435
@danthe3rd (Contributor) commented:

@jfc4050 just merged d23d04e, which adds support for fmha.attn_bias.LowerTriangularMaskWithTensorBias to do what you want. Backpropagation through the bias is not functional yet, as it's a bit more involved to implement.
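For reference, usage would look roughly like the sketch below (tensor shapes and the bias layout are illustrative; check the xformers docs for the exact requirements):

```
import torch
import xformers.ops as xops
from xformers.ops.fmha.attn_bias import LowerTriangularMaskWithTensorBias

B, M, H, K = 2, 1024, 8, 64
dtype, device = torch.float16, "cuda"
q = torch.randn(B, M, H, K, dtype=dtype, device=device)
k = torch.randn(B, M, H, K, dtype=dtype, device=device)
v = torch.randn(B, M, H, K, dtype=dtype, device=device)
bias = torch.randn(B, H, M, M, dtype=dtype, device=device)  # additive attention bias

# Causality and the tensor bias travel together in one attn_bias object, so the
# kernel keeps the structural information it needs to skip fully-masked tiles.
out = xops.memory_efficient_attention(q, k, v, attn_bias=LowerTriangularMaskWithTensorBias(bias))
```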
