In [2]:
import os
from IPython.core.debugger import set_trace

# os.environ["TRITON_INTERPRET"] = "1"  # needs to be set *before* triton is imported


def check_tensors_gpu_ready(*tensors):
    """Assert that every given tensor can be handed to a Triton kernel.

    Each tensor must be contiguous. Unless the Triton interpreter is enabled
    (environment variable TRITON_INTERPRET set to "1"), each tensor must also
    live on a CUDA device.

    Args:
        *tensors: objects exposing an `is_contiguous()` method and an `is_cuda` attribute.

    Raises:
        AssertionError: if a tensor is not contiguous, or is not on CUDA while
            the interpreter is disabled.
    """
    interpreting = os.environ.get("TRITON_INTERPRET") == "1"
    for t in tensors:
        # `is_contiguous` is a method: without the call, the bound method object
        # is always truthy and this check could never fail.
        assert t.is_contiguous(), "A tensor is not contiguous"
        if not interpreting:
            assert t.is_cuda, "A tensor is not on cuda"


def test_pid_conds(conds, pid_0=[0], pid_1=[0], pid_2=[0]):
    """Test if condition on pids are fulfilled
    E.g.:
        '=0'  checks that pid_0 == 0
        ',>1' checks that pid_1 > 1
        '>1,=0' checks that pid_0 > 1 and pid_1 == 0
    """
    pids = pid_0[0], pid_1[0], pid_2[0]
    conds = conds.replace(" ", "").split(",")
    for i, (cond, pid) in enumerate(zip(conds, pids)):
        if cond == "":
            continue
        op, threshold = cond[0], int(cond[1:])
        if op not in ["<", ">", ">=", "<=", "=", "!="]:
            raise ValueError(f"Rules may only use these ops: '<','>','>=','<=','=', '!='. Invalid rule: '{condition}'.")
        op = "==" if op == "=" else op
        if not eval(f"{pid} {op} {threshold}"):
            return False
    return True


assert test_pid_conds("")
assert test_pid_conds(">0", [1], [1])
assert not test_pid_conds(">0", [0], [1])
assert test_pid_conds("=0,=1", [0], [1], [0])


def breakpoint_if(conds, pid_0=[0], pid_1=[0], pid_2=[0]):
    """Drop into the IPython debugger when the pids satisfy all rules in `conds`.

    The rule syntax is the one accepted by `test_pid_conds`.
    """
    if not test_pid_conds(conds, pid_0, pid_1, pid_2):
        return
    set_trace()


def print_if(txt, conds, pid_0=[0], pid_1=[0], pid_2=[0]):
    """Print `txt` when the pids satisfy all rules in `conds`.

    The rule syntax is the one accepted by `test_pid_conds`.
    """
    should_print = test_pid_conds(conds, pid_0, pid_1, pid_2)
    if should_print:
        print(txt)


def cdiv(a, b):
    """Return `a` divided by `b`, rounded up (for positive integers).

    Used to count how many blocks of size `b` are needed to cover `a` elements.
    """
    shifted = a + b - 1  # bump the numerator so floor division rounds up
    return shifted // b


# Exact division and division with a remainder
assert all(cdiv(a, b) == want for a, b, want in [(10, 2, 5), (10, 3, 4)])

In [3]:
import os

# os.environ["TRITON_INTERPRET"] = "1"


import torch
from torch.nn import functional as F
import triton
import triton.language as tl
from typing import Optional

from attn_torch import torch_scaled_dot_product_attention

# Print tensors without scientific notation so small differences are easier to read.
torch.set_printoptions(sci_mode=False)

## PyTorch implementation of FlashAttention
I follow the notation of Algorithm 1 in the FlashAttention paper: https://arxiv.org/pdf/2205.14135.pdf

Helpful resources:
- [Triton implementation of FA1](https://github.com/openai/triton/blob/fdf1c1f2a1f4de37ce1fb31316d53004d6e7e98c/python/tutorials/06-fused-attention.py)

In [4]:
def pad_to(x, dim, size, value=0.0):
    """Grow `x` along `dim` to length `size` by appending entries filled with `value`.

    If `x` already has at least `size` entries along `dim`, it is returned
    unchanged (it is never truncated).
    """
    missing = size - x.size(dim)
    if missing <= 0:
        return x

    # Same shape as x except along `dim`, where only the missing entries are created.
    filler_shape = list(x.shape)
    filler_shape[dim] = missing
    filler = torch.full(filler_shape, value, dtype=x.dtype, device=x.device)
    return torch.cat([x, filler], dim=dim)


@torch.no_grad()
def torch_flash_attention_kernel(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    softmax_scale: Optional[float] = None,
    B_r: int = 128,
    B_c: int = 128,
):
    """Flash attention kernel implementation using torch operations.

    This implementation closely follows Algorithm 1 in the FlashAttention paper: https://arxiv.org/pdf/2205.14135.pdf.
    The differences are that the attention scores are scaled by `softmax_scale` as part of the computation,
    and that the two loops are swapped: the outer loop runs over query blocks, the inner loop over key/value blocks.

    This implementation is not intended to be used; it is only for reference and testing purposes.

    Args:
        Q: Queries tensor of shape [Z, H, N, D]
        K: Keys tensor of shape [Z, H, N, D]
        V: Values tensor of shape [Z, H, N, D]
        softmax_scale: Factor applied to Q @ K^T before the softmax. Defaults to 1 / sqrt(D) when None.
        B_r: The block size for the rows
        B_c: The block size for the columns

    Returns:
        Attention output tensor of shape [Z, H, N, D].
    """

    Z, H, N, D = Q.shape
    dtype = Q.dtype
    device = Q.device

    # Compare against None explicitly: `softmax_scale or default` would silently
    # replace an explicitly requested scale of 0.0 by the default.
    if softmax_scale is None:
        softmax_scale = 1.0 / (D**0.5)

    def inner(Q, K, V):
        """Attention for a single (batch, head) slice; Q, K, V are [N, D]."""

        # 3. Divide Q into T_r blocks of size [B_r, D] each (the last block may be shorter)
        T_r = cdiv(N, B_r)
        Q = list(torch.split(Q, B_r))  # T_r blocks of [<=B_r, D]

        # 3. Divide K, V into T_c blocks of size [B_c, D] each (the last block may be shorter)
        T_c = cdiv(N, B_c)
        K = list(torch.split(K, B_c))  # T_c blocks of [<=B_c, D]
        V = list(torch.split(V, B_c))  # T_c blocks of [<=B_c, D]

        # 2. and 4. O is collected block by block; each finished O_i is appended below
        O = list()

        # 7. Outer loop (NOTE: in Algorithm 1, this is the inner loop)
        for i in range(T_r):
            # 8. Load Q_i, O_i, l_i, m_i into SRAM
            Q_i = Q[i]  # [<=B_r, D]
            Q_i = pad_to(Q_i, 0, B_r)  # simulate padding -> [B_r, D]

            # 2. and 4. Divide l, m into T_r blocks of size [B_r] each
            l_i = torch.zeros(B_r, device=device, dtype=dtype)  # [B_r]
            m_i = torch.full((B_r,), float("-inf"), device=device, dtype=dtype)  # [B_r]
            O_i = torch.zeros(B_r, D, device=device, dtype=dtype)  # [B_r, D]

            # 5. Inner loop (NOTE: in Algorithm 1, this is the outer loop)
            for j in range(T_c):
                # 6. Load K_j, V_j into SRAM
                K_j = K[j]  # [<=B_c, D]
                V_j = V[j]  # [<=B_c, D]

                K_j = pad_to(K_j, 0, B_c)  # simulate padding -> [B_c, D]
                V_j = pad_to(V_j, 0, B_c)  # simulate padding -> [B_c, D]

                # 9. On chip, compute S_ij = Q_i @ K_j^T
                S_ij = Q_i @ K_j.T  # [B_r, B_c]

                # 9a. Scale by softmax_scale (not in the paper, but part of the attention formula)
                S_ij = S_ij * softmax_scale

                # 9b. Mask out-of-bounds elements: key columns at global index >= N are padding
                S_ij = torch.where(torch.arange(B_c, device=device).unsqueeze(0) + j * B_c < N, S_ij, -float("inf"))

                # 10. On chip, compute mtilde_ij = rowmax(S_ij)
                mtilde_ij = S_ij.max(dim=1).values  # [B_r]

                # 10. On chip, compute Ptilde_ij = exp(S_ij - mtilde_ij)
                Ptilde_ij = torch.exp(S_ij - mtilde_ij.unsqueeze(1))  # [B_r, B_c]

                # 11. On chip, compute ltilde_ij = rowsum(Ptilde_ij)
                ltilde_ij = Ptilde_ij.sum(dim=1)  # [B_r]

                # 11. On chip, compute mnew_i = max(m_i, mtilde_ij)
                mnew_i = torch.maximum(m_i, mtilde_ij)  # [B_r]

                # 11. On chip, compute lnew_i = exp(m_i - mnew_i) * l_i + exp(mtilde_ij - mnew_i) * ltilde_ij
                alpha = torch.exp(m_i - mnew_i)  # [B_r]
                beta = torch.exp(mtilde_ij - mnew_i)  # [B_r]
                lnew_i = alpha * l_i + beta * ltilde_ij  # [B_r]

                # 12. Write O_i = diag(lnew_i)^-1 (diag(l_i) exp(m_i - mnew_i) O_i + exp(mtilde_ij - mnew_i) Ptilde_ij V_j) to HBM
                P_scale = beta / lnew_i  # [B_r]
                O_scale = l_i / lnew_i * alpha  # [B_r]
                O_i = O_i * O_scale.unsqueeze(1) + (Ptilde_ij * P_scale.unsqueeze(1)) @ V_j

                # 13. Write l_i = lnew_i to HBM
                l_i = lnew_i

                # 13. Write m_i = mnew_i to HBM
                m_i = mnew_i

            O.append(O_i)  # write to HBM

        O = torch.cat(O)  # [T_r * B_r, D]
        O = O[:N]  # remove padding
        return O

    # Run inner across Z, H dimensions
    O = torch.stack([torch.stack([inner(Q[z, h], K[z, h], V[z, h]) for h in range(H)]) for z in range(Z)])
    return O


Z = 6  # batch size
H = 2  # num heads
N = 8  # sequence length
D = 4  # embed dim (head dim)

# M = 128  # on-chip SRAM size
# B_c = cdiv(M, 4 * d)  # block size
# B_r = min(B_c, d)

B_c = B_r = 4

# Seed the RNG so the comparison below uses the same inputs after every kernel restart.
torch.manual_seed(0)
Q = torch.randn(Z, H, N, D, dtype=torch.float32, device="cuda")
K = torch.randn(Z, H, N, D, dtype=torch.float32, device="cuda")
V = torch.randn(Z, H, N, D, dtype=torch.float32, device="cuda")

# Compare the blocked reference against the plain attention implementation.
output_flash = torch_flash_attention_kernel(Q, K, V, B_r=B_r, B_c=B_c)
# print(F.scaled_dot_product_attention(Q, K, V))
output_torch = torch_scaled_dot_product_attention(Q, K, V)
print(torch.allclose(output_flash, output_torch, atol=1e-6))

True


In [31]:
@triton.jit
def triton_flash_attention_kernel(
    Q_ptr,
    K_ptr,
    V_ptr,
    O_ptr,
    stride_Q0,
    stride_Q1,
    stride_K0,
    stride_K1,
    stride_V0,
    stride_V1,
    stride_O0,
    stride_O1,
    N: int,
    D: int,
    softmax_scale: float,
    B_r: tl.constexpr,
    B_c: tl.constexpr,
    B_d: tl.constexpr,
    allow_tf32: tl.constexpr = False,
):
    """Flash attention forward pass for one [N, D] problem, following Algorithm 1 of the FlashAttention paper.

    Each program (grid axis 0) handles one block of B_r query rows and loops over all
    key/value blocks of B_c rows. Step numbers in the comments refer to Algorithm 1.

    Args:
        Q_ptr, K_ptr, V_ptr: pointers to the [N, D] query, key and value matrices.
        O_ptr: pointer to the [N, D] output matrix.
        stride_*0, stride_*1: strides of the respective matrix along dim 0 and dim 1.
        N: number of rows (sequence length).
        D: number of columns (head dim); must equal B_d.
        softmax_scale: factor applied to Q @ K^T before the softmax.
        B_r: block size for the query rows.
        B_c: block size for the key/value rows.
        B_d: block size along the head dim.
        allow_tf32: forwarded to tl.dot.

    NOTE(review): the accumulators use Q's dtype; this notebook only exercises float32.
    NOTE(review): for row-major inputs the Triton tutorials pass order=(1, 0) to
    make_block_ptr; confirm whether order=(0, 1) is intended here.
    """
    assert D == B_d
    i = tl.program_id(0)

    # 8. Load Q_i, O_i, l_i, m_i into SRAM
    Q_i_ptrs = tl.make_block_ptr(
        base=Q_ptr,
        shape=(N, D),
        strides=(stride_Q0, stride_Q1),
        offsets=(i * B_r, 0),
        block_shape=(B_r, B_d),
        order=(0, 1),
    )
    # The last query block may reach past row N: guard the load and fill the
    # out-of-bounds rows with zeros (same as the padding in the torch reference).
    Q_i = tl.load(Q_i_ptrs, boundary_check=(0, 1), padding_option="zero")  # [B_r, D]

    # The output accumulator has one column per head-dim element, not per key.
    O_i = tl.zeros((B_r, B_d), dtype=Q_i.dtype)  # [B_r, D]
    l_i = tl.zeros((B_r,), dtype=Q_i.dtype)  # [B_r]
    m_i = tl.full((B_r,), -float("inf"), dtype=Q_i.dtype)  # [B_r]

    # 3. Divide K, V into T_c blocks of size [B_c, D] each
    T_c = tl.cdiv(N, B_c)

    K_j_ptrs = tl.make_block_ptr(
        base=K_ptr,
        shape=(N, D),
        strides=(stride_K0, stride_K1),
        offsets=(0, 0),
        block_shape=(B_c, B_d),
        order=(0, 1),
    )
    V_j_ptrs = tl.make_block_ptr(
        base=V_ptr,
        shape=(N, D),
        strides=(stride_V0, stride_V1),
        offsets=(0, 0),
        block_shape=(B_c, B_d),
        order=(0, 1),
    )

    # Inner loop (NOTE: in Algorithm 1, this is the outer loop; Algorithm 1's inner loop is the outer loop here via tl.program_id(0))
    for j in range(T_c):
        # 3. Divide K, V into T_c blocks of size [B_c, D] each
        # 6. Load K_j, V_j into SRAM
        # Zero padding matters for V_j: masked probabilities are exactly 0, but
        # 0 * <undefined padding> could still produce NaN in the P @ V product.
        K_j = tl.load(K_j_ptrs, boundary_check=(0, 1), padding_option="zero")  # [B_c, D]
        V_j = tl.load(V_j_ptrs, boundary_check=(0, 1), padding_option="zero")  # [B_c, D]

        K_j_T = tl.trans(K_j)  # [D, B_c]

        # 9. On chip, compute S_ij = Q_i @ K_j^T
        S_ij = tl.dot(Q_i, K_j_T, allow_tf32=allow_tf32)  # [B_r, B_c]

        # 9a. Scale by softmax_scale (not in the paper, but part of the attention formula)
        S_ij = S_ij * softmax_scale

        # 9b. Mask out-of-bounds elements: key columns at global index >= N are padding
        cols = j * B_c + tl.arange(0, B_c)
        S_ij = tl.where((cols[None, :] < N), S_ij, -float("inf"))

        # 10. On chip, compute mtilde_ij = rowmax(S_ij)
        mtilde_ij = tl.max(S_ij, axis=1)  # [B_r]

        # 10. On chip, compute Ptilde_ij = exp(S_ij - mtilde_ij)
        Ptilde_ij = tl.exp(S_ij - mtilde_ij[:, None])  # [B_r, B_c]

        # 11. On chip, compute ltilde_ij = rowsum(Ptilde_ij)
        ltilde_ij = tl.sum(Ptilde_ij, axis=1)  # [B_r]

        # 11. On chip, compute mnew_i = max(m_i, mtilde_ij)
        mnew_i = tl.maximum(m_i, mtilde_ij)  # [B_r]

        # 11. On chip, compute lnew_i = exp(m_i - mnew_i) * l_i + exp(mtilde_ij - mnew_i) * ltilde_ij
        alpha = tl.exp(m_i - mnew_i)  # [B_r]
        beta = tl.exp(mtilde_ij - mnew_i)  # [B_r]
        lnew_i = alpha * l_i + beta * ltilde_ij  # [B_r]

        # 12. Write O_i = diag(lnew_i)^-1 (diag(l_i) exp(m_i - mnew_i) O_i + exp(mtilde_ij - mnew_i) Ptilde_ij V_j) to HBM
        P_scale = beta / lnew_i  # [B_r]
        O_scale = l_i / lnew_i * alpha  # [B_r]
        O_i = O_i * O_scale[:, None] + tl.dot(Ptilde_ij * P_scale[:, None], V_j, allow_tf32=allow_tf32)

        # 13. Write l_i = lnew_i to HBM
        l_i = lnew_i

        # 13. Write m_i = mnew_i to HBM
        m_i = mnew_i

        # Advance block pointers to the next block
        K_j_ptrs = K_j_ptrs.advance((B_c, 0))
        V_j_ptrs = V_j_ptrs.advance((B_c, 0))

    # 12. Write O_i to HBM
    O_i_ptrs = tl.make_block_ptr(
        base=O_ptr,
        shape=(N, D),
        strides=(stride_O0, stride_O1),
        offsets=(i * B_r, 0),
        block_shape=(B_r, B_d),
        order=(0, 1),
    )
    # Guard the store as well: rows >= N of the last block must not be written.
    tl.store(O_i_ptrs, O_i, boundary_check=(0, 1))


def triton_flash_attention(Q, K, V, B_r, B_c, softmax_scale: Optional[float] = None):
    """Compute single-head attention for [N, D] inputs with the Triton flash attention kernel.

    Args:
        Q: Queries tensor of shape [N, D]
        K: Keys tensor of shape [N, D]
        V: Values tensor of shape [N, D]
        B_r: The block size for the query rows
        B_c: The block size for the key/value rows
        softmax_scale: Factor applied to Q @ K^T before the softmax. Defaults to 1 / sqrt(D) when None.

    Returns:
        Attention output tensor of shape [N, D], same dtype and device as Q.

    Raises:
        ValueError: if K or V does not have the same shape as Q.
    """
    N, D = Q.shape
    # The kernel indexes K and V with Q's N and D, so mismatched shapes would
    # lead to wrong results or out-of-bounds accesses instead of a clear error.
    if K.shape != Q.shape or V.shape != Q.shape:
        raise ValueError(
            f"Q, K and V must have the same shape, got {tuple(Q.shape)}, {tuple(K.shape)}, {tuple(V.shape)}."
        )
    dtype = Q.dtype

    # 2. Initialize O in HBM (l and m only live inside the kernel)
    O = torch.zeros(N, D, device=Q.device, dtype=dtype)  # [N, D]

    # The kernel processes the full head dim in a single block.
    B_d = D

    # One program per block of B_r query rows.
    T_r = cdiv(N, B_r)
    if softmax_scale is None:
        softmax_scale = 1.0 / D**0.5

    triton_flash_attention_kernel[(T_r, 1, 1)](
        Q,
        K,
        V,
        O,
        Q.stride(0),
        Q.stride(1),
        K.stride(0),
        K.stride(1),
        V.stride(0),
        V.stride(1),
        O.stride(0),
        O.stride(1),
        N,
        D,
        softmax_scale,
        B_r,
        B_c,
        B_d,
    )
    return O


N = 34  # sequence length (not a multiple of the block size, so the last block is partial)
D = 16  # embed dim (head dim)
M = 128  # on-chip SRAM size
B_r = B_c = 16
B_d = D

# B_c = cdiv(M, 4 * d)  # block size
# B_r = min(B_c, d)


# N = 2
# D = 16
# B_r = B_c = 16
# B_d = D


# Fixed seed so the comparison below always runs on the same inputs.
torch.manual_seed(0)
Q = torch.randn(N, D, dtype=torch.float32, device="cuda")
K = torch.randn(N, D, dtype=torch.float32, device="cuda")
V = torch.randn(N, D, dtype=torch.float32, device="cuda")

output_triton = triton_flash_attention(Q, K, V, B_r=B_r, B_c=B_c)
# print(F.scaled_dot_product_attention(Q, K, V))
# output_torch = torch_scaled_dot_product_attention(Q, K, V)
torch.cuda.synchronize()
# The torch reference expects [Z, H, N, D]: add singleton batch and head dims, then strip them from the result.
Q, K, V = Q.unsqueeze(0).unsqueeze(0), K.unsqueeze(0).unsqueeze(0), V.unsqueeze(0).unsqueeze(0)
output_torch = torch_flash_attention_kernel(Q, K, V, B_r=B_r, B_c=B_c).squeeze(0).squeeze(0)
# Element-wise agreement within tolerance, followed by the mean squared difference.
print(torch.allclose(output_triton, output_torch, atol=1e-6))
print((output_torch - output_triton).pow(2).mean())

True
tensor(    0.0000, device='cuda:0')
