# Triton Attention Systems: Naive ‚Üí Page ‚Üí Flash

üéØ **Weekly Goal**  
Implement attention from scratch in Triton, profile performance bottlenecks,  
understand KV cache memory layouts (PagedAttention), and build a mini FlashAttention kernel  
to develop intuition for **IO-awareness, tiling, SRAM reuse, and kernel fusion**.

---

# Day 2 ‚Äî Naive Triton Attention

## Objective

Implement the most straightforward attention pipeline:

attn = softmax(QK·µÄ) @ V

Each stage must be implemented as an independent Triton kernel.  
‚ö†Ô∏è No fusion. No tiling optimization. No IO reduction tricks.

---

## Tasks

- [ ] Implement QK·µÄ kernel
- [ ] Implement row-wise softmax kernel
- [ ] Implement P @ V kernel
- [ ] Add optional mask support (causal / padding)
- [ ] Validate correctness vs PyTorch reference
- [ ] Measure max / mean absolute error
- [ ] Test small and large sequence lengths

---

## Key Concepts

- Attention compute complexity: O(n¬≤d)
- Memory traffic complexity: O(n¬≤)
- Materializing the attention matrix is expensive
- Softmax requires multiple passes:
  - max reduction
  - exp + sum
  - normalization

---

## Deliverables

- triton_naive_attention.py
- Correctness validation script
- Basic latency benchmark (ms)

---

# Day 3 ‚Äî Profiling & Bottleneck Analysis

## Objective

Diagnose why naive attention is slow using Nsight Compute.

---

## Tasks

- [ ] Profile kernels with Nsight Compute
- [ ] Collect:
  - DRAM throughput
  - SM efficiency
  - Achieved occupancy
  - Warp stall reasons
- [ ] Identify whether bottleneck is:
  - memory-bound
  - reduction-bound
  - compute-bound
- [ ] Sweep:
  - block sizes
  - sequence length (512 ‚Üí 4k ‚Üí 8k)
  - fp16 vs fp32

---

## Key Concepts

- Softmax is typically memory-bound
- QK·µÄ behaves like GEMM (often compute-bound)
- Writing and rereading n¬≤ matrices dominates IO
- Arithmetic intensity determines roofline behavior

---

## Deliverables

### Performance Table

| Impl  | ms | GB/s | TFLOPs | Speedup |
|-------|----|------|--------|---------|
| Torch |    |      |        | 1.0x    |
| Naive |    |      |        |         |

### Bottleneck Analysis Writeup

Explain:

- Why softmax is IO-heavy  
- Why n¬≤ memory traffic dominates  
- What stall reason dominates  
- Whether QK·µÄ saturates compute units  

---

# Day 4 ‚Äî PageAttention (KV Cache Layout)

## Objective

Understand how vLLM reduces KV memory waste via block-based paging.

---

## Tasks

- [ ] Study contiguous KV layout
- [ ] Study paged KV layout
- [ ] Design fixed-size KV blocks
- [ ] Implement logical-to-physical block mapping
- [ ] Write toy Triton PageAttention kernel
- [ ] Validate correctness
- [ ] Measure memory usage

---

## Key Concepts

- KV cache grows linearly with sequence length
- Contiguous layout leads to fragmentation
- Paged layout uses block tables
- Improves memory utilization for long-context inference

---

## Memory Comparison

| Mode        | KV Memory | Fragmentation | Best Use Case |
|------------|------------|---------------|---------------|
| Contiguous |            |               |               |
| Paged      |            |               |               |

---

## Deliverables

- triton_page_attention.py
- Memory usage comparison
- Short explanation of when paging helps

---

# Day 5 ‚Äî FlashAttention Theory & Tiling Design

## Objective

Understand IO-aware attention and why FlashAttention is faster.

---

## Tasks

- [ ] Study FlashAttention core ideas:
  - SRAM reuse
  - Block Q
  - Block K
  - Online softmax
  - Avoid n¬≤ materialization
- [ ] Derive why IO is reduced
- [ ] Compute arithmetic intensity before vs after tiling
- [ ] Design kernel parameters:
  - BLOCK_M
  - BLOCK_N
  - BLOCK_D
- [ ] Write kernel skeleton:
  - for k_tile in K:
  - compute qk_tile
  - update running max
  - update running sum
  - accumulate output


---

## Key Concepts

- Avoid writing S (n √ó n) to DRAM
- Online softmax enables single-pass normalization
- FlashAttention reduces memory traffic from O(n¬≤) ‚Üí O(nd)
- Kernel fusion increases arithmetic intensity

---

## Deliverables

- FlashAttention design document
- Arithmetic intensity comparison
- Kernel skeleton file

---

# Day 6 ‚Äî Triton FlashAttention (Mini Version)

## Objective

Implement a fused, tiled attention kernel in Triton.

---

## Tasks

- [ ] Implement tiled QK·µÄ
- [ ] Implement online softmax
- [ ] Fuse V multiplication
- [ ] Integrate into single kernel
- [ ] Validate correctness
- [ ] Benchmark vs naive implementation

---

## Final Comparison

| Impl   | ms | GB/s | TFLOPs | Speedup |
|--------|----|------|--------|---------|
| Naive  |    |      |        | 1.0x    |
| Flash  |    |      |        | 2.0x+   |

---

## Key Concepts

- SRAM reuse eliminates n¬≤ writes
- Fusion reduces global memory traffic
- FlashAttention shifts kernel toward compute-bound region
- IO-awareness matters more than reducing FLOPs

---

# End-of-Week Takeaways

- Attention performance is dominated by memory traffic
- Softmax is more memory-bound than QK·µÄ
- Kernel fusion drastically improves arithmetic intensity
- FlashAttention works by reducing IO, not reducing math
- Triton enables CUDA-level attention kernel design in Python


In [1]:
# triton_naive_attention_skeleton.py
# ============================================================
# Day 2 ‚Äî Triton Naive Attention Kernel (NO SOLUTION)
# Goal:
#   Implement naive attention:
#       attn = softmax(Q @ K.T) @ V
#   - Separate kernels for each step (QK^T, softmax, PV)
#   - Correctness validation vs PyTorch
#   - Support mask (causal / padding via additive -inf)
#   - Intentionally NOT optimized (no fusion, no FlashAttention tricks)
#
# Notes:
#   - This is a skeleton with TODOs only. Fill in kernels + launcher code.
#   - Keep correctness first; performance will be poor by design.
# ============================================================

import math
import torch
import triton
import triton.language as tl


# -----------------------------
# Utilities
# -----------------------------
def _assert_cuda(x: torch.Tensor, name: str):
    if not x.is_cuda:
        raise ValueError(f"{name} must be on CUDA, got {x.device}")
    if not x.is_contiguous():
        raise ValueError(f"{name} must be contiguous for this skeleton.")


def _make_additive_causal_mask(n: int, device, dtype):
    """
    Returns additive mask M in shape [n, n]:
      M[i, j] = 0 for j <= i
      M[i, j] = -inf for j > i
    Used as: scores = scores + M
    """
    # TODO: implement causal mask creation
    # raise NotImplementedError
    if device is None:
        device = "cpu"
    if dtype is None:
        dtype = torch.float32

    # Use the minimum finite value for the dtype to represent "-inf" in practice.
    # For fp16/bf16, true -inf exists, but using finfo.min is also common and safe.
    # neg_inf = torch.finfo(dtype).min

    # upper triangular (strictly above diagonal) => future positions
    # shape [n, n], True where j > i
    future = torch.triu(torch.ones((n, n), device=device, dtype=torch.bool), diagonal=1)

    # start from zeros, fill future with neg_inf
    mask = torch.zeros((n, n), device=device, dtype=dtype)
    mask = mask.masked_fill(future, -float("inf"))
    return mask


def _make_additive_padding_mask(valid_lens: torch.Tensor, n: int, device, dtype):
    """
    valid_lens: [B] or [n] style lengths; for this Day2 skeleton we keep it simple:
      - Assume a single sequence length n, and valid_lens is optional.
      - If you want per-row masking, expand to [n, n] additive mask.
      scores = scores + padding_mask
    """
    # TODO: implement padding mask (optional)
    # raise NotImplementedError
    if device is None:
        device = "cpu"
    if dtype is None:
        dtype = torch.float32

    neg_inf = torch.finfo(dtype).min

    # assume single length
    if valid_lens is None:
        return torch.zeros((n, n), device=device, dtype=dtype)

    L = int(valid_lens.item())

    # shape [n]
    key_positions = torch.arange(n, device=device)

    # True where j >= L
    invalid = key_positions >= L

    # expand to [n, n] (each row same mask)
    invalid = invalid.unsqueeze(0).expand(n, n)

    mask = torch.zeros((n, n), device=device, dtype=dtype)
    mask = mask.masked_fill(invalid, neg_inf)

    return mask


# ============================================================
# Kernel 1: Scores = Q @ K^T
# Q: [N, D], K: [N, D]  => Scores: [N, N]
# ============================================================
@triton.jit
def qk_t_kernel(
    q_ptr, k_ptr, s_ptr,
    N: tl.constexpr, D: tl.constexpr,
    stride_qn: tl.constexpr, stride_qd: tl.constexpr,
    stride_kn: tl.constexpr, stride_kd: tl.constexpr,
    stride_sn: tl.constexpr, stride_sm: tl.constexpr,
    # Tile sizes (intentionally simple / naive)
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """
    Compute a tile of S = QK^T.
    Program ids map over (rows, cols) tiles of S.

    TODO:
    - Compute pid_m, pid_n
    - Compute row/col offsets
    - Load Q tile [BLOCK_M, BLOCK_K]
    - Load K tile [BLOCK_N, BLOCK_K] (note K^T => K rows act like cols)
    - Accumulate dot products
    - Store to S
    """
    # TODO: implement
    tl.static_assert(BLOCK_K <= D,
    "BLOCK_K must be <= D for this skeleton")
    # raise NotImplementedError

    # 2D program grid: each program handles a (BLOCK_M x BLOCK_N) tile of X
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)


    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    m_mask = offs_m < N
    n_mask = offs_n < N


    for k0 in range(0, D, BLOCK_K):
        d_offsets = k0 + offs_k
        d_mask = d_offsets < D

        q_ptrs = q_ptr + offs_m[:, None] * stride_qn + d_offsets[None, :] * stride_qd   # [BM, BK]
        k_ptrs = k_ptr + offs_n[:, None] * stride_kn + d_offsets[None, :] * stride_kd  # [BN,BK]

        # 2D masks for loads
        q_load_mask = m_mask[:, None] & d_mask[None, :]    # [BM, BK]
        k_load_mask = n_mask[:, None] & d_mask[None, :]    # [BK, BN]


        q_tile = tl.load(q_ptrs, mask = q_load_mask, other = 0.0)
        k_tile = tl.load(k_ptrs, mask = k_load_mask, other = 0.0)

        acc += tl.dot(q_tile, tl.trans(k_tile))

    s_ptrs = s_ptr + offs_m[:, None] * stride_sn + offs_n[None, :] * stride_sm
    tl.store(s_ptrs, acc, mask = m_mask[:, None] & n_mask[None, :])


# ============================================================
# Kernel 2: Softmax over each row of S (row-wise)
# S: [N, N] -> P: [N, N]
# Optional additive mask: M: [N, N] where invalid positions are -inf
# ============================================================
@triton.jit
def softmax_row_kernel(
    s_ptr, m_ptr, p_ptr,
    N: tl.constexpr,
    stride_sn: tl.constexpr, stride_sm: tl.constexpr,
    stride_mn: tl.constexpr, stride_mm: tl.constexpr,
    stride_pn: tl.constexpr, stride_pm: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_MASK: tl.constexpr,
):
    """
    Row-wise softmax:
      p[i, :] = softmax(s[i, :] + mask[i, :])
      p = s + m

    TODO:
    - Map program id to a row i
    - Load a row block of scores
    - If HAS_MASK, load mask and add
    - Numerically stable softmax:
        x = x - max(x)
        exp = tl.exp(x)
        denom = tl.sum(exp)
        p = exp / denom
    - Store p

    Notes:
    - This skeleton assumes N can be larger than BLOCK_N; you may loop over blocks
      or restrict this Day2 to N <= BLOCK_N initially.
    """
    # TODO: implement
    # raise NotImplementedError
    # if N > BLOCK_N:
    #     raise ValueError("Naive row-softmax requires N <= BLOCK_N")
    pid_row = tl.program_id(0)
    offs = tl.arange(0, BLOCK_N)

    s_row_ptr = s_ptr + pid_row * stride_sn + offs * stride_sm
    p_row_ptr = p_ptr + pid_row * stride_pn + offs * stride_pm

    mask = offs < N
    s_row = tl.load(s_row_ptr, mask = mask, other = -float("inf")).to(tl.float32)
    if HAS_MASK:
        m_row_ptr = m_ptr + pid_row * stride_mn + offs * stride_mm
        m_row = tl.load(m_row_ptr, mask = mask, other = 0.0).to(tl.float32)
        s_row = s_row + m_row

    # stable softmax
    s_max = tl.max(s_row, axis = 0)
    s_row = s_row - s_max
    s_exp = tl.exp(s_row)
    s_sum = tl.sum(s_exp, axis=0)
    p_row = s_exp / s_sum

    tl.store(p_row_ptr, p_row, mask = mask)


# ============================================================
# Kernel 3: Out = P @ V
# P: [N, N], V: [N, D] -> O: [N, D]
# ============================================================
@triton.jit
def pv_kernel(
    p_ptr, v_ptr, o_ptr,
    N: tl.constexpr, D: tl.constexpr,
    stride_pn: tl.constexpr, stride_pm: tl.constexpr,
    stride_vn: tl.constexpr, stride_vd: tl.constexpr,
    stride_on: tl.constexpr, stride_od: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """
    Compute a tile of O = P V.

    TODO:
    - Program ids over (rows of O, cols of O)
    - Load P tile [BLOCK_M, BLOCK_K]
    - Load V tile [BLOCK_K, BLOCK_N] (here K dimension is N of P / V rows)
    - Accumulate
    - Store to O
    """
    # TODO: implement
    # raise NotImplementedError
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    m_mask = offs_m < N
    n_mask = offs_n < D

    for k0 in tl.static_range(0, N, BLOCK_K):
          # current K indices for this chunk
          k_offsets = k0 + offs_k
          k_mask = k_offsets < N

          # build pointer grids for this chunk
          p_ptrs = p_ptr + offs_m[:, None] * stride_pn + k_offsets[None, :] * stride_pm
          v_ptrs = v_ptr + k_offsets[:, None] * stride_vn + offs_n[None, :] * stride_vd

          # 2D masks for loads
          p_load_mask = m_mask[:, None] & k_mask[None, :]
          v_load_mask = k_mask[:, None] & n_mask[None, :]

          # masked loads: out-of-bounds => 0
          p_tile = tl.load(p_ptrs, mask=p_load_mask, other=0).to(tl.float32)
          v_tile = tl.load(v_ptrs, mask=v_load_mask, other=0).to(tl.float32)

          # accumulate (fp32)
          # tl.dot will typically accumulate in fp32 when acc is fp32
          acc += tl.dot(p_tile, v_tile)

    o_tile = acc
    o_ptrs = o_ptr + offs_m[:, None] * stride_on + offs_n[None, :] * stride_od
    tl.store(o_ptrs, o_tile, mask=m_mask[:,None] & n_mask[None,:])


# ============================================================
# Launchers (NO SOLUTION)
# ============================================================
def qk_t_triton(Q: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    """
    Compute S = Q @ K^T
    Q, K: [N, D] contiguous CUDA tensors
    Returns:
      S: [N, N]
    """
    _assert_cuda(Q, "Q")
    _assert_cuda(K, "K")
    assert Q.shape == K.shape
    N, D = Q.shape

    S = torch.empty((N, N), device=Q.device, dtype=torch.float32)  # scores typically fp32

    # TODO:
    # - Choose BLOCK_M/BLOCK_N/BLOCK_K (naive defaults)
    # - Define grid mapping over tiles
    # - Call qk_t_kernel[grid](...)
    # raise NotImplementedError
    BLOCK_M = 128
    BLOCK_N = 128
    BLOCK_K = 64
    # if N > BLOCK_N:
    #     raise ValueError("Naive row-softmax requires N <= BLOCK_N")

    N, D = Q.shape
    N2, D2 = K.shape
    assert N == N2, "Q and K must have the same sequence length"
    assert D == D2, "Q and K must have the same head dimension"

    grid = (
        triton.cdiv(N, BLOCK_M),
        triton.cdiv(N, BLOCK_N),
    )

    qk_t_kernel[grid](
        Q, K, S,
        N = N, D = D,
        stride_qn=Q.stride(0), stride_qd=Q.stride(1),
        stride_kn=K.stride(0), stride_kd=K.stride(1),
        stride_sn=S.stride(0), stride_sm=S.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K)

    return S


def softmax_triton(S: torch.Tensor, mask: torch.Tensor | None) -> torch.Tensor:
    """
    Compute P = softmax(S + mask) row-wise.
    S: [N, N]
    mask: [N, N] additive mask (0 or -inf). If None, no mask.
    Returns:
      P: [N, N] (same dtype as S or fp16/fp32 choice)
    """
    _assert_cuda(S, "S")
    N, N2 = S.shape
    assert N == N2

    if mask is not None:
        _assert_cuda(mask, "mask")
        assert mask.shape == (N, N)

    P = torch.empty_like(S)

    # TODO:
    # - Choose BLOCK_N
    # - grid = (N,) one program per row (or per row-block)
    # - HAS_MASK constexpr
    # - Call softmax_row_kernel[grid](...)
    # raise NotImplementedError
    if N <= 128: BLOCK_N=128
    elif N <= 256: BLOCK_N=256
    elif N <= 512: BLOCK_N=512
    elif N <= 1024: BLOCK_N=1024
    else:
        raise ValueError(f"Naive row-softmax requires N <= 1024, got N={N}")
    grid = (N,)

    HAS_MASK = mask is not None
    m_ptr = mask if HAS_MASK else S
    stride_mn = mask.stride(0) if HAS_MASK else 0
    stride_mm = mask.stride(1) if HAS_MASK else 0


    softmax_row_kernel[grid](
    S, m_ptr, P,
    N,
    S.stride(0), S.stride(1),
    stride_mn, stride_mm,
    P.stride(0), P.stride(1),
    BLOCK_N=BLOCK_N,
    HAS_MASK=HAS_MASK,
)


    return P


def pv_triton(P: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    """
    Compute O = P @ V
    P: [N, N]
    V: [N, D]
    Returns:
      O: [N, D]
    """
    _assert_cuda(P, "P")
    _assert_cuda(V, "V")
    N, N2 = P.shape
    assert N == N2
    assert V.shape[0] == N
    D = V.shape[1]

    O = torch.empty((N, D), device=V.device, dtype=torch.float32)

    # TODO:
    # - Choose BLOCK_M/BLOCK_N/BLOCK_K
    # - Define grid over O tiles
    # - Call pv_kernel[grid](...)
    # raise NotImplementedError
    BLOCK_M = 128
    BLOCK_N = 128
    BLOCK_K = 64

    grid = (
        triton.cdiv(N, BLOCK_M),  # pid_m
        triton.cdiv(D, BLOCK_N),  # pid_n
    )

    pv_kernel[grid](
        P, V, O,
        N=N, D=D,
        stride_pn=P.stride(0), stride_pm=P.stride(1),
        stride_vn=V.stride(0), stride_vd=V.stride(1),
        stride_on=O.stride(0), stride_od=O.stride(1),
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K)

    return O

def naive_attention_triton(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, mask: torch.Tensor | None = None):
    """
    Full naive attention:
      S = QK^T
      P = softmax(S + mask)
      O = P V
    """
    # TODO:
    # - Call qk_t_triton
    # - Call softmax_triton
    # - Call pv_triton
    # raise NotImplementedError
    _assert_cuda(Q, "Q")
    _assert_cuda(K, "K")
    _assert_cuda(V, "V")
    assert Q.shape == K.shape, "Q and K must have shape [N, D]"
    assert Q.shape == V.shape, "For this toy naive version, assume V has shape [N, D]"
    N, D = Q.shape

    if mask is not None:
        _assert_cuda(mask, "mask")
        assert mask.shape == (N, N), "mask must be [N, N] additive mask (0 / -inf)"

    # 1) Scores: S = Q @ K^T   -> [N, N] (often fp32)
    S = qk_t_triton(Q, K)

    # 2) Probabilities: P = softmax(S + mask)  -> [N, N]
    P = softmax_triton(S, mask)

    # 3) Output: O = P @ V     -> [N, D]
    O = pv_triton(P, V)

    return O


# ============================================================
# PyTorch reference & correctness checks (NO SOLUTION)
# ============================================================
def naive_attention_torch(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, mask: torch.Tensor | None = None):
    """
    Reference implementation in PyTorch:
      attn = softmax(Q @ K.T + mask) @ V
    """
    # TODO: implement torch reference (use float32 accumulation if needed)
    # raise NotImplementedError
    assert Q.shape == K.shape == V.shape, "This toy reference assumes Q,K,V are all [N, D]"
    N, D = Q.shape

    # Use fp32 for scores/softmax stability, regardless of input dtype
    Qf = Q.to(torch.float32)
    Kf = K.to(torch.float32)
    Vf = V.to(torch.float32)

    scores = Qf @ Kf.transpose(0, 1)  # [N, N]

    if mask is not None:
        assert mask.shape == (N, N), f"mask must be [N, N], got {mask.shape}"
        scores = scores + mask.to(torch.float32)

    P = torch.softmax(scores, dim=-1)  # row-wise softmax
    O = P @ Vf  # [N, D]

    return O

@torch.no_grad()
def check_correctness(device="cuda", dtype=torch.float16, N=256, D=64, use_mask=True):
    torch.manual_seed(0)
    Q = torch.randn((N, D), device=device, dtype=dtype)
    K = torch.randn((N, D), device=device, dtype=dtype)
    V = torch.randn((N, D), device=device, dtype=dtype)

    mask = None
    if use_mask:
        # TODO: create a causal mask (or padding mask)
        mask = _make_additive_causal_mask(N, device=device, dtype=torch.float32)

    # TODO:
    # - Run torch reference
    # - Run triton naive attention
    # - Compare max/mean error
    # raise NotImplementedError
    # Reference (PyTorch)
    out_ref = naive_attention_torch(Q, K, V, mask=mask)

    # Triton naive
    out_tri = naive_attention_triton(Q, K, V, mask=mask)

    # Compare (cast both to fp32 for fair error)
    diff = (out_tri.to(torch.float32) - out_ref.to(torch.float32)).abs()
    max_err = diff.max().item()
    mean_err = diff.mean().item()
    rmse = torch.sqrt((diff * diff).mean()).item()

    print(f"[check_correctness] N={N}, D={D}, dtype={dtype}, use_mask={use_mask}")
    print(f"  max_abs_err : {max_err:.6e}")
    print(f"  mean_abs_err: {mean_err:.6e}")
    print(f"  rmse        : {rmse:.6e}")


@torch.no_grad()
def quick_bench(device="cuda", dtype=torch.float16, N=1024, D=64, iters=50, warmup=10, use_mask=False):
    """
    Simple benchmark harness (intentionally minimal).
    """
    Q = torch.randn((N, D), device=device, dtype=dtype)
    K = torch.randn((N, D), device=device, dtype=dtype)
    V = torch.randn((N, D), device=device, dtype=dtype)

    mask = None
    if use_mask:
        # TODO: create mask
        mask = _make_additive_causal_mask(N, device=device, dtype=torch.float32)

    # TODO:
    # - Warmup runs
    # - Time with CUDA events
    # - Print ms/iter for torch vs triton
    # raise NotImplementedError
    torch.manual_seed(0)

    Q = torch.randn((N, D), device=device, dtype=dtype)
    K = torch.randn((N, D), device=device, dtype=dtype)
    V = torch.randn((N, D), device=device, dtype=dtype)

    mask = None
    if use_mask:
        # Additive causal mask: 0 or -inf
        mask = _make_additive_causal_mask(N, device=device, dtype=torch.float32)

    # ----------------------------
    # Warmup
    # ----------------------------
    for _ in range(warmup):
        naive_attention_triton(Q, K, V, mask=mask)
        naive_attention_torch(Q, K, V, mask=mask)
    torch.cuda.synchronize()

    # ----------------------------
    # Benchmark Triton
    # ----------------------------
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(iters):
        naive_attention_triton(Q, K, V, mask=mask)
    end.record()
    torch.cuda.synchronize()

    triton_ms = start.elapsed_time(end) / iters

    # ----------------------------
    # Benchmark Torch
    # ----------------------------
    start.record()
    for _ in range(iters):
        naive_attention_torch(Q, K, V, mask=mask)
    end.record()
    torch.cuda.synchronize()

    torch_ms = start.elapsed_time(end) / iters

    speedup = torch_ms / triton_ms

    print(f"\n[quick_bench]")
    print(f"N={N}, D={D}, dtype={dtype}, mask={use_mask}")
    print(f"Triton: {triton_ms:.3f} ms")
    print(f"Torch : {torch_ms:.3f} ms")
    print(f"Speedup (Torch/Triton): {speedup:.2f}x")


if __name__ == "__main__":
    # TODO: run correctness + small bench
    check_correctness(N=128, D=64, use_mask=True)
    quick_bench(N=1024, D=64, use_mask=False)
    # pass


[check_correctness] N=128, D=64, dtype=torch.float16, use_mask=True
  max_abs_err : 1.192093e-06
  mean_abs_err: 6.595712e-08
  rmse        : 1.132886e-07

[quick_bench]
N=1024, D=64, dtype=torch.float16, mask=False
Triton: 4.712 ms
Torch : 0.145 ms
Speedup (Torch/Triton): 0.03x


In [2]:
# triton_naive_attention_skeleton.py
# ============================================================
# Day 2 ‚Äî Triton Naive Attention Kernel (NO SOLUTION)
# Goal:
  #improvement:1. change qk_t_kernel to uncoalscing reading on K demension(lower performance)
  # 2. softmax change the accuracy to bf16 0.04x to 0.14x
  # 3. softmax 2 pass kernel Pass1: row maxÔºõPass2: sum(exp) + store(lower performance)

# ============================================================

import math
import torch
import triton
import triton.language as tl


# -----------------------------
# Utilities
# -----------------------------
def _assert_cuda(x: torch.Tensor, name: str):
    if not x.is_cuda:
        raise ValueError(f"{name} must be on CUDA, got {x.device}")
    if not x.is_contiguous():
        raise ValueError(f"{name} must be contiguous for this skeleton.")


def _make_additive_causal_mask(n: int, device, dtype):
    """
    Returns additive mask M in shape [n, n]:
      M[i, j] = 0 for j <= i
      M[i, j] = -inf for j > i
    Used as: scores = scores + M
    """
    # TODO: implement causal mask creation
    # raise NotImplementedError
    if device is None:
        device = "cpu"
    if dtype is None:
        dtype = torch.float32

    # Use the minimum finite value for the dtype to represent "-inf" in practice.
    # For fp16/bf16, true -inf exists, but using finfo.min is also common and safe.
    # neg_inf = torch.finfo(dtype).min

    # upper triangular (strictly above diagonal) => future positions
    # shape [n, n], True where j > i
    future = torch.triu(torch.ones((n, n), device=device, dtype=torch.bool), diagonal=1)

    # start from zeros, fill future with neg_inf
    mask = torch.zeros((n, n), device=device, dtype=dtype)
    mask = mask.masked_fill(future, -float("inf"))
    return mask


def _make_additive_padding_mask(valid_lens: torch.Tensor, n: int, device, dtype):
    """
    valid_lens: [B] or [n] style lengths; for this Day2 skeleton we keep it simple:
      - Assume a single sequence length n, and valid_lens is optional.
      - If you want per-row masking, expand to [n, n] additive mask.
      scores = scores + padding_mask
    """
    # TODO: implement padding mask (optional)
    # raise NotImplementedError
    if device is None:
        device = "cpu"
    if dtype is None:
        dtype = torch.float32

    neg_inf = torch.finfo(dtype).min

    # assume single length
    if valid_lens is None:
        return torch.zeros((n, n), device=device, dtype=dtype)

    L = int(valid_lens.item())

    # shape [n]
    key_positions = torch.arange(n, device=device)

    # True where j >= L
    invalid = key_positions >= L

    # expand to [n, n] (each row same mask)
    invalid = invalid.unsqueeze(0).expand(n, n)

    mask = torch.zeros((n, n), device=device, dtype=dtype)
    mask = mask.masked_fill(invalid, neg_inf)

    return mask


# ============================================================
# Kernel 1: Scores = Q @ K^T
# Q: [N, D], K: [N, D]  => Scores: [N, N]
# ============================================================
@triton.jit
def qk_t_kernel(
    q_ptr, k_ptr, s_ptr,
    N: tl.constexpr, D: tl.constexpr,
    stride_qn: tl.constexpr, stride_qd: tl.constexpr,
    stride_kn: tl.constexpr, stride_kd: tl.constexpr,
    stride_sn: tl.constexpr, stride_sm: tl.constexpr,
    # Tile sizes (intentionally simple / naive)
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """
    Compute a tile of S = QK^T.
    Program ids map over (rows, cols) tiles of S.

    TODO:
    - Compute pid_m, pid_n
    - Compute row/col offsets
    - Load Q tile [BLOCK_M, BLOCK_K]
    - Load K tile [BLOCK_N, BLOCK_K] (note K^T => K rows act like cols)
    - Accumulate dot products
    - Store to S
    """
    # TODO: implement
    tl.static_assert(BLOCK_K <= D,
    "BLOCK_K must be <= D for this skeleton")
    # raise NotImplementedError

    # 2D program grid: each program handles a (BLOCK_M x BLOCK_N) tile of X
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)


    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    m_mask = offs_m < N
    n_mask = offs_n < N


    for k0 in range(0, D, BLOCK_K):
        d_offsets = k0 + offs_k
        d_mask = d_offsets < D

        q_ptrs = q_ptr + offs_m[:, None] * stride_qn + d_offsets[None, :] * stride_qd   # [BM, BK]
        k_ptrs = k_ptr + offs_n[:, None] * stride_kn + d_offsets[None, :] * stride_kd  # [BN,BK]

        # 2D masks for loads
        q_load_mask = m_mask[:, None] & d_mask[None, :]    # [BM, BK]
        k_load_mask = n_mask[:, None] & d_mask[None, :]    # [BK, BN]


        q_tile = tl.load(q_ptrs, mask = q_load_mask, other = 0.0)
        k_tile = tl.load(k_ptrs, mask = k_load_mask, other = 0.0)

        acc += tl.dot(q_tile, tl.trans(k_tile))

    s_ptrs = s_ptr + offs_m[:, None] * stride_sn + offs_n[None, :] * stride_sm
    tl.store(s_ptrs, acc, mask = m_mask[:, None] & n_mask[None, :])


# ============================================================
# Kernel 2: Softmax over each row of S (row-wise)
# S: [N, N] -> P: [N, N]
# Optional additive mask: M: [N, N] where invalid positions are -inf
# ============================================================
@triton.jit
def softmax_row_kernel(
    s_ptr, m_ptr, p_ptr,
    N: tl.constexpr,
    stride_sn: tl.constexpr, stride_sm: tl.constexpr,
    stride_mn: tl.constexpr, stride_mm: tl.constexpr,
    stride_pn: tl.constexpr, stride_pm: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_MASK: tl.constexpr,
):
    """
    Row-wise softmax:
      p[i, :] = softmax(s[i, :] + mask[i, :])
      p = s + m

    TODO:
    - Map program id to a row i
    - Load a row block of scores
    - If HAS_MASK, load mask and add
    - Numerically stable softmax:
        x = x - max(x)
        exp = tl.exp(x)
        denom = tl.sum(exp)
        p = exp / denom
    - Store p

    Notes:
    - This skeleton assumes N can be larger than BLOCK_N; you may loop over blocks
      or restrict this Day2 to N <= BLOCK_N initially.
    """
    # TODO: implement
    # 2-pass row-wise softmax for a single row per program.

    # Pass 1: row_max = max_j (s[row, j] + mask[row, j])
    # Pass 2: row_sum = sum_j exp(s[row, j] + mask[row, j] - row_max)
    #         write p[row, j] = exp(...) / row_sum
    pid_row = tl.program_id(0)
    # ----------------------------
    # Pass 1: compute row max
    # ----------------------------
    row_max = tl.full((), -float("inf"), tl.float32)

    # loop over columns in blocks
    for c0 in range(0, N, BLOCK_N):
        offs = c0 + tl.arange(0, BLOCK_N)
        col_mask = offs < N

        s_row_ptr = s_ptr + pid_row * stride_sn + offs * stride_sm
        x = tl.load(s_row_ptr, mask=col_mask, other=-float("inf")).to(tl.float32)

        if HAS_MASK:
            m_row_ptr = m_ptr + pid_row * stride_mn + offs * stride_mm
            m = tl.load(m_row_ptr, mask=col_mask, other=0.0).to(tl.float32)
            x = x + m

        block_max = tl.max(x, axis=0)
        row_max = tl.maximum(row_max, block_max)

    # ----------------------------
    # Pass 2: compute sum(exp) and write output
    # ----------------------------
    row_sum = tl.zeros((), dtype=tl.float32)

    # 2a) sum
    for c0 in range(0, N, BLOCK_N):
        offs = c0 + tl.arange(0, BLOCK_N)
        col_mask = offs < N

        s_row_ptr = s_ptr + pid_row * stride_sn + offs * stride_sm
        x = tl.load(s_row_ptr, mask=col_mask, other=-float("inf")).to(tl.float32)

        if HAS_MASK:
            m_row_ptr = m_ptr + pid_row * stride_mn + offs * stride_mm
            m = tl.load(m_row_ptr, mask=col_mask, other=0.0).to(tl.float32)
            x = x + m

        x = x - row_max
        exp_x = tl.exp(x)
        row_sum += tl.sum(exp_x, axis=0)

    # 2b) write
    inv_sum = 1.0 / row_sum
    for c0 in range(0, N, BLOCK_N):
        offs = c0 + tl.arange(0, BLOCK_N)
        col_mask = offs < N

        s_row_ptr = s_ptr + pid_row * stride_sn + offs * stride_sm
        x = tl.load(s_row_ptr, mask=col_mask, other=-float("inf")).to(tl.float32)

        if HAS_MASK:
            m_row_ptr = m_ptr + pid_row * stride_mn + offs * stride_mm
            m = tl.load(m_row_ptr, mask=col_mask, other=0.0).to(tl.float32)
            x = x + m

        x = x - row_max
        p = tl.exp(x) * inv_sum  # fp32

        p_row_ptr = p_ptr + pid_row * stride_pn + offs * stride_pm
        tl.store(p_row_ptr, p.to(tl.float16), mask=col_mask)



# ============================================================
# Kernel 3: Out = P @ V
# P: [N, N], V: [N, D] -> O: [N, D]
# ============================================================
@triton.jit
def pv_kernel(
    p_ptr, v_ptr, o_ptr,
    N: tl.constexpr, D: tl.constexpr,
    stride_pn: tl.constexpr, stride_pm: tl.constexpr,
    stride_vn: tl.constexpr, stride_vd: tl.constexpr,
    stride_on: tl.constexpr, stride_od: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """
    Compute a tile of O = P V.

    TODO:
    - Program ids over (rows of O, cols of O)
    - Load P tile [BLOCK_M, BLOCK_K]
    - Load V tile [BLOCK_K, BLOCK_N] (here K dimension is N of P / V rows)
    - Accumulate
    - Store to O
    """
    # TODO: implement
    # raise NotImplementedError
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    m_mask = offs_m < N
    n_mask = offs_n < D

    for k0 in tl.static_range(0, N, BLOCK_K):
          # current K indices for this chunk
          k_offsets = k0 + offs_k
          k_mask = k_offsets < N

          # build pointer grids for this chunk
          p_ptrs = p_ptr + offs_m[:, None] * stride_pn + k_offsets[None, :] * stride_pm
          v_ptrs = v_ptr + k_offsets[:, None] * stride_vn + offs_n[None, :] * stride_vd

          # 2D masks for loads
          p_load_mask = m_mask[:, None] & k_mask[None, :]
          v_load_mask = k_mask[:, None] & n_mask[None, :]

          # masked loads: out-of-bounds => 0
          p_tile = tl.load(p_ptrs, mask=p_load_mask, other=0).to(tl.float32)
          v_tile = tl.load(v_ptrs, mask=v_load_mask, other=0).to(tl.float32)

          # accumulate (fp32)
          # tl.dot will typically accumulate in fp32 when acc is fp32
          acc += tl.dot(p_tile, v_tile)

    o_tile = acc
    o_ptrs = o_ptr + offs_m[:, None] * stride_on + offs_n[None, :] * stride_od
    tl.store(o_ptrs, o_tile, mask=m_mask[:,None] & n_mask[None,:])


# ============================================================
# Launchers (NO SOLUTION)
# ============================================================
def qk_t_triton(Q: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    """
    Compute S = Q @ K^T
    Q, K: [N, D] contiguous CUDA tensors
    Returns:
      S: [N, N]
    """
    _assert_cuda(Q, "Q")
    _assert_cuda(K, "K")
    assert Q.shape == K.shape
    N, D = Q.shape

    S = torch.empty((N, N), device=Q.device, dtype=torch.float32)  # scores typically fp32

    # TODO:
    # - Choose BLOCK_M/BLOCK_N/BLOCK_K (naive defaults)
    # - Define grid mapping over tiles
    # - Call qk_t_kernel[grid](...)
    # raise NotImplementedError
    BLOCK_M = 128
    BLOCK_N = 128
    BLOCK_K = 64
    # if N > BLOCK_N:
    #     raise ValueError("Naive row-softmax requires N <= BLOCK_N")

    N, D = Q.shape
    N2, D2 = K.shape
    assert N == N2, "Q and K must have the same sequence length"
    assert D == D2, "Q and K must have the same head dimension"

    grid = (
        triton.cdiv(N, BLOCK_M),
        triton.cdiv(N, BLOCK_N),
    )

    qk_t_kernel[grid](
        Q, K, S,
        N = N, D = D,
        stride_qn=Q.stride(0), stride_qd=Q.stride(1),
        stride_kn=K.stride(0), stride_kd=K.stride(1),
        stride_sn=S.stride(0), stride_sm=S.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K)

    return S


def softmax_triton(S: torch.Tensor, mask: torch.Tensor | None) -> torch.Tensor:
    """
    Compute P = softmax(S + mask) row-wise.
    S: [N, N]
    mask: [N, N] additive mask (0 or -inf). If None, no mask.
    Returns:
      P: [N, N] (same dtype as S or fp16/fp32 choice)
    """
    _assert_cuda(S, "S")
    N, N2 = S.shape
    assert N == N2

    if mask is not None:
        _assert_cuda(mask, "mask")
        assert mask.shape == (N, N)

    # P = torch.empty_like(S)
    P = torch.empty((N, N), device=S.device, dtype=torch.float16)

    # TODO:
    # - Choose BLOCK_N
    # - grid = (N,) one program per row (or per row-block)
    # - HAS_MASK constexpr
    # - Call softmax_row_kernel[grid](...)
    # raise NotImplementedError
    if N <= 128: BLOCK_N=128
    elif N <= 256: BLOCK_N=256
    elif N <= 512: BLOCK_N=512
    elif N <= 1024: BLOCK_N=1024
    else:
        raise ValueError(f"Naive row-softmax requires N <= 1024, got N={N}")
    grid = (N,)

    HAS_MASK = mask is not None
    m_ptr = mask if HAS_MASK else S
    stride_mn = mask.stride(0) if HAS_MASK else 0
    stride_mm = mask.stride(1) if HAS_MASK else 0


    softmax_row_kernel[grid](
    S, m_ptr, P,
    N,
    S.stride(0), S.stride(1),
    stride_mn, stride_mm,
    P.stride(0), P.stride(1),
    BLOCK_N=BLOCK_N,
    HAS_MASK=HAS_MASK,
)


    return P


def pv_triton(P: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    """
    Compute O = P @ V
    P: [N, N]
    V: [N, D]
    Returns:
      O: [N, D]
    """
    _assert_cuda(P, "P")
    _assert_cuda(V, "V")
    N, N2 = P.shape
    assert N == N2
    assert V.shape[0] == N
    D = V.shape[1]

    O = torch.empty((N, D), device=V.device, dtype=torch.float32)

    # TODO:
    # - Choose BLOCK_M/BLOCK_N/BLOCK_K
    # - Define grid over O tiles
    # - Call pv_kernel[grid](...)
    # raise NotImplementedError
    BLOCK_M = 128
    BLOCK_N = 128
    BLOCK_K = 64

    grid = (
        triton.cdiv(N, BLOCK_M),  # pid_m
        triton.cdiv(D, BLOCK_N),  # pid_n
    )

    pv_kernel[grid](
        P, V, O,
        N=N, D=D,
        stride_pn=P.stride(0), stride_pm=P.stride(1),
        stride_vn=V.stride(0), stride_vd=V.stride(1),
        stride_on=O.stride(0), stride_od=O.stride(1),
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K)

    return O

def naive_attention_triton(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, mask: torch.Tensor | None = None):
    """
    Full naive attention:
      S = QK^T
      P = softmax(S + mask)
      O = P V
    """
    # TODO:
    # - Call qk_t_triton
    # - Call softmax_triton
    # - Call pv_triton
    # raise NotImplementedError
    _assert_cuda(Q, "Q")
    _assert_cuda(K, "K")
    _assert_cuda(V, "V")
    assert Q.shape == K.shape, "Q and K must have shape [N, D]"
    assert Q.shape == V.shape, "For this toy naive version, assume V has shape [N, D]"
    N, D = Q.shape

    if mask is not None:
        _assert_cuda(mask, "mask")
        assert mask.shape == (N, N), "mask must be [N, N] additive mask (0 / -inf)"

    # 1) Scores: S = Q @ K^T   -> [N, N] (often fp32)
    S = qk_t_triton(Q, K)

    # 2) Probabilities: P = softmax(S + mask)  -> [N, N]
    P = softmax_triton(S, mask)

    # 3) Output: O = P @ V     -> [N, D]
    O = pv_triton(P, V)

    return O


# ============================================================
# PyTorch reference & correctness checks (NO SOLUTION)
# ============================================================
def naive_attention_torch(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, mask: torch.Tensor | None = None):
    """
    Reference implementation in PyTorch:
      attn = softmax(Q @ K.T + mask) @ V
    """
    # TODO: implement torch reference (use float32 accumulation if needed)
    # raise NotImplementedError
    assert Q.shape == K.shape == V.shape, "This toy reference assumes Q,K,V are all [N, D]"
    N, D = Q.shape

    # Use fp32 for scores/softmax stability, regardless of input dtype
    Qf = Q.to(torch.float32)
    Kf = K.to(torch.float32)
    Vf = V.to(torch.float32)

    scores = Qf @ Kf.transpose(0, 1)  # [N, N]

    if mask is not None:
        assert mask.shape == (N, N), f"mask must be [N, N], got {mask.shape}"
        scores = scores + mask.to(torch.float32)

    P = torch.softmax(scores, dim=-1)  # row-wise softmax
    O = P @ Vf  # [N, D]

    return O

@torch.no_grad()
def check_correctness(device="cuda", dtype=torch.float16, N=256, D=64, use_mask=True):
    torch.manual_seed(0)
    Q = torch.randn((N, D), device=device, dtype=dtype)
    K = torch.randn((N, D), device=device, dtype=dtype)
    V = torch.randn((N, D), device=device, dtype=dtype)

    mask = None
    if use_mask:
        # TODO: create a causal mask (or padding mask)
        mask = _make_additive_causal_mask(N, device=device, dtype=torch.float32)

    # TODO:
    # - Run torch reference
    # - Run triton naive attention
    # - Compare max/mean error
    # raise NotImplementedError
    # Reference (PyTorch)
    out_ref = naive_attention_torch(Q, K, V, mask=mask)

    # Triton naive
    out_tri = naive_attention_triton(Q, K, V, mask=mask)

    # Compare (cast both to fp32 for fair error)
    diff = (out_tri.to(torch.float32) - out_ref.to(torch.float32)).abs()
    max_err = diff.max().item()
    mean_err = diff.mean().item()
    rmse = torch.sqrt((diff * diff).mean()).item()

    print(f"[check_correctness] N={N}, D={D}, dtype={dtype}, use_mask={use_mask}")
    print(f"  max_abs_err : {max_err:.6e}")
    print(f"  mean_abs_err: {mean_err:.6e}")
    print(f"  rmse        : {rmse:.6e}")


@torch.no_grad()
def quick_bench(device="cuda", dtype=torch.float16, N=1024, D=64, iters=50, warmup=10, use_mask=False):
    """
    Simple benchmark harness (intentionally minimal).
    """
    Q = torch.randn((N, D), device=device, dtype=dtype)
    K = torch.randn((N, D), device=device, dtype=dtype)
    V = torch.randn((N, D), device=device, dtype=dtype)

    mask = None
    if use_mask:
        # TODO: create mask
        mask = _make_additive_causal_mask(N, device=device, dtype=torch.float32)

    # TODO:
    # - Warmup runs
    # - Time with CUDA events
    # - Print ms/iter for torch vs triton
    # raise NotImplementedError
    torch.manual_seed(0)

    Q = torch.randn((N, D), device=device, dtype=dtype)
    K = torch.randn((N, D), device=device, dtype=dtype)
    V = torch.randn((N, D), device=device, dtype=dtype)

    mask = None
    if use_mask:
        # Additive causal mask: 0 or -inf
        mask = _make_additive_causal_mask(N, device=device, dtype=torch.float32)

    # ----------------------------
    # Warmup
    # ----------------------------
    for _ in range(warmup):
        naive_attention_triton(Q, K, V, mask=mask)
        naive_attention_torch(Q, K, V, mask=mask)
    torch.cuda.synchronize()

    # ----------------------------
    # Benchmark Triton
    # ----------------------------
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(iters):
        naive_attention_triton(Q, K, V, mask=mask)
    end.record()
    torch.cuda.synchronize()

    triton_ms = start.elapsed_time(end) / iters

    # ----------------------------
    # Benchmark Torch
    # ----------------------------
    start.record()
    for _ in range(iters):
        naive_attention_torch(Q, K, V, mask=mask)
    end.record()
    torch.cuda.synchronize()

    torch_ms = start.elapsed_time(end) / iters

    speedup = torch_ms / triton_ms

    print(f"\n[quick_bench]")
    print(f"N={N}, D={D}, dtype={dtype}, mask={use_mask}")
    print(f"Triton: {triton_ms:.3f} ms")
    print(f"Torch : {torch_ms:.3f} ms")
    print(f"Speedup (Torch/Triton): {speedup:.2f}x")


if __name__ == "__main__":
    # TODO: run correctness + small bench
    check_correctness(N=128, D=64, use_mask=True)
    quick_bench(N=1024, D=64, use_mask=False)
    # pass


[check_correctness] N=128, D=64, dtype=torch.float16, use_mask=True
  max_abs_err : 7.545948e-04
  mean_abs_err: 1.005516e-04
  rmse        : 1.451281e-04

[quick_bench]
N=1024, D=64, dtype=torch.float16, mask=False
Triton: 4.049 ms
Torch : 0.144 ms
Speedup (Torch/Triton): 0.04x


In [3]:
# day3_profile_bottleneck_skeleton.py
# ============================================================
# Day 3 ‚Äî Profiling + Bottleneck Analysis (NO SOLUTION)
#
# Goal:
#   Profile naive attention vs torch attention, identify bottlenecks.
#
# Tasks:
#   - Nsight Compute metrics:
#       * DRAM throughput
#       * SM efficiency
#       * Stall reasons
#   - Decide bottleneck:
#       * memory-bound?
#       * reduction-bound?
#   - Sweep:
#       * different block sizes
#       * different sequence lengths
#
# Outputs:
#   - Markdown table comparing naive vs torch (printed)
#   - Bottleneck analysis template (printed)
#
# Notes:
#   - Plug in your Day2 implementations:
#       naive_attention_triton(Q,K,V,mask,cfg)
#       naive_attention_torch(Q,K,V,mask)
# ============================================================

from __future__ import annotations
import json
from dataclasses import dataclass, asdict
from typing import Optional, Dict, Any, List, Tuple

import torch


# ============================================================
# TODO: import your Day2 implementations
# ============================================================
def naive_attention_triton(Q, K, V, mask=None, cfg=None):
    # TODO: call your Triton naive attention implementation
    # raise NotImplementedError
    # return naive_attention_torch(Q, K, V, mask=mask)
    _assert_cuda(Q, "Q")
    _assert_cuda(K, "K")
    _assert_cuda(V, "V")
    assert Q.shape == K.shape, "Q and K must have shape [N, D]"
    assert Q.shape == V.shape, "For this toy naive version, assume V has shape [N, D]"
    N, D = Q.shape

    if mask is not None:
        _assert_cuda(mask, "mask")
        assert mask.shape == (N, N), "mask must be [N, N] additive mask (0 / -inf)"

    # 1) Scores: S = Q @ K^T   -> [N, N] (often fp32)
    S = qk_t_triton(Q, K)

    # 2) Probabilities: P = softmax(S + mask)  -> [N, N]
    P = softmax_triton(S, mask)

    # 3) Output: O = P @ V     -> [N, D]
    O = pv_triton(P, V)

    return O



def naive_attention_torch(Q, K, V, mask=None):
    # TODO: call your PyTorch reference implementation
    # raise NotImplementedError
    # return naive_attention_torch(Q, K, V, mask=mask)
    assert Q.shape == K.shape == V.shape, "This toy reference assumes Q,K,V are all [N, D]"
    N, D = Q.shape

    # Use fp32 for scores/softmax stability, regardless of input dtype
    Qf = Q.to(torch.float32)
    Kf = K.to(torch.float32)
    Vf = V.to(torch.float32)

    scores = Qf @ Kf.transpose(0, 1)  # [N, N]

    if mask is not None:
        assert mask.shape == (N, N), f"mask must be [N, N], got {mask.shape}"
        scores = scores + mask.to(torch.float32)

    P = torch.softmax(scores, dim=-1)  # row-wise softmax
    O = P @ Vf  # [N, D]

    return O



# ============================================================
# Config definition for block size sweep
# ============================================================
@dataclass(frozen=True)
class TritonNaiveCfg:
    BLOCK_M: int
    BLOCK_N: int
    BLOCK_K: int
    SOFTMAX_BLOCK: int
    num_warps: int = 4


# ============================================================
# Benchmark utilities
# ============================================================
def cuda_time_ms(fn, iters=30, warmup=10) -> float:
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()

    return start.elapsed_time(end) / iters


def gflops_qk_pv(N: int, D: int) -> float:
    flops = 4.0 * N * N * D
    return flops / 1e9


def estimate_bytes(N: int, D: int, elem_bytes: int = 2) -> int:
    # TODO: refine if using fp32 intermediate
    qkv = 3 * N * D * elem_bytes
    s_mat = 2 * N * N * elem_bytes
    p_mat = 2 * N * N * elem_bytes
    out = N * D * elem_bytes
    return qkv + s_mat + p_mat + out


# ============================================================
# Experiment execution
# ============================================================
@dataclass
class ResultRow:
    impl: str
    N: int
    D: int
    cfg: Optional[Dict[str, Any]]
    ms: float
    gflops: float
    est_gbs: float


def run_case(N: int, D: int, cfg: TritonNaiveCfg):
    device = "cuda"
    dtype = torch.float16

    Q = torch.randn((N, D), device=device, dtype=dtype)
    K = torch.randn((N, D), device=device, dtype=dtype)
    V = torch.randn((N, D), device=device, dtype=dtype)

    # --- Torch baseline ---
    def fn_torch():
        return naive_attention_torch(Q, K, V, mask=None)

    ms_torch = cuda_time_ms(fn_torch)
    gflops = gflops_qk_pv(N, D)
    bytes_est = estimate_bytes(N, D)

    torch_row = ResultRow(
        impl="torch",
        N=N,
        D=D,
        cfg=None,
        ms=ms_torch,
        gflops=gflops / (ms_torch / 1e3),
        est_gbs=(bytes_est / (ms_torch / 1e3)) / 1e9,
    )

    # --- Triton naive ---
    def fn_triton():
        return naive_attention_triton(Q, K, V, mask=None, cfg=cfg)

    ms_triton = cuda_time_ms(fn_triton)

    triton_row = ResultRow(
        impl="naive_triton",
        N=N,
        D=D,
        cfg=asdict(cfg),
        ms=ms_triton,
        gflops=gflops / (ms_triton / 1e3),
        est_gbs=(bytes_est / (ms_triton / 1e3)) / 1e9,
    )

    return torch_row, triton_row


def sweep(seq_lens: List[int], D: int, cfgs: List[TritonNaiveCfg]):
    rows: List[ResultRow] = []
    for N in seq_lens:
        for cfg in cfgs:
            torch_row, triton_row = run_case(N, D, cfg)
            rows.append(torch_row)
            rows.append(triton_row)
    return rows


# ============================================================
# Output formatting
# ============================================================
def print_markdown_table(rows: List[ResultRow]):
    print("\n# Day 3 Results (Naive Triton vs Torch)\n")
    print("| impl | N | D | cfg | ms | GFLOP/s | est_GB/s | speedup_vs_torch |")
    print("|------|---|---|-----|----|---------|----------|------------------|")

    torch_map = {(r.N, r.D): r.ms for r in rows if r.impl == "torch"}

    for r in rows:
        base = torch_map.get((r.N, r.D), None)
        speedup = base / r.ms if (base and r.impl != "torch") else 1.0

        cfg_str = "-"
        if r.cfg:
            cfg_str = f"BM={r.cfg['BLOCK_M']},BN={r.cfg['BLOCK_N']},BK={r.cfg['BLOCK_K']},SB={r.cfg['SOFTMAX_BLOCK']},w={r.cfg['num_warps']}"

        print(f"| {r.impl} | {r.N} | {r.D} | {cfg_str} | "
              f"{r.ms:.4f} | {r.gflops:.2f} | {r.est_gbs:.2f} | "
              f"{speedup:.2f}x |")


def print_bottleneck():
    print("\n\n# Bottleneck Analysis (Fill After Nsight Compute)\n")
    print("## Nsight Compute Observations")
    print("- DRAM throughput (% peak): TODO")
    print("- SM throughput (% peak): TODO")
    print("- Dominant stall reasons:")
    print("  - long scoreboard: TODO")
    print("  - memory dependency: TODO")
    print("  - barrier: TODO")
    print("  - math pipe throttle: TODO\n")

    print("## Bottleneck Classification")
    print("- [ ] Memory-bound")
    print("- [ ] Reduction-bound")
    print("- [ ] Compute-bound\n")

    print("## Interpretation")
    print("- Naive attention materializes N√óN matrices.")
    print("- Softmax requires multiple passes (max/sum/normalize).")
    print("- Heavy DRAM traffic likely dominates performance.\n")

    print("## Next Steps")
    print("- Tune block sizes.")
    print("- Increase arithmetic intensity.")
    print("- Consider fusion (FlashAttention).\n")


# ============================================================
# Main
# ============================================================
def main():
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA required.")

    seq_lens = [256, 512, 1024]  # TODO: extend if desired
    D = 64

    cfgs = [
        TritonNaiveCfg(64, 64, 32, 256, 4),
        TritonNaiveCfg(128, 64, 32, 256, 4),
        TritonNaiveCfg(64, 128, 32, 512, 8),
    ]

    rows = sweep(seq_lens, D, cfgs)

    print_markdown_table(rows)
    # print_bottleneck()


if __name__ == "__main__":
    main()



# Day 3 Results (Naive Triton vs Torch)

| impl | N | D | cfg | ms | GFLOP/s | est_GB/s | speedup_vs_torch |
|------|---|---|-----|----|---------|----------|------------------|
| torch | 256 | 64 | - | 0.0944 | 177.81 | 6.95 | 1.00x |
| naive_triton | 256 | 64 | BM=64,BN=64,BK=32,SB=256,w=4 | 1.0278 | 16.32 | 0.64 | 0.09x |
| torch | 256 | 64 | - | 0.0904 | 185.62 | 7.25 | 1.00x |
| naive_triton | 256 | 64 | BM=128,BN=64,BK=32,SB=256,w=4 | 1.0256 | 16.36 | 0.64 | 0.09x |
| torch | 256 | 64 | - | 0.0884 | 189.74 | 7.41 | 1.00x |
| naive_triton | 256 | 64 | BM=64,BN=128,BK=32,SB=512,w=8 | 1.0265 | 16.34 | 0.64 | 0.09x |
| torch | 512 | 64 | - | 0.1027 | 653.30 | 22.97 | 1.00x |
| naive_triton | 512 | 64 | BM=64,BN=64,BK=32,SB=256,w=4 | 1.5901 | 42.20 | 1.48 | 0.07x |
| torch | 512 | 64 | - | 0.1031 | 650.90 | 22.88 | 1.00x |
| naive_triton | 512 | 64 | BM=128,BN=64,BK=32,SB=256,w=4 | 1.5892 | 42.23 | 1.48 | 0.07x |
| torch | 512 | 64 | - | 0.1036 | 647.97 | 22.78 | 1.00x |
| naive_trito

In [4]:
# day4_paged_attention_toy_skeleton.py
# ============================================================
# Day 4 ‚Äî PageAttention (Toy) (NO SOLUTION)
#
# Goal:
#   Understand vLLM-style KV cache paging by building a toy PageAttention in Triton.
#
# What you'll implement (toy scope):
#   - Two KV cache modes:
#       (1) Contiguous KV: K,V stored as [T, D] for each sequence (single sequence toy)
#       (2) Paged KV: K,V stored in fixed-size blocks; a block table maps logical blocks to physical blocks
#   - A toy attention computation that reads K,V via the selected layout:
#       out = softmax(Q @ K^T + mask) @ V
#
# Tasks:
#   - Study: contiguous KV vs paged KV
#   - Design: KV block layout + block table
#   - Implement: Triton PageAttention (toy)
#   - Validate correctness vs torch reference
#   - Memory usage stats (allocated bytes, fragmentation estimate)
#
# Notes:
#   - This is a skeleton with TODOs only.
#   - Keep it SIMPLE: single-head, single sequence, fp16 inputs, fp32 accum.
#   - You can extend later to multi-head/batch.
# ============================================================

from __future__ import annotations
import math
from dataclasses import dataclass, asdict
from typing import Optional, Dict, Any, Tuple

import torch
import triton
import triton.language as tl


# ============================================================
# Data structures
# ============================================================
@dataclass(frozen=True)
class PageCfg:
    # page/block size in tokens
    BLOCK_T: int
    # head dim
    D: int
    # number of physical blocks allocated in the KV pool
    NUM_PHYS_BLOCKS: int

    # toy kernel tiling knobs (optional)
    BLOCK_M: int = 64      # query rows (here usually 1 query, but keep generic)
    BLOCK_N: int = 128     # keys columns tile
    num_warps: int = 4


@dataclass
class MemStats:
    mode: str
    logical_T: int
    D: int
    block_T: int
    num_logical_blocks: int
    num_phys_blocks: int
    kv_bytes_allocated: int
    kv_bytes_used: int
    fragmentation_bytes: int
    fragmentation_ratio: float


# ============================================================
# Helper: build additive causal mask (optional)
# ============================================================
def make_additive_causal_mask(T: int, device, dtype=torch.float32) -> torch.Tensor:
    """
    Returns [T, T] additive causal mask:
      0 for j <= i
      -inf for j > i
    """
    # TODO: implement
    # raise NotImplementedError


# ============================================================
# Contiguous KV layout (toy)
# ============================================================
def alloc_contiguous_kv(T: int, D: int, device="cuda", dtype=torch.float16) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Allocate contiguous K,V as [T, D]
    """
    K = torch.empty((T, D), device=device, dtype=dtype)
    V = torch.empty((T, D), device=device, dtype=dtype)
    return K, V


# ============================================================
# Paged KV layout (toy)
# ============================================================
def alloc_paged_kv_pool(num_phys_blocks: int, block_T: int, D: int, device="cuda", dtype=torch.float16):
    """
    Allocate a KV pool with fixed blocks:
      K_pool: [num_phys_blocks, block_T, D]
      V_pool: [num_phys_blocks, block_T, D]
    """
    K_pool = torch.empty((num_phys_blocks, block_T, D), device=device, dtype=dtype)
    V_pool = torch.empty((num_phys_blocks, block_T, D), device=device, dtype=dtype)
    return K_pool, V_pool


def build_block_table(num_logical_blocks: int, num_phys_blocks: int, device="cuda") -> torch.Tensor:
    """
    Block table maps logical block idx -> physical block idx:
      block_table[lb] = pb

    For toy:
      - you can map 0..L-1 to some subset of physical blocks
      - support non-contiguous placement to mimic fragmentation avoidance

    Returns:
      block_table: [num_logical_blocks] int32
    """
    # TODO: implement mapping strategy (e.g., random perm, or identity)
    # raise NotImplementedError

    assert num_logical_blocks <= num_phys_blocks, "num_logical_blocks must be <= num_phys_blocks"
    assert num_logical_blocks >= 0, "num_logical_blocks must be >= 0"
    assert num_phys_blocks > 0, "num_phys_blocks must be > 0"

    if num_logical_blocks == 0:
        return torch.empty((0,), device=device, dtype=torch.int32)

    # random non-contiguous mapping (deterministic if you set torch.manual_seed outside)
    perm = torch.randperm(num_phys_blocks, device=device, dtype=torch.int64)
    block_table = perm[:num_logical_blocks].to(torch.int32)

    return block_table


def write_tokens_to_paged_kv(
    K_tokens: torch.Tensor, V_tokens: torch.Tensor,
    K_pool: torch.Tensor, V_pool: torch.Tensor,
    block_table: torch.Tensor, block_T: int
):
    """
    Scatter logical tokens [T, D] into paged pools using block_table.
    This simulates how vLLM stores KV into pages.

    Inputs:
      K_tokens, V_tokens: [T, D]
      K_pool, V_pool: [PB, block_T, D]
      block_table: [LB] (logical blocks)
      block_T: tokens per block

    TODO:
      - For each logical token index t:
          lb = t // block_T
          off = t % block_T
          pb = block_table[lb]
          write into K_pool[pb, off, :]
    """
    # TODO: implement scatter

    # -----------------------------
    # Step 1: Validate inputs (shapes / dtypes / devices)
    # -----------------------------

    assert isinstance(K_tokens, torch.Tensor) and isinstance(V_tokens, torch.Tensor)
    assert isinstance(K_pool, torch.Tensor) and isinstance(V_pool, torch.Tensor)
    assert isinstance(block_table, torch.Tensor)

    # token tensors: [T, D]
    assert K_tokens.ndim == 2, f"K_tokens must be [T,D], got {K_tokens.shape}"
    assert V_tokens.ndim == 2, f"V_tokens must be [T,D], got {V_tokens.shape}"
    assert K_tokens.shape == V_tokens.shape, f"K_tokens and V_tokens must match, got {K_tokens.shape} vs {V_tokens.shape}"
    T, D = K_tokens.shape

    # pool tensors: [PB, block_T, D]
    assert K_pool.ndim == 3, f"K_pool must be [PB,block_T,D], got {K_pool.shape}"
    assert V_pool.ndim == 3, f"V_pool must be [PB,block_T,D], got {V_pool.shape}"
    assert K_pool.shape == V_pool.shape, f"K_pool and V_pool must match, got {K_pool.shape} vs {V_pool.shape}"
    PB, BT, Dp = K_pool.shape

    assert BT == block_T, f"pool block_T={BT} must equal block_T arg={block_T}"
    assert Dp == D, f"pool D={Dp} must equal tokens D={D}"

    # device consistency
    dev = K_pool.device
    assert K_tokens.device == dev, f"K_tokens device {K_tokens.device} must match K_pool device {dev}"
    assert V_tokens.device == dev, f"V_tokens device {V_tokens.device} must match K_pool device {dev}"
    assert V_pool.device == dev, f"V_pool device {V_pool.device} must match K_pool device {dev}"
    assert block_table.device == dev, f"block_table device {block_table.device} must match K_pool device {dev}"

    # dtype sanity (toy: usually fp16/bf16 for K/V pools)
    assert K_tokens.dtype == K_pool.dtype, f"K_tokens dtype {K_tokens.dtype} must match K_pool dtype {K_pool.dtype}"
    assert V_tokens.dtype == V_pool.dtype, f"V_tokens dtype {V_tokens.dtype} must match V_pool dtype {V_pool.dtype}"

    # block_table: [LB] integer
    assert block_table.ndim == 1, f"block_table must be 1D [LB], got {block_table.shape}"
    assert block_table.dtype in (torch.int32, torch.int64), f"block_table must be int32/int64, got {block_table.dtype}"

    # block_table must cover all logical blocks for T tokens
    LB = (T + block_T - 1) // block_T
    assert block_table.numel() >= LB, f"block_table too short: need LB={LB}, got {block_table.numel()}"

    # pb range check (optional but strongly recommended)
    # Only check the portion we will actually use (first LB entries)
    bt_used = block_table[:LB].to(torch.int64)
    assert int(bt_used.min().item()) >= 0, "block_table contains negative physical block id"
    assert int(bt_used.max().item()) < PB, f"block_table contains pb >= PB (PB={PB})"

    # Check uniqueness
    unique_pb = torch.unique(bt_used)

    assert unique_pb.numel() == bt_used.numel(), (
        "block_table contains duplicate physical block ids "
        "(would cause KV overwrite)"
    )

    # -----------------------------
    # Step 2: Scatter write (pb scalar + write V_pool too)
    # -----------------------------

    for t in range(T):
        lb = t // block_T
        off = t % block_T

        pb = int(block_table[lb].item())

        K_pool[pb, off, :] = K_tokens[t, :]
        V_pool[pb, off, :] = V_tokens[t, :]


    # -----------------------------
    # Step3: zero out unused slots in the last block (debug-friendly)
    # -----------------------------
    valid = T % block_T
    if valid > 0:
        lb_last = LB - 1  # last logical block
        pb_last = int(block_table[lb_last].item())

        pb_last = int(block_table[lb_last].item())
        K_pool[pb_last, valid:, :].zero_()
        V_pool[pb_last, valid:, :].zero_()

    # raise NotImplementedError






# ============================================================
# Memory stats
# ============================================================
def mem_stats_contiguous(T: int, D: int, dtype=torch.float16) -> MemStats:
    elem = torch.tensor([], dtype=dtype).element_size()
    used = 2 * T * D * elem  # K and V
    return MemStats(
        mode="contiguous",
        logical_T=T,
        D=D,
        block_T=0,
        num_logical_blocks=0,
        num_phys_blocks=0,
        kv_bytes_allocated=used,
        kv_bytes_used=used,
        fragmentation_bytes=0,
        fragmentation_ratio=0.0,
    )


def mem_stats_paged(T: int, cfg: PageCfg, dtype=torch.float16) -> MemStats:
    elem = torch.tensor([], dtype=dtype).element_size()
    block_T = cfg.BLOCK_T
    LB = (T + block_T - 1) // block_T
    allocated = 2 * cfg.NUM_PHYS_BLOCKS * block_T * cfg.D * elem
    used = 2 * T * cfg.D * elem
    frag = allocated - used
    return MemStats(
        mode="paged",
        logical_T=T,
        D=cfg.D,
        block_T=block_T,
        num_logical_blocks=LB,
        num_phys_blocks=cfg.NUM_PHYS_BLOCKS,
        kv_bytes_allocated=allocated,
        kv_bytes_used=used,
        fragmentation_bytes=max(0, frag),
        fragmentation_ratio=max(0.0, frag / max(1, allocated)),
    )


# ============================================================
# Triton: toy "paged gather" helper (kernel-side addressing)
# ============================================================
@triton.jit
def paged_kv_gather_kernel(
    # pointers
    k_pool_ptr, v_pool_ptr,
    block_table_ptr,
    # output contiguous buffers for debugging (optional)
    k_out_ptr, v_out_ptr,
    # sizes
    T: tl.constexpr, D: tl.constexpr,
    BLOCK_T: tl.constexpr,
    # strides (pool is [PB, BLOCK_T, D])
    stride_kpb: tl.constexpr, stride_kpt: tl.constexpr, stride_kd: tl.constexpr,
    stride_vpb: tl.constexpr, stride_vpt: tl.constexpr, stride_vd: tl.constexpr,
    stride_out_t: tl.constexpr, stride_out_d: tl.constexpr,
):
    """
    OPTIONAL helper kernel:
      Gather paged K/V into contiguous [T, D] buffers.
    This is NOT how vLLM does it (they avoid materializing), but useful for debugging.

    TODO:
      - Map program id to a token block
      - For each token t in the block:
          lb = t // BLOCK_T
          off = t % BLOCK_T
          pb = block_table[lb]
          load K_pool[pb, off, :]
          store into k_out[t, :]
      - Similarly for V
    """
    # TODO: implement (optional)
    # raise NotImplementedError

    # ---- program id -> logical block id ----
    pid = tl.program_id(axis=0)
    t_idx = pid * BLOCK_T + tl.arange(0, BLOCK_T)
    mask_t = t_idx < T
    mask_td = mask_t[:, None]

    # ---- logical block -> physical block ----
    pb = tl.load(block_table_ptr + pid).to(tl.int64)

    # ---- offsets within a block (tokens) and within a vector (D) ----
    offs_t_in_block = tl.arange(0, BLOCK_T)
    offs_d = tl.arange(0, D)

    # ---- build pool pointers [BLOCK_T, D] ----
    k_ptrs = k_pool_ptr + pb * stride_kpb + offs_t_in_block[:, None] * stride_kpt + offs_d[None, :] * stride_kd
    v_ptrs = v_pool_ptr + pb * stride_vpb + offs_t_in_block[:, None] * stride_vpt + offs_d[None, :] * stride_vd

    # ---- build output pointers [BLOCK_T, D] ----
    k_out_ptrs = k_out_ptr + t_idx[:, None] * stride_out_t + offs_d[None, :] * stride_out_d
    v_out_ptrs = v_out_ptr + t_idx[:, None] * stride_out_t + offs_d[None, :] * stride_out_d

    # ---- load from pool and store to contiguous outputs ----
    k = tl.load(k_ptrs, mask=mask_td, other=0.0)
    v = tl.load(v_ptrs, mask=mask_td, other=0.0)

    tl.store(k_out_ptrs, k, mask=mask_td)
    tl.store(v_out_ptrs, v, mask=mask_td)


# ============================================================
# Triton PageAttention (toy)
# ============================================================
@triton.jit
def page_attention_kernel(
    q_ptr,                    # [M, D]
    k_pool_ptr, v_pool_ptr,   # [PB, BLOCK_T, D]
    block_table_ptr,          # [LB]
    mask_ptr,                 # [M, T] additive mask (optional, can be null/dummy if HAS_MASK=False)
    out_ptr,                  # [M, D]
    # sizes
    M: tl.constexpr,
    T: tl.constexpr,
    D: tl.constexpr,
    BLOCK_T: tl.constexpr,
    # strides for Q [M, D]
    stride_qm: tl.constexpr, stride_qd: tl.constexpr,
    # strides for pool [PB, BLOCK_T, D]
    stride_kpb: tl.constexpr, stride_kpt: tl.constexpr, stride_kd: tl.constexpr,
    stride_vpb: tl.constexpr, stride_vpt: tl.constexpr, stride_vd: tl.constexpr,
    # strides for mask [M, T]
    stride_mm: tl.constexpr, stride_mt: tl.constexpr,
    # strides for Out [M, D]
    stride_om: tl.constexpr, stride_od: tl.constexpr,
    # tiling
    BLOCK_N: tl.constexpr,
    HAS_MASK: tl.constexpr,
):
    """
    Toy paged attention:
      out[m, :] = softmax( q[m,:] @ K[:T,:]^T + mask ) @ V[:T,:]

    Constraints / simplifying assumptions:
      - Single head
      - Uses block_table to locate K/V blocks
      - Does NOT attempt FlashAttention fusion tricks (this is day4, not day6)
      - You may implement:
          (A) full materialization of scores for toy correctness
          or
          (B) streaming softmax (more advanced, optional)
    Skeleton expects TODOs only.

    TODO:
      1) Load q vector for row m
      2) Iterate over key tiles t0:t0+BLOCK_N
          - For each token t in tile:
              lb = t // BLOCK_T
              off = t % BLOCK_T
              pb = block_table[lb]
              load k = K_pool[pb, off, :]
              compute score = dot(q, k)
              apply mask if HAS_MASK
          - softmax over T tokens (requires reduction across tiles)
      3) Weighted sum over V similarly:
          out = sum_j p_j * v_j

    Because softmax needs a global normalization across all T,
    you will likely need:
      - a two-pass approach (scores -> softmax -> PV), OR
      - an online softmax approach.

    For this Day4 toy, pick the simplest correct approach.
    """
    # TODO: implement
    # raise NotImplementedError
    m = tl.program_id(0)


    # ---- load q[m, :] ----
    d = tl.arange(0, D)
    mask_d = d < D
    q = tl.load(q_ptr + m * stride_qm + d * stride_qd, mask = mask_d, other = 0.0).to(tl.float32)
    inv_sqrt_d = 1.0 / tl.sqrt(tl.full([], D, tl.float32))

    # Pass 1: compute global max score for numerical stability
    max_s = tl.full([], -float("inf"), tl.float32)

    for t0 in range(0, T, BLOCK_N):
        t = t0 + tl.arange(0, BLOCK_N)               # [BN]
        mask_t = t < T

        #page address translation
        lb = t // BLOCK_T                             # [BN]
        off = t % BLOCK_T                             # [BN]

        pb = tl.load(block_table_ptr + lb, mask=mask_t, other=0).to(tl.int64)  # [BN]


        # load K tile: [BN, D]
        k_ptrs = (k_pool_ptr + pb[:, None] * stride_kpb + off[:, None] * stride_kpt + d[None, :] * stride_kd)
        k = tl.load(k_ptrs, mask=mask_t[:, None], other=0.0).to(tl.float32)

        #score
        s = tl.sum(k * q[None, :], axis=1) * inv_sqrt_d   # [BN]

        if HAS_MASK:
          mvals = tl.load(mask_ptr + m * stride_mm + t * stride_mt, mask=mask_t, other=-float("inf")).to(tl.float32)
          s = s + mvals

        # invalid tokens -> -inf so they don't affect max
        s = tl.where(mask_t, s, -float("inf"))
        max_s = tl.maximum(max_s, tl.max(s, axis=0))

    # Pass 2: compute sumexp and accumulate PV
    denom = tl.full([], 0.0, tl.float32)
    out = tl.zeros([D], dtype=tl.float32)

    for t0 in range(0, T, BLOCK_N):
        t = t0 + tl.arange(0, BLOCK_N)
        mask_t = t < T

        lb = t // BLOCK_T
        off = t % BLOCK_T

        pb = tl.load(block_table_ptr + lb, mask=mask_t, other=0).to(tl.int64)

        # load K tile: [BN, D]
        k_ptrs = (k_pool_ptr + pb[:, None] * stride_kpb + off[:, None] * stride_kpt + d[None, :] * stride_kd)
        k = tl.load(k_ptrs, mask=mask_t[:, None], other=0.0).to(tl.float32)

        s = tl.sum(k * q[None, :], axis=1) * inv_sqrt_d  # [BN]

        if HAS_MASK:
            mvals = tl.load(mask_ptr + m * stride_mm + t * stride_mt, mask=mask_t, other=-float("inf")).to(tl.float32)
            s = s + mvals

        s = tl.where(mask_t, s, -float("inf"))

        # exp(score - max)
        w = tl.exp(s - max_s)                           # [BN]
        w = tl.where(mask_t, w, 0.0)

        denom += tl.sum(w, axis=0)

        # load V: [BN, D]
        v_ptrs = (v_pool_ptr + pb[:, None] * stride_vpb + off[:, None] * stride_vpt + d[None, :] * stride_vd)
        v = tl.load(v_ptrs, mask=mask_t[:, None], other=0.0).to(tl.float32)

        # out += sum_j w_j * v_j
        out += tl.sum(v * w[:, None], axis=0)

    # normalize
    denom = tl.maximum(denom, 1e-9)
    out = out / denom

    # store
    tl.store(out_ptr + m * stride_om + d * stride_od, out.to(tl.float16))

# ============================================================
# Torch references
# ============================================================
def attention_torch_contiguous(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, mask: Optional[torch.Tensor] = None):
    """
    Reference attention for contiguous KV:
      out = softmax(Q @ K.T + mask) @ V
    """
    # TODO: implement (use fp32 scores for stability)
    # raise NotImplementedError
    d = Q.shape[-1]


    # fp32 compute for stability
    scores = torch.matmul(Q.float(), K.float().transpose(-1, -2))
    scores /= math.sqrt(d)

    if mask is not None:
        scores += mask.to(scores.dtype)

    scores = scores.softmax(dim=-1)
    out = torch.matmul(scores, V.float())

    return out.to(Q.dtype)



def attention_torch_from_paged(
    Q: torch.Tensor,
    K_pool: torch.Tensor, V_pool: torch.Tensor,
    block_table: torch.Tensor, T: int, cfg: PageCfg,
    mask: Optional[torch.Tensor] = None
):
    """
    Reference attention by first gathering paged KV into contiguous K,V (for correctness only).
    Then run standard torch attention.
    """
    # TODO:
    # - gather K,V into [T,D] using block_table
    # - call attention_torch_contiguous
    # raise NotImplementedError


    # Shapes (toy):
    # Q:        [M, D]
    # K_pool:   [PB, BLOCK_T, D]
    # V_pool:   [PB, BLOCK_T, D]
    # block_table: [LB] where LB = ceil(T / BLOCK_T)
    # mask:     None or [M, T] additive mask

    assert Q.ndim == 2, f"Q must be [M,D], got {Q.shape}"
    assert K_pool.ndim == 3 and V_pool.ndim == 3, f"K_pool/V_pool must be [PB,BLOCK_T,D]"
    assert K_pool.shape == V_pool.shape, "K_pool and V_pool must have same shape"
    PB, BLOCK_T, D = K_pool.shape

    assert D == cfg.D, f"Pool D={D} must match cfg.D={cfg.D}"
    assert BLOCK_T == cfg.BLOCK_T, f"Pool BLOCK_T={BLOCK_T} must match cfg.BLOCK_T={cfg.BLOCK_T}"
    assert Q.shape[1] == D, f"Q D={Q.shape[1]} must match pool D={D}"
    assert 0 < T <= cfg.BLOCK_T * K_pool.shape[0] * 10_000, "T looks unreasonable for given pool (sanity check)"


    # block_table length must cover all logical blocks needed for T tokens
    LB = (T + BLOCK_T - 1) // BLOCK_T
    assert block_table.ndim == 1, f"block_table must be 1D, got {block_table.shape}"
    assert block_table.numel() >= LB, f"block_table too short: need {LB}, got {block_table.numel()}"
    assert block_table.dtype in (torch.int32, torch.int64)

    # Build logical token indices [0..T-1]
    device = Q.device
    t_idx = torch.arange(T, device=device, dtype=torch.int64)          # [T]
    lb = t_idx // BLOCK_T                                              # [T]
    off = t_idx % BLOCK_T                                              # [T]
    pb = block_table[lb].to(torch.int64)                               # [T]


    # Gather K,V into contiguous [T,D]
    # Advanced indexing: K_pool[pb, off] -> [T,D]
    K_contig = K_pool[pb, off, :]                                      # [T,D]
    V_contig = V_pool[pb, off, :]                                      # [T,D]


    return attention_torch_contiguous(Q, K_contig, V_contig, mask=mask)

# ============================================================
# Driver: build toy data, run correctness checks
# ============================================================
@torch.no_grad()
def check_correctness(
    T: int = 1024,
    D: int = 64,
    block_T: int = 16,
    num_phys_blocks: int = 128,
    M: int = 1,
    dtype=torch.float16,
    use_mask: bool = False,
):
    device = "cuda"
    torch.manual_seed(0)

    cfg = PageCfg(BLOCK_T=block_T, D=D, NUM_PHYS_BLOCKS=num_phys_blocks)

    # Create a toy query (M queries)
    Q = torch.randn((M, D), device=device, dtype=dtype)

    # Create logical tokens for KV (as if appended over time)
    K_tokens = torch.randn((T, D), device=device, dtype=dtype)
    V_tokens = torch.randn((T, D), device=device, dtype=dtype)

    # --- Contiguous baseline ---
    K_contig, V_contig = alloc_contiguous_kv(T, D, device=device, dtype=dtype)
    K_contig.copy_(K_tokens)
    V_contig.copy_(V_tokens)

    mask = None
    if use_mask:
        # TODO: define mask shape; for toy use [M, T] or [M, T] additive
        # or full [M, T] if you compute scores row-wise.
        # mask = ...
        # raise NotImplementedError("TODO: mask construction")
        keep = T // 2
        mask = torch.zeros((M, T), device=device, dtype=torch.float32)
        mask[:, keep:] = float("-inf")

    # TODO: torch contiguous reference
    # out_ref = attention_torch_contiguous(Q, K_contig, V_contig, mask=mask)
    out_ref = attention_torch_contiguous(Q, K_contig, V_contig, mask=mask)

    # --- Paged layout ---
    K_pool, V_pool = alloc_paged_kv_pool(num_phys_blocks, block_T, D, device=device, dtype=dtype)
    LB = (T + block_T - 1) // block_T
    block_table = build_block_table(LB, num_phys_blocks, device=device)

    write_tokens_to_paged_kv(K_tokens, V_tokens, K_pool, V_pool, block_table, block_T)

    # TODO: Triton paged attention
    # out_paged = page_attention_triton(Q, K_pool, V_pool, block_table, T, cfg, mask=mask)
    out_paged = page_attention_triton(Q, K_pool, V_pool, block_table, T, cfg, mask=mask)

    # TODO: compare out_paged with out_ref
    # max_err = (out_paged - out_ref).abs().max().item()
    # mean_err = (out_paged - out_ref).abs().mean().item()
    # print(...)
    # raise NotImplementedError
    diff = (out_paged - out_ref).float()
    max_err = diff.abs().max().item()
    mean_err = diff.abs().mean().item()

    print("\n=== Correctness Check ===")
    print(f"T={T}, D={D}, block_T={block_T}, PB={num_phys_blocks}, M={M}, dtype={dtype}, use_mask={use_mask}")
    print(f"max_abs_err  = {max_err:.6e}")
    print(f"mean_abs_err = {mean_err:.6e}")

    # Simple tolerance guidance (toy fp16): adjust if needed
    tol = 5e-2 if dtype in (torch.float16, torch.bfloat16) else 1e-4
    if max_err > tol:
        print(f"[WARN] max_err {max_err:.3e} > tol {tol:.3e} (check paging addr / mask / numerics)")
    else:
        print("[OK] Within tolerance.")

    return {
        "max_abs_err": max_err,
        "mean_abs_err": mean_err,
        "out_ref": out_ref,
        "out_paged": out_paged,
        "mask": mask,
        "block_table": block_table,
    }


def page_attention_triton(
    Q: torch.Tensor,
    K_pool: torch.Tensor,
    V_pool: torch.Tensor,
    block_table: torch.Tensor,
    T: int,
    cfg: PageCfg,
    mask: Optional[torch.Tensor] = None,
    BLOCK_N: int = 128,
    num_warps: int = 4,
) -> torch.Tensor:
    """
    Launcher for page_attention_kernel.
    Q: [M, D]
    K_pool/V_pool: [PB, BLOCK_T, D]
    block_table: [LB]
    Returns:
      out: [M, D]
    """
    # TODO:
    # - Validate shapes/dtypes
    # - Allocate out
    # - Define grid (e.g., one program per query row m)
    # - Pass strides and constexpr args
    # - HAS_MASK toggle
    # raise NotImplementedError
    assert Q.ndim == 2
    M, D = Q.shape
    assert K_pool.ndim == 3 and V_pool.ndim == 3
    PB, BT, Dp = K_pool.shape
    assert (PB, BT, Dp) == V_pool.shape
    assert BT == cfg.BLOCK_T and Dp == D
    LB = (T + cfg.BLOCK_T - 1) // cfg.BLOCK_T
    assert block_table.ndim == 1 and block_table.numel() >= LB
    assert block_table.dtype in (torch.int32, torch.int64)

    out = torch.empty((M, D), device=Q.device, dtype=Q.dtype)

    HAS_MASK = mask is not None
    if HAS_MASK:
        assert mask.shape == (M, T), f"mask must be [M,T], got {mask.shape}"
        # mask can be fp16/fp32; kernel reads as fp32
        mask_ptr = mask
        stride_mm, stride_mt = mask.stride()
    else:
        # dummy tensor (won't be read when HAS_MASK=False)
        mask_ptr = out  # any valid pointer on device
        stride_mm, stride_mt = 0, 0

    grid = (M,)
    page_attention_kernel[grid](
        Q, K_pool, V_pool, block_table, mask_ptr, out,
        M=M, T=T, D=D, BLOCK_T=cfg.BLOCK_T,
        stride_qm=Q.stride(0), stride_qd=Q.stride(1),
        stride_kpb=K_pool.stride(0), stride_kpt=K_pool.stride(1), stride_kd=K_pool.stride(2),
        stride_vpb=V_pool.stride(0), stride_vpt=V_pool.stride(1), stride_vd=V_pool.stride(2),
        stride_mm=stride_mm, stride_mt=stride_mt,
        stride_om=out.stride(0), stride_od=out.stride(1),
        BLOCK_N=BLOCK_N,
        HAS_MASK=HAS_MASK,
        num_warps=num_warps,
    )

    return out

# ============================================================
# Memory statistics printing
# ============================================================
def print_mem_stats(T: int, D: int, cfg: PageCfg, dtype=torch.float16):
    c = mem_stats_contiguous(T, D, dtype=dtype)
    p = mem_stats_paged(T, cfg, dtype=dtype)

    print("\n=== Memory Stats ===")
    print(f"Contiguous KV: allocated={c.kv_bytes_allocated/1e6:.3f} MB, used={c.kv_bytes_used/1e6:.3f} MB")
    print(f"Paged KV     : allocated={p.kv_bytes_allocated/1e6:.3f} MB, used={p.kv_bytes_used/1e6:.3f} MB, "
          f"frag={p.fragmentation_ratio*100:.2f}%")
    print("====================\n")


# ============================================================
# Main
# ============================================================
def main():
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA required.")

    # TODO: adjust toy parameters
    T = 1024
    D = 64
    block_T = 16
    num_phys_blocks = 128
    M = 1

    cfg = PageCfg(BLOCK_T=block_T, D=D, NUM_PHYS_BLOCKS=num_phys_blocks)
    print_mem_stats(T, D, cfg, dtype=torch.float16)

    # TODO: run correctness
    check_correctness(T=T, D=D, block_T=block_T, num_phys_blocks=num_phys_blocks, M=M, use_mask=False)
    # raise NotImplementedError("TODO: wire up correctness once kernels are implemented")


if __name__ == "__main__":
    main()



=== Memory Stats ===
Contiguous KV: allocated=0.262 MB, used=0.262 MB
Paged KV     : allocated=0.524 MB, used=0.262 MB, frag=50.00%


=== Correctness Check ===
T=1024, D=64, block_T=16, PB=128, M=1, dtype=torch.float16, use_mask=False
max_abs_err  = 0.000000e+00
mean_abs_err = 0.000000e+00
[OK] Within tolerance.


In [14]:
# day6_flashattention_mini_skeleton.py
# ============================================================
# Day 6 ‚Äî Triton FlashAttention (Mini) (NO SOLUTION)
#
# Goal:
#   Implement a mini FlashAttention-style kernel in Triton:
#     - tiled QK^T
#     - online softmax (streaming max/sum)
#     - fuse V multiplication
#     - single kernel end-to-end
#   Then:
#     - validate correctness vs PyTorch
#     - compare performance vs naive attention (Day2)
#
# Scope (toy but realistic):
#   - Single head (extend later)
#   - One batch (extend later)
#   - Q,K,V: [N, D] (N = seq length, D = head dim)
#   - Output O: [N, D]
#   - Optional causal mask (recommended)
#   - Inputs fp16/bf16, accumulate fp32
#
# Notes:
#   - This is a skeleton with TODOs only (no solution).
#   - You will need to choose tiling sizes that fit SRAM (shared memory/registers).
# ============================================================

from __future__ import annotations
import math
from dataclasses import dataclass, asdict
from typing import Optional, Dict, Any, Tuple, List

import torch
import triton
import triton.language as tl


# ============================================================
# TODO: import your Day2 naive attention for comparison
# ============================================================
def naive_attention_triton(Q, K, V, mask=None, cfg=None):
    # TODO: import and call your Day2 implementation
    _assert_cuda(Q, "Q")
    _assert_cuda(K, "K")
    _assert_cuda(V, "V")
    assert Q.shape == K.shape, "Q and K must have shape [N, D]"
    assert Q.shape == V.shape, "For this toy naive version, assume V has shape [N, D]"
    N, D = Q.shape

    if mask is not None:
        _assert_cuda(mask, "mask")
        assert mask.shape == (N, N), "mask must be [N, N] additive mask (0 / -inf)"

    # 1) Scores: S = Q @ K^T   -> [N, N] (often fp32)
    S = qk_t_triton(Q, K)

    # 2) Probabilities: P = softmax(S + mask)  -> [N, N]
    P = softmax_triton(S, mask)

    # 3) Output: O = P @ V     -> [N, D]
    O = pv_triton(P, V)

    return O


# ============================================================
# Config
# ============================================================
@dataclass(frozen=True)
class FlashCfg:
    BLOCK_M: int     # rows of Q processed per program
    BLOCK_N: int     # cols of K/V per step (streaming over N)
    BLOCK_D: int     # head dim tile (usually == D, but keep generic)
    num_warps: int = 4
    num_stages: int = 2  # optional pipelining


# ============================================================
# Mask helper (optional)
# ============================================================
def make_additive_causal_mask(N: int, device="cuda", dtype=torch.float32) -> torch.Tensor:
    """
    Returns additive causal mask [N, N]:
      0 for j <= i, -inf for j > i
    Used as: scores = scores + mask
    """
    # TODO: implement
    # raise NotImplementedError


# ============================================================
# FlashAttention mini kernel (single kernel)
# ============================================================
@triton.jit
def flashattn_mini_kernel(
    q_ptr, k_ptr, v_ptr, o_ptr,
    # optional mask pointer (additive), can be None by HAS_MASK flag
    mask_ptr,
    N: tl.constexpr, D: tl.constexpr,
    stride_qn: tl.constexpr, stride_qd: tl.constexpr,
    stride_kn: tl.constexpr, stride_kd: tl.constexpr,
    stride_vn: tl.constexpr, stride_vd: tl.constexpr,
    stride_on: tl.constexpr, stride_od: tl.constexpr,
    # mask strides (if used): mask is [N, N] additive
    stride_mn: tl.constexpr, stride_mm: tl.constexpr,
    # tiling
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_D: tl.constexpr,
    HAS_MASK: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    # scale (typically 1/sqrt(D))
    SCALE: tl.constexpr,
):
    """
    Compute O = softmax(QK^T + mask) V using tiling + online softmax, fused with V.

    Structure (conceptual):
      For each block of queries (m tile):
        - initialize:
            m_i = -inf         # running max per query row
            l_i = 0            # running sum(exp(scores - m_i))
            acc = 0            # running output accumulator (fp32)
        - for n_tile over keys/values:
            scores = q_tile @ k_tile^T * SCALE + mask_tile
            # online softmax update:
            m_new = max(m_i, rowmax(scores))
            alpha = exp(m_i - m_new)
            p = exp(scores - m_new)
            l_new = l_i * alpha + rowsum(p)
            acc = acc * alpha[:,None] + p @ v_tile
            m_i = m_new
            l_i = l_new
        - normalize:
            out = acc / l_i[:,None]
        - store out

    TODOs:
      - Map program_id to query block start
      - Load Q tile [BLOCK_M, D] (or [BLOCK_M, BLOCK_D] with loop if needed)
      - Loop over K/V tiles along N:
          * Load K tile [BLOCK_N, D]
          * Compute score tile [BLOCK_M, BLOCK_N] in fp32
          * Apply causal masking if IS_CAUSAL (score for j>i = -inf)
          * Apply additive mask if HAS_MASK (mask_ptr)
          * Update online softmax stats (m_i, l_i)
          * Fuse V multiplication: acc += p @ V_tile
      - Final normalize acc by l_i
      - Store O tile
    """
    # TODO: implement
    # raise NotImplementedError
    #----------------Determine which query rows this block owns-----------------
    pid = tl.program_id(axis=0)
    m_start = pid * BLOCK_M

    row = m_start + tl.arange(0, BLOCK_M)
    row_mask = row < N
    d = tl.arange(0, D) #[D]
    d_mask   = d < D
    q_mask   = row_mask[:, None] & d_mask[None, :]

    #----------------Load a tile of Q-----------------
    d = tl.arange(0, D)                      # [D]
    q_ptrs = q_ptr + row[:, None] * stride_qn + d[None, :] * stride_qd
    q = tl.load(q_ptrs, mask=q_mask, other=0.0)# [BM, D]

    #----------------Initialize online softmax state-----------------
    m = tl.full([BLOCK_M], -float("inf"), tl.float32)
    l = tl.zeros([BLOCK_M], tl.float32)
    acc = tl.zeros([BLOCK_M, D], tl.float32)

    #----------------Main Loop: Iterate over K/V tiles-----------------
    for n_start in range(0, N, BLOCK_N):
          # indices of K/V rows for this tile
          col = n_start + tl.arange(0, BLOCK_N)# [BN]
          col_mask = col < N

          # load K tile [BLOCK_N, D]
          k_ptrs = k_ptr + col[:, None] * stride_kn + d[None, :] * stride_kd  # [BN, D]
          k_mask = col_mask[:, None] & d_mask[None, :]
          k_tile = tl.load(k_ptrs, mask=k_mask, other=0.0)

          # load V tile [BLOCK_N, D]
          v_ptrs = v_ptr + col[:, None] * stride_vn + d[None, :] * stride_vd  # [BN, D]
          v_mask = col_mask[:, None] & d_mask[None, :]
          v_tile = tl.load(v_ptrs, mask=v_mask, other=0.0)

          # compute score tile [BLOCK_M, BLOCK_N]
          scores = tl.dot(q, tl.trans(k_tile)) * SCALE

          if IS_CAUSAL:
              # apply causal masking if IS_CAUSAL (score for j>i = -inf)
              row_mask = row < N              # [BLOCK_M]
              col_mask = col < N              # [BLOCK_N]

              in_bounds = row_mask[:, None] & col_mask[None, :]  # [BLOCK_M, BLOCK_N]
              causal_keep = col[None, :] <= row[:, None]         # [BLOCK_M, BLOCK_N]

              keep = in_bounds & causal_keep
              scores = tl.where(keep, scores, -float("inf"))

          if HAS_MASK:
              # apply additive mask if HAS_MASK (mask_ptr)
              mask_ptrs = mask_ptr + row[:, None] * stride_mn + col[None, :] * stride_mm
              mask_load_mask = (row < N)[:, None] & (col < N)[None, :]
              mask_tile = tl.load(mask_ptrs, mask=mask_load_mask, other=0.0).to(tl.float32)
              scores = scores + mask_tile

          # Online Softmax Update
          # 1) row-wise max on this tile
          s_max = tl.max(scores, axis=1)# [BLOCK_M]

          # 2) new running max per row
          m_new = tl.maximum(m, s_max)# [BLOCK_M]

          # 3) rescale factor to bring old accumulators into the new max "gauge"
          alpha = tl.exp(m - m_new)# [BLOCK_M]

          # 4) exponentiate current tile scores using the new max (broadcast m_new across columns)
          p = tl.exp(scores - m_new[:, None]) # [BLOCK_M, BLOCK_N]

          # 5) update running sum per row
          l_new = l * alpha + tl.sum(p, axis=1)        # [BLOCK_M]

          # 6) update output accumulator (fused P @ V)
          #    p is fp32; accumulate in fp32 for stability
          acc = acc * alpha[:, None] + tl.dot(p, v_tile.to(tl.float32)) # [BLOCK_M, D]

          # 7) commit
          m = m_new
          l = l_new


    out = acc / l[:, None]   # [BLOCK_M, D]
    d = tl.arange(0, D)  # [D]
    o_ptrs = o_ptr + row[:, None] * stride_on + d[None, :] * stride_od   # [BLOCK_M, D]
    row_mask = row < N
    o_mask = row_mask[:, None]   # [BLOCK_M, 1] broadcast to [BLOCK_M, D]
    tl.store(o_ptrs, out.to(tl.float16), mask=o_mask)


# ============================================================
# Launcher
# ============================================================
def _assert_2d(x: torch.Tensor, name: str) -> None:
    """Assert tensor is 2D [N, D]."""
    if not isinstance(x, torch.Tensor):
        raise TypeError(f"{name} must be a torch.Tensor, got {type(x)}")
    if x.ndim != 2:
        raise ValueError(f"{name} must be 2D [N, D], got shape {tuple(x.shape)}")


def _assert_cuda_contig(x: torch.Tensor, name: str) -> None:
    """Assert tensor is on CUDA and contiguous."""
    if not x.is_cuda:
        raise ValueError(f"{name} must be on CUDA, got {x.device}")
    if not x.is_contiguous():
        raise ValueError(f"{name} must be contiguous, got strides {x.stride()} and shape {tuple(x.shape)}")
def flashattn_mini_triton(
    Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    causal: bool = False,
    cfg: FlashCfg = FlashCfg(BLOCK_M=32, BLOCK_N=64, BLOCK_D=64, num_warps=4),
) -> torch.Tensor:
    """
    Q,K,V: [N, D] contiguous CUDA tensors
    mask: additive [N, N] or None
    causal: if True, apply causal mask internally
    """

    # -----------------------------
    # Validate Q/K/V
    # -----------------------------
    _assert_2d(Q, "Q")
    _assert_2d(K, "K")
    _assert_2d(V, "V")

    _assert_cuda_contig(Q, "Q")
    _assert_cuda_contig(K, "K")
    _assert_cuda_contig(V, "V")

    N, D = Q.shape
    if K.shape != (N, D):
        raise ValueError(f"K must have shape {(N, D)}, got {tuple(K.shape)}")
    if V.shape != (N, D):
        raise ValueError(f"V must have shape {(N, D)}, got {tuple(V.shape)}")

    if Q.dtype not in (torch.float16, torch.bfloat16, torch.float32):
        raise ValueError(f"Q dtype must be fp16/bf16/fp32, got {Q.dtype}")
    if K.dtype != Q.dtype or V.dtype != Q.dtype:
        raise ValueError(f"K and V must have same dtype as Q (Q={Q.dtype}, K={K.dtype}, V={V.dtype})")

    # Optional but recommended: enforce same device
    if K.device != Q.device or V.device != Q.device:
        raise ValueError(f"Q/K/V must be on the same device. Q={Q.device}, K={K.device}, V={V.device}")

    # -----------------------------
    # Validate cfg (mini kernel assumption)
    # -----------------------------
    if cfg.BLOCK_M <= 0 or cfg.BLOCK_N <= 0:
        raise ValueError("cfg.BLOCK_M and cfg.BLOCK_N must be positive")
    # This mini kernel (as written in our discussion) loads full D directly.
    # If your kernel actually loops over D with BLOCK_D, you can relax this.
    if cfg.BLOCK_D != D:
        raise ValueError(f"This mini wrapper expects cfg.BLOCK_D == D. Got BLOCK_D={cfg.BLOCK_D}, D={D}")

    # -----------------------------
    # Mask handling
    # -----------------------------
    has_mask = mask is not None
    is_causal = bool(causal)

    if has_mask:
        if mask.ndim != 2 or mask.shape != (N, N):
            raise ValueError(f"mask must be [N, N] = {(N, N)}, got {tuple(mask.shape)}")
        if not mask.is_cuda:
            raise ValueError(f"mask must be on CUDA, got {mask.device}")
        # For this mini wrapper, keep it simple: contiguous mask
        if not mask.is_contiguous():
            raise ValueError("mask must be contiguous for this mini wrapper")
        if mask.dtype not in (torch.float16, torch.bfloat16, torch.float32):
            raise ValueError(f"mask dtype must be fp16/bf16/fp32, got {mask.dtype}")

        mask_ptr = mask
        stride_mn, stride_mm = mask.stride()
    else:
        # Kernel ignores mask_ptr when HAS_MASK=False; pass any valid pointer.
        mask_ptr = Q
        stride_mn, stride_mm = 0, 0

    # -----------------------------
    # Allocate output
    # -----------------------------
    out = torch.empty((N, D), device=Q.device, dtype=Q.dtype)

    # -----------------------------
    # Strides (in elements)
    # -----------------------------
    stride_qn, stride_qd = Q.stride()
    stride_kn, stride_kd = K.stride()
    stride_vn, stride_vd = V.stride()
    stride_on, stride_od = out.stride()

    # -----------------------------
    # Launch params
    # -----------------------------
    grid = (triton.cdiv(N, cfg.BLOCK_M),)
    scale = 1.0 / math.sqrt(D)

    # -----------------------------
    # Call kernel
    # -----------------------------
    flashattn_mini_kernel[grid](
        Q, K, V, out,
        mask_ptr,
        N=N, D=D,
        stride_qn=stride_qn, stride_qd=stride_qd,
        stride_kn=stride_kn, stride_kd=stride_kd,
        stride_vn=stride_vn, stride_vd=stride_vd,
        stride_on=stride_on, stride_od=stride_od,
        stride_mn=stride_mn, stride_mm=stride_mm,
        BLOCK_M=cfg.BLOCK_M,
        BLOCK_N=cfg.BLOCK_N,
        BLOCK_D=cfg.BLOCK_D,
        HAS_MASK=has_mask,
        IS_CAUSAL=is_causal,
        SCALE=scale,
        num_warps=cfg.num_warps,
        num_stages=1
    )

    return out

# ============================================================
# PyTorch reference + correctness
# ============================================================
def flashattn_ref_torch(
    Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    causal: bool = False,
):
    """
    Reference attention in torch:
      scores = Q @ K.T / sqrt(D)
      if causal: apply causal mask
      if mask: scores += mask
      P = softmax(scores)
      O = P @ V
    """
    # TODO: implement reference (use fp32 for scores/softmax for stability)
    # Notes:
    #   - scores/softmax computed in fp32 for stability
    #   - output returned in same dtype as Q

    assert Q.ndim == 2 and K.ndim == 2 and V.ndim == 2
    N, D = Q.shape
    assert K.shape == (N, D) and V.shape == (N, D)

    # fp32 scores for stability
    q = Q.float()
    k = K.float()
    v = V.float()

    scores = q @ k.transpose(0, 1)  # [N, N]
    scores *= (1.0 / math.sqrt(D))

    if causal:
        # Upper triangle (j > i) set to -inf
        # Use a bool mask and masked_fill for clarity.
        causal_mask = torch.triu(torch.ones((N, N), device=Q.device, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(causal_mask, float("-inf"))

    if mask is not None:
        assert mask.shape == (N, N), f"mask must be [N, N], got {tuple(mask.shape)}"
        scores = scores + mask.float()

    P = torch.softmax(scores, dim=-1)   # [N, N] fp32
    O = P @ v                           # [N, D] fp32
    return O.to(dtype=Q.dtype)


@torch.no_grad()
def check_correctness(
    N=1024, D=64, dtype=torch.float16,
    use_mask=False, causal=True,
    cfg: FlashCfg = FlashCfg(BLOCK_M=64, BLOCK_N=64, BLOCK_D=64, num_warps=4),
):
    device = "cuda"
    torch.manual_seed(0)

    Q = torch.randn((N, D), device=device, dtype=dtype)
    K = torch.randn((N, D), device=device, dtype=dtype)
    V = torch.randn((N, D), device=device, dtype=dtype)

    mask = None
    if use_mask:
        # TODO: create additive mask (e.g., padding or random -inf positions)
        drop_prob = 0.05
        drop = (torch.rand((N, N), device=device) < drop_prob)
        mask = torch.zeros((N, N), device=device, dtype=torch.float32)
        mask = mask.masked_fill(drop, float("-inf"))

        # Optional: don't mask diagonal to avoid degenerate rows
        diag = torch.eye(N, device=device, dtype=torch.bool)
        mask = mask.masked_fill(diag, 0.0)

    # TODO:
    # - out_ref = flashattn_ref_torch(...)
    # - out_tri = flashattn_mini_triton(...)
    # - print max/mean abs error
    out_ref = flashattn_ref_torch(Q, K, V, mask=mask, causal=causal)
    out_tri = flashattn_mini_triton(Q, K, V, mask=mask, causal=causal, cfg=cfg)

    # Compare in fp32 for reporting
    diff = (out_ref.float() - out_tri.float()).abs()
    max_abs = diff.max().item()
    mean_abs = diff.mean().item()

    print(f"N={N}, D={D}, dtype={dtype}, use_mask={use_mask}, causal={causal}")
    print(f"max_abs_err = {max_abs:.3e}")
    print(f"mean_abs_err = {mean_abs:.3e}")

    return max_abs, mean_abs


# ============================================================
# Benchmark: compare naive vs flash
# ============================================================
def cuda_time_ms(fn, iters=30, warmup=10) -> float:
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    s = torch.cuda.Event(enable_timing=True)
    e = torch.cuda.Event(enable_timing=True)

    s.record()
    for _ in range(iters):
        fn()
    e.record()
    torch.cuda.synchronize()
    return s.elapsed_time(e) / iters


@torch.no_grad()
def compare_perf(
    N_list: List[int] = [256, 512, 1024, 2048],
    D: int = 64,
    dtype=torch.float16,
    causal: bool = True,
    cfg_flash: FlashCfg = FlashCfg(BLOCK_M=64, BLOCK_N=64, BLOCK_D=64, num_warps=4),
    cfg_naive: Optional[Dict[str, Any]] = None,
):
    device = "cuda"
    torch.manual_seed(0)

    print("| N | Impl | ms | speedup_vs_naive |")
    print("|---|------|----|------------------|")

    for N in N_list:
        Q = torch.randn((N, D), device=device, dtype=dtype)
        K = torch.randn((N, D), device=device, dtype=dtype)
        V = torch.randn((N, D), device=device, dtype=dtype)

        mask = None
        if causal:
            # For naive attention you may need a materialized mask; for flash you might do internal causal.
            # TODO: create mask for naive if required by your implementation.
            pass

        # --- naive ---
        def fn_naive():
            return naive_attention_triton(Q, K, V, mask=mask, cfg=cfg_naive)

        # --- flash ---
        def fn_flash():
            return flashattn_mini_triton(Q, K, V, mask=None, causal=causal, cfg=cfg_flash)

        # TODO: optionally benchmark torch reference too
        ms_naive = cuda_time_ms(fn_naive)
        ms_flash = cuda_time_ms(fn_flash)

        speedup = ms_naive / ms_flash if ms_flash > 0 else float("inf")

        print(f"| {N} | naive | {ms_naive:.4f} | 1.00x |")
        print(f"| {N} | flash | {ms_flash:.4f} | {speedup:.2f}x |")


# ============================================================
# Main
# ============================================================
def main():
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA required.")

    # TODO: run correctness first on small N
    check_correctness(N=256, D=64, causal=True, use_mask=False)

    # TODO: then benchmark scaling
    compare_perf(N_list=[256, 512, 1024, 1024], D=64)

    # raise NotImplementedError("TODO: wire up your kernels and run correctness/bench")


if __name__ == "__main__":
    main()


N=256, D=64, dtype=torch.float16, use_mask=False, causal=True
max_abs_err = 2.441e-04
mean_abs_err = 4.479e-08
| N | Impl | ms | speedup_vs_naive |
|---|------|----|------------------|
| 256 | naive | 1.0282 | 1.00x |
| 256 | flash | 0.2426 | 4.24x |
| 512 | naive | 1.7125 | 1.00x |
| 512 | flash | 0.4714 | 3.63x |
| 1024 | naive | 3.7537 | 1.00x |
| 1024 | flash | 0.5541 | 6.77x |
| 1024 | naive | 3.5619 | 1.00x |
| 1024 | flash | 0.4618 | 7.71x |


In [17]:
# day6_flashattention_v1_vs_v2_splitk.py
# ============================================================
# FlashAttention v1 (single kernel) vs v2-style split-K (2 kernels)
# - correctness vs torch
# - perf benchmark (v1 vs v2)
#
# Notes for Tesla T4:
# - T4 shared memory per block limit: 64KB
# - Keep BLOCK_N modest (e.g., 64) and num_stages=1 first.
# ============================================================

from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Optional, List

import torch
import triton
import triton.language as tl


# ============================================================
# Utilities
# ============================================================
def _assert_2d(x: torch.Tensor, name: str) -> None:
    if not isinstance(x, torch.Tensor):
        raise TypeError(f"{name} must be a torch.Tensor, got {type(x)}")
    if x.ndim != 2:
        raise ValueError(f"{name} must be 2D [N, D], got shape {tuple(x.shape)}")


def _assert_cuda_contig(x: torch.Tensor, name: str) -> None:
    if not x.is_cuda:
        raise ValueError(f"{name} must be on CUDA, got {x.device}")
    if not x.is_contiguous():
        raise ValueError(f"{name} must be contiguous, got strides {x.stride()} and shape {tuple(x.shape)}")


def _out_tl_dtype_from_torch(dtype: torch.dtype):
    # Map torch dtype -> triton tl dtype (constexpr friendly)
    if dtype == torch.float16:
        return tl.float16
    if dtype == torch.bfloat16:
        return tl.bfloat16
    if dtype == torch.float32:
        return tl.float32
    raise ValueError(f"Unsupported dtype {dtype}")


def cuda_time_ms(fn, iters=30, warmup=10) -> float:
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    s = torch.cuda.Event(enable_timing=True)
    e = torch.cuda.Event(enable_timing=True)
    s.record()
    for _ in range(iters):
        fn()
    e.record()
    torch.cuda.synchronize()
    return s.elapsed_time(e) / iters


# ============================================================
# v1 Config
# ============================================================
@dataclass(frozen=True)
class FlashV1Cfg:
    BLOCK_M: int = 64
    BLOCK_N: int = 64   # T4-safe starter
    BLOCK_D: int = 64
    num_warps: int = 4
    num_stages: int = 1  # T4-safe starter


# ============================================================
# v1 Kernel (single kernel)
# ============================================================
@triton.jit
def flashattn_v1_kernel(
    q_ptr, k_ptr, v_ptr, o_ptr,
    mask_ptr,
    N: tl.constexpr, D: tl.constexpr,
    stride_qn: tl.constexpr, stride_qd: tl.constexpr,
    stride_kn: tl.constexpr, stride_kd: tl.constexpr,
    stride_vn: tl.constexpr, stride_vd: tl.constexpr,
    stride_on: tl.constexpr, stride_od: tl.constexpr,
    stride_mn: tl.constexpr, stride_mm: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_D: tl.constexpr,
    HAS_MASK: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    SCALE: tl.constexpr,
    OUT_DTYPE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    m_start = pid * BLOCK_M

    row = m_start + tl.arange(0, BLOCK_M)                  # [BM]
    row_mask = row < N

    d = tl.arange(0, D)                                    # [D]
    d_mask = d < D
    q_ptrs = q_ptr + row[:, None] * stride_qn + d[None, :] * stride_qd
    q = tl.load(q_ptrs, mask=row_mask[:, None] & d_mask[None, :], other=0.0).to(tl.float32)  # [BM, D]

    m = tl.full([BLOCK_M], -float("inf"), tl.float32)      # [BM]
    l = tl.zeros([BLOCK_M], tl.float32)                    # [BM]
    acc = tl.zeros([BLOCK_M, D], tl.float32)               # [BM, D]

    # stream over all keys
    for n_start in range(0, N, BLOCK_N):
        col = n_start + tl.arange(0, BLOCK_N)              # [BN]
        col_mask = col < N

        k_ptrs = k_ptr + col[:, None] * stride_kn + d[None, :] * stride_kd
        v_ptrs = v_ptr + col[:, None] * stride_vn + d[None, :] * stride_vd
        k = tl.load(k_ptrs, mask=col_mask[:, None] & d_mask[None, :], other=0.0).to(tl.float32)  # [BN, D]
        v = tl.load(v_ptrs, mask=col_mask[:, None] & d_mask[None, :], other=0.0).to(tl.float32)  # [BN, D]

        scores = tl.dot(q, tl.trans(k)) * SCALE            # [BM, BN], fp32

        # causal mask: keep only col <= row (global indices!)
        if IS_CAUSAL:
            keep = (col[None, :] <= row[:, None]) & (row_mask[:, None] & col_mask[None, :])
            scores = tl.where(keep, scores, -float("inf"))
        else:
            # still ensure OOB columns don't contribute
            scores = tl.where(row_mask[:, None] & col_mask[None, :], scores, -float("inf"))

        if HAS_MASK:
            mptrs = mask_ptr + row[:, None] * stride_mn + col[None, :] * stride_mm
            mvals = tl.load(mptrs, mask=row_mask[:, None] & col_mask[None, :], other=0.0).to(tl.float32)
            scores = scores + mvals

        s_max = tl.max(scores, axis=1)                     # [BM]
        m_new = tl.maximum(m, s_max)                       # [BM]
        alpha = tl.exp(m - m_new)                          # [BM]
        p = tl.exp(scores - m_new[:, None])                # [BM, BN]

        l_new = l * alpha + tl.sum(p, axis=1)              # [BM]
        acc = acc * alpha[:, None] + tl.dot(p, v)          # [BM, D]

        m = m_new
        l = l_new

    out = acc / l[:, None]                                 # [BM, D]
    o_ptrs = o_ptr + row[:, None] * stride_on + d[None, :] * stride_od
    tl.store(o_ptrs, out.to(OUT_DTYPE), mask=row_mask[:, None] & d_mask[None, :])


def flashattn_v1_triton(
    Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    causal: bool = False,
    cfg: FlashV1Cfg = FlashV1Cfg(),
) -> torch.Tensor:
    _assert_2d(Q, "Q"); _assert_2d(K, "K"); _assert_2d(V, "V")
    _assert_cuda_contig(Q, "Q"); _assert_cuda_contig(K, "K"); _assert_cuda_contig(V, "V")
    N, D = Q.shape
    if K.shape != (N, D) or V.shape != (N, D):
        raise ValueError("K,V must match Q shape")
    if cfg.BLOCK_D != D:
        raise ValueError(f"Mini assumes BLOCK_D == D. Got {cfg.BLOCK_D} vs D={D}")
    if Q.dtype not in (torch.float16, torch.bfloat16, torch.float32):
        raise ValueError("Q dtype must be fp16/bf16/fp32")
    if K.dtype != Q.dtype or V.dtype != Q.dtype:
        raise ValueError("K and V must match Q dtype")

    has_mask = mask is not None
    if has_mask:
        if mask.shape != (N, N) or (not mask.is_cuda) or (not mask.is_contiguous()):
            raise ValueError("mask must be CUDA contiguous and shape [N,N]")
        stride_mn, stride_mm = mask.stride()
        mask_ptr = mask
    else:
        mask_ptr = Q
        stride_mn, stride_mm = 0, 0

    out = torch.empty((N, D), device=Q.device, dtype=Q.dtype)

    stride_qn, stride_qd = Q.stride()
    stride_kn, stride_kd = K.stride()
    stride_vn, stride_vd = V.stride()
    stride_on, stride_od = out.stride()

    grid = (triton.cdiv(N, cfg.BLOCK_M),)
    scale = 1.0 / math.sqrt(D)
    out_tl_dtype = _out_tl_dtype_from_torch(Q.dtype)

    flashattn_v1_kernel[grid](
        Q, K, V, out,
        mask_ptr,
        N=N, D=D,
        stride_qn=stride_qn, stride_qd=stride_qd,
        stride_kn=stride_kn, stride_kd=stride_kd,
        stride_vn=stride_vn, stride_vd=stride_vd,
        stride_on=stride_on, stride_od=stride_od,
        stride_mn=stride_mn, stride_mm=stride_mm,
        BLOCK_M=cfg.BLOCK_M, BLOCK_N=cfg.BLOCK_N, BLOCK_D=cfg.BLOCK_D,
        HAS_MASK=has_mask, IS_CAUSAL=bool(causal),
        SCALE=scale,
        OUT_DTYPE=out_tl_dtype,
        num_warps=cfg.num_warps,
        num_stages=cfg.num_stages,
    )
    return out


# ============================================================
# v2 Config (split-K)
# ============================================================
@dataclass(frozen=True)
class FlashV2Cfg:
    BLOCK_M: int = 64
    BLOCK_N: int = 64      # T4-safe starter
    BLOCK_D: int = 64
    num_splits: int = 4    # split-K parallelism
    num_warps: int = 4
    num_stages: int = 1    # T4-safe starter


# ============================================================
# v2 Stage 1: (Q-block, split) -> partial (m_s, l_s, acc_s)
# ============================================================
@triton.jit
def flashattn_v2_stage1_kernel(
    q_ptr, k_ptr, v_ptr,
    mask_ptr,
    m_partial_ptr,
    l_partial_ptr,
    acc_partial_ptr,
    N: tl.constexpr, D: tl.constexpr,

    stride_qn: tl.constexpr, stride_qd: tl.constexpr,
    stride_kn: tl.constexpr, stride_kd: tl.constexpr,
    stride_vn: tl.constexpr, stride_vd: tl.constexpr,
    stride_mn: tl.constexpr, stride_mm: tl.constexpr,

    STRIDE_PM_QB: tl.constexpr,
    STRIDE_PM_SPLIT: tl.constexpr,

    STRIDE_PL_QB: tl.constexpr,
    STRIDE_PL_SPLIT: tl.constexpr,

    STRIDE_PACC_QB: tl.constexpr,
    STRIDE_PACC_SPLIT: tl.constexpr,
    STRIDE_PACC_M: tl.constexpr,
    STRIDE_PACC_D: tl.constexpr,

    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_D: tl.constexpr,
    NUM_SPLITS: tl.constexpr,

    HAS_MASK: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    SCALE: tl.constexpr,
):
    pid_qb = tl.program_id(0)   # Q-block id
    pid_s  = tl.program_id(1)   # split id

    m_start = pid_qb * BLOCK_M
    row = m_start + tl.arange(0, BLOCK_M)               # [BM]
    row_mask = row < N

    d = tl.arange(0, D)
    d_mask = d < D

    # split range
    chunk = (N + NUM_SPLITS - 1) // NUM_SPLITS         # constexpr arithmetic
    k_start = pid_s * chunk
    k_end = tl.minimum(N, (pid_s + 1) * chunk)

    # load Q tile
    q_ptrs = q_ptr + row[:, None] * stride_qn + d[None, :] * stride_qd
    q = tl.load(q_ptrs, mask=row_mask[:, None] & d_mask[None, :], other=0.0).to(tl.float32)  # [BM, D]

    # online softmax state for this split
    m = tl.full([BLOCK_M], -float("inf"), tl.float32)
    l = tl.zeros([BLOCK_M], tl.float32)
    acc = tl.zeros([BLOCK_M, D], tl.float32)

    # number of tiles within one split chunk (compile-time)
    num_tiles = (chunk + BLOCK_N - 1) // BLOCK_N

    # iterate tiles inside [k_start, k_end)
    for n_start in range(0, N, BLOCK_N):
        col = n_start + tl.arange(0, BLOCK_N)           # [BN]
        col_valid = (col >= k_start) & (col < k_end) & (col < N)

        k_ptrs = k_ptr + col[:, None] * stride_kn + d[None, :] * stride_kd
        v_ptrs = v_ptr + col[:, None] * stride_vn + d[None, :] * stride_vd
        k = tl.load(k_ptrs, mask=col_valid[:, None] & d_mask[None, :], other=0.0).to(tl.float32)  # [BN, D]
        v = tl.load(v_ptrs, mask=col_valid[:, None] & d_mask[None, :], other=0.0).to(tl.float32)  # [BN, D]

        scores = tl.dot(q, tl.trans(k)) * SCALE         # [BM, BN]

        if IS_CAUSAL:
            keep = (col[None, :] <= row[:, None]) & (row_mask[:, None] & col_valid[None, :])
            scores = tl.where(keep, scores, -float("inf"))
        else:
            scores = tl.where(row_mask[:, None] & col_valid[None, :], scores, -float("inf"))

        if HAS_MASK:
            mptrs = mask_ptr + row[:, None] * stride_mn + col[None, :] * stride_mm
            mvals = tl.load(mptrs, mask=row_mask[:, None] & col_valid[None, :], other=0.0).to(tl.float32)
            scores = scores + mvals

        s_max = tl.max(scores, axis=1)
        m_new = tl.maximum(m, s_max)
        alpha = tl.exp(m - m_new)
        p = tl.exp(scores - m_new[:, None])

        l_new = l * alpha + tl.sum(p, axis=1)
        acc = acc * alpha[:, None] + tl.dot(p, v)

        m = m_new
        l = l_new

    # store partials
    offs_m = tl.arange(0, BLOCK_M)

    pm_ptrs = m_partial_ptr + pid_qb * STRIDE_PM_QB + pid_s * STRIDE_PM_SPLIT + offs_m
    pl_ptrs = l_partial_ptr + pid_qb * STRIDE_PL_QB + pid_s * STRIDE_PL_SPLIT + offs_m
    tl.store(pm_ptrs, m, mask=row_mask)
    tl.store(pl_ptrs, l, mask=row_mask)

    pacc_ptrs = (
        acc_partial_ptr
        + pid_qb * STRIDE_PACC_QB
        + pid_s * STRIDE_PACC_SPLIT
        + offs_m[:, None] * STRIDE_PACC_M
        + d[None, :] * STRIDE_PACC_D
    )
    tl.store(pacc_ptrs, acc, mask=row_mask[:, None] & d_mask[None, :])


# ============================================================
# v2 Stage 2: merge splits -> final O
# ============================================================
@triton.jit
def flashattn_v2_stage2_kernel(
    o_ptr,
    m_partial_ptr,
    l_partial_ptr,
    acc_partial_ptr,
    N: tl.constexpr, D: tl.constexpr,

    stride_on: tl.constexpr, stride_od: tl.constexpr,

    STRIDE_PM_QB: tl.constexpr,
    STRIDE_PM_SPLIT: tl.constexpr,

    STRIDE_PL_QB: tl.constexpr,
    STRIDE_PL_SPLIT: tl.constexpr,

    STRIDE_PACC_QB: tl.constexpr,
    STRIDE_PACC_SPLIT: tl.constexpr,
    STRIDE_PACC_M: tl.constexpr,
    STRIDE_PACC_D: tl.constexpr,

    BLOCK_M: tl.constexpr,
    BLOCK_D: tl.constexpr,
    NUM_SPLITS: tl.constexpr,
    OUT_DTYPE: tl.constexpr,
):
    pid_qb = tl.program_id(0)
    m_start = pid_qb * BLOCK_M
    row = m_start + tl.arange(0, BLOCK_M)               # [BM]
    row_mask = row < N

    d = tl.arange(0, D)
    d_mask = d < D

    # 1) merged m = max_s m_s
    m_merged = tl.full([BLOCK_M], -float("inf"), tl.float32)
    offs_m = tl.arange(0, BLOCK_M)
    for s in tl.static_range(0, NUM_SPLITS):
        pm_ptrs = m_partial_ptr + pid_qb * STRIDE_PM_QB + s * STRIDE_PM_SPLIT + offs_m
        m_s = tl.load(pm_ptrs, mask=row_mask, other=-float("inf"))
        m_merged = tl.maximum(m_merged, m_s)

    # 2) merged l and acc
    l_merged = tl.zeros([BLOCK_M], tl.float32)
    acc_merged = tl.zeros([BLOCK_M, D], tl.float32)

    for s in tl.static_range(0, NUM_SPLITS):
        pm_ptrs = m_partial_ptr + pid_qb * STRIDE_PM_QB + s * STRIDE_PM_SPLIT + offs_m
        pl_ptrs = l_partial_ptr + pid_qb * STRIDE_PL_QB + s * STRIDE_PL_SPLIT + offs_m
        m_s = tl.load(pm_ptrs, mask=row_mask, other=-float("inf"))
        l_s = tl.load(pl_ptrs, mask=row_mask, other=0.0)

        w = tl.exp(m_s - m_merged)                      # [BM]
        l_merged += l_s * w

        pacc_ptrs = (
            acc_partial_ptr
            + pid_qb * STRIDE_PACC_QB
            + s * STRIDE_PACC_SPLIT
            + offs_m[:, None] * STRIDE_PACC_M
            + d[None, :] * STRIDE_PACC_D
        )
        acc_s = tl.load(pacc_ptrs, mask=row_mask[:, None] & d_mask[None, :], other=0.0)
        acc_merged += acc_s * w[:, None]

    out = acc_merged / l_merged[:, None]
    o_ptrs = o_ptr + row[:, None] * stride_on + d[None, :] * stride_od
    tl.store(o_ptrs, out.to(OUT_DTYPE), mask=row_mask[:, None] & d_mask[None, :])


def flashattn_v2_splitk_triton(
    Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    causal: bool = False,
    cfg: FlashV2Cfg = FlashV2Cfg(),
) -> torch.Tensor:
    _assert_2d(Q, "Q"); _assert_2d(K, "K"); _assert_2d(V, "V")
    _assert_cuda_contig(Q, "Q"); _assert_cuda_contig(K, "K"); _assert_cuda_contig(V, "V")
    N, D = Q.shape
    if K.shape != (N, D) or V.shape != (N, D):
        raise ValueError("K,V must match Q shape")
    if cfg.BLOCK_D != D:
        raise ValueError(f"Mini assumes BLOCK_D == D. Got {cfg.BLOCK_D} vs D={D}")

    if Q.dtype not in (torch.float16, torch.bfloat16, torch.float32):
        raise ValueError("Q dtype must be fp16/bf16/fp32")
    if K.dtype != Q.dtype or V.dtype != Q.dtype:
        raise ValueError("K and V must match Q dtype")

    has_mask = mask is not None
    if has_mask:
        if mask.shape != (N, N) or (not mask.is_cuda) or (not mask.is_contiguous()):
            raise ValueError("mask must be CUDA contiguous and shape [N,N]")
        stride_mn, stride_mm = mask.stride()
        mask_ptr = mask
    else:
        mask_ptr = Q
        stride_mn, stride_mm = 0, 0

    out = torch.empty((N, D), device=Q.device, dtype=Q.dtype)

    num_qb = triton.cdiv(N, cfg.BLOCK_M)
    num_splits = cfg.num_splits

    # partial buffers in fp32
    m_partial = torch.empty((num_qb, num_splits, cfg.BLOCK_M), device=Q.device, dtype=torch.float32)
    l_partial = torch.empty((num_qb, num_splits, cfg.BLOCK_M), device=Q.device, dtype=torch.float32)
    acc_partial = torch.empty((num_qb, num_splits, cfg.BLOCK_M, D), device=Q.device, dtype=torch.float32)

    # strides (elements)
    STRIDE_PM_QB, STRIDE_PM_SPLIT, _ = m_partial.stride()
    STRIDE_PL_QB, STRIDE_PL_SPLIT, _ = l_partial.stride()
    STRIDE_PACC_QB, STRIDE_PACC_SPLIT, STRIDE_PACC_M, STRIDE_PACC_D = acc_partial.stride()

    stride_qn, stride_qd = Q.stride()
    stride_kn, stride_kd = K.stride()
    stride_vn, stride_vd = V.stride()
    stride_on, stride_od = out.stride()

    scale = 1.0 / math.sqrt(D)
    out_tl_dtype = _out_tl_dtype_from_torch(Q.dtype)

    # Stage 1: 2D grid
    grid1 = (num_qb, num_splits)
    flashattn_v2_stage1_kernel[grid1](
        Q, K, V,
        mask_ptr,
        m_partial, l_partial, acc_partial,
        N=N, D=D,
        stride_qn=stride_qn, stride_qd=stride_qd,
        stride_kn=stride_kn, stride_kd=stride_kd,
        stride_vn=stride_vn, stride_vd=stride_vd,
        stride_mn=stride_mn, stride_mm=stride_mm,
        STRIDE_PM_QB=STRIDE_PM_QB, STRIDE_PM_SPLIT=STRIDE_PM_SPLIT,
        STRIDE_PL_QB=STRIDE_PL_QB, STRIDE_PL_SPLIT=STRIDE_PL_SPLIT,
        STRIDE_PACC_QB=STRIDE_PACC_QB, STRIDE_PACC_SPLIT=STRIDE_PACC_SPLIT,
        STRIDE_PACC_M=STRIDE_PACC_M, STRIDE_PACC_D=STRIDE_PACC_D,
        BLOCK_M=cfg.BLOCK_M, BLOCK_N=cfg.BLOCK_N, BLOCK_D=cfg.BLOCK_D,
        NUM_SPLITS=num_splits,
        HAS_MASK=has_mask, IS_CAUSAL=bool(causal),
        SCALE=scale,
        num_warps=cfg.num_warps,
        num_stages=cfg.num_stages,
    )

    # Stage 2: 1D grid
    grid2 = (num_qb,)
    flashattn_v2_stage2_kernel[grid2](
        out,
        m_partial, l_partial, acc_partial,
        N=N, D=D,
        stride_on=stride_on, stride_od=stride_od,
        STRIDE_PM_QB=STRIDE_PM_QB, STRIDE_PM_SPLIT=STRIDE_PM_SPLIT,
        STRIDE_PL_QB=STRIDE_PL_QB, STRIDE_PL_SPLIT=STRIDE_PL_SPLIT,
        STRIDE_PACC_QB=STRIDE_PACC_QB, STRIDE_PACC_SPLIT=STRIDE_PACC_SPLIT,
        STRIDE_PACC_M=STRIDE_PACC_M, STRIDE_PACC_D=STRIDE_PACC_D,
        BLOCK_M=cfg.BLOCK_M, BLOCK_D=cfg.BLOCK_D,
        NUM_SPLITS=num_splits,
        OUT_DTYPE=out_tl_dtype,
        num_warps=cfg.num_warps,
    )
    return out


# ============================================================
# Torch reference (correctness)
# ============================================================
def flashattn_ref_torch(
    Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    causal: bool = False,
) -> torch.Tensor:
    _assert_2d(Q, "Q"); _assert_2d(K, "K"); _assert_2d(V, "V")
    N, D = Q.shape
    q = Q.float()
    k = K.float()
    v = V.float()

    scores = (q @ k.transpose(0, 1)) * (1.0 / math.sqrt(D))
    if causal:
        cm = torch.triu(torch.ones((N, N), device=Q.device, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(cm, float("-inf"))
    if mask is not None:
        scores = scores + mask.float()

    P = torch.softmax(scores, dim=-1)
    O = P @ v
    return O.to(dtype=Q.dtype)


@torch.no_grad()
def check_correctness(
    N=1024, D=64, dtype=torch.float16,
    use_mask=False, causal=True,
    cfg_v1: FlashV1Cfg = FlashV1Cfg(),
    cfg_v2: FlashV2Cfg = FlashV2Cfg(),
):
    device = "cuda"
    torch.manual_seed(0)

    Q = torch.randn((N, D), device=device, dtype=dtype)
    K = torch.randn((N, D), device=device, dtype=dtype)
    V = torch.randn((N, D), device=device, dtype=dtype)

    mask = None
    if use_mask:
        drop_prob = 0.05
        drop = (torch.rand((N, N), device=device) < drop_prob)
        mask = torch.zeros((N, N), device=device, dtype=torch.float32)
        mask = mask.masked_fill(drop, float("-inf"))
        diag = torch.eye(N, device=device, dtype=torch.bool)
        mask = mask.masked_fill(diag, 0.0)

    out_ref = flashattn_ref_torch(Q, K, V, mask=mask, causal=causal)
    out_v1 = flashattn_v1_triton(Q, K, V, mask=mask, causal=causal, cfg=cfg_v1)
    out_v2 = flashattn_v2_splitk_triton(Q, K, V, mask=mask, causal=causal, cfg=cfg_v2)

    diff1 = (out_ref.float() - out_v1.float()).abs()
    diff2 = (out_ref.float() - out_v2.float()).abs()

    print(f"[Correctness] N={N} D={D} dtype={dtype} use_mask={use_mask} causal={causal}")
    print(f"  v1 max_abs={diff1.max().item():.3e} mean_abs={diff1.mean().item():.3e}")
    print(f"  v2 max_abs={diff2.max().item():.3e} mean_abs={diff2.mean().item():.3e}")


@torch.no_grad()
def compare_perf(
    N_list: List[int] = [256, 512, 1024, 2048],
    D: int = 64,
    dtype=torch.float16,
    causal: bool = True,
    cfg_v1: FlashV1Cfg = FlashV1Cfg(),
    cfg_v2: FlashV2Cfg = FlashV2Cfg(),
):
    device = "cuda"
    torch.manual_seed(0)

    print("| N | v1 ms | v2 ms | v2/v1 speedup |")
    print("|---|------:|------:|--------------:|")

    for N in N_list:
        Q = torch.randn((N, D), device=device, dtype=dtype)
        K = torch.randn((N, D), device=device, dtype=dtype)
        V = torch.randn((N, D), device=device, dtype=dtype)

        # no materialized mask for perf (causal handled inside kernels)
        def fn_v1():
            return flashattn_v1_triton(Q, K, V, mask=None, causal=causal, cfg=cfg_v1)

        def fn_v2():
            return flashattn_v2_splitk_triton(Q, K, V, mask=None, causal=causal, cfg=cfg_v2)

        ms_v1 = cuda_time_ms(fn_v1)
        ms_v2 = cuda_time_ms(fn_v2)
        speedup = (ms_v1 / ms_v2) if ms_v2 > 0 else float("inf")
        print(f"| {N} | {ms_v1:7.4f} | {ms_v2:7.4f} | {speedup:14.2f}x |")


def main():
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA required.")

    # T4-friendly starters
    cfg_v1 = FlashV1Cfg(BLOCK_M=64, BLOCK_N=64, BLOCK_D=64, num_warps=4, num_stages=1)
    cfg_v2 = FlashV2Cfg(BLOCK_M=64, BLOCK_N=64, BLOCK_D=64, num_splits=4, num_warps=4, num_stages=1)

    # correctness on small N first
    check_correctness(N=256, D=64, causal=True, use_mask=False, cfg_v1=cfg_v1, cfg_v2=cfg_v2)

    # perf
    compare_perf(N_list=[256, 512, 1024, 2048], D=64, causal=True, cfg_v1=cfg_v1, cfg_v2=cfg_v2)


if __name__ == "__main__":
    main()

[Correctness] N=256 D=64 dtype=torch.float16 use_mask=False causal=True
  v1 max_abs=2.441e-04 mean_abs=4.479e-08
  v2 max_abs=nan mean_abs=nan
| N | v1 ms | v2 ms | v2/v1 speedup |
|---|------:|------:|--------------:|
| 256 |  0.2352 |  0.3407 |           0.69x |
| 512 |  0.4575 |  0.9071 |           0.50x |
| 1024 |  0.8932 |  3.6534 |           0.24x |
| 2048 |  2.3417 | 12.0738 |           0.19x |
