# How does FA2 `flash_attn_varlen_func` handle masks?

In this notebook, we check whether packed sequences passed to `flash_attn_varlen_func` are independent of each other, following up on the issue https://github.com/Dao-AILab/flash-attention/issues/654.

If masks are applied correctly, we would get the same result whether we use a packed sequence with `flash_attn_varlen_func` or use multiple sequences with `flash_attn_func`. We won't use any padding.
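Concretely, write $c$ for the cumulative sequence lengths (`cu_seqlens`, with $c_0 = 0$), $d$ for the head dimension, and $q_t, k_t, v_t$ for the packed query, key, and value rows. With `causal=True`, equal query and key lengths, and the default scale $1/\sqrt{d}$, the property we expect is sketched below. It is our statement of the expected behavior, not a formula taken from the library source. Output row $i$ of sequence $b$ should only involve rows of sequence $b$:

$$
O_{c_b + i} = \sum_{j=0}^{i} \frac{\exp\left(q_{c_b+i} \cdot k_{c_b+j} / \sqrt{d}\right)}{\sum_{j'=0}^{i} \exp\left(q_{c_b+i} \cdot k_{c_b+j'} / \sqrt{d}\right)} \, v_{c_b+j}, \qquad 0 \le i < c_{b+1} - c_b
$$

Since no row outside $[c_b, c_{b+1})$ appears on the right-hand side, this is exactly what `flash_attn_func` computes on sequence $b$ alone.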

In [1]:
from nbdev.showdoc import *
from fastcore.all import *

import torch
from flash_attn import flash_attn_func, flash_attn_varlen_func

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"  # flash_attn kernels need a GPU; they will not run on the CPU fallback
dtype = torch.bfloat16

The docstrings for `flash_attn_func` and `flash_attn_varlen_func` are shown below. The main differences are:

1. **Sequence handling**:
   - `flash_attn_func`: Processes regular batched sequences where each sequence has the same length (padded if necessary)
   - `flash_attn_varlen_func`: Handles variable-length sequences packed into a single tensor, avoiding computation on padding tokens

2. **Input format**:
   - `flash_attn_func`: Takes inputs in shape [batch_size, seq_len, num_heads, head_dim]
   - `flash_attn_varlen_func`: Takes inputs in shape [total_tokens, num_heads, head_dim] where total_tokens is the sum of all sequence lengths

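3. **Additional parameters for `flash_attn_varlen_func`**:
   - Requires `cu_seqlens_q` and `cu_seqlens_k` (cumulative sequence lengths, `torch.int32`, shape `(batch_size + 1,)`) to mark sequence boundaries in the packed query and key tensors
   - Requires `max_seqlen_q` and `max_seqlen_k`, the maximum query and key sequence lengths in the batch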

4. **Efficiency**:
   - `flash_attn_varlen_func` is more memory-efficient for batches with varying sequence lengths
   - `flash_attn_func` is simpler to use when all sequences have the same length

5. **Use cases**:
   - `flash_attn_func`: Simplest choice when all sequences have the same length or are padded to a common length
   - `flash_attn_varlen_func`: Suited to batches of many sequences with different lengths, in training (e.g., sequence packing) as well as inference, since no work is spent on padding

Both functions implement the same core attention algorithm with the same optimizations, but `flash_attn_varlen_func` adds the ability to handle variable-length sequences more efficiently.
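To put a number on the padding point, here is a small illustrative sketch (the helper `padded_vs_packed_rows` is hypothetical, not part of `flash_attn`). It counts how many token rows each layout holds: a padded batch for `flash_attn_func` stores `batch_size * max_seqlen` rows, while a packed batch for `flash_attn_varlen_func` stores only the real tokens.

```python
def padded_vs_packed_rows(seq_lens):
    """Count token rows in a padded batch vs. a packed batch (hypothetical helper).

    Padded layout (flash_attn_func):        (batch_size, max_seqlen, nheads, headdim)
    Packed layout (flash_attn_varlen_func): (total_tokens, nheads, headdim)
    """
    padded = len(seq_lens) * max(seq_lens)  # every sequence padded to the longest one
    packed = sum(seq_lens)                  # only real tokens, no padding
    return padded, packed

print(padded_vs_packed_rows([5, 3]))         # (10, 8)
print(padded_vs_packed_rows([512, 16, 16]))  # (1536, 544)
```

The gap grows with the spread of lengths: one long sequence among many short ones forces a padded batch to carry mostly padding.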

In [3]:
show_doc(flash_attn_func)

---

### flash_attn_func

>      flash_attn_func (q, k, v, dropout_p=0.0, softmax_scale=None,
>                       causal=False, window_size=(-1, -1), softcap=0.0,
>                       alibi_slopes=None, deterministic=False,
>                       return_attn_probs=False)

*dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
    1 1 1 1 0
    1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
    0 0
    0 0
    0 0
    1 0
    1 1
If the row of the mask is all zero, the output will be zero.

If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

Arguments:
    q: (batch_size, seqlen, nheads, headdim)
    k: (batch_size, seqlen, nheads_k, headdim)
    v: (batch_size, seqlen, nheads_k, headdim)
    dropout_p: float. Dropout probability.
    softmax_scale: float. The scaling of QK^T before applying softmax.
        Default to 1 / sqrt(headdim).
    causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
    window_size: (left, right). If not (-1, -1), implements sliding window local attention.
    alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
        (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
        is added to the attention score of query i and key j.
    deterministic: bool. Whether to use the deterministic implementation of the backward pass,
        which is slightly slower and uses more memory. The forward pass is always deterministic.
    return_attn_probs: bool. Whether to return the attention probabilities. This option is for
       testing only. The returned probabilities are not guaranteed to be correct
       (they might not have the right scaling).
Return:
    out: (batch_size, seqlen, nheads, headdim).
    softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
        logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
        normalization factor).
    S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
        The output of softmax (possibly with different scaling). It also encodes the dropout
        pattern (negative means that location was dropped, nonnegative means it was kept).*

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| q |  |  |  |
| k |  |  |  |
| v |  |  |  |
| dropout_p | float | 0.0 |  |
| softmax_scale | NoneType | None |  |
| causal | bool | False |  |
| window_size | tuple | (-1, -1) | -1 means infinite context window |
| softcap | float | 0.0 | 0.0 means deactivated |
| alibi_slopes | NoneType | None |  |
| deterministic | bool | False |  |
| return_attn_probs | bool | False |  |
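The bottom-right alignment of the causal mask in the docstring can be reproduced with a short sketch. The rule below, "query `i` keeps key `j` iff `j <= i + seqlen_k - seqlen_q`", is read off the two docstring examples and matches the upper bound of the sliding-window range with `window_size[1] = 0`; the function name is hypothetical. In this notebook `seqlen_q == seqlen_k`, so the offset is 0 and the mask is the usual lower-triangular one.

```python
def bottom_right_causal_mask(seqlen_q, seqlen_k):
    """Causal mask aligned to the bottom-right corner (1 = keep, 0 = masked out).

    Query i keeps key j iff j <= i + seqlen_k - seqlen_q.
    """
    offset = seqlen_k - seqlen_q
    return [[1 if j <= i + offset else 0 for j in range(seqlen_k)]
            for i in range(seqlen_q)]

for row in bottom_right_causal_mask(2, 5):
    print(*row)
# 1 1 1 1 0
# 1 1 1 1 1
```

Calling `bottom_right_causal_mask(5, 2)` gives the second docstring example, where the first three query rows are fully masked (and so, per the docstring, produce zero output).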

In [4]:
show_doc(flash_attn_varlen_func)

---

### flash_attn_varlen_func

>      flash_attn_varlen_func (q, k, v, cu_seqlens_q, cu_seqlens_k,
>                              max_seqlen_q, max_seqlen_k, dropout_p=0.0,
>                              softmax_scale=None, causal=False,
>                              window_size=(-1, -1), softcap=0.0,
>                              alibi_slopes=None, deterministic=False,
>                              return_attn_probs=False, block_table=None)

*dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
    1 1 1 1 0
    1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
    0 0
    0 0
    0 0
    1 0
    1 1
If the row of the mask is all zero, the output will be zero.

If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

Arguments:
    q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
    k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
    v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
    cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
       of the sequences in the batch, used to index into q.
    cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
       of the sequences in the batch, used to index into kv.
    max_seqlen_q: int. Maximum query sequence length in the batch.
    max_seqlen_k: int. Maximum key sequence length in the batch.
    dropout_p: float. Dropout probability.
    softmax_scale: float. The scaling of QK^T before applying softmax.
        Default to 1 / sqrt(headdim).
    causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
    window_size: (left, right). If not (-1, -1), implements sliding window local attention.
    softcap: float. Anything > 0 activates softcapping attention.
    alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
        (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
        is added to the attention score of query i and key j.
    deterministic: bool. Whether to use the deterministic implementation of the backward pass,
        which is slightly slower and uses more memory. The forward pass is always deterministic.
    return_attn_probs: bool. Whether to return the attention probabilities. This option is for
       testing only. The returned probabilities are not guaranteed to be correct
       (they might not have the right scaling).
Return:
    out: (total, nheads, headdim).
    softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
        logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
        normalization factor).
    S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
        The output of softmax (possibly with different scaling). It also encodes the dropout
        pattern (negative means that location was dropped, nonnegative means it was kept).*

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| q |  |  |  |
| k |  |  |  |
| v |  |  |  |
| cu_seqlens_q |  |  |  |
| cu_seqlens_k |  |  |  |
| max_seqlen_q |  |  |  |
| max_seqlen_k |  |  |  |
| dropout_p | float | 0.0 |  |
| softmax_scale | NoneType | None |  |
| causal | bool | False |  |
| window_size | tuple | (-1, -1) | -1 means infinite context window |
| softcap | float | 0.0 | 0.0 means deactivated |
| alibi_slopes | NoneType | None |  |
| deterministic | bool | False |  |
| return_attn_probs | bool | False |  |
| block_table | NoneType | None |  |
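Both docstrings describe the same MQA/GQA head mapping. A minimal sketch of that mapping follows, assuming (as the docstring example implies) that consecutive groups of `nheads_q // nheads_k` query heads share one K/V head; the function name is hypothetical. This notebook uses a single head, so the mapping is trivial here.

```python
def kv_head_for_q_head(h, nheads_q, nheads_k):
    """Index of the K/V head that query head h attends to under MQA/GQA (hypothetical helper).

    The docstring requires nheads_q to be divisible by nheads_k.
    """
    assert nheads_q % nheads_k == 0, "nheads_q must be divisible by nheads_k"
    group_size = nheads_q // nheads_k  # query heads per K/V head
    return h // group_size

print([kv_head_for_q_head(h, 6, 2) for h in range(6)])  # [0, 0, 0, 1, 1, 1]
```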

## `flash_attn_func`

Do the two functions produce the same result? First, we run `flash_attn_func` on two separate sequences.

In [5]:
torch.manual_seed(42)

# Setup example data: 2 sequences
batch_size, n_heads, head_dim = 1, 1, 8
seq_lens = [5, 3]  # First sequence has 5 tokens, second has 3

# Create query, key, value tensors for each sequence
q1 = torch.randn(1, seq_lens[0], n_heads, head_dim, device=device, dtype=dtype)
k1 = torch.rand_like(q1)
v1 = torch.rand_like(q1)
q1.shape

torch.Size([1, 5, 1, 8])

The shape of q, k, and v is `(batch_size, seq_len, n_heads, head_dim)`.

In [6]:
q2 = torch.randn(1, seq_lens[1], n_heads, head_dim, device=device, dtype=dtype)
k2 = torch.rand_like(q2)
v2 = torch.rand_like(q2)
q2.shape

torch.Size([1, 3, 1, 8])

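Next, we modify k and v (q is left unchanged): every key vector is shifted by +1, and the value vector at position `i` is shifted by +`i`, so the values differ clearly across positions.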

In [7]:
for i in range(seq_lens[0]):
    k1[0, i, 0] = k1[0, i, 0] + 1
    v1[0, i, 0] = v1[0, i, 0] + i
q1, k1, v1

(tensor([[[[ 0.1943,  2.1562, -0.1719,  0.8477, -1.9219,  0.6523, -0.6484,
            -0.8164]],
 
          [[ 0.5273, -1.2734, -1.6641, -0.3027, -0.0928,  0.1992, -1.1172,
             1.8594]],
 
          [[-0.7148,  0.6875,  0.7969, -0.0334,  1.4922, -0.5156, -0.2539,
             1.4766]],
 
          [[-0.3262, -1.1562,  2.3594, -0.6914,  0.1836, -1.1797, -1.8047,
            -1.5781]],
 
          [[ 0.8398,  1.4219,  0.6484,  0.4258, -1.5859,  0.6211,  1.6875,
            -0.6641]]]], device='cuda:0', dtype=torch.bfloat16),
 tensor([[[[1.9844, 1.1250, 1.5625, 1.5234, 1.7500, 1.5938, 1.9688, 1.8984]],
 
          [[1.7734, 1.6719, 1.5469, 1.5078, 1.5938, 1.2578, 1.8906, 1.1875]],
 
          [[1.7188, 1.1797, 1.9844, 1.6250, 1.4844, 1.7344, 1.2734, 1.9062]],
 
          [[1.6562, 1.6484, 1.3281, 1.0234, 1.1406, 1.0625, 1.4141, 1.5781]],
 
          [[1.8203, 1.6953, 1.5781, 1.0703, 1.7344, 1.5156, 1.6875, 1.3516]]]],
        device='cuda:0', dtype=torch.bfloat16),
 tensor([[[[

In [8]:
for i in range(seq_lens[1]):
    k2[0, i, 0] = k2[0, i, 0] + 1
    v2[0, i, 0] = v2[0, i, 0] + i
q2, k2, v2

(tensor([[[[-0.6992, -1.8672, -0.8828, -1.6641, -0.4316,  0.9492,  0.6602,
             0.0447]],
 
          [[ 0.5703,  2.1875, -0.2471, -1.3828,  0.0603, -0.2432,  1.3203,
             0.5195]],
 
          [[-0.6094,  0.1001, -0.8945, -0.9375, -0.2656,  1.5312,  0.5586,
            -0.9453]]]], device='cuda:0', dtype=torch.bfloat16),
 tensor([[[[1.3281, 2.0000, 1.0703, 1.2891, 1.5000, 1.1406, 1.5781, 1.8281]],
 
          [[1.4219, 1.1172, 1.2734, 1.5469, 1.9844, 1.1797, 1.5547, 1.2969]],
 
          [[1.3828, 1.4062, 1.5625, 1.8203, 1.0156, 1.0625, 1.9609, 1.1719]]]],
        device='cuda:0', dtype=torch.bfloat16),
 tensor([[[[0.3145, 0.3066, 0.6133, 0.2256, 0.1050, 0.2305, 0.2070, 0.0510]],
 
          [[1.2109, 1.7266, 1.9688, 1.1953, 1.4375, 1.0703, 1.4688, 1.9062]],
 
          [[2.2500, 2.7188, 2.9219, 2.6562, 2.2969, 2.4219, 2.1250, 2.6562]]]],
        device='cuda:0', dtype=torch.bfloat16))

In [9]:
out1 = flash_attn_func(q1, k1, v1, causal=True)
out2 = flash_attn_func(q2, k2, v2, causal=True)
out1

tensor([[[[0.8242, 0.2129, 0.7305, 0.3203, 0.7539, 0.2129, 0.9297, 0.1445]],

         [[0.9922, 0.7930, 1.0000, 0.6211, 1.1406, 0.5625, 1.2266, 0.5898]],

         [[1.4688, 1.6484, 1.6641, 1.5469, 1.7031, 1.4844, 1.7344, 1.2969]],

         [[2.0469, 2.3125, 2.1719, 2.3594, 2.3438, 2.2188, 2.3281, 2.1094]],

         [[2.3125, 2.4688, 2.4062, 2.5625, 2.6406, 2.5312, 2.6875, 2.4688]]]],
       device='cuda:0', dtype=torch.bfloat16)

In [10]:
out2

tensor([[[[0.3145, 0.3066, 0.6133, 0.2256, 0.1050, 0.2305, 0.2070, 0.0510]],

         [[0.5703, 0.7148, 1.0000, 0.5039, 0.4863, 0.4707, 0.5703, 0.5820]],

         [[1.2422, 1.5547, 1.8125, 1.3359, 1.2578, 1.2188, 1.2422, 1.5078]]]],
       device='cuda:0', dtype=torch.bfloat16)
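For reference, here is a naive sketch of what `flash_attn_func(..., causal=True)` should compute for one sequence and one head when `seqlen_q == seqlen_k`: softmax of the scaled scores over keys `0..i`, then a weighted sum of the value vectors. It is plain Python in float64 (the name `causal_attention` is hypothetical), not the FlashAttention algorithm, which computes the same result in tiles. It also explains why the first row of `out2` equals the first row of `v2`: the first query can only attend to itself.

```python
import math

def causal_attention(q, k, v, softmax_scale=None):
    """Naive causal attention for one sequence and one head (hypothetical reference).

    q, k, v: lists of seq_len vectors (lists of floats) of length head_dim.
    """
    head_dim = len(q[0])
    scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(head_dim)
    out = []
    for i, qi in enumerate(q):
        # Causal mask: query i only sees keys 0..i
        scores = [scale * sum(a * b for a, b in zip(qi, k[j])) for j in range(i + 1)]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        # Softmax-weighted sum of the visible value vectors
        out.append([sum(w * v[j][d] for j, w in enumerate(weights)) / total
                    for d in range(head_dim)])
    return out

q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(causal_attention(q, k, v)[0])  # [1.0, 2.0], i.e. the first value vector
```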

## `flash_attn_varlen_func`

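To use `flash_attn_varlen_func`, we pack the sequences along the token dimension and compute `cu_seqlens` (cumulative sequence lengths) and `max_seq_len`. Since the queries and keys come from the same sequences, the same `cu_seqlens` and `max_seq_len` serve for both the query and the key arguments.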

In [11]:
# Setup for packed sequences
batch_size, n_heads, head_dim = 1, 1, 8
seq_lens = [5, 3]
total_tokens = sum(seq_lens)

# Create cumulative sequence lengths tensor
cu_seqlens = torch.tensor([0, seq_lens[0], seq_lens[0] + seq_lens[1]], dtype=torch.int32, device=device) # has to be int32
cu_seqlens

tensor([0, 5, 8], device='cuda:0', dtype=torch.int32)
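The hardcoded tensor above generalizes to any list of lengths: `cu_seqlens` is the prefix sum of `seq_lens` with a leading 0, so it has `batch_size + 1` entries. A minimal sketch with the standard library (the helper name `make_cu_seqlens` is hypothetical):

```python
from itertools import accumulate

def make_cu_seqlens(seq_lens):
    """Cumulative sequence lengths with a leading 0 (length batch_size + 1).

    Sequence b occupies packed rows cu_seqlens[b]:cu_seqlens[b + 1].
    """
    return list(accumulate(seq_lens, initial=0))

seq_lens = [5, 3]
cu_seqlens = make_cu_seqlens(seq_lens)
max_seqlen = max(seq_lens)
print(cu_seqlens, max_seqlen)  # [0, 5, 8] 5
```

Before passing it to `flash_attn_varlen_func`, convert it as in the cell above, e.g. `torch.tensor(cu_seqlens, dtype=torch.int32, device=device)`.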

In [12]:
max_seq_len = max(seq_lens)
max_seq_len

5

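The packed tensors have no batch dimension: we concatenate along the sequence dimension and squeeze out the batch axis, giving shape `(total_tokens, n_heads, head_dim)`.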

In [13]:
q = torch.cat([q1, q2], dim=1).squeeze(0)
k = torch.cat([k1, k2], dim=1).squeeze(0)
v = torch.cat([v1, v2], dim=1).squeeze(0)
q.shape, k.shape, v.shape

(torch.Size([8, 1, 8]), torch.Size([8, 1, 8]), torch.Size([8, 1, 8]))
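Packing is reversible: `cu_seqlens` tells us where each sequence starts and ends along the first dimension. A small sketch of the inverse operation (the helper `unpack` is hypothetical) that works on anything sliceable along dim 0, such as a Python list or a `(total_tokens, n_heads, head_dim)` tensor, with `cu_seqlens` given as a Python list (e.g. `cu_seqlens.tolist()`):

```python
def unpack(packed, cu_seqlens):
    """Split a packed sequence (rows along dim 0) back into per-sequence pieces."""
    return [packed[cu_seqlens[b]:cu_seqlens[b + 1]] for b in range(len(cu_seqlens) - 1)]

tokens = ["a0", "a1", "a2", "a3", "a4", "b0", "b1", "b2"]  # 5 tokens + 3 tokens
print(unpack(tokens, [0, 5, 8]))
# [['a0', 'a1', 'a2', 'a3', 'a4'], ['b0', 'b1', 'b2']]
```

For torch tensors, `torch.split(packed, seq_lens)` gives the same pieces directly from the lengths.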

In [14]:
q, k, v

(tensor([[[ 0.1943,  2.1562, -0.1719,  0.8477, -1.9219,  0.6523, -0.6484,
           -0.8164]],
 
         [[ 0.5273, -1.2734, -1.6641, -0.3027, -0.0928,  0.1992, -1.1172,
            1.8594]],
 
         [[-0.7148,  0.6875,  0.7969, -0.0334,  1.4922, -0.5156, -0.2539,
            1.4766]],
 
         [[-0.3262, -1.1562,  2.3594, -0.6914,  0.1836, -1.1797, -1.8047,
           -1.5781]],
 
         [[ 0.8398,  1.4219,  0.6484,  0.4258, -1.5859,  0.6211,  1.6875,
           -0.6641]],
 
         [[-0.6992, -1.8672, -0.8828, -1.6641, -0.4316,  0.9492,  0.6602,
            0.0447]],
 
         [[ 0.5703,  2.1875, -0.2471, -1.3828,  0.0603, -0.2432,  1.3203,
            0.5195]],
 
         [[-0.6094,  0.1001, -0.8945, -0.9375, -0.2656,  1.5312,  0.5586,
           -0.9453]]], device='cuda:0', dtype=torch.bfloat16),
 tensor([[[1.9844, 1.1250, 1.5625, 1.5234, 1.7500, 1.5938, 1.9688, 1.8984]],
 
         [[1.7734, 1.6719, 1.5469, 1.5078, 1.5938, 1.2578, 1.8906, 1.1875]],
 
         [[1.7188, 

In [15]:
varlen_out = flash_attn_varlen_func(q, k, v,
                                    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
                                    max_seqlen_q=max_seq_len, max_seqlen_k=max_seq_len,
                                    dropout_p=0.0, causal=True)
varlen_out

tensor([[[0.8242, 0.2129, 0.7305, 0.3203, 0.7539, 0.2129, 0.9297, 0.1445]],

        [[0.9922, 0.7930, 1.0000, 0.6211, 1.1406, 0.5625, 1.2266, 0.5898]],

        [[1.4688, 1.6484, 1.6641, 1.5469, 1.7031, 1.4844, 1.7344, 1.2969]],

        [[2.0469, 2.3125, 2.1719, 2.3594, 2.3438, 2.2188, 2.3281, 2.1094]],

        [[2.3125, 2.4688, 2.4062, 2.5625, 2.6406, 2.5312, 2.6875, 2.4688]],

        [[0.3145, 0.3066, 0.6133, 0.2256, 0.1050, 0.2305, 0.2070, 0.0510]],

        [[0.5703, 0.7148, 1.0000, 0.5039, 0.4863, 0.4707, 0.5703, 0.5820]],

        [[1.2422, 1.5547, 1.8125, 1.3359, 1.2578, 1.2188, 1.2422, 1.5078]]],
       device='cuda:0', dtype=torch.bfloat16)
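One way to picture what this packed call should compute: dense causal attention over all 8 packed tokens, restricted by a block-diagonal mask so that a token only sees earlier tokens of its own sequence. The plain-Python sketch below builds that mask (the function name is hypothetical). It is a mental model only; as far as we understand, the kernel does not build such a matrix and instead uses `cu_seqlens` to locate each sequence.

```python
def varlen_causal_mask(seq_lens):
    """Dense (total, total) boolean mask equivalent to packed causal attention.

    mask[r][c] is True when packed query row r may attend to packed key row c:
    both rows belong to the same sequence and c is not after r.
    """
    total = sum(seq_lens)
    # seq_id[t] = index of the sequence that packed token t belongs to
    seq_id = [b for b, n in enumerate(seq_lens) for _ in range(n)]
    return [[seq_id[r] == seq_id[c] and c <= r for c in range(total)]
            for r in range(total)]

for row in varlen_causal_mask([5, 3]):
    print("".join("1" if keep else "0" for keep in row))
# 10000000
# 11000000
# 11100000
# 11110000
# 11111000
# 00000100
# 00000110
# 00000111
```

Row 5 (the first token of the second sequence) keeps only itself, which is why the sixth row of `varlen_out` should equal the first value vector of the second sequence.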

Now let's compare it with results from `flash_attn_func`. 

In [16]:
out1.shape

torch.Size([1, 5, 1, 8])

In [17]:
out2.shape

torch.Size([1, 3, 1, 8])

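First, we concatenate the two `flash_attn_func` outputs along the sequence dimension and drop the batch axis so the shape matches `varlen_out`.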

In [18]:
packed_out = torch.cat([out1, out2], dim=1).squeeze(0)
packed_out.shape

torch.Size([8, 1, 8])

Next, we write a small helper to compare two tensors. It moves them to the CPU, upcasts bfloat16 to float32, and checks shapes and values with `torch.allclose`, falling back to fastcore's `test_eq` for non-tensor inputs.

In [19]:
def test_eq_cuda(a, b, rtol=1e-5, atol=1e-8):
    "Test equality of tensors, handling CUDA tensors and bfloat16"
    # Move tensors to the CPU so they can be compared uniformly
    a_cpu = a.cpu() if hasattr(a, 'cpu') else a
    b_cpu = b.cpu() if hasattr(b, 'cpu') else b

    # Upcast bfloat16 to float32. Compare dtype objects directly:
    # str(torch.bfloat16) is 'torch.bfloat16', so a string check against 'bfloat16' never matches
    if isinstance(a_cpu, torch.Tensor) and a_cpu.dtype == torch.bfloat16:
        a_cpu = a_cpu.float()
    if isinstance(b_cpu, torch.Tensor) and b_cpu.dtype == torch.bfloat16:
        b_cpu = b_cpu.float()

    if isinstance(a_cpu, torch.Tensor) and isinstance(b_cpu, torch.Tensor):
        # Tensors: check shapes first, then values with torch.allclose
        assert a_cpu.shape == b_cpu.shape, f"Shapes don't match: {a_cpu.shape} vs {b_cpu.shape}"
        assert torch.allclose(a_cpu, b_cpu, rtol=rtol, atol=atol), f"Tensors not close enough:\n{a_cpu}\n{b_cpu}"
    else:
        # Anything else: fall back to fastcore's test_eq
        test_eq(a_cpu, b_cpu)

In [20]:
test_eq_cuda(varlen_out, packed_out)

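The test passes: every row of the packed `flash_attn_varlen_func` output matches the corresponding row from the two separate `flash_attn_func` calls, so with `causal=True` the tokens of the second sequence do not attend to the first sequence. The outputs are small enough to inspect directly.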

In [21]:
varlen_out

tensor([[[0.8242, 0.2129, 0.7305, 0.3203, 0.7539, 0.2129, 0.9297, 0.1445]],

        [[0.9922, 0.7930, 1.0000, 0.6211, 1.1406, 0.5625, 1.2266, 0.5898]],

        [[1.4688, 1.6484, 1.6641, 1.5469, 1.7031, 1.4844, 1.7344, 1.2969]],

        [[2.0469, 2.3125, 2.1719, 2.3594, 2.3438, 2.2188, 2.3281, 2.1094]],

        [[2.3125, 2.4688, 2.4062, 2.5625, 2.6406, 2.5312, 2.6875, 2.4688]],

        [[0.3145, 0.3066, 0.6133, 0.2256, 0.1050, 0.2305, 0.2070, 0.0510]],

        [[0.5703, 0.7148, 1.0000, 0.5039, 0.4863, 0.4707, 0.5703, 0.5820]],

        [[1.2422, 1.5547, 1.8125, 1.3359, 1.2578, 1.2188, 1.2422, 1.5078]]],
       device='cuda:0', dtype=torch.bfloat16)

In [22]:
packed_out

tensor([[[0.8242, 0.2129, 0.7305, 0.3203, 0.7539, 0.2129, 0.9297, 0.1445]],

        [[0.9922, 0.7930, 1.0000, 0.6211, 1.1406, 0.5625, 1.2266, 0.5898]],

        [[1.4688, 1.6484, 1.6641, 1.5469, 1.7031, 1.4844, 1.7344, 1.2969]],

        [[2.0469, 2.3125, 2.1719, 2.3594, 2.3438, 2.2188, 2.3281, 2.1094]],

        [[2.3125, 2.4688, 2.4062, 2.5625, 2.6406, 2.5312, 2.6875, 2.4688]],

        [[0.3145, 0.3066, 0.6133, 0.2256, 0.1050, 0.2305, 0.2070, 0.0510]],

        [[0.5703, 0.7148, 1.0000, 0.5039, 0.4863, 0.4707, 0.5703, 0.5820]],

        [[1.2422, 1.5547, 1.8125, 1.3359, 1.2578, 1.2188, 1.2422, 1.5078]]],
       device='cuda:0', dtype=torch.bfloat16)