## Triton Core Kernels & Benchmarking

🎯 **Weekly Goal**  
Implement and tune three core ML kernels (MatMul, Softmax, LayerNorm) using Triton,  
develop intuition for **block sizes, `num_warps`, numerical stability, and kernel fusion**,  
and perform systematic benchmarking against PyTorch to understand real performance trade-offs.

---

## Tuned MatMul (Block Size / `num_warps`)

### Objective
Implement a tile-based Triton GEMM kernel and tune **`BLOCK_M / BLOCK_N / BLOCK_K`**
and **`num_warps`** to study their impact on throughput, occupancy, and register pressure.

### Tasks
- [ ] Implement a basic Triton matmul kernel (tile-based)
- [ ] Define `BLOCK_M / BLOCK_N / BLOCK_K` as `tl.constexpr`
- [ ] Evaluate multiple tile configurations (e.g., 64×64×32, 128×128×32)
- [ ] Sweep `num_warps ∈ {4, 8}` and compare performance
- [ ] Compare against `torch.matmul` as a baseline

### Key Concepts
- **Tiling**: each program computes one output tile
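- **Block size** controls arithmetic intensity (FLOPs per byte loaded from global memory): a larger tile reuses each loaded element in more multiply-adds
- **`num_warps`** trades off parallelism vs. register usage
- Large GEMMs are typically **compute-bound**, unlike vector add, which is memory-bound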
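To make the block-size bullet concrete, here is a back-of-the-envelope sketch in plain Python (no GPU needed). It assumes that each program loads one `BLOCK_M×BLOCK_K` tile of A and one `BLOCK_K×BLOCK_N` tile of B from global memory per K-step. It ignores cache reuse and the final C store, so the numbers are an idealized model rather than a measurement. The helper name `tile_intensity` is hypothetical.

```python
def tile_intensity(block_m: int, block_n: int, block_k: int, bytes_per_elem: int = 2) -> float:
    """FLOPs per byte of global-memory traffic for one K-step of one C tile.

    Idealized model: an A tile [BM, BK] and a B tile [BK, BN] are loaded once,
    BM*BN*BK multiply-adds (2 FLOPs each) are performed, caches and the C store
    are ignored.
    """
    flops = 2 * block_m * block_n * block_k
    bytes_loaded = (block_m * block_k + block_k * block_n) * bytes_per_elem
    return flops / bytes_loaded


# The two configurations from the task list, fp16 inputs (2 bytes per element)
for cfg in [(64, 64, 32), (128, 128, 32)]:
    print(cfg, tile_intensity(*cfg))
```

`BLOCK_K` cancels out: for fp16 the ratio simplifies to BM·BN / (BM + BN) FLOPs per byte, so doubling both tile edges doubles the intensity. The cost is a larger fp32 accumulator. With `num_warps=8` (256 threads), a 128×128 accumulator already needs 64 fp32 values per thread, before counting the A/B tiles.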

### Deliverables
- Runnable `triton_matmul.py`
- Performance table (ms / TFLOPS) for different configurations
- Short analysis identifying the best configuration and why
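For the ms/TFLOPS table, note that a GEMM of shape [M, K] × [K, N] performs 2·M·N·K FLOPs (one multiply and one add per term). Below is a minimal helper, with a hypothetical name, that converts a measured latency into achieved TFLOPS:

```python
def gemm_tflops(m: int, n: int, k: int, ms: float) -> float:
    """Achieved TFLOPS for an [m, k] x [k, n] matmul that took `ms` milliseconds."""
    flops = 2 * m * n * k
    return flops / (ms * 1e-3) / 1e12


# Illustrative latency, not a measurement: a 1024^3 GEMM finishing in 0.5 ms
print(f"{gemm_tflops(1024, 1024, 1024, 0.5):.2f} TFLOPS")
```

Comparing this number against the GPU's peak fp16 throughput shows how close a given tile configuration gets to being compute-bound.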

---

## Triton Softmax (Numerically Stable)

### Objective
Implement a numerically stable row-wise softmax kernel in Triton,
focusing on **max-subtraction**, reduction patterns, and block/warp mapping.

### Tasks
- [ ] Implement row-wise softmax in Triton
- [ ] Apply `x - max(x)` for numerical stability
- [ ] Decompose into two stages:
  - max reduction
  - exp + sum reduction
- [ ] Correctly handle arbitrary feature dimensions (non-power-of-two)
- [ ] Compare against `torch.softmax`
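Why `x - max(x)` matters: `exp` overflows for large inputs long before the normalized result does. The sketch below is pure Python (standard library only, illustrative function names) and follows the two-stage scheme from the task list. Softmax is unchanged by subtracting a constant from every element, so the stable version returns the same values as the naive one whenever the naive one does not overflow.

```python
import math


def softmax_naive(row):
    exps = [math.exp(v) for v in row]          # float64 exp overflows above ~709.78
    total = sum(exps)
    return [e / total for e in exps]


def softmax_stable(row):
    m = max(row)                                # stage 1: max reduction
    exps = [math.exp(v - m) for v in row]       # stage 2: exp + sum reduction
    total = sum(exps)
    return [e / total for e in exps]


row = [1000.0, 1000.0, 999.0]
print(softmax_stable(row))
try:
    softmax_naive(row)
except OverflowError as err:
    print("naive softmax overflowed:", err)
```

In fp32, which the Triton kernels here use, the limit is far lower: `exp` overflows to `inf` above roughly 88.7.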

### Key Concepts
- **Numerical stability** for exponential operations
- **Reduction patterns** within a program
- Softmax is often **memory-bound with reductions**
- Triton enables explicit control over reduction structure
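Regarding reduction structure: the chunked kernel B in the notebook below makes three passes over each row. A known alternative, usually called *online softmax*, merges the max and sum reductions into a single pass by rescaling the running sum whenever the running max grows. The sketch below is plain Python that walks fixed-size chunks to mirror `BLOCK_D`. It illustrates the recurrence and is not a Triton kernel; `online_max_and_sum` is a hypothetical helper.

```python
import math


def online_max_and_sum(row, block):
    """One pass over `row` in chunks of `block`; returns (max, sum(exp(x - max)))."""
    m = -math.inf
    s = 0.0
    for start in range(0, len(row), block):
        chunk = row[start:start + block]
        m_new = max(m, max(chunk))
        # rescale the old sum to the new max, then add this chunk's contribution
        s = s * math.exp(m - m_new) + sum(math.exp(v - m_new) for v in chunk)
        m = m_new
    return m, s


row = [0.3, 5.0, -2.0, 7.5, 1.0, 7.5, -0.5]
m, s = online_max_and_sum(row, block=3)
probs = [math.exp(v - m) / s for v in row]   # final normalize pass
print(m, sum(probs))
```

The normalize step still has to re-read the row, so this approach cuts the passes from three to two rather than to one.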

### Deliverables
- `triton_softmax.py`
- Correctness check vs. PyTorch (max / mean error)
- Performance comparison table (ms / GB/s)

---

## Triton LayerNorm

### Objective
Implement Triton LayerNorm (forward pass) and understand
**mean/variance reductions**, `eps` stabilization, and the performance benefits of kernel fusion.

### Tasks
- [ ] Implement LayerNorm forward in Triton
- [ ] Compute per-row mean and variance
- [ ] Apply `rsqrt(var + eps)`
- [ ] Support affine parameters (`gamma`, `beta`)
- [ ] Compare against `torch.nn.functional.layer_norm`
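The tasks above correspond to the reference computation below, written in plain Python for a single row. It is a sketch for checking intuition, using the biased variance and `eps` inside the square root, which is how `torch.nn.functional.layer_norm` defines LayerNorm. `layer_norm_row` is a hypothetical name.

```python
import math


def layer_norm_row(x, gamma, beta, eps=1e-5):
    d = len(x)
    mean = sum(x) / d                                  # reduction 1: mean
    var = sum((v - mean) ** 2 for v in x) / d          # reduction 2: biased variance
    rstd = 1.0 / math.sqrt(var + eps)                  # rsqrt(var + eps)
    return [(v - mean) * rstd * g + b for v, g, b in zip(x, gamma, beta)]


x = [1.0, 2.0, 3.0, 4.0]
y = layer_norm_row(x, gamma=[1.0] * 4, beta=[0.0] * 4)
print([round(v, 4) for v in y])
```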

### Key Concepts
- Two reductions: mean → variance
- **Kernel fusion**: normalization + affine in a single kernel
- LayerNorm is typically **memory-bound with reductions**
- A fused Triton kernel avoids materializing intermediate tensors (e.g., the normalized `x` before the affine step) in global memory
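When D is too large for a single tile, it is tempting to compute the variance in one chunked pass as E[x²] − E[x]². That form can cancel catastrophically when the mean is large relative to the spread. Welford's online update is one standard single-pass alternative. The plain-Python sketch below (hypothetical helper names) contrasts the two approaches in float64.

```python
def var_naive(xs):
    n = len(xs)
    mean = sum(xs) / n
    return sum(v * v for v in xs) / n - mean * mean   # E[x^2] - E[x]^2


def var_welford(xs):
    n, mean, m2 = 0, 0.0, 0.0
    for v in xs:
        n += 1
        delta = v - mean
        mean += delta / n
        m2 += delta * (v - mean)     # uses the updated mean
    return m2 / n                    # biased (population) variance, as LayerNorm uses


xs = [1e9 + 1.0, 1e9 + 2.0, 1e9 + 3.0]   # true biased variance = 2/3
print("naive  :", var_naive(xs))
print("welford:", var_welford(xs))
```

Both terms in the naive formula are float64 values near 1e18, where adjacent representable numbers are 128 apart. Their difference is therefore a multiple of 128 and can never equal 2/3. fp32 has only 24 significand bits, so a GPU kernel hits the same cancellation at much smaller offsets.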

### Deliverables
- `triton_layernorm.py`
- Correctness validation (max / mean error)
- Triton vs. PyTorch performance comparison

---

## Benchmark: Triton vs PyTorch

### Objective
Systematically benchmark Triton kernels against PyTorch eager kernels
to identify **when Triton wins, when it does not, and why**.

### Tasks
- [ ] Build a unified benchmark framework (CUDA events)
- [ ] Compare the following operators:
  - vector add
  - fused add + ReLU
  - softmax
  - layernorm
- [ ] Record:
  - latency (ms)
  - effective bandwidth (GB/s) or throughput (TFLOPS)
  - speedup
- [ ] Repeat experiments across different tensor sizes

### Key Concepts
- **Bandwidth-bound vs. compute-bound** kernels
- Kernel launch overhead
- Real benefits of operator fusion
- Why Triton excels at fused kernels rather than single primitive ops
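One way to predict up front whether an operator is bandwidth-bound or compute-bound is a roofline estimate. You compare its arithmetic intensity (FLOPs per byte moved) with the device's ridge point, which is peak FLOP/s divided by peak bandwidth. The sketch below counts only compulsory DRAM traffic. Its peak numbers are placeholders, roughly the published fp16 tensor-core and memory figures for a Tesla T4; substitute your own GPU's datasheet values.

```python
def ridge_point(peak_tflops: float, peak_gbps: float) -> float:
    """FLOPs per byte at which the compute and bandwidth roofs meet."""
    return (peak_tflops * 1e12) / (peak_gbps * 1e9)


def classify(flops: float, bytes_moved: float, ridge: float) -> str:
    intensity = flops / bytes_moved
    return "compute-bound" if intensity > ridge else "bandwidth-bound"


# Placeholder device numbers (assumed, not measured)
RIDGE = ridge_point(peak_tflops=65.0, peak_gbps=320.0)

n = 1 << 24                       # fp32 vector add: 1 FLOP, 2 reads + 1 write per element
vec_add = classify(flops=n, bytes_moved=3 * 4 * n, ridge=RIDGE)

m = 4096                          # square fp16 GEMM, each matrix moved once
gemm = classify(flops=2 * m**3, bytes_moved=3 * 2 * m * m, ridge=RIDGE)

print(f"ridge = {RIDGE:.0f} FLOP/B, vector add: {vec_add}, 4096^3 GEMM: {gemm}")
```

Softmax and LayerNorm do only a handful of FLOPs per element, so they also sit far below the ridge point. Their speed depends on how few bytes they move.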

### Deliverables
- `benchmark_triton_vs_torch.py`
- Unified comparison table (Markdown / CSV)
- Summary covering:
  - which kernels benefit most from Triton
  - which PyTorch kernels are already near-optimal
  - implications for ML systems optimization
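For the unified comparison table, one option is to collect one dict per (op, size, impl) measurement and render the same rows as both Markdown and CSV. The sketch below uses only the standard library. The field names and the `to_markdown` / `to_csv` / `add_speedups` helpers are assumptions, and the timings are placeholders, not results.

```python
import csv
import io

FIELDS = ["op", "size", "impl", "ms", "speedup"]


def add_speedups(rows, baseline="torch"):
    """Speedup = baseline ms / impl ms, matched on (op, size)."""
    base = {(r["op"], r["size"]): r["ms"] for r in rows if r["impl"] == baseline}
    for r in rows:
        r["speedup"] = base[(r["op"], r["size"])] / r["ms"]
    return rows


def to_markdown(rows):
    lines = ["| " + " | ".join(FIELDS) + " |", "|" + "---|" * len(FIELDS)]
    for r in rows:
        cells = [f"{r[f]:.4f}" if isinstance(r[f], float) else str(r[f]) for f in FIELDS]
        lines.append("| " + " | ".join(cells) + " |")
    return "\n".join(lines)


def to_csv(rows):
    buf = io.StringIO()
    writer = csv.DictWriter(buf, fieldnames=FIELDS)
    writer.writeheader()
    writer.writerows(rows)
    return buf.getvalue()


# Placeholder timings (illustrative only)
rows = add_speedups([
    {"op": "softmax", "size": "4096x1024", "impl": "torch", "ms": 0.20},
    {"op": "softmax", "size": "4096x1024", "impl": "triton", "ms": 0.16},
])
print(to_markdown(rows))
print(to_csv(rows))
```

Keying the speedup on (op, size) lets the same table hold the repeated runs across tensor sizes.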

---

## End-of-Week Takeaways

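- Triton is **not** “a faster PyTorch”
- Triton enables **block-level GPU kernel design in Python**: you write per-tile logic, and the compiler handles per-thread details such as memory coalescing and shared-memory use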
- The real performance gains come from:
  - kernel fusion
  - explicit reduction control
  - tile-aware kernel design
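Counting global-memory traffic shows why fusion is where Triton pays off for memory-bound ops. Take `relu(a + b)` on n elements. Eager PyTorch runs two kernels: the add writes a temporary, and ReLU reads it back. A fused kernel reads `a` and `b` once and writes the result once. The count below is idealized (no caching, no launch overhead), and the helper names are hypothetical.

```python
def unfused_add_relu_bytes(n: int, elem_bytes: int = 4) -> int:
    add = 3 * n * elem_bytes          # read a, read b, write tmp
    relu = 2 * n * elem_bytes         # read tmp, write out
    return add + relu


def fused_add_relu_bytes(n: int, elem_bytes: int = 4) -> int:
    return 3 * n * elem_bytes         # read a, read b, write out


n = 1 << 24
ratio = unfused_add_relu_bytes(n) / fused_add_relu_bytes(n)
print(f"unfused moves {ratio:.2f}x the bytes of the fused kernel")
```

If both versions run near peak bandwidth, the fused kernel's best-case speedup is about 5/3 ≈ 1.67×, plus one fewer launch. A single primitive such as vector add has no intermediate traffic to remove, which is why PyTorch's kernel for it is hard to beat.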


In [None]:
#matmul_skeleton
import torch
import triton
import triton.language as tl

# ============================================================
# Day 3: Tuned MatMul (skeleton notes + reference implementation)
# Goal:
#   - Implement a tile-based matmul kernel in Triton
#   - Tune BLOCK_M / BLOCK_N / BLOCK_K and num_warps
#   - Validate vs torch.matmul
#
# Notes:
#   - Assume A: [M, K], B: [K, N], C: [M, N], fp16 inputs, fp16 output (acc fp32).
#   - You may start with fp16 and accumulate in fp32.
# ============================================================

@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # TODO:
    # 1) program ids for 2D tiling: pid_m, pid_n
    # 2) compute offsets for A tile and B tile
    # 3) loop over K tiles:
    #    - tl.load A and B tiles with masks
    #    - accumulate using tl.dot / manual FMA
    # 4) tl.store to C with mask for M,N boundaries
    #
    # Hints:
    #   pid_m = tl.program_id(0)
    #   pid_n = tl.program_id(1)
    #   offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    #   offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    #   offs_k = tl.arange(0, BLOCK_K)
    #
    #   use tl.multiple_of / tl.assume if needed (optional)
    # raise NotImplementedError("TODO: implement matmul_kernel")

    # Program ids: one program per (BLOCK_M x BLOCK_N) tile of C
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Compute row/col offsets for this program's C tile
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)


    m_mask = offs_m < M          # [BM]
    n_mask = offs_n < N          # [BN]

    for k0 in range(0, K, BLOCK_K):
        # current K indices for this chunk
        k_offsets = k0 + offs_k          # [BK]
        k_mask = k_offsets < K           # [BK]

        # build pointer grids for this chunk
        a_ptrs = a_ptr + offs_m[:, None] * stride_am + k_offsets[None, :] * stride_ak   # [BM, BK]
        b_ptrs = b_ptr + k_offsets[:, None] * stride_bk + offs_n[None, :] * stride_bn   # [BK, BN]

        # 2D masks for loads
        a_load_mask = m_mask[:, None] & k_mask[None, :]    # [BM, BK]
        b_load_mask = k_mask[:, None] & n_mask[None, :]    # [BK, BN]

        # masked loads: out-of-bounds => 0
        a_tile = tl.load(a_ptrs, mask=a_load_mask, other=0.0)  # [BM, BK], fp16/bf16
        b_tile = tl.load(b_ptrs, mask=b_load_mask, other=0.0)  # [BK, BN], fp16/bf16

        # accumulate in fp32: tl.dot on fp16/bf16 inputs returns fp32 by default
        acc += tl.dot(a_tile, b_tile)

    # cast to C's element type (fp16 or bf16, matching the launcher's allocation)
    c_tile = acc.to(c_ptr.dtype.element_ty)
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, c_tile, mask=m_mask[:, None] & n_mask[None, :])

def triton_matmul(A: torch.Tensor, B: torch.Tensor,
                  BLOCK_M=128, BLOCK_N=128, BLOCK_K=32,
                  num_warps=8):
    assert A.is_cuda and B.is_cuda
    assert A.dtype in (torch.float16, torch.bfloat16)
    assert A.dtype == B.dtype, "A and B must share a dtype (fp16 or bf16)"
    assert A.is_contiguous() and B.is_contiguous()
    M, K = A.shape
    K2, N = B.shape
    assert K == K2

    C = torch.empty((M, N), device=A.device, dtype=A.dtype)

    grid = (
        triton.cdiv(M, BLOCK_M),
        triton.cdiv(N, BLOCK_N),
    )

    matmul_kernel[grid](
        A, B, C,
        M, N, K,
        A.stride(0), A.stride(1),
        B.stride(0), B.stride(1),
        C.stride(0), C.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
        num_warps=num_warps,   # Triton launch meta
    )
    return C

@torch.no_grad()
def check_matmul():
    if not torch.cuda.is_available():
        print("CUDA not available.")
        return

    device = "cuda"
    torch.manual_seed(0)

    M, K, N = 512, 1024, 768
    A = torch.randn((M, K), device=device, dtype=torch.float16)
    B = torch.randn((K, N), device=device, dtype=torch.float16)

    # PyTorch baselines: the fp16 path and a stricter fp32 reference
    C_ref_fp16 = A @ B
    C_ref_fp32 = A.float() @ B.float()

    # Triton (raises NotImplementedError while the kernel is still a stub)
    try:
        C_tri = triton_matmul(A, B, BLOCK_M=128, BLOCK_N=128, BLOCK_K=32, num_warps=8)
    except NotImplementedError as e:
        print(f"[Day3] matmul kernel not implemented yet: {e}")
        return

    # C_tri is fp16, like the torch fp16 path
    print("err vs torch fp16:", (C_tri - C_ref_fp16).abs().max().item())
    print("err vs fp32 ref  :", (C_tri.float() - C_ref_fp32).abs().max().item())

    # correctness
    max_err = (C_tri - C_ref_fp16).abs().max().item()
    print(f"[Day3] max_abs_err = {max_err:.3e}")
    # You can tighten thresholds after tuning (a relative tolerance scales better with |C|)
    assert max_err < 7e-2, "Too large error (fp16). Improve implementation."

if __name__ == "__main__":
    check_matmul()

err vs torch fp16: 0.0625
err vs fp32 ref  : 0.0601959228515625
[Day3] max_abs_err = 6.250e-02
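The max error of 0.0625 is not arbitrary: it is one fp16 ulp for values in [64, 128). With K = 1024 and standard-normal inputs, entries of C have a standard deviation of roughly sqrt(1024) = 32. That puts a few percent of the outputs in [64, 128), where fp16 spacing is 0.0625, and a handful above 128, where it is 0.125. Two correct kernels that accumulate in a different order can therefore round to neighboring fp16 values. The fp32-reference error of about 0.06 also fits plain rounding to fp16 above 128. Below is a standard-library sketch (hypothetical helper names) of the spacing, and of a relative tolerance check that uses the same criterion as `torch.allclose`:

```python
import math
import struct


def fp16_spacing(x: float) -> float:
    """Gap between adjacent fp16 values near |x| (normal range only)."""
    _, exp = math.frexp(abs(x))       # |x| = mant * 2**exp with mant in [0.5, 1)
    return 2.0 ** (exp - 11)          # fp16 stores 10 explicit mantissa bits


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest fp16 value via struct's 'e' format."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def allclose(a, b, rtol=1e-2, atol=1e-2):
    """torch.allclose criterion, elementwise: |a - b| <= atol + rtol * |b|."""
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


print(fp16_spacing(100.0), fp16_spacing(200.0))   # spacing in [64, 128) and [128, 256)
print(to_fp16(100.03), to_fp16(100.04))
```

Once tuning is done, a relative check such as `torch.testing.assert_close(C_tri, C_ref_fp16, rtol=..., atol=...)` scales with the magnitude of C, unlike the fixed `7e-2` threshold.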


In [None]:
#matmul_skeleton soft_pipeline
import torch
import triton
import triton.language as tl

# ============================================================
# Day 3: Tuned MatMul, static_range / num_stages variant (skeleton notes + reference implementation)
# Goal:
#   - Implement a tile-based matmul kernel in Triton
#   - Tune BLOCK_M / BLOCK_N / BLOCK_K and num_warps
#   - Validate vs torch.matmul
#
# Notes:
#   - Assume A: [M, K], B: [K, N], C: [M, N], fp16 inputs, fp16 output (acc fp32).
#   - You may start with fp16 and accumulate in fp32.
# ============================================================

@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    MAX_K_TILES: tl.constexpr,            # [NEW] meta: static_range upper bound in K tiles (unrolls cdiv(K, BLOCK_K) steps, not K=1024)
    OUT_DTYPE: tl.constexpr,              # [NEW] meta: 0->fp16, 1->bf16
):
    # TODO:
    # 1) program ids for 2D tiling: pid_m, pid_n
    # 2) compute offsets for A tile and B tile
    # 3) loop over K tiles:
    #    - tl.load A and B tiles with masks
    #    - accumulate using tl.dot / manual FMA
    # 4) tl.store to C with mask for M,N boundaries
    #
    # Hints:
    #   pid_m = tl.program_id(0)
    #   pid_n = tl.program_id(1)
    #   offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    #   offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    #   offs_k = tl.arange(0, BLOCK_K)
    #
    #   use tl.multiple_of / tl.assume if needed (optional)
    # raise NotImplementedError("TODO: implement matmul_kernel")

    # Program ids: one program per (BLOCK_M x BLOCK_N) tile of C
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Compute row/col offsets for this program's C tile
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)


    m_mask = offs_m < M          # [BM]
    n_mask = offs_n < N          # [BN]
    #for k0 in range(0, K, BLOCK_K):
        #k_offsets = k0 + offs_k          # [BK]
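    # Software pipelining (overlap the next tile's load with the current tile's compute):
    #   time →
    #   iter 0:  load
    #   iter 1:       load     compute
    #   iter 2:              load     compute
    #   iter 3:                     load     compute
    # Triton's pipeliner (num_stages) targets `range` / `tl.range` loops. `tl.static_range`
    # is unrolled at compile time, so this variant may not benefit from num_stages the
    # way the plain `range` loop in the first version does.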
    k_tiles = tl.cdiv(K, BLOCK_K)
    for k_it in tl.static_range(0, MAX_K_TILES):
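        # The launcher passes MAX_K_TILES == cdiv(K, BLOCK_K), so this guard is always true
        # today; it only matters if MAX_K_TILES is rounded up to share one compilation.
        valid_k_iter = k_it < k_tiles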

        # k0 = k_it * BLOCK_K
        # k_offsets = k0 + offs_k
        k_offsets = k_it * BLOCK_K + offs_k

        k_mask = k_offsets < K           # [BK]

        final_k_mask = valid_k_iter & k_mask   # [BK]

        # build pointer grids for this chunk
        a_ptrs = a_ptr + offs_m[:, None] * stride_am + k_offsets[None, :] * stride_ak   # [BM, BK]
        b_ptrs = b_ptr + k_offsets[:, None] * stride_bk + offs_n[None, :] * stride_bn   # [BK, BN]

        # 2D masks for loads
        a_load_mask = m_mask[:, None] & final_k_mask[None, :]    # [BM, BK]
        b_load_mask = final_k_mask[:, None] & n_mask[None, :]    # [BK, BN]

        # masked loads: out-of-bounds => 0
        a_tile = tl.load(a_ptrs, mask=a_load_mask, other=0.0)  # [BM, BK], fp16/bf16
        b_tile = tl.load(b_ptrs, mask=b_load_mask, other=0.0)  # [BK, BN], fp16/bf16

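        # accumulate in fp32: tl.dot on fp16/bf16 inputs returns fp32 by default
        acc += tl.dot(a_tile, b_tile)

    # OUT_DTYPE is constexpr, so this `if` is resolved at compile time and each branch
    # keeps its own dtype (tl.where would have to unify fp16 and bf16 into one type)
    if OUT_DTYPE == 1:
        c_tile = acc.to(tl.bfloat16)
    else:
        c_tile = acc.to(tl.float16)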
    c_ptrs = c_ptr + offs_m[:,None]*stride_cm + offs_n[None,:]*stride_cn
    tl.store(c_ptrs, c_tile, mask=m_mask[:,None] & n_mask[None,:])

def triton_matmul(A: torch.Tensor, B: torch.Tensor,
                  BLOCK_M=128, BLOCK_N=128, BLOCK_K=32,
                  num_warps=8, num_stages=4):
    assert A.is_cuda and B.is_cuda
    assert A.dtype in (torch.float16, torch.bfloat16)
    assert B.dtype in (torch.float16, torch.bfloat16)
    assert A.is_contiguous() and B.is_contiguous()
    M, K = A.shape
    K2, N = B.shape
    assert K == K2
    assert A.dtype == B.dtype, "For now, require A and B have same dtype (fp16 or bf16)."

    C = torch.empty((M, N), device=A.device, dtype=A.dtype)

    grid = (
        triton.cdiv(M, BLOCK_M),
        triton.cdiv(N, BLOCK_N),
    )

    MAX_K_TILES = triton.cdiv(K, BLOCK_K)

    OUT_DTYPE = 1 if A.dtype == torch.bfloat16 else 0

    matmul_kernel[grid](
        A, B, C,
        M, N, K,
        A.stride(0), A.stride(1),
        B.stride(0), B.stride(1),
        C.stride(0), C.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
        MAX_K_TILES=MAX_K_TILES,
        OUT_DTYPE=OUT_DTYPE,
        num_warps=num_warps,   # Triton launch meta
        num_stages=num_stages,

    )
    return C

@torch.no_grad()
def check_matmul():
    if not torch.cuda.is_available():
        print("CUDA not available.")
        return

    device = "cuda"
    torch.manual_seed(0)

    M, K, N = 512, 1024, 768
    A = torch.randn((M, K), device=device, dtype=torch.float16)
    B = torch.randn((K, N), device=device, dtype=torch.float16)

    # PyTorch baselines: the fp16 path and a stricter fp32 reference
    C_ref_fp16 = A @ B
    C_ref_fp32 = A.float() @ B.float()

    # Triton (raises NotImplementedError while the kernel is still a stub)
    try:
        C_tri = triton_matmul(A, B, BLOCK_M=128, BLOCK_N=128, BLOCK_K=32, num_warps=8, num_stages=4)
    except NotImplementedError as e:
        print(f"[Day3] matmul kernel not implemented yet: {e}")
        return

    # C_tri is fp16, like the torch fp16 path
    print("err vs torch fp16:", (C_tri - C_ref_fp16).abs().max().item())
    print("err vs fp32 ref  :", (C_tri.float() - C_ref_fp32).abs().max().item())

    # correctness
    max_err = (C_tri - C_ref_fp16).abs().max().item()
    print(f"[Day3] max_abs_err = {max_err:.3e}")
    # You can tighten thresholds after tuning (a relative tolerance scales better with |C|)
    assert max_err < 7e-2, "Too large error (fp16). Improve implementation."

if __name__ == "__main__":
    check_matmul()

err vs torch fp16: 0.0625
err vs fp32 ref  : 0.0601959228515625
[Day3] max_abs_err = 6.250e-02


In [5]:
import time
import torch
import triton
import triton.language as tl
# ============================================================
# Day 4: Numerically Stable Softmax (skeleton notes + two implementations)
# Goal:
#   - Implement row-wise stable softmax: y = exp(x - max) / sum(exp(x - max))
#   - Handle any D (not necessarily a power of two)
#   - Validate vs torch.softmax
# ============================================================
# Plan:
# 1) pid = tl.program_id(0) for row index
# 2) offsets = tl.arange(0, BLOCK_D)
# 3) load x row block(s) with mask
# 4) compute max over D (may need multiple loads if D > BLOCK_D)
# 5) compute exp(x - max), sum, and normalize
#
# Minimal baseline: one program handles one row, with BLOCK_D >= D (kernel A).
# Then extend to D > BLOCK_D by processing the row in chunks (kernel B).
# ============================================================
# Kernel A: One program handles one row
# Requirement: BLOCK_D >= D
# Single-pass: read once, write once
# ============================================================

@triton.jit
def softmax_kernel_A(
    x_ptr, y_ptr,
    B, D,
    stride_xb, stride_xd,
    stride_yb, stride_yd,
    BLOCK_D: tl.constexpr,
):
    pid = tl.program_id(0)  # row index
    offs = tl.arange(0, BLOCK_D)

    x_row_ptr = x_ptr + pid * stride_xb + offs * stride_xd
    y_row_ptr = y_ptr + pid * stride_yb + offs * stride_yd

    mask = offs < D

    x = tl.load(x_row_ptr, mask=mask, other=-float("inf"))
    x_max = tl.max(x, axis=0)
    ex = tl.exp(x - x_max)
    denom = tl.sum(ex, axis=0)
    y = ex / denom

    tl.store(y_row_ptr, y, mask=mask)


# ============================================================
# Kernel B: Chunked version (supports D > BLOCK_D)
# 3 passes over the row
# D must be constexpr: tl.static_range unrolls over it, so each new D recompiles
# ============================================================

@triton.jit
def softmax_kernel_B(
    x_ptr, y_ptr,
    B,
    stride_xb, stride_xd,
    stride_yb, stride_yd,
    D: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    pid = tl.program_id(0)
    offs = tl.arange(0, BLOCK_D)

    # Pass 1: compute row max
    row_max = -float("inf")
    for start in tl.static_range(0, D, BLOCK_D):
        cols = start + offs
        mask = cols < D
        x = tl.load(x_ptr + pid * stride_xb + cols * stride_xd,
                    mask=mask, other=-float("inf"))
        row_max = tl.maximum(row_max, tl.max(x, axis=0))

    # Pass 2: compute denominator
    denom = 0.0
    for start in tl.static_range(0, D, BLOCK_D):
        cols = start + offs
        mask = cols < D
        x = tl.load(x_ptr + pid * stride_xb + cols * stride_xd,
                    mask=mask, other=-float("inf"))
        denom += tl.sum(tl.exp(x - row_max), axis=0)

    # Pass 3: write output
    for start in tl.static_range(0, D, BLOCK_D):
        cols = start + offs
        mask = cols < D
        x = tl.load(x_ptr + pid * stride_xb + cols * stride_xd,
                    mask=mask, other=-float("inf"))
        y = tl.exp(x - row_max) / denom
        tl.store(y_ptr + pid * stride_yb + cols * stride_yd,
                 y, mask=mask)


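# ============================================================
# Python wrapper (imported by the Day 6 benchmark as triton_softmax)
# ============================================================

def triton_softmax(x: torch.Tensor, BLOCK_D=None):
    """Row-wise softmax over dim=1 of a 2D CUDA tensor.

    Uses kernel A when one power-of-two tile covers the row, otherwise the
    chunked kernel B. BLOCK_D must be a power of two (tl.arange requirement).
    """
    assert x.is_cuda and x.dim() == 2
    Bsz, D = x.shape
    if BLOCK_D is None:
        BLOCK_D = triton.next_power_of_2(D)
    assert (BLOCK_D & (BLOCK_D - 1)) == 0, "BLOCK_D must be a power of two"
    y = torch.empty_like(x)
    grid = (Bsz,)
    if BLOCK_D >= D:
        softmax_kernel_A[grid](x, y, Bsz, D,
                               x.stride(0), x.stride(1), y.stride(0), y.stride(1),
                               BLOCK_D=BLOCK_D)
    else:
        softmax_kernel_B[grid](x, y, Bsz,
                               x.stride(0), x.stride(1), y.stride(0), y.stride(1),
                               D=D, BLOCK_D=BLOCK_D)
    return y


# ============================================================
# Benchmark helpers
# ============================================================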

@torch.no_grad()
def bench(fn, iters=200, warmup=50):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()

    return start.elapsed_time(end) / iters  # ms/iter


def max_abs_err(a, b):
    return float((a - b).abs().max().item())


def compute_gbps(bytes_processed, ms):
    seconds = ms / 1000.0
    return bytes_processed / seconds / 1e9


# ============================================================
# Main experiment
# ============================================================

def main():

    torch.manual_seed(0)
    device = "cuda"
    dtype = torch.float32
    elem_size = torch.tensor([], dtype=dtype).element_size()

    B, D = 4096, 2048
    x = torch.randn((B, D), device=device, dtype=dtype)
    y_ref = torch.softmax(x, dim=1)

    print(f"\nGPU: {torch.cuda.get_device_name(0)}")
    print(f"Shape: B={B}, D={D}, dtype={dtype}")
    print("=" * 100)

    warps_list = [2, 4, 8]
    stages_list = [1, 2, 3]

    print(f"{'Kernel':<10} | {'warps':>5} | {'stages':>6} | {'ms':>8} | {'GB/s':>10} | {'max_err':>10}")
    print("-" * 100)

    # ------------------------
    # Kernel A
    # ------------------------
    BLOCK_D_A = triton.next_power_of_2(D)

    for w in warps_list:
        for s in stages_list:

            def run_A():
                y = torch.empty_like(x)
                grid = (B,)
                softmax_kernel_A[grid](
                    x, y,
                    B, D,
                    x.stride(0), x.stride(1),
                    y.stride(0), y.stride(1),
                    BLOCK_D=BLOCK_D_A,
                    num_warps=w,
                    num_stages=s,
                )
                return y

            yA = run_A()
            err = max_abs_err(yA, y_ref)

            ms = bench(run_A)

            bytes_processed = B * D * elem_size * 2  # 1 read + 1 write
            gbps = compute_gbps(bytes_processed, ms)

            print(f"{'A':<10} | {w:5d} | {s:6d} | {ms:8.4f} | {gbps:10.2f} | {err:10.3e}")

    # ------------------------
    # Kernel B
    # ------------------------
    BLOCK_D_B = 1024

    for w in warps_list:
        for s in stages_list:

            def run_B():
                y = torch.empty_like(x)
                grid = (B,)
                softmax_kernel_B[grid](
                    x, y,
                    B,
                    x.stride(0), x.stride(1),
                    y.stride(0), y.stride(1),
                    D=D,
                    BLOCK_D=BLOCK_D_B,
                    num_warps=w,
                    num_stages=s,
                )
                return y

            yB = run_B()
            err = max_abs_err(yB, y_ref)

            ms = bench(run_B)

            bytes_processed = B * D * elem_size * 4  # 3 reads + 1 write
            gbps = compute_gbps(bytes_processed, ms)

            print(f"{'B':<10} | {w:5d} | {s:6d} | {ms:8.4f} | {gbps:10.2f} | {err:10.3e}")

    print("=" * 100)
    print("Notes:")
    print(" - Kernel A performs 1 read + 1 write per element.")
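    print(" - Kernel B performs 3 reads + 1 write per element; re-reads of a row may be")
    print("   served from L2, so its GB/s counts logical bytes and can exceed DRAM bandwidth.")
    print(" - num_warps affects parallelism and register pressure.")
    print(" - num_stages pipelines loads inside loops; kernel A has no loop, so it has no effect there.")
    print(" - For large D, kernel B should scale better: kernel A holds the whole row in registers.")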
    print("=" * 100)


if __name__ == "__main__":
    main()



GPU: Tesla T4
Shape: B=4096, D=2048, dtype=torch.float32
Kernel     | warps | stages |       ms |       GB/s |    max_err
----------------------------------------------------------------------------------------------------
A          |     2 |      1 |   0.2892 |     232.03 |  5.588e-09
A          |     2 |      2 |   0.2894 |     231.86 |  5.588e-09
A          |     2 |      3 |   0.2894 |     231.91 |  5.588e-09
A          |     4 |      1 |   0.2834 |     236.84 |  3.725e-09
A          |     4 |      2 |   0.2833 |     236.91 |  3.725e-09
A          |     4 |      3 |   0.2832 |     236.94 |  3.725e-09
A          |     8 |      1 |   0.2825 |     237.58 |  5.588e-09
A          |     8 |      2 |   0.2835 |     236.69 |  5.588e-09
A          |     8 |      3 |   0.2835 |     236.71 |  5.588e-09
B          |     2 |      1 |   0.2958 |     453.82 |  4.602e-02
B          |     2 |      2 |   0.2955 |     454.22 |  4.602e-02
B          |     2 |      3 |   0.2957 |     453.85 |  4.602e

In [None]:
#layernorm_skeleton
import torch
import triton
import triton.language as tl
import torch.nn.functional as F

# ============================================================
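# Day 5: LayerNorm Forward (single-tile implementation, BLOCK_D >= D)
# Goal:
#   - Implement LayerNorm forward:
#       y = (x - mean) * rsqrt(var + eps) * gamma + beta
#   - Validate vs torch.nn.functional.layer_norm
# ============================================================

@triton.jit
def layernorm_fwd_kernel(
    x_ptr, gamma_ptr, beta_ptr, y_ptr,
    B, D,
    stride_xb, stride_xd,
    stride_yb, stride_yd,
    eps,
    BLOCK_D: tl.constexpr,
):
    # Steps: 1) pid => row  2) load x row  3) mean and var  4) normalize + affine  5) store y
    # The whole row lives in one tile (requires BLOCK_D >= D); a chunked variant
    # would loop over D in BLOCK_D steps, as softmax kernel B does.
    pid = tl.program_id(0)
    offs = tl.arange(0, BLOCK_D)
    mask = offs < D

    x = tl.load(x_ptr + pid * stride_xb + offs * stride_xd, mask=mask, other=0.0)
    mean = tl.sum(x, axis=0) / D
    diff = tl.where(mask, x - mean, 0.0)        # zero padded lanes so they don't enter var
    var = tl.sum(diff * diff, axis=0) / D       # biased variance, as in F.layer_norm
    rstd = tl.rsqrt(var + eps)

    gamma = tl.load(gamma_ptr + offs, mask=mask, other=0.0)
    beta = tl.load(beta_ptr + offs, mask=mask, other=0.0)
    y = diff * rstd * gamma + beta              # normalize + affine fused in one kernel
    tl.store(y_ptr + pid * stride_yb + offs * stride_yd, y, mask=mask)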

def triton_layernorm(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor,
                     eps=1e-5, BLOCK_D=1024):
    assert x.is_cuda and gamma.is_cuda and beta.is_cuda
    assert x.dtype == torch.float32 and gamma.dtype == torch.float32 and beta.dtype == torch.float32
    assert x.is_contiguous() and gamma.is_contiguous() and beta.is_contiguous()
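    B, D = x.shape
    # tl.arange needs a power-of-two BLOCK_D, and a single-tile kernel needs the whole row to fit
    assert BLOCK_D >= D and (BLOCK_D & (BLOCK_D - 1)) == 0, "BLOCK_D must be a power of two >= D"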
    y = torch.empty_like(x)
    grid = (B,)
    layernorm_fwd_kernel[grid](
        x, gamma, beta, y,
        B, D,
        x.stride(0), x.stride(1),
        y.stride(0), y.stride(1),
        eps,
        BLOCK_D=BLOCK_D,
        num_warps=4,
    )
    return y

@torch.no_grad()
def check_layernorm():
    if not torch.cuda.is_available():
        print("CUDA not available.")
        return
    device = "cuda"
    torch.manual_seed(0)

    B, D = 4096, 1024
    x = torch.randn((B, D), device=device, dtype=torch.float32)
    gamma = torch.randn((D,), device=device, dtype=torch.float32)
    beta = torch.randn((D,), device=device, dtype=torch.float32)
    eps = 1e-5

    y_ref = F.layer_norm(x, (D,), gamma, beta, eps=eps)

    try:
        y_tri = triton_layernorm(x, gamma, beta, eps=eps, BLOCK_D=1024)
    except NotImplementedError as e:
        print(f"[Day5] layernorm kernel not implemented yet: {e}")
        return

    max_err = (y_tri - y_ref).abs().max().item()
    mean_err = (y_tri - y_ref).abs().mean().item()
    print(f"[Day5] max_abs_err = {max_err:.3e}, mean_abs_err = {mean_err:.3e}")
    assert max_err < 2e-4, "Too large error for fp32 layernorm."

if __name__ == "__main__":
    check_layernorm()


In [None]:
#benchmark_skeleton
import torch
import time
import importlib

# ============================================================
# Day 6: Benchmark Triton vs PyTorch for Day3/4/5
# - If kernel not implemented, it will skip and print a message.
# - Uses CUDA events for timing.
# ============================================================

@torch.no_grad()
def bench_ms(fn, iters=200, warmup=50):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    s = torch.cuda.Event(enable_timing=True)
    e = torch.cuda.Event(enable_timing=True)
    s.record()
    for _ in range(iters):
        fn()
    e.record()
    torch.cuda.synchronize()
    return s.elapsed_time(e) / iters

def print_table(rows):
    header = f"{'Op':<10} | {'Impl':<8} | {'ms/iter':>10} | {'speedup':>8} | {'note':<20}"
    print(header)
    print("-" * len(header))
    for r in rows:
        print(f"{r['op']:<10} | {r['impl']:<8} | {r['ms']:>10.4f} | {r['speedup']:>8.2f} | {r['note']:<20}")

def main():
    if not torch.cuda.is_available():
        print("CUDA not available.")
        return

    print(f"GPU: {torch.cuda.get_device_name(0)}")
    device = "cuda"
    torch.manual_seed(0)

    rows = []

    # -------------------------
    # Day3 MatMul
    # -------------------------
    try:
        day3 = importlib.import_module("day3_matmul_skeleton")
        M, K, N = 1024, 1024, 1024
        A = torch.randn((M, K), device=device, dtype=torch.float16)
        B = torch.randn((K, N), device=device, dtype=torch.float16)

        torch_fn = lambda: A @ B
        torch_ms = bench_ms(torch_fn, iters=100, warmup=20)

        def triton_fn():
            return day3.triton_matmul(A, B, BLOCK_M=128, BLOCK_N=128, BLOCK_K=32, num_warps=8)

        rows.append({"op":"matmul", "impl":"torch", "ms":torch_ms, "speedup":1.0, "note":"torch.matmul"})
        try:
            tri_ms = bench_ms(triton_fn, iters=100, warmup=20)
            rows.append({"op":"matmul", "impl":"triton", "ms":tri_ms, "speedup":torch_ms/tri_ms, "note":"BM=128 BN=128"})
        except NotImplementedError:
            rows.append({"op":"matmul", "impl":"triton", "ms":float("nan"), "speedup":float("nan"), "note":"TODO kernel"})
    except Exception as e:
        # import errors, missing attributes and Triton compile errors all land here
        rows.append({"op":"matmul", "impl":"triton", "ms":float("nan"), "speedup":float("nan"), "note":f"error: {e}"})


    # -------------------------
    # Day4 Softmax
    # -------------------------
    try:
        day4 = importlib.import_module("day4_softmax_skeleton")
        Bsz, D = 4096, 1024
        x = torch.randn((Bsz, D), device=device, dtype=torch.float32)

        torch_fn = lambda: torch.softmax(x, dim=1)
        torch_ms = bench_ms(torch_fn)

        def triton_fn():
            return day4.triton_softmax(x, BLOCK_D=1024)

        rows.append({"op":"softmax", "impl":"torch", "ms":torch_ms, "speedup":1.0, "note":"torch.softmax"})
        try:
            tri_ms = bench_ms(triton_fn)
            rows.append({"op":"softmax", "impl":"triton", "ms":tri_ms, "speedup":torch_ms/tri_ms, "note":"BLOCK_D=1024"})
        except NotImplementedError:
            rows.append({"op":"softmax", "impl":"triton", "ms":float("nan"), "speedup":float("nan"), "note":"TODO kernel"})
    except Exception as e:
        # import errors, missing attributes and Triton compile errors all land here
        rows.append({"op":"softmax", "impl":"triton", "ms":float("nan"), "speedup":float("nan"), "note":f"error: {e}"})


    # -------------------------
    # Day5 LayerNorm
    # -------------------------
    try:
        day5 = importlib.import_module("day5_layernorm_skeleton")
        Bsz, D = 4096, 1024
        x = torch.randn((Bsz, D), device=device, dtype=torch.float32)
        gamma = torch.randn((D,), device=device, dtype=torch.float32)
        beta = torch.randn((D,), device=device, dtype=torch.float32)
        eps = 1e-5

        torch_fn = lambda: torch.nn.functional.layer_norm(x, (D,), gamma, beta, eps=eps)
        torch_ms = bench_ms(torch_fn)

        def triton_fn():
            return day5.triton_layernorm(x, gamma, beta, eps=eps, BLOCK_D=1024)

        rows.append({"op":"layernorm", "impl":"torch", "ms":torch_ms, "speedup":1.0, "note":"F.layer_norm"})
        try:
            tri_ms = bench_ms(triton_fn)
            rows.append({"op":"layernorm", "impl":"triton", "ms":tri_ms, "speedup":torch_ms/tri_ms, "note":"BLOCK_D=1024"})
        except NotImplementedError:
            rows.append({"op":"layernorm", "impl":"triton", "ms":float("nan"), "speedup":float("nan"), "note":"TODO kernel"})
    except Exception as e:
        # import errors, missing attributes and Triton compile errors all land here
        rows.append({"op":"layernorm", "impl":"triton", "ms":float("nan"), "speedup":float("nan"), "note":f"error: {e}"})


    print()
    print_table(rows)

if __name__ == "__main__":
    main()
