In [1]:
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [2]:
from flash_attn.flash_attn_interface import (
    _flash_attn_varlen_forward,
    _flash_attn_varlen_backward,
)
from typing import Tuple, Optional
import torch

In [3]:
head_num = 16
dim = 128
seq_len = 100
chunk_size = 5
batch_size = 1

In [4]:
# seqlen, 3, nheads, d
q = torch.randn(seq_len, head_num, dim).cuda().to(torch.bfloat16)
k = torch.randn(seq_len, head_num, dim).cuda().to(torch.bfloat16)
v = torch.randn(seq_len, head_num, dim).cuda().to(torch.bfloat16)

In [5]:
attention_mask = torch.ones((batch_size, seq_len))
attention_mask

tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])

In [6]:
cu_seqlens = torch.tensor([  0, 100], dtype = torch.int32).cuda()
block_out_full, _, _, _, _, _, _, _ = _flash_attn_varlen_forward(
    q, k, v, cu_seqlens, cu_seqlens, 100, 100, 0.0, 1.0,
                          causal = True, window_size=(-1, -1),
    alibi_slopes=None, return_softmax = False, block_table = None)

In [7]:
import torch.nn.functional as F

def flatten_varlen_lse(lse, cu_seqlens):
    new_lse = []
    for i in range(len(cu_seqlens) - 1):
        start, end = cu_seqlens[i], cu_seqlens[i + 1]
        new_lse.append(lse[i, :, : end - start])
    return torch.cat(new_lse, dim=1)

def _update_out_and_lse(
    out: torch.Tensor,
    lse: torch.Tensor,
    block_out: torch.Tensor,
    block_lse: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:

    block_out = block_out.to(torch.float32)
    block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)

    # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
    # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
    # For additional context and discussion, please refer to:
    # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
    out = out - F.sigmoid(block_lse - lse) * (out - block_out)
    lse = lse - F.logsigmoid(lse - block_lse)

    return out, lse

def update_out_and_lse(
    out: Optional[torch.Tensor],
    lse: Optional[torch.Tensor],
    block_out: torch.Tensor,
    block_lse: torch.Tensor,
    slice_=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if out is None:
        if slice_ is not None:
            raise RuntimeError("first update_out_and_lse should not pass slice_ args")
        out = block_out.to(torch.float32)
        lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
    elif slice_ is not None:
        slice_out, slice_lse = out[slice_], lse[slice_]
        slice_out, slice_lse = _update_out_and_lse(
            slice_out, slice_lse, block_out, block_lse
        )
        out[slice_], lse[slice_] = slice_out, slice_lse
    else:
        out, lse = _update_out_and_lse(out, lse, block_out, block_lse)
    return out, lse

## Causal

In [8]:
q_chunks = q.chunk(chunk_size, dim = 0)
k_chunks = k.chunk(chunk_size, dim = 0)
v_chunks = v.chunk(chunk_size, dim = 0)
seq_chunk = seq_len // chunk_size

In [9]:
out = None
lse = None

q_ = q_chunks[0]
cu_seqlens = torch.tensor([  0, 20], dtype = torch.int32).cuda()

for i in range(len(q_chunks)):
    block_out, _, _, _, _, block_lse, _, _ = _flash_attn_varlen_forward(
        q_, k_chunks[i], v_chunks[i], cu_seqlens, cu_seqlens, 100, 100, 0.0, 1.0,
                          causal = True and i == 0, window_size=(-1, -1),
    alibi_slopes=None, return_softmax = False, block_table = None)
    if block_lse.dim() == 3:
        old_lse = True
        block_lse = flatten_varlen_lse(
            block_lse,
            cu_seqlens=cu_seqlens,
        )
    out, lse = update_out_and_lse(out, lse, block_out, block_lse)

In [10]:
block_out_full[:20].argmax(-1)

tensor([[ 35,  25, 111,  18,  75,  96,  89,  51,  56,  48, 125,  10, 116,  11,
          38,  72],
        [ 27,  25, 122,  18,  75,  96,  89, 102,  81,  93, 110,  10,  47,  63,
          38,  72],
        [ 35,  96, 122,  18,  80,  65,  85, 119,  56,  97, 110,  10,   7,  63,
          71,  21],
        [ 82,  96,  34,  92, 114,  19,  27,  51, 121,   4, 112, 126,   7,  21,
          71,  66],
        [ 27, 107,  88,  99,  75,  14,  89, 119,  81,  93,  34,  45,   7,  11,
         112,  77],
        [ 82,  13,  88,  18,  12,  80,  81, 119,  59,  20,  34,  45, 116,  21,
         103,  77],
        [ 33, 120, 115,  48,  75,  65, 108, 119, 113,  93,  34,  52, 116,  11,
         103,  77],
        [ 42, 114,  22,  64,  49,  20, 108,   9,  56,  68, 125,  10,   7,  87,
         103, 105],
        [ 42, 120, 115,  76,  64, 117,  27, 102,  59, 110,  34,  14,  47,  88,
         112,  72],
        [ 42,  18,  34,  92,  77, 121,  89,  36, 106,  93,  90,  75,  47,  33,
         126,  72],
        [ 

In [11]:
out.argmax(-1)

tensor([[ 14,  88,  93,   7, 127, 108,  99,  52,  40, 113,  41, 125,  16,  23,
         123, 113],
        [ 57,  98,  23,  65, 119,   6,  44,  29,  56,  10,  14,  94,  42,  43,
          19,  82],
        [ 35,  55,  23,  19,  70,  55,  44,  94,  34,  39,   3,  73,  89,  50,
          89,  39],
        [116,  59,  81,  17,  75,  29,  20,  88,  56,  43,   7, 126,  87,  43,
         112,  51],
        [ 71,  72,  73,  96,  12, 125,  26,   8,  54,  86, 120,  73,  66, 115,
         122,  80],
        [  1,  54,  37, 125,  12, 103,  54, 107,  16, 118,   1, 110,   4,  35,
          10, 107],
        [ 11,  99,  55,  14, 112,  41,  42,  11, 126,  30,  23,  26,  64, 105,
           7,  11],
        [ 34,  37, 101,  59,  70,  55,  66, 111,  58,  68, 125,  81,   7,  93,
          70,   3],
        [ 85,  10,  21,  18,  41, 123,  37,  62,  43,  66,  52,  66,  63,  54,
          73,  20],
        [ 18,  25,   1,  13,  87,  79, 121, 104,  85,  13,  90,   1,  96,  18,
          49,  95],
        [ 

In [12]:
torch.mean((block_out_full[:20].argmax(-1) == out.argmax(-1)).float())

tensor(0.0969, device='cuda:0')

## Non Causal

In [13]:
cu_seqlens = torch.tensor([  0, 100], dtype = torch.int32).cuda()
block_out_full_noncausal, _, _, _, _, _, _, _ = _flash_attn_varlen_forward(
    q, k, v, cu_seqlens, cu_seqlens, 100, 100, 0.0, 1.0,
                          causal = False, window_size=(-1, -1),
    alibi_slopes=None, return_softmax = False, block_table = None)

In [14]:
out = None
lse = None

q_ = q_chunks[0]
cu_seqlens = torch.tensor([  0, 20], dtype = torch.int32).cuda()

for i in range(len(q_chunks)):
    block_out, _, _, _, _, block_lse, _, _ = _flash_attn_varlen_forward(
        q_, k_chunks[i], v_chunks[i], cu_seqlens, cu_seqlens, 100, 100, 0.0, 1.0,
                          causal = False, window_size=(-1, -1),
    alibi_slopes=None, return_softmax = False, block_table = None)
    if block_lse.dim() == 3:
        old_lse = True
        block_lse = flatten_varlen_lse(
            block_lse,
            cu_seqlens=cu_seqlens,
        )
    out, lse = update_out_and_lse(out, lse, block_out, block_lse)

In [15]:
block_out_full_noncausal[:20].argmax(-1)

tensor([[ 14,  96,  26,  76, 127, 108,  99,  52,  40, 113,  41, 125,  16,  23,
         123,  60],
        [ 57,  98,  23, 115, 119,  19,  44,  29,  56,  10,  14,  94,  42,  43,
          19,  82],
        [ 41,  15,  23,  19,  70,  55,  44,  94,  34,   8,  32,  73,  21,  50,
          89,  39],
        [116,  59,  81,  17,  75,  29,  20,  88,  56,  43,   7, 126,  87,  43,
         112,  51],
        [ 71,  72,  73,  96,  12, 125,  26,   8,  54,  86, 120,  73,  66, 115,
         122,  80],
        [  1,  54,  37, 125,  12, 103,  54, 107,  16, 118,   1, 110,  61,  47,
          10, 107],
        [ 11,  99,  55,  14,  17,  41,  42,  11,   4,  30,  23,  26,  64, 105,
           7,  11],
        [ 34, 100, 101,  59,  70,  55,  66, 111,  58,  68, 125,  81,   7,  93,
          70,   3],
        [ 85,  10,  21,  18,  41, 123,  37,  62,  43,  66,  52,  66,  63,  54,
          73,  20],
        [ 18,  25,   1,  13,  87,  79,  74, 104,  34,  13,  90,   1,  96,  18,
          49,  95],
        [ 

In [16]:
out.argmax(-1)

tensor([[ 14,  96,  26,  76, 127, 108,  99,  52,  40, 113,  41, 125,  16,  23,
         123,  60],
        [ 57,  98,  23, 115, 119,  19,  44,  29,  56,  10,  14,  94,  42,  43,
          19,  82],
        [ 41,  15,  23,  19,  70,  55,  44,  94,  34,   8,  32,  73,  21,  50,
          89,  39],
        [116,  59,  81,  17,  75,  29,  20,  88,  56,  43,   7, 126,  87,  43,
         112,  51],
        [ 71,  72,  73,  96,  12, 125,  26,   8,  54,  86, 120,  73,  66, 115,
         122,  80],
        [  1,  54,  37, 125,  12, 103,  54, 107,  16, 118,   1, 110,  61,  47,
          10, 107],
        [ 11,  99,  55,  14,  17,  41,  42,  11, 126,  30,  23,  26,  64, 105,
           7,  11],
        [ 34, 100, 101,  59,  70,  55,  66, 111,  58,  68, 125,  81,   7,  93,
          70,   3],
        [ 85,  10,  21,  18,  41, 123,  37,  62,  43,  66,  52,  66,  63,  54,
          73,  20],
        [ 18,  25,   1,  13,  87,  79,  74, 104,  34,  13,  90,   1,  96,  18,
          49,  95],
        [ 

In [17]:
torch.mean((block_out_full_noncausal[:20].argmax(-1) == out.argmax(-1)).float())

tensor(0.9969, device='cuda:0')