**Forward Pass**

In [2]:
# dependencies
import torch
import triton
import triton.language as tl


In [5]:
@triton.jit
def _attn_fwd_inner(
    O_block,
    l_i,
    m_i,
    Q_block,
    K_block_ptr,
    V_block_ptr,
    block_index_q,
    softmax_scale,
    BLOCK_SIZE_Q: tl.constexpr,
    BLOCK_SIZE_KV: tl.constexpr,
    STAGE: tl.constexpr,
    offs_q: tl.constexpr,
    offs_kv: tl.constexpr,
    SEQ_LEN: tl.constexpr,
    ):
  """Accumulate one range of K/V blocks into the running attention state of a Q block.

  Implements the online-softmax update: for every K/V block in the range
  selected by STAGE, the running row maximum `m_i`, the running row sum `l_i`
  and the un-normalised output accumulator `O_block` are rescaled and updated.

  Args:
    O_block: output accumulator for this query block (not yet divided by l_i).
    l_i: running row-wise sum of exponentials, one entry per query row.
    m_i: running row-wise maximum of the scaled scores, one entry per query row.
    Q_block: the already-loaded query tile.
    K_block_ptr: block pointer into K, advanced along its second dimension
      (the caller builds it over a transposed [HEAD_DIM, SEQ_LEN] view).
    V_block_ptr: block pointer into V, advanced along its first dimension.
    block_index_q: index of the query block handled by this program.
    softmax_scale: multiplier applied to Q @ K^T before the softmax.
    BLOCK_SIZE_Q: number of query rows per block.
    BLOCK_SIZE_KV: number of key/value rows per block.
    STAGE: 1 = key blocks strictly before this query block (no masking),
      2 = the key range aligned with this query block (causal mask applied),
      anything else = the whole sequence (non-causal attention).
    offs_q: absolute sequence positions of the query rows in this block.
    offs_kv: positions 0..BLOCK_SIZE_KV-1 within a key/value block.
    SEQ_LEN: total sequence length.

  Returns:
    The updated (O_block, l_i, m_i).
  """
  # range of key positions handled by this stage
  if STAGE == 1:
    # From 0 up to (but excluding) the block on the diagonal
    lo, hi = 0, block_index_q * BLOCK_SIZE_Q
  elif STAGE == 2:
    # Used only for the block in which there is transition between non-masked and masked keys
    lo, hi = block_index_q * BLOCK_SIZE_Q, (block_index_q + 1) * BLOCK_SIZE_Q
    lo = tl.multiple_of(lo, BLOCK_SIZE_Q)
  else:
    # Only used for non-causal attention
    lo, hi = 0, SEQ_LEN

  # point the block pointers to the first K, V block of the range
  K_block_ptr = tl.advance(K_block_ptr, (0, lo))
  V_block_ptr = tl.advance(V_block_ptr, (lo, 0))

  # loop over k, v and update accumulator
  for start_kv in range(lo, hi, BLOCK_SIZE_KV):
    # Tell the compiler that start_kv is a multiple of BLOCK_SIZE_KV so it can optimize accordingly
    start_kv = tl.multiple_of(start_kv, BLOCK_SIZE_KV)

    # -- compute qk ----
    K_block = tl.load(K_block_ptr)
    QK_block = tl.dot(Q_block, K_block)

    if STAGE == 2:
      # Causal mask: a query may only attend to keys at or before its own position.
      # Masked entries get a large negative bias so their exponential is ~0.
      mask = offs_q[:, None] >= (start_kv + offs_kv[None, :])
      QK_block = QK_block * softmax_scale + tl.where(mask, 0, -1.0e6)
      m_ij = tl.maximum(m_i, tl.max(QK_block, 1))
      QK_block -= m_ij[:, None]
    else:
      # Compute the maximum value of qk or keep the old max value
      m_ij = tl.maximum(m_i, tl.max(QK_block, 1) * softmax_scale)
      QK_block = QK_block * softmax_scale - m_ij[:, None]
    # Compute the exponential of each dot product, so now we are computing exp(qk_ij - m_ij)
    P_block = tl.math.exp(QK_block)

    # Compute the sum by rows of the attention scores
    l_ij = tl.sum(P_block, 1)

    # This is the correction factor for the previous l_i
    alpha = tl.math.exp(m_i - m_ij)

    # Apply the correction factor to the previous l_i and add the new l_ij
    l_i = l_i * alpha + l_ij

    V_block = tl.load(V_block_ptr)
    # Both operands of tl.dot must share a dtype, so follow V's dtype instead
    # of hard-coding float16 (identical behavior for float16 inputs).
    P_block = P_block.to(V_block.dtype)

    # This computes the following: O_new = P x V + O_old * alpha
    O_block = O_block * alpha[:, None]
    O_block = tl.dot(P_block, V_block, O_block) # O_block += P_block @ V_block

    m_i = m_ij

    # Move to the next K, V block
    V_block_ptr = tl.advance(V_block_ptr, (BLOCK_SIZE_KV, 0)) # V[SEQ_LEN, HEAD_DIM]
    K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_SIZE_KV)) # K[HEAD_DIM, SEQ_LEN]

  return O_block, l_i, m_i

In [6]:
# Triton kernel signature - makes the python method a triton kernel
@triton.jit
def _attn_fwd(
    Q,  # BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM # This is a pointer; Q[index_batch, index_head, :, :] is selected
    # by adding a per-(batch, head) offset
    K,  # BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM
    V,  # BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM
    softmax_scale,
    M,  # BATCH_SIZE, NUM_HEADS, SEQ_LEN
    O,  # BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM
    stride_Q_batch,
    stride_Q_head,
    stride_Q_seq,
    stride_Q_dim,
    stride_K_batch,
    stride_K_head,
    stride_K_seq,
    stride_K_dim,
    stride_V_batch,
    stride_V_head,
    stride_V_seq,
    stride_V_dim,
    stride_O_batch,
    stride_O_head,
    stride_O_seq,
    stride_O_dim,
    BATCH_SIZE,
    NUM_HEADS: tl.constexpr,
    SEQ_LEN: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_SIZE_Q: tl.constexpr,
    BLOCK_SIZE_KV: tl.constexpr,
    STAGE: tl.constexpr,
):
  """Attention forward kernel: one program handles one query block of one (batch, head).

  Program axis 0 selects the query block, program axis 1 selects the flattened
  (batch, head) index. The program loads its Q tile, streams over the K/V
  blocks through `_attn_fwd_inner`, normalises the accumulator, and writes
  the output tile to `O` and the per-row logsumexp to `M`.

  STAGE is 3 for causal attention and 1 for non-causal attention.
  `M` is indexed as a contiguous [BATCH_SIZE * NUM_HEADS, SEQ_LEN] buffer.
  No boundary checks are performed on loads or stores, so the caller must
  make SEQ_LEN a multiple of the block sizes.
  """
  tl.static_assert(BLOCK_SIZE_KV <= HEAD_DIM)

  # This indicates which block in the sequence length to process
  block_index_q = tl.program_id(0)

  # This indicates which head and batch to process. Each program is associated with a single head of a single batch
  index_batch_head = tl.program_id(1)
  # This indicates which batch this program is associated with (each batch has NUM_HEADS heads)
  index_batch = index_batch_head // NUM_HEADS
  # This indicates the position of the head in the batch
  index_head = index_batch_head % NUM_HEADS

  # Per-tensor offsets to the [SEQ_LEN, HEAD_DIM] slab of this (batch, head).
  # Each tensor uses its own strides so that Q, K, V and O need not share a layout.
  index_batch_64 = index_batch.to(tl.int64)
  index_head_64 = index_head.to(tl.int64)
  q_offset = index_batch_64 * stride_Q_batch + index_head_64 * stride_Q_head
  k_offset = index_batch_64 * stride_K_batch + index_head_64 * stride_K_head
  v_offset = index_batch_64 * stride_V_batch + index_head_64 * stride_V_head
  o_offset = index_batch_64 * stride_O_batch + index_head_64 * stride_O_head

  # Q[index_batch, index_head, block_index_q * BLOCK_SIZE_Q:, :]
  Q_block_ptr = tl.make_block_ptr(
      base=Q + q_offset,
      shape=(SEQ_LEN, HEAD_DIM),
      strides=(stride_Q_seq, stride_Q_dim),
      offsets=(block_index_q * BLOCK_SIZE_Q, 0),
      block_shape=(BLOCK_SIZE_Q, HEAD_DIM),
      order=(1, 0),
  )

  # V[index_batch, index_head, :, :] - starts at the first block; the inner loop advances it
  V_block_ptr = tl.make_block_ptr(
      base=V + v_offset,
      shape=(SEQ_LEN, HEAD_DIM),
      strides=(stride_V_seq, stride_V_dim),
      offsets=(0, 0),
      block_shape=(BLOCK_SIZE_KV, HEAD_DIM),
      order=(1, 0),
  )

  # K[index_batch, index_head, :, :] viewed as its transpose [HEAD_DIM, SEQ_LEN]
  K_block_ptr = tl.make_block_ptr(
      base=K + k_offset,
      shape=(HEAD_DIM, SEQ_LEN),
      strides=(
          stride_K_dim,
          stride_K_seq,
      ),  # We invert the strides w.r.t Q, so we transpose the matrix
      offsets=(0, 0),
      block_shape=(HEAD_DIM, BLOCK_SIZE_KV),
      order=(0, 1),
  )

  # O[index_batch, index_head, block_index_q * BLOCK_SIZE_Q:, :]
  O_block_ptr = tl.make_block_ptr(
      base=O + o_offset,
      shape=(SEQ_LEN, HEAD_DIM),
      strides=(stride_O_seq, stride_O_dim),
      offsets=(block_index_q * BLOCK_SIZE_Q, 0),
      block_shape=(BLOCK_SIZE_Q, HEAD_DIM),
      order=(1, 0),
  )
  # offs_q: the offsets for the tokens in the Q to process
  offs_q = block_index_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)

  # offs_kv: the offsets for the tokens in the K and V sequence to process
  offs_kv = tl.arange(0, BLOCK_SIZE_KV)

  # m_i: the running maximum. We have one for each query
  m_i = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) - float("inf")

  # l_i: the running sum. We have one for each query (as we sum the attention scores by rows).
  # The initial 1.0 is multiplied by exp(-inf - m_ij) = 0 on the first update, so it does not
  # contribute to the final sum.
  l_i = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) + 1.0

  # acc: the accumulator for the output, which is a group of rows of the O matrix
  O_block = tl.zeros([BLOCK_SIZE_Q, HEAD_DIM], dtype=tl.float32)

  # load the query tile once; it is reused for every K/V block
  Q_block = tl.load(Q_block_ptr)


  # Causal Attention: We don't want Q to attend K that come after it
  if STAGE == 1 or STAGE == 3:
    # Non-causal (STAGE == 1): the inner loop runs as stage 3 and visits every K, V block.
    # Causal (STAGE == 3): the inner loop runs as stage 1 and visits the blocks to the left of the diagonal.
    O_block, l_i, m_i = _attn_fwd_inner(
          O_block,
          l_i,
          m_i,
          Q_block,
          K_block_ptr,
          V_block_ptr,
          block_index_q,
          softmax_scale,
          BLOCK_SIZE_Q,
          BLOCK_SIZE_KV,
          4 - STAGE,
          offs_q,
          offs_kv,
          SEQ_LEN,
      )
  if STAGE == 3:
    # Causal only: process the block on the diagonal, where the causal mask has to be applied
    O_block, l_i, m_i = _attn_fwd_inner(
          O_block,
          l_i,
          m_i,
          Q_block,
          K_block_ptr,
          V_block_ptr,
          block_index_q,
          softmax_scale,
          BLOCK_SIZE_Q,
          BLOCK_SIZE_KV,
          2,
          offs_q,
          offs_kv,
          SEQ_LEN,
      )

  m_i += tl.math.log(
      l_i
  ) # This is needed to compute the logsumexp for the backwards pass
  O_block = O_block / l_i[:, None]

  # One logsumexp value per query row of this (batch, head)
  m_ptrs = M + index_batch_head * SEQ_LEN + offs_q

  # Write the results back; without these stores O and M would stay uninitialised
  tl.store(m_ptrs, m_i)
  tl.store(O_block_ptr, O_block.to(O.type.element_ty))

In [7]:
class TritonAttention(torch.autograd.Function):
  """Autograd wrapper that runs the `_attn_fwd` Triton kernel.

  Only `forward` is defined in this class body; calling `.backward()` on the
  result requires a `backward` staticmethod to be added.
  """

  @staticmethod
  def forward(ctx, Q, K, V, causal, softmax_scale):
    """Compute attention output for Q, K, V with the Triton forward kernel.

    Args:
      ctx: autograd context; tensors and settings for a backward pass are stored on it.
      Q, K, V: tensors of shape (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM).
      causal: if truthy, each query only attends to keys at or before its position.
      softmax_scale: multiplier applied to Q @ K^T before the softmax.

    Returns:
      O, a tensor with the same shape and dtype as Q.

    Raises:
      AssertionError: if the head dimensions of Q, K and V differ, or if the
        shapes are incompatible with the fixed block sizes used for the launch.
    """
    HEAD_DIM_Q, HEAD_DIM_K = Q.shape[-1], K.shape[-1]
    HEAD_DIM_V = V.shape[-1]

    BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM = Q.shape
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V

    # The kernel is not autotuned, so the tile sizes must be supplied at launch.
    # BLOCK_SIZE_Q is a multiple of BLOCK_SIZE_KV so the causal diagonal block
    # is covered by whole K/V blocks.
    BLOCK_SIZE_Q = 64
    BLOCK_SIZE_KV = 32
    # The kernel performs no boundary checks on its loads and stores.
    assert SEQ_LEN % BLOCK_SIZE_Q == 0, "SEQ_LEN must be a multiple of BLOCK_SIZE_Q"
    # Mirrors the tl.static_assert inside the kernel, but fails before launch.
    assert BLOCK_SIZE_KV <= HEAD_DIM, "HEAD_DIM must be at least BLOCK_SIZE_KV"

    O = torch.empty_like(Q)
    stage = 3 if causal else 1

    # launch grid (tile) - how many parallel programs to launch
    grid = lambda args: (
        # ceil(SEQ_LEN / BLOCK_SIZE_Q)
        triton.cdiv(SEQ_LEN, args["BLOCK_SIZE_Q"]), # which group of queries we are going to work with (how many tiles along the sequence)
        BATCH_SIZE * NUM_HEADS, # which head of which batch element we are going to work with? (one "row" per (batch, head) pair)
        1,
    )

    # M is the logsumexp for backward pass, one for each query
    M = torch.empty(
        (BATCH_SIZE, NUM_HEADS, SEQ_LEN), device=Q.device, dtype=torch.float32
    )

    # Triton Kernel
    _attn_fwd[grid](
            Q=Q, # just starting pointer
            K=K,
            V=V,
            softmax_scale=softmax_scale,
            M=M,
            O=O,
            stride_Q_batch=Q.stride(0),
            stride_Q_head=Q.stride(1),
            stride_Q_seq=Q.stride(2),
            stride_Q_dim=Q.stride(3),
            stride_K_batch=K.stride(0),
            stride_K_head=K.stride(1),
            stride_K_seq=K.stride(2),
            stride_K_dim=K.stride(3),
            stride_V_batch=V.stride(0),
            stride_V_head=V.stride(1),
            stride_V_seq=V.stride(2),
            stride_V_dim=V.stride(3),
            stride_O_batch=O.stride(0),
            stride_O_head=O.stride(1),
            stride_O_seq=O.stride(2),
            stride_O_dim=O.stride(3),
            BATCH_SIZE=Q.shape[0],
            NUM_HEADS=Q.shape[1],
            SEQ_LEN=Q.shape[2],
            HEAD_DIM=HEAD_DIM_K,
            BLOCK_SIZE_Q=BLOCK_SIZE_Q,
            BLOCK_SIZE_KV=BLOCK_SIZE_KV,
            STAGE=stage,
        )

    # Keep everything a backward pass would need
    ctx.save_for_backward(Q, K, V, O, M)
    ctx.grid = grid
    ctx.softmax_scale = softmax_scale
    ctx.HEAD_DIM = HEAD_DIM_K
    ctx.causal = causal
    return O


In [8]:
# Compare the Triton attention against a naive PyTorch reference
def test_op(BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM, causal, dtype=torch.float16):
  """Check TritonAttention against a naive PyTorch attention on random CUDA inputs.

  Builds random Q, K, V of shape (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM),
  runs a reference implementation and the Triton implementation forward and
  backward, and asserts that outputs and gradients agree within `atol`.

  Example: test_op(BATCH_SIZE=2, NUM_HEADS=4, SEQ_LEN=256, HEAD_DIM=64, causal=True)

  NOTE(review): the gradient comparison calls `tri_out.backward`, which needs
  `TritonAttention.backward`; confirm it is defined before running this test.

  Raises:
    AssertionError: if any output or gradient differs by more than the tolerance.
  """
  # initialise query, key and value sequence
  Q = (
      torch.empty(
          (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM),
          dtype=dtype, device="cuda"
      )
      .normal_(mean=0.0, std=0.5)
      .requires_grad_()
  )
  K = (
      torch.empty(
          (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM),
          dtype=dtype, device="cuda"
      )
      .normal_(mean=0.0, std=0.5)
      .requires_grad_()
  )
  V = (
      torch.empty(
          (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM),
          dtype=dtype, device="cuda"
      )
      .normal_(mean=0.0, std=0.5)
      .requires_grad_()
  )

  # scaling factor
  # QK^t/sqrt(HEAD_DIM)
  softmax_scale = 1 / (HEAD_DIM ** 0.5)
  dO = torch.randn_like(Q) # upstream gradient for backward

  # Naive Attention
  MASK = torch.tril(torch.ones((SEQ_LEN, SEQ_LEN), device = "cuda"))
  P = torch.matmul(Q, K.transpose(2, 3)) * softmax_scale # swaps dim 2 with dim 3
  if causal:
    P[:, :, MASK == 0] = float("-inf")
  # softmax in float32, then back to the input dtype so it can be multiplied with V
  P = torch.softmax(P.float(), dim = -1).to(dtype)

  ref_O = torch.matmul(P, V)
  ref_O.backward(dO)
  # grab the reference gradients and reset .grad for the next backward pass
  ref_dV, V.grad = V.grad.clone(), None
  ref_dK, K.grad = K.grad.clone(), None
  ref_dQ, Q.grad = Q.grad.clone(), None

  # Triton Implementation
  tri_out = TritonAttention.apply(Q, K, V, causal, softmax_scale).to(dtype)
  tri_out.backward(dO)
  tri_dV, V.grad = V.grad.clone(), None
  tri_dK, K.grad = K.grad.clone(), None
  tri_dQ, Q.grad = Q.grad.clone(), None

  # compare
  rtol = 0.0 # relative tolerance
  atol = 1e-2 # absolute tolerance
  assert torch.allclose(ref_O, tri_out, rtol=rtol, atol=atol)
  assert torch.allclose(ref_dV, tri_dV, rtol=rtol, atol=atol)
  assert torch.allclose(ref_dK, tri_dK, rtol=rtol, atol=atol)
  assert torch.allclose(ref_dQ, tri_dQ, rtol=rtol, atol=atol)