# In [None]:
import triton
import triton.language as tl
import torch
import math

@triton.jit
def kernel(q_ptr, k_ptr, v_ptr, o_ptr,
           head_qty: tl.constexpr, seq_len: tl.constexpr, head_size: tl.constexpr,
           softmax_scale: tl.constexpr, num_blocks: tl.constexpr, block_size: tl.constexpr):
    """Single-pass (online-softmax) attention for one (head, query-block) tile.

    Grid axis 0 selects the head and axis 1 selects a block of ``block_size``
    query rows. Each program streams over every key/value block of its head.
    It keeps a running row maximum ``m_i`` and a running softmax denominator
    ``l_i``, so the full seq_len x seq_len score matrix is never materialized.

    Addressing: element (head, row, col) lives at
    ``head * seq_len * head_size + row * head_size + col``. The buffers must
    therefore be dense (head_qty, seq_len, head_size) arrays; the caller is
    responsible for that.

    Args:
        q_ptr, k_ptr, v_ptr: query / key / value buffers.
        o_ptr: output buffer with the same layout; only rows < seq_len are written.
        head_qty: number of heads (not read here; the launch grid covers it).
        seq_len: number of valid rows per head.
        head_size: columns per row; must be a power of two (``tl.arange``).
        softmax_scale: factor applied to Q @ K^T before the softmax.
        num_blocks: key/value blocks to visit, ceil(seq_len / block_size).
        block_size: rows per tile; power of two, and >= 16 for ``tl.dot``.
    """
    head_i = tl.program_id(0)
    block_i = tl.program_id(1)

    head_base = head_i * head_size * seq_len
    cols = tl.arange(0, head_size)
    col_mask = cols[None, :] < head_size

    # Query rows owned by this program, as a column vector for 2-D offsets.
    qo_rows = block_size * block_i + tl.arange(0, block_size)[:, None]
    qo_ofs = head_base + qo_rows * head_size + cols[None, :]
    qo_mask = (qo_rows < seq_len) & col_mask

    # other=0.0: without it, rows past seq_len would hold undefined values.
    q_block = tl.load(q_ptr + qo_ofs, mask=qo_mask, other=0.0)

    # Accumulate in float32 regardless of the input dtype.
    o_block = tl.zeros((block_size, head_size), dtype=tl.float32)
    m_i = tl.full((block_size,), float('-inf'), dtype=tl.float32)
    l_i = tl.zeros((block_size,), dtype=tl.float32)

    for block_j in range(0, num_blocks):
        kv_idx = block_size * block_j + tl.arange(0, block_size)
        kv_valid = kv_idx < seq_len
        kv_ofs = head_base + kv_idx[:, None] * head_size + cols[None, :]
        kv_mask = kv_valid[:, None] & col_mask

        k_block = tl.load(k_ptr + kv_ofs, mask=kv_mask, other=0.0)
        v_block = tl.load(v_ptr + kv_ofs, mask=kv_mask, other=0.0)

        qk_block = tl.dot(q_block, k_block.T, allow_tf32=False) * softmax_scale
        # Padded keys (index >= seq_len) must receive zero softmax weight.
        qk_block = tl.where(kv_valid[None, :], qk_block, float('-inf'))

        # New running max. Rescale the previous sum and output by exp(old - new).
        m_ij = tl.maximum(m_i, tl.max(qk_block, 1))
        p_block = tl.exp(qk_block - m_ij[:, None])
        alpha = tl.exp(m_i - m_ij)
        l_i = alpha * l_i + tl.sum(p_block, 1)

        o_block = o_block * alpha[:, None] + tl.dot(
            p_block.to(v_block.dtype), v_block, allow_tf32=False)
        m_i = m_ij

    o_block = o_block / l_i[:, None]
    tl.store(o_ptr + qo_ofs, o_block.to(o_ptr.dtype.element_ty), mask=qo_mask)


def attention(q, k, v, head_qty, seq_len, head_size, softmax_scale, block_size=16):
    """Compute softmax(q @ k^T * softmax_scale) @ v per head with the Triton kernel.

    Args:
        q, k, v: GPU tensors of shape (head_qty, seq_len, head_size).
        head_qty: number of heads.
        seq_len: sequence length per head.
        head_size: feature size per row; must be a power of two.
        softmax_scale: multiplier applied to the attention scores.
        block_size: rows per tile; a power of two >= 16 (``tl.dot`` minimum).

    Returns:
        A tensor with the same shape and dtype as ``q``.

    Raises:
        ValueError: if a tensor's shape is not (head_qty, seq_len, head_size),
            or if head_size / block_size violate the kernel's constraints.
    """
    expected = (head_qty, seq_len, head_size)
    for name, t in (("q", q), ("k", k), ("v", v)):
        if tuple(t.shape) != expected:
            raise ValueError(f"{name} has shape {tuple(t.shape)}, expected {expected}")
    if head_size <= 0 or head_size & (head_size - 1):
        raise ValueError(f"head_size must be a power of two, got {head_size}")
    if block_size < 16 or block_size & (block_size - 1):
        raise ValueError(f"block_size must be a power of two >= 16, got {block_size}")

    # The kernel computes flat offsets that assume a dense row-major layout.
    q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

    # Integer ceil-division: exact for any size, unlike math.ceil on a float.
    num_blocks = (seq_len + block_size - 1) // block_size
    # Every element is written by the kernel, so no zero-fill is needed.
    o = torch.empty_like(q)
    grid = (head_qty, num_blocks)
    kernel[grid](q, k, v, o, head_qty, seq_len, head_size, softmax_scale, num_blocks, block_size)
    return o

head_qty = 4
seq_len = 128
head_size = 64

# Fixed seed so that a failing comparison can be reproduced.
torch.manual_seed(0)

softmax_scale = 1.0 / math.sqrt(head_size)
q = torch.randn((head_qty, seq_len, head_size), device='cuda', dtype=torch.float32)
k = torch.randn((head_qty, seq_len, head_size), device='cuda', dtype=torch.float32)
v = torch.randn((head_qty, seq_len, head_size), device='cuda', dtype=torch.float32)

output = attention(q, k, v, head_qty, seq_len, head_size, softmax_scale)

# PyTorch reference. Its default scale is 1/sqrt(head_size), which matches softmax_scale.
compare = torch.nn.functional.scaled_dot_product_attention(q, k, v)

# The kernel and the reference accumulate in different orders. fp32 rounding
# alone can therefore exceed a purely absolute 1e-6 bound, so use a small
# combined tolerance and report the actual worst-case error.
max_err = (output - compare).abs().max().item()
test = torch.allclose(output, compare, atol=1e-5, rtol=1e-5)

print(output)
print(compare)
print(f"max abs error: {max_err:.3e}")

if test:
    print("Test passed")
else:
    print("Test failed")

