In [65]:
import os

os.environ["TRITON_INTERPRETED"] = "1"

import torch
from torch.nn import functional as F
import triton
import triton.language as tl
from triton_utils import test_pid_conds, breakpoint_if, print_if, check_tensors_gpu_ready, cdiv
from typing import Optional

In [125]:
def torch_scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Reference implementation of scaled dot product attention using torch operations.

    Computes ``softmax(q @ k^T / sqrt(D_k)) @ v`` over the last two dimensions.
    Leading dimensions (e.g. batch and heads) are optional: the function is used
    in this notebook both with ``[B, H, S, D]`` and with plain ``[S, D]`` inputs.

    Args:
        q: Queries tensor of shape [..., S, D_k]
        k: Keys tensor of shape [..., S, D_k]
        v: Values tensor of shape [..., S, D_v]

    Returns:
        The output of the attention mechanism, a single tensor of shape [..., S, D_v].
        The intermediate attention logits and attention weights are NOT returned, so the
        result must not be tuple-unpacked (that would iterate over its first dimension).

    Shapes:
        ...: optional leading dimensions, e.g. B (batch size) and H (number of heads)
        S: sequence length
        D_k: key dimension
        D_v: value dimension
    """
    d_k = q.shape[-1]
    # Scaled similarity between every query row and every key row.
    attn_logits = (q @ k.transpose(-2, -1)) / d_k**0.5  # [..., S, S]
    # Normalize over the key axis so that each query row sums to 1.
    attention = F.softmax(attn_logits, dim=-1)  # [..., S, S]
    return attention @ v  # [..., S, D_v]

In [126]:
def pad_on_dim(tensor, size, dim, value=0):
    """Extends a tensor along one dimension by appending a constant value.

    Args:
        tensor: The input tensor
        size: The target length of dimension ``dim`` after padding
        dim: The dimension along which the padding is appended
        value: The fill value of the appended entries (0 by default)

    Returns:
        A tensor equal to ``tensor`` with ``size - tensor.shape[dim]`` extra
        entries of ``value`` appended at the end of dimension ``dim``.
    """
    missing = size - tensor.shape[dim]
    # The target size must not be smaller than the current size.
    assert missing >= 0

    # The filler block matches the input everywhere except along `dim`.
    filler_shape = [*tensor.shape]
    filler_shape[dim] = missing
    filler = torch.full(filler_shape, value, dtype=tensor.dtype, device=tensor.device)
    return torch.cat((tensor, filler), dim=dim)


# Quick check of pad_on_dim: dimension 2 has size 4 and is padded to size 7.
# The tensor is created with device="cuda", so this cell needs a CUDA device.
x = torch.randn(1, 1, 4, 4, dtype=torch.float32, device="cuda")
# Last expression of the cell: the shape of the padded tensor is displayed.
pad_on_dim(x, 7, 2).shape

torch.Size([1, 1, 7, 4])

## PyTorch implementation of FlashAttention
I follow the notation of Algorithm 1 of the FlashAttention paper: https://arxiv.org/pdf/2205.14135.pdf

In [147]:
@torch.no_grad()
def torch_flash_attention_kernel(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, B_r: int, B_c: int):
    """Flash attention kernel implementation using torch operations.

    This implementation closely follows Algorithm 1 in the FlashAttention paper: https://arxiv.org/pdf/2205.14135.pdf.
    Unlike Algorithm 1, this implementation performs the scaling by sqrt(d) as part of the attention formula.

    The output is accumulated one (query block, key/value block) tile at a time. For every
    query row a running row maximum ``m`` and a running softmax denominator ``l`` are kept,
    so that the output accumulated so far can be rescaled whenever a later key block
    changes the row maximum.

    Args:
        Q: Queries tensor of shape [N, d]
        K: Keys tensor of shape [N_k, d] (N_k == N for self-attention)
        V: Values tensor of shape [N_k, d_v] (d_v may differ from d)
        B_r: The block size for the rows (query blocks)
        B_c: The block size for the columns (key/value blocks)

    Returns:
        The attention output ``softmax(Q @ K^T / sqrt(d)) @ V`` of shape [N, d_v].

    Raises:
        ValueError: If K and V do not have the same number of rows.
    """
    N, d = Q.shape
    d_v = V.shape[1]
    dtype = Q.dtype

    if K.shape[0] != V.shape[0]:
        raise ValueError(f"K and V must have the same number of rows, got {K.shape[0]} and {V.shape[0]}")

    # 2. Initialize O, l, m in HBM
    #    O takes its width from V so that value and key dimensions may differ.
    O = torch.zeros(N, d_v, device=Q.device, dtype=dtype)  # [N, d_v]
    l = torch.zeros(N, device=Q.device, dtype=dtype)  # [N]
    m = torch.full((N,), float("-inf"), device=Q.device, dtype=dtype)  # [N]

    # Without query rows there is nothing to compute (and no row to reduce over).
    if N == 0:
        return O

    # 3. Divide Q into T_r blocks of size [B_r, d] each (the last block may be smaller)
    Q_blocks = torch.split(Q, B_r)

    # 3. Divide K, V into T_c blocks of size [B_c, d] / [B_c, d_v] each
    K_blocks = torch.split(K, B_c)
    V_blocks = torch.split(V, B_c)

    # 4. Divide O, l, m into T_r blocks each; kept as lists because blocks are replaced below
    O_blocks = list(torch.split(O, B_r))  # T_r x [B_r, d_v]
    l_blocks = list(torch.split(l, B_r))  # T_r x [B_r]
    m_blocks = list(torch.split(m, B_r))  # T_r x [B_r]

    sqrt_d = d**0.5

    # 5. Outer loop over key/value blocks (6. load K_j, V_j into SRAM)
    for K_j, V_j in zip(K_blocks, V_blocks):
        # 7. Inner loop over query blocks
        for i, Q_i in enumerate(Q_blocks):
            # 8. Load Q_i, O_i, l_i, m_i into SRAM
            O_i = O_blocks[i]  # [B_r, d_v]
            l_i = l_blocks[i]  # [B_r]
            m_i = m_blocks[i]  # [B_r]

            # 9. On chip, compute S_ij = Q_i @ K_j^T
            # 9a. Scale by sqrt(d) (not in the paper, but part of the attention formula)
            S_ij = (Q_i @ K_j.T) / sqrt_d  # [B_r, B_c]

            # 10. On chip, compute mtilde_ij = rowmax(S_ij), Ptilde_ij = exp(S_ij - mtilde_ij)
            mtilde_ij = S_ij.max(dim=1).values  # [B_r]
            Ptilde_ij = torch.exp(S_ij - mtilde_ij.unsqueeze(1))  # [B_r, B_c]

            # 11. On chip, compute ltilde_ij = rowsum(Ptilde_ij), mnew_i = max(m_i, mtilde_ij)
            ltilde_ij = Ptilde_ij.sum(dim=1)  # [B_r]
            mnew_i = torch.maximum(m_i, mtilde_ij)  # [B_r]

            # Factors that move the old accumulators and the new tile to the common maximum mnew_i.
            # On the first key block m_i is -inf, so old_factor is exp(-inf) = 0.
            old_factor = torch.exp(m_i - mnew_i)  # [B_r]
            new_factor = torch.exp(mtilde_ij - mnew_i)  # [B_r]

            # 11. On chip, compute lnew_i = exp(m_i - mnew_i) * l_i + exp(mtilde_ij - mnew_i) * ltilde_ij
            lnew_i = old_factor * l_i + new_factor * ltilde_ij  # [B_r]

            # 12. Write O_i = diag(lnew_i)^-1 (diag(l_i) exp(m_i - mnew_i) O_i + exp(mtilde_ij - mnew_i) Ptilde_ij V_j) to HBM
            #     Multiplying by a diagonal matrix is a per-row scaling, so broadcasting is used
            #     instead of materializing [B_r, B_r] diagonal matrices.
            carried = (l_i * old_factor).unsqueeze(1) * O_i  # [B_r, d_v]
            fresh = (new_factor.unsqueeze(1) * Ptilde_ij) @ V_j  # [B_r, d_v]
            O_blocks[i] = (carried + fresh) / lnew_i.unsqueeze(1)  # write to HBM

            # 13. Write l_i = lnew_i, m_i = mnew_i to HBM
            l_blocks[i] = lnew_i
            m_blocks[i] = mnew_i

    return torch.cat(O_blocks)


# Check: the blocked flash attention implementation must agree with the reference implementation.
N = 6  # sequence length (number of rows of Q, K and V)
d = 4  # embed dim (head dim)
M = 128  # on-chip SRAM size

# Block sizes derived from the SRAM size M (presumably cdiv is a ceiling division - confirm in triton_utils)
B_c = cdiv(M, 4 * d)  # column block size
B_r = min(B_c, d)  # row block size

# NOTE: the two values computed above are overridden here. With N = 6 and a block size of 5
# the last block holds a single row, so blocks of unequal size are exercised as well.
B_c = 5
B_r = 5

# Random inputs; device="cuda" means this cell needs a CUDA device.
Q = torch.randn(N, d, dtype=torch.float32, device="cuda")
K = torch.randn(N, d, dtype=torch.float32, device="cuda")
V = torch.randn(N, d, dtype=torch.float32, device="cuda")

output_flash = torch_flash_attention_kernel(Q, K, V, B_r, B_c)
# print(F.scaled_dot_product_attention(Q, K, V))
output_torch = torch_scaled_dot_product_attention(Q, K, V)
# Prints True when both outputs agree within an absolute tolerance of 1e-6.
print(torch.allclose(output_flash, output_torch, atol=1e-6))

True


In [3]:
# Problem dimensions for the multi-head attention comparison
B, H, S = 2, 2, 4  # batch size, number of heads, sequence length
D_k, D_v = 8, 16  # key embed dimension, value embed dimension

# Uniform random query, key and value tensors, drawn in this order
q, k, v = (torch.rand(B, H, S, width) for width in (D_k, D_k, D_v))

In [4]:
# Attention
# output_torch = F.scaled_dot_product_attention(q, k, v)
# The reference function returns one tensor. Tuple-unpacking it (`out, *_ = ...`) would
# iterate over the batch dimension and keep only the first batch element.
output_torch = torch_scaled_dot_product_attention(q, k, v)
# NOTE(review): scaled_dot_product_attention is not defined in the visible part of this
# notebook - confirm it is defined before this cell runs on a fresh kernel.
output_triton = scaled_dot_product_attention(q, k, v)

# Compare
print("Shapes equal:", output_torch.shape == output_triton.shape)
print("Output equal:", torch.allclose(output_torch, output_triton))

Shapes equal: True
Output equal: True
