<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/Flash_Attention_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import triton
import triton.language as tl
import torch
import torch.nn as nn
# Device string for torch: 'cuda' when a CUDA device is available, else 'cpu'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
@triton.jit
def _attn_fwd_inner(
          O_block,
          l_i,
          m_i,
          Q_block,
          K_block_ptr,
          V_block_ptr,
          pid_q,
          softmax_scale,
          BLOCK_SIZE_Q: tl.constexpr,
          BLOCK_SIZE_KV: tl.constexpr,
          STAGE: tl.constexpr,
          offs_q: tl.constexpr,
          offs_kv: tl.constexpr,
          SEQ_LEN: tl.constexpr,):
  """Accumulate one query block's attention output over a range of key/value blocks.

  Runs the online-softmax update: for every K/V block in the range selected by
  STAGE, the running row maximum ``m_i``, the running normaliser ``l_i`` and the
  un-normalised output ``O_block`` are rescaled and extended.

  STAGE selects the key range:
    * 1 -> keys in ``[0, pid_q * BLOCK_SIZE_Q)`` (no mask applied),
    * 2 -> keys in ``[pid_q * BLOCK_SIZE_Q, (pid_q + 1) * BLOCK_SIZE_Q)`` with the
           causal mask ``offs_q >= key index`` applied,
    * anything else -> all keys ``[0, SEQ_LEN)`` (no mask applied).

  K_block_ptr / V_block_ptr are moved with ``tl.advance``, so they must be block
  pointers; K is stepped along its second axis and V along its first axis.

  Returns the updated ``(O_block, l_i, m_i)``.
  """
  if STAGE == 1:
    lo, hi = 0, pid_q * BLOCK_SIZE_Q
  elif STAGE == 2:
    lo, hi = pid_q * BLOCK_SIZE_Q, (pid_q + 1) * BLOCK_SIZE_Q
    # Compiler hint: the start of the diagonal block is a multiple of BLOCK_SIZE_Q.
    lo = tl.multiple_of(lo, BLOCK_SIZE_Q)
  else:
    lo, hi = 0, SEQ_LEN
  # Jump both pointers to the first key/value row of the selected range.
  K_block_ptr = tl.advance(K_block_ptr, (0, lo))
  V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
  for start_kv in range(lo, hi, BLOCK_SIZE_KV):
    start_kv = tl.multiple_of(start_kv, BLOCK_SIZE_KV)
    K_block = tl.load(K_block_ptr)
    QK_block = tl.dot(Q_block, K_block)
    if STAGE == 2:
      # Diagonal block: keys after the query position get a large negative bias.
      mask = offs_q[:, None] >= (start_kv + offs_kv[None, :])
      QK_block = QK_block * softmax_scale + tl.where(mask, 0, -1.0e6)
      # Fixed: the original read the undefined name `QK_BLOCK` here.
      m_ij = tl.maximum(m_i, tl.max(QK_block, 1))
      QK_block -= m_ij[:, None]
    else:
      # Fixed: the original read the undefined name `sofmtax_scale` here.
      m_ij = tl.maximum(m_i, tl.max(QK_block, 1) * softmax_scale)
      QK_block = QK_block * softmax_scale - m_ij[:, None]
    P_block = tl.math.exp(QK_block)
    l_ij = tl.sum(P_block, 1)
    # Correction factor that re-bases the previous statistics onto the new maximum.
    alpha = tl.math.exp(m_i - m_ij)
    l_i = l_i * alpha + l_ij
    V_block = tl.load(V_block_ptr)
    P_block = P_block.to(tl.float16)
    O_block = O_block * alpha[:, None]
    # Third argument is the accumulator: O_block += P_block @ V_block.
    O_block = tl.dot(P_block, V_block, O_block)
    m_i = m_ij
    V_block_ptr = tl.advance(V_block_ptr, (BLOCK_SIZE_KV, 0))
    K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_SIZE_KV))
  return O_block, l_i, m_i

@triton.autotune(
    [
        triton.Config(
            {"BLOCK_SIZE_Q": BLOCK_SIZE_Q, "BLOCK_SIZE_KV": BLOCK_SIZE_KV},
            num_stages=num_stages,
            num_warps=num_warps,
        )
        for BLOCK_SIZE_Q in [16, 32]
        for BLOCK_SIZE_KV in [16, 32]
        for num_stages in ([3,4,7])
        for num_warps in [2,4]
        # The causal split walks keys in BLOCK_SIZE_KV steps up to multiples of
        # BLOCK_SIZE_Q; a KV tile wider than the Q tile would overshoot into the
        # diagonal block unmasked and count those keys twice.
        if BLOCK_SIZE_KV <= BLOCK_SIZE_Q
    ],
    key=["SEQ_LEN", "HEAD_DIM"],)
@triton.jit
def _attn_fwd(
    Q,K,V,
    softmax_scale,
    M,O,
    stride_Q_batch,
    stride_Q_head,
    stride_Q_seq,
    stride_Q_dim,
    stride_V_batch,
    stride_V_head,
    stride_V_seq,
    stride_V_dim,
    stride_K_batch,
    stride_K_head,
    stride_K_seq,
    stride_K_dim,
    stride_O_batch,
    stride_O_head,
    stride_O_seq,
    stride_O_dim,
    BATCH_SIZE,
    NUM_HEADS: tl.constexpr,
    SEQ_LEN: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_SIZE_Q: tl.constexpr,
    BLOCK_SIZE_KV: tl.constexpr,
    STAGE:tl.constexpr,):
    """Forward attention kernel: one program handles one query block of one (batch, head).

    Grid axis 0 indexes the query block; grid axis 1 indexes the flattened
    (batch, head) pair. The program writes the normalised output tile to ``O``
    and the per-row value ``m_i + log(l_i)`` to ``M``.

    STAGE == 1 attends to every key. STAGE == 3 is the causal path: keys left of
    the diagonal block are processed unmasked, then the diagonal block is
    processed with the causal mask.

    No boundary checks are performed, so SEQ_LEN is assumed to be a multiple of
    the selected BLOCK_SIZE_Q and BLOCK_SIZE_KV, and ``M`` is indexed as a
    contiguous (batch * heads, SEQ_LEN) buffer. BATCH_SIZE is accepted but not
    read by the kernel.
    """
    pid_q = tl.program_id(axis=0)
    index_batch_head = tl.program_id(axis=1)
    # The flattened index advances through every head of a batch before moving to
    # the next batch, so integer division gives the batch and modulo gives the head.
    index_batch = index_batch_head // NUM_HEADS
    index_head = index_batch_head % NUM_HEADS
    batch_i64 = index_batch.to(tl.int64)
    head_i64 = index_head.to(tl.int64)
    # Each tensor is offset with its own batch/head strides (the original reused
    # Q's strides for K, V and O).
    q_offset = batch_i64 * stride_Q_batch + head_i64 * stride_Q_head
    k_offset = batch_i64 * stride_K_batch + head_i64 * stride_K_head
    v_offset = batch_i64 * stride_V_batch + head_i64 * stride_V_head
    o_offset = batch_i64 * stride_O_batch + head_i64 * stride_O_head

    # Absolute query positions of this block and relative key positions in a KV tile.
    offs_q = pid_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)
    offs_kv = tl.arange(0, BLOCK_SIZE_KV)

    # Block pointers are required because _attn_fwd_inner moves K and V with
    # tl.advance, which does not operate on plain pointer tensors.
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(SEQ_LEN, HEAD_DIM),
        strides=(stride_Q_seq, stride_Q_dim),
        offsets=(pid_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, HEAD_DIM),
        order=(1, 0),
    )
    # K is addressed transposed, as (HEAD_DIM, SEQ_LEN), so Q @ K needs no transpose.
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(HEAD_DIM, SEQ_LEN),
        strides=(stride_K_dim, stride_K_seq),
        offsets=(0, 0),
        block_shape=(HEAD_DIM, BLOCK_SIZE_KV),
        order=(0, 1),
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(SEQ_LEN, HEAD_DIM),
        strides=(stride_V_seq, stride_V_dim),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_KV, HEAD_DIM),
        order=(1, 0),
    )
    O_block_ptr = tl.make_block_ptr(
        base=O + o_offset,
        shape=(SEQ_LEN, HEAD_DIM),
        strides=(stride_O_seq, stride_O_dim),
        offsets=(pid_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, HEAD_DIM),
        order=(1, 0),
    )

    Q_block = tl.load(Q_block_ptr)
    # Running row maximum, running normaliser and un-normalised output accumulator.
    m_i = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) + 1.0
    O_block = tl.zeros([BLOCK_SIZE_Q, HEAD_DIM], dtype=tl.float32)
    if STAGE == 1 or STAGE == 3:
        # 4 - STAGE maps 1 -> 3 (all keys) and 3 -> 1 (keys left of the diagonal).
        O_block, l_i, m_i = _attn_fwd_inner(
            O_block,
            l_i,
            m_i,
            Q_block,
            K_block_ptr,
            V_block_ptr,
            pid_q,
            softmax_scale,
            BLOCK_SIZE_Q,
            BLOCK_SIZE_KV,
            4 - STAGE,
            offs_q,
            offs_kv,
            SEQ_LEN,
        )
    if STAGE == 3:
        # Causal path: finish with the masked diagonal block.
        O_block, l_i, m_i = _attn_fwd_inner(
            O_block,
            l_i,
            m_i,
            Q_block,
            K_block_ptr,
            V_block_ptr,
            pid_q,
            softmax_scale,
            BLOCK_SIZE_Q,
            BLOCK_SIZE_KV,
            2,
            offs_q,
            offs_kv,
            SEQ_LEN,
        )
    # Store m_i + log(l_i) per row; the backward kernels subtract it inside exp().
    m_i += tl.math.log(l_i)
    O_block = O_block / l_i[:, None]
    m_ptrs = M + index_batch_head * SEQ_LEN + offs_q
    tl.store(m_ptrs, m_i)
    tl.store(O_block_ptr, O_block.to(O.type.element_ty))
@triton.jit
def _attn_bwd_preprocess(
      O,
      dO,
      D, # same layout as M: (batch_size, nbr_heads, seq_len)
      SEQ_LEN,
      BLOCK_SIZE_Q:tl.constexpr,
      HEAD_DIM: tl.constexpr,):
  """Compute D[row] = sum over the head dimension of dO[row] * O[row].

  Grid axis 0 indexes a block of BLOCK_SIZE_Q rows; grid axis 1 indexes the
  flattened (batch, head) pair.

  The addressing hard-codes a contiguous layout: HEAD_DIM * SEQ_LEN elements per
  (batch, head) and HEAD_DIM elements per sequence row. No boundary checks are
  performed.
  """
  block_index_q = tl.program_id(0)
  offs_q = block_index_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)
  index_batch_head = tl.program_id(1)
  offs_dim = tl.arange(0, HEAD_DIM)
  # Element offsets of the (BLOCK_SIZE_Q, HEAD_DIM) tile, shared by O and dO.
  tile_offsets = index_batch_head * HEAD_DIM * SEQ_LEN + offs_q[:, None] * HEAD_DIM + offs_dim[None, :]
  # Fixed: the original loaded from the undefined name `O_ptr`; the parameter is `O`.
  O_block = tl.load(O + tile_offsets).to(tl.float32)
  dO_block = tl.load(dO + tile_offsets).to(tl.float32)
  D_block = tl.sum(dO_block * O_block, axis=1)
  # Fixed: D_block is 1-D, so its pointers are 1-D too (the original used
  # offs_q[None,:], a (1, BLOCK_SIZE_Q) pointer tensor).
  D_block_ptrs = D + index_batch_head * SEQ_LEN + offs_q
  tl.store(D_block_ptrs, D_block)
@triton.jit
def _attn_bwd_dq(
    Q,
    K,
    V,
    softmax_scale,
    dO,
    dQ,
    dK,
    dV,
    M,
    D,
    stride_batch,
    stride_head,
    stride_seq,
    stride_dim,
    NUM_HEADS,
    SEQ_LEN,
    BLOCK_Q: tl.constexpr,
    BLOCK_KV: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    STAGE: tl.constexpr,):
    """Compute dQ for one block of BLOCK_Q query rows of one (batch, head).

    Grid axis 0 selects the query block and grid axis 2 the flattened
    (batch, head) pair. The query block stays fixed while the kernel walks over
    SEQ_LEN // BLOCK_KV key/value blocks, accumulating
    ``dQ += softmax_scale * dS @ K`` with ``dS = P * (dO @ V^T - D)``.
    With STAGE == 3 the probabilities above the diagonal are zeroed (causal).

    dK and dV are part of the signature but are not written by this kernel.
    """
    batch_head = tl.program_id(2)
    batch_idx = batch_head // NUM_HEADS
    head_idx = batch_head % NUM_HEADS
    # Offset into the 4-D tensors and into the per-row M / D buffers.
    tensor_base = (stride_batch * batch_idx + stride_head * head_idx).to(tl.int64)
    row_base = (batch_head * SEQ_LEN).to(tl.int64)

    Q += tensor_base
    K += tensor_base
    V += tensor_base
    dO += tensor_base
    dQ += tensor_base
    M += row_base
    D += row_base

    offs_dim = tl.arange(0, HEAD_DIM)
    offs_q = tl.program_id(0) * BLOCK_Q + tl.arange(0, BLOCK_Q)
    # Element offsets of this program's (BLOCK_Q, HEAD_DIM) tile; shared by Q, dO and dQ.
    q_tile = offs_q[:, None] * stride_seq + offs_dim[None, :] * stride_dim

    Q_tile = tl.load(Q + q_tile)
    dO_tile = tl.load(dO + q_tile)
    row_stat = tl.load(M + offs_q)[:, None]
    Di = tl.load(D + offs_q)
    dQ_acc = tl.zeros([BLOCK_Q, HEAD_DIM], dtype=tl.float32)

    # K and V are read as transposed (HEAD_DIM, BLOCK_KV) tiles.
    kv_cols = tl.arange(0, BLOCK_KV)
    kv_tile_T = kv_cols[None, :] * stride_seq + offs_dim[:, None] * stride_dim
    kT_ptrs = K + kv_tile_T
    vT_ptrs = V + kv_tile_T

    for step in range(SEQ_LEN // BLOCK_KV):
        kT = tl.load(kT_ptrs)
        vT = tl.load(vT_ptrs)
        scores = softmax_scale * tl.dot(Q_tile, kT)
        probs = tl.math.exp(scores - row_stat)

        if STAGE == 3:
            # Causal masking: drop keys positioned after the query.
            kv_pos = step * BLOCK_KV + kv_cols
            probs = tl.where(offs_q[:, None] >= kv_pos[None, :], probs, 0.0)

        dP = tl.dot(dO_tile, vT).to(tl.float32)
        dS = (probs * (dP - Di[:, None])).to(tl.float16)
        dQ_acc += softmax_scale * tl.dot(dS, tl.trans(kT))

        # Step both transposed tiles to the next BLOCK_KV sequence rows.
        kT_ptrs += BLOCK_KV * stride_seq
        vT_ptrs += BLOCK_KV * stride_seq

    tl.store(dQ + q_tile, dQ_acc)

@triton.jit
def _attn_bwd_dk_dv(
      Q,
      K,
      V,
      softmax_scale,
      dO,
      dQ,
      dK,
      dV,
      M,
      D,
      stride_batch,
      stride_head,
      stride_seq,
      stride_dim,
      NUM_HEADS,
      SEQ_LEN,
      BLOCK_Q: tl.constexpr,
      BLOCK_KV: tl.constexpr,
      HEAD_DIM: tl.constexpr,
      STAGE: tl.constexpr):
    """Compute dK and dV for one block of BLOCK_KV key/value rows of one (batch, head).

    Grid axis 0 selects the key/value block and grid axis 2 the flattened
    (batch, head) pair. The K/V block stays fixed while the kernel walks over
    SEQ_LEN // BLOCK_Q query blocks, accumulating
    ``dV += P^T @ dO`` and ``dK += softmax_scale * dS^T @ Q``.
    With STAGE == 3 the probabilities above the diagonal are zeroed (causal).

    dQ is part of the signature but is not written by this kernel.
    """
    index_batch_head = tl.program_id(2)
    index_batch = index_batch_head // NUM_HEADS
    index_head = index_batch_head % NUM_HEADS
    # Fixed: the original read the undefined name `stride_heads`.
    offset_batch_head = (stride_batch * index_batch + stride_head * index_head).to(tl.int64)
    # Equivalent to index_batch * NUM_HEADS * SEQ_LEN + index_head * SEQ_LEN.
    offset_batch_head_seq = (index_batch_head * SEQ_LEN).to(tl.int64)
    offs_dim = tl.arange(0, HEAD_DIM)
    # This program owns one K/V block and loops over every query block.
    index_block_kv = tl.program_id(0)
    start_kv = index_block_kv * BLOCK_KV
    offs_kv = start_kv + tl.arange(0, BLOCK_KV)
    Q += offset_batch_head
    K += offset_batch_head
    V += offset_batch_head
    dO += offset_batch_head
    dQ += offset_batch_head
    dK += offset_batch_head
    dV += offset_batch_head
    # Fixed: M and D hold one value per sequence row, so they are offset by the
    # row offset (the original applied the 4-D tensor offset here).
    M += offset_batch_head_seq
    D += offset_batch_head_seq
    dV_block = tl.zeros([BLOCK_KV, HEAD_DIM], dtype=tl.float32)
    dK_block = tl.zeros([BLOCK_KV, HEAD_DIM], dtype=tl.float32)
    K_block = tl.load(K + offs_kv[:, None] * stride_seq + offs_dim[None, :] * stride_dim)
    V_block = tl.load(V + offs_kv[:, None] * stride_seq + offs_dim[None, :] * stride_dim)
    offs_q = tl.arange(0, BLOCK_Q)
    # Q is read transposed, as (HEAD_DIM, BLOCK_Q); dO is read as (BLOCK_Q, HEAD_DIM).
    qT_ptrs = Q + offs_q[None, :] * stride_seq + offs_dim[:, None] * stride_dim
    dO_ptrs = dO + offs_q[:, None] * stride_seq + offs_dim[None, :] * stride_dim
    curr_q = 0
    num_steps = SEQ_LEN // BLOCK_Q
    for blk_idx in range(num_steps):
        qT_block = tl.load(qT_ptrs)
        offs_q = curr_q + tl.arange(0, BLOCK_Q)
        m = tl.load(M + offs_q)
        QK_T_block = softmax_scale * tl.dot(K_block, qT_block)
        P_T_block = tl.math.exp(QK_T_block - m[None, :])

        if STAGE == 3:
            # Transposed layout: query positions run along columns and key
            # positions along rows; keep only keys at or before the query.
            mask_block = offs_q[None, :] >= offs_kv[:, None]
            P_T_block = tl.where(mask_block, P_T_block, 0.0)
        dO_block = tl.load(dO_ptrs)
        dV_block += tl.dot(P_T_block.to(tl.float16), dO_block)
        Di = tl.load(D + offs_q)
        dpT_block = tl.dot(V_block, tl.trans(dO_block)).to(tl.float32)
        dS_T_block = P_T_block * (dpT_block - Di[None, :])
        dS_T_block = dS_T_block.to(tl.float16)
        dK_block += softmax_scale * tl.dot(dS_T_block, tl.trans(qT_block))
        # Move on to the next BLOCK_Q query rows.
        curr_q += BLOCK_Q
        qT_ptrs += BLOCK_Q * stride_seq
        dO_ptrs += BLOCK_Q * stride_seq
    dV_block_ptrs = dV + offs_kv[:, None] * stride_seq + offs_dim[None, :] * stride_dim
    tl.store(dV_block_ptrs, dV_block)
    dK_block_ptrs = dK + offs_kv[:, None] * stride_seq + offs_dim[None, :] * stride_dim
    tl.store(dK_block_ptrs, dK_block)

In [4]:
class TritonAttention(torch.autograd.Function):
  """Autograd wrapper that runs attention through the Triton kernels in this file.

  Call as ``TritonAttention.apply(Q, K, V, causal, softmax_scale)`` with Q, K, V
  of shape (batch, heads, seq_len, head_dim).
  """

  @staticmethod
  def forward(ctx,Q,K,V,causal,softmax_scale):
    """Launch the forward kernel and stash what backward needs on ``ctx``.

    Returns the attention output ``O`` (same shape and dtype as ``Q``). The
    per-row statistic ``M`` written by the kernel is saved for backward together
    with Q, K, V and O.
    """
    HEAD_DIM_Q,HEAD_DIM_K = Q.shape[-1],K.shape[-1]
    HEAD_DIM_V = V.shape[-1]
    BATCH_SIZE,NUM_HEADS,SEQ_LEN,HEAD_DIM = Q.shape
    # Fixed: the original compared HEAD_DIM_K with itself, so V was never checked.
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    O = torch.empty_like(Q)
    # Kernel STAGE: 3 selects the causal path, 1 the full (non-causal) path.
    stage = 3 if causal else 1
    # One program per query block (size chosen by the autotuner) and per (batch, head).
    grid = lambda args : (
        triton.cdiv(SEQ_LEN,args['BLOCK_SIZE_Q']),
        BATCH_SIZE * NUM_HEADS,
        1,)
    M = torch.empty((BATCH_SIZE,NUM_HEADS,SEQ_LEN),device=Q.device,dtype=torch.float32)
    _attn_fwd[grid](
            Q=Q,
            K=K,
            V=V,
            softmax_scale=softmax_scale,
            M=M,
            O=O,
            stride_Q_batch=Q.stride(0),
            stride_Q_head=Q.stride(1),
            stride_Q_seq=Q.stride(2),
            stride_Q_dim=Q.stride(3),
            stride_K_batch=K.stride(0),
            stride_K_head=K.stride(1),
            stride_K_seq=K.stride(2),
            stride_K_dim=K.stride(3),
            stride_V_batch=V.stride(0),
            stride_V_head=V.stride(1),
            stride_V_seq=V.stride(2),
            stride_V_dim=V.stride(3),
            stride_O_batch=O.stride(0),
            stride_O_head=O.stride(1),
            stride_O_seq=O.stride(2),
            stride_O_dim=O.stride(3),
            BATCH_SIZE=Q.shape[0],
            NUM_HEADS=Q.shape[1],
            SEQ_LEN=Q.shape[2],
            HEAD_DIM=HEAD_DIM_K,
            STAGE=stage,
        )
    ctx.save_for_backward(Q,K,V,O,M)
    ctx.grid = grid
    # Fixed: the original assigned the undefined name `sftmax_scale`.
    ctx.softmax_scale = softmax_scale
    ctx.HEAD_DIM = HEAD_DIM_K
    ctx.causal = causal
    return O

  @staticmethod
  def backward(ctx, dO):
    """Compute (dQ, dK, dV) with the three backward kernels.

    Returns a 5-tuple matching forward's inputs; the slots for ``causal`` and
    ``softmax_scale`` are None.

    Raises:
      ValueError: if the sequence length is not a multiple of 16, because no
        valid block size can tile it.
    """
    Q, K, V, O, M = ctx.saved_tensors
    assert dO.is_contiguous()
    assert Q.stride() == K.stride() == V.stride() == O.stride() == dO.stride()
    BATCH_SIZE, NUM_HEADS, SEQ_LEN = Q.shape[:3]
    NUM_WARPS, NUM_STAGES = 4, 3
    # The kernels do no boundary checks and the grids use floor division, so the
    # block sizes must divide SEQ_LEN exactly. Pick the largest power of two that
    # divides SEQ_LEN, capped at the original 128 (macro) / 32 (micro) sizes.
    # The original fixed 128/32 launched zero programs for SEQ_LEN < 128 and
    # skipped the tail otherwise, returning uninitialised gradients.
    pow2_divisor = SEQ_LEN & -SEQ_LEN
    BLOCK_SIZE_MACRO = min(128, pow2_divisor)
    BLOCK_SIZE_MICRO = min(32, BLOCK_SIZE_MACRO)
    if BLOCK_SIZE_MACRO < 16:
      # Tiles smaller than 16 rows are rejected rather than handed to tl.dot.
      raise ValueError(
          f"SEQ_LEN must be a positive multiple of 16 for the backward pass, got {SEQ_LEN}")
    dQ = torch.empty_like(Q)
    dK = torch.empty_like(K)
    dV = torch.empty_like(V)
    preprocess_grid = (SEQ_LEN // BLOCK_SIZE_MACRO, BATCH_SIZE * NUM_HEADS)
    # D has the same shape as M: one value per (batch, head, row).
    D = torch.empty_like(M)
    _attn_bwd_preprocess[preprocess_grid](
        O=O,
        dO=dO,
        D=D,
        SEQ_LEN=SEQ_LEN,
        BLOCK_SIZE_Q=BLOCK_SIZE_MACRO,
        HEAD_DIM=ctx.HEAD_DIM,)
    grid = (SEQ_LEN // BLOCK_SIZE_MACRO, 1, BATCH_SIZE * NUM_HEADS)
    stage = 3 if ctx.causal else 1
    # dK/dV: each program owns a MACRO-sized K/V block and loops over MICRO-sized Q blocks.
    _attn_bwd_dk_dv[grid](
        Q=Q,
        K=K,
        V=V,
        softmax_scale=ctx.softmax_scale,
        dO=dO,
        dQ=dQ,
        dK=dK,
        dV=dV,
        M=M,
        D=D,
        stride_batch=Q.stride(0),
        stride_head=Q.stride(1),
        stride_seq=Q.stride(2),
        stride_dim=Q.stride(3),
        NUM_HEADS=NUM_HEADS,
        SEQ_LEN=SEQ_LEN,
        BLOCK_Q=BLOCK_SIZE_MICRO,
        BLOCK_KV=BLOCK_SIZE_MACRO,
        HEAD_DIM=ctx.HEAD_DIM,
        STAGE=stage,
        num_warps=NUM_WARPS,
        num_stages=NUM_STAGES,)
    # dQ: each program owns a MACRO-sized Q block and loops over MICRO-sized K/V blocks.
    _attn_bwd_dq[grid](
            Q=Q,
            K=K,
            V=V,
            softmax_scale=ctx.softmax_scale,
            dO=dO,
            dQ=dQ,
            dK=dK,
            dV=dV,
            M=M,
            D=D,
            stride_batch=Q.stride(0),
            stride_head=Q.stride(1),
            stride_seq=Q.stride(2),
            stride_dim=Q.stride(3),
            NUM_HEADS=NUM_HEADS,
            SEQ_LEN=SEQ_LEN,
            BLOCK_Q=BLOCK_SIZE_MACRO,
            BLOCK_KV=BLOCK_SIZE_MICRO,
            HEAD_DIM=ctx.HEAD_DIM,
            STAGE=stage,
            num_warps=NUM_WARPS,
            num_stages=NUM_STAGES,
        )
    return dQ, dK, dV, None, None



In [None]:
def test_op(BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM, causal, dtype=torch.float16):
    """Check TritonAttention against a plain PyTorch attention implementation.

    Draws random Q, K, V and an upstream gradient on the CUDA device, runs both
    implementations forward and backward, and asserts that the outputs and the
    three input gradients agree within an absolute tolerance of 1e-2.
    """
    shape = (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM)

    def _random_input():
        # Normal(0, 0.5) leaf tensor that records gradients.
        tensor = torch.empty(shape, dtype=dtype, device="cuda")
        return tensor.normal_(mean=0.0, std=0.5).requires_grad_()

    def _pop_grads():
        # Snapshot the gradients in (V, K, Q) order and clear them for the next run.
        snapshots = []
        for leaf in (V, K, Q):
            snapshots.append(leaf.grad.clone())
            leaf.grad = None
        return snapshots

    Q, K, V = _random_input(), _random_input(), _random_input()

    softmax_scale = 1 / (HEAD_DIM**0.5)
    dO = torch.randn_like(Q)

    # Reference implementation: explicit softmax(Q K^T * scale) V.
    MASK = torch.tril(torch.ones((SEQ_LEN, SEQ_LEN), device="cuda"))
    scores = torch.matmul(Q, K.transpose(2, 3)) * softmax_scale
    if causal:
        scores[:, :, MASK == 0] = float("-inf")
    probs = torch.softmax(scores.float(), dim=-1).half()
    ref_O = torch.matmul(probs, V)
    ref_O.backward(dO)
    ref_dV, ref_dK, ref_dQ = _pop_grads()

    # Triton implementation.
    tri_out = TritonAttention.apply(Q, K, V, causal, softmax_scale).half()
    tri_out.backward(dO)
    tri_dV, tri_dK, tri_dQ = _pop_grads()

    # Compare with an absolute tolerance only.
    rtol = 0.0
    atol = 1e-2
    comparisons = (
        (ref_O, tri_out),
        (ref_dK, tri_dK),
        (ref_dV, tri_dV),
        (ref_dQ, tri_dQ),
    )
    for expected, actual in comparisons:
        assert torch.allclose(expected, actual, atol=atol, rtol=rtol)


if __name__ == "__main__":
    # Run the causal correctness check, then report CUDA memory usage.
    test_op(BATCH_SIZE=2, NUM_HEADS=4, SEQ_LEN=16, HEAD_DIM=32, causal=True)
    allocated_bytes = torch.cuda.memory_allocated()
    print("GPU Memory Allocated:", allocated_bytes)
    reserved_bytes = torch.cuda.memory_reserved()
    print("GPU Memory Reserved:", reserved_bytes)
    #test_op(BATCH_SIZE=1, NUM_HEADS=1, SEQ_LEN=2, HEAD_DIM=4, causal=False)
    print("PASSED")

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
