# Week 2, Day 7: Optimized GEMM — Putting It All Together

**Time:** ~1 hour

**Goal:** Combine all optimizations to build a production-quality GEMM that reaches 80%+ of cuBLAS throughput.

## The Challenge

Week 1: Tiled Triton → 500 GFLOPS  
Week 2 Goal: Optimized → **80%+ of cuBLAS**

We'll use every technique from this week.

In [None]:
import torch
import triton
import triton.language as tl
from triton.testing import do_bench

---
## Step 1: The Challenge (5 min)

### Optimization Checklist

| Day | Technique | Status |
|-----|-----------|--------|
| 1 | Profiling & Metrics | ✓ Can measure |
| 2 | Coalesced Memory Access | ✓ |
| 3 | Bank-Conflict-Free SMEM | ✓ |
| 4 | Software Pipelining | ✓ |
| 5 | TMA (Hopper) | ✓ (if available) |
| 6 | Tensor Cores | ✓ |

Let's combine them all!

---
## Step 2: Explore (15 min)

### The Ultimate GEMM Recipe

```
1. Load tiles via TMA (or coalesced loads with pipelining)
2. Stage in SMEM without bank conflicts (padding in CUDA; Triton's compiler swizzles the layout)
3. Compute via Tensor Cores (tl.dot with FP16)
4. Accumulate in FP32 (precision)
5. Pipeline loads with compute (latency hiding)
6. Tune tile sizes via autotuning
```
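Step 2 of the recipe is easiest to see by computing bank indices by hand. The sketch below is pure Python and assumes 32 shared-memory banks of 4 bytes each, with a warp reading one column of a row-major FP32 tile (thread `t` reads row `t`). The helper names are made up for illustration; in Triton you never write this padding yourself.

```python
from collections import Counter

NUM_BANKS = 32        # assumption: 32 banks, as on recent NVIDIA GPUs
BANK_WIDTH_BYTES = 4  # assumption: each bank serves one 4-byte word


def bank_of(row, col, row_stride_elems, elem_bytes=4):
    """Bank index of element (row, col) in a row-major SMEM tile."""
    byte_addr = (row * row_stride_elems + col) * elem_bytes
    return (byte_addr // BANK_WIDTH_BYTES) % NUM_BANKS


def column_read_conflict_degree(row_stride_elems, warp_size=32):
    """Max number of threads that land in one bank when thread t reads
    element (t, 0). A result of 1 means the access is conflict-free."""
    banks = [bank_of(t, 0, row_stride_elems) for t in range(warp_size)]
    return max(Counter(banks).values())


print(column_read_conflict_degree(32))  # unpadded rows: 32-way conflict
print(column_read_conflict_degree(33))  # one element of padding per row: 1
```

Padding each row by one element shifts consecutive rows into consecutive banks, which is the same idea as a swizzled layout, done with extra memory instead of index arithmetic.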

In [None]:
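# Full configuration space for autotuning. Every config is compiled and
# benchmarked, so the kernel below autotunes over a hand-picked subset instead.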
def get_autotune_configs():
    """Generate autotuning configurations."""
    configs = []
    
    # Block sizes: larger = more data reuse, but more SMEM
    block_sizes = [
        (64, 64, 32),
        (128, 64, 32),
        (64, 128, 32),
        (128, 128, 32),
        (128, 128, 64),
        (256, 64, 32),
        (64, 256, 32),
    ]
    
    # Pipeline stages
    stages_options = [2, 3, 4]
    
    # Warps per block
    warps_options = [4, 8]
    
    for bm, bn, bk in block_sizes:
        for stages in stages_options:
            for warps in warps_options:
                configs.append(
                    triton.Config(
                        {'BLOCK_M': bm, 'BLOCK_N': bn, 'BLOCK_K': bk},
                        num_stages=stages,
                        num_warps=warps,
                    )
                )
    
    return configs

print(f"Total configurations: {len(get_autotune_configs())}")
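Larger tiles and deeper pipelines both multiply shared-memory use, so some of these combinations may not fit on a given GPU. A rough, hypothetical estimate (not a Triton API) is `num_stages × (BLOCK_M×BLOCK_K + BLOCK_K×BLOCK_N) × bytes_per_element`. It ignores the epilogue and any layout padding the compiler adds, and the 100 KiB budget below is an arbitrary example rather than a real hardware limit.

```python
def estimate_smem_bytes(block_m, block_n, block_k, num_stages, elem_bytes=2):
    """Rough SMEM needed to buffer one A tile and one B tile per pipeline stage."""
    per_stage = (block_m * block_k + block_k * block_n) * elem_bytes
    return num_stages * per_stage


def prune_configs(block_sizes, stages_options, smem_budget_bytes):
    """Keep (bm, bn, bk, stages) combinations whose estimate fits the budget."""
    kept = []
    for bm, bn, bk in block_sizes:
        for stages in stages_options:
            if estimate_smem_bytes(bm, bn, bk, stages) <= smem_budget_bytes:
                kept.append((bm, bn, bk, stages))
    return kept


block_sizes = [(64, 64, 32), (128, 128, 64), (256, 64, 32)]
print(estimate_smem_bytes(128, 128, 64, 4))  # 131072 bytes = 128 KiB
print(len(prune_configs(block_sizes, [2, 3, 4], 100 * 1024)))  # 8 of 9 fit
```

`triton.autotune` also accepts a `prune_configs_by` argument if you want to apply this kind of filter before benchmarking.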

---
## Step 3: The Concept (10 min)

### Key Performance Factors

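1. **Tile Size**: Larger tiles increase data reuse but consume more SMEM and registers, which lowers occupancy
2. **Pipeline Depth**: More stages hide more memory latency, at the cost of one extra SMEM buffer per stage
3. **Warp Count**: Enough warps to keep the Tensor Cores issuing MMAs and to hide latency
4. **Data Type**: FP16 inputs for Tensor Cores, FP32 accumulator for precision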
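The tile-size trade-off can be made concrete with per-tile arithmetic intensity. Each K-step loads a `BLOCK_M×BLOCK_K` tile of A and a `BLOCK_K×BLOCK_N` tile of B, then performs `2×BLOCK_M×BLOCK_N×BLOCK_K` FLOPs, so `BLOCK_K` cancels out. This sketch assumes every tile comes from global memory with no L2 hits, so it is a lower bound on real reuse.

```python
def tile_intensity(block_m, block_n, elem_bytes=2):
    """FLOPs per byte loaded from global memory for one output tile."""
    flops_per_k = 2 * block_m * block_n           # per unit of K
    bytes_per_k = (block_m + block_n) * elem_bytes  # one column of A + one row of B
    return flops_per_k / bytes_per_k


for bm, bn in [(64, 64), (128, 128), (256, 128)]:
    print(f"{bm}x{bn}: {tile_intensity(bm, bn):.1f} FLOPs/byte")
# 64x64: 32.0, 128x128: 64.0, 256x128: 85.3
```

Even large tiles sit well below the H100 balance point of roughly 295 FLOPs/byte if every load went to DRAM, which is one reason L2 reuse between neighboring tiles matters.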

### Theoretical Peak Analysis

```
H100 SXM FP16 Tensor Core (dense): ~990 TFLOPS
H100 SXM HBM3 bandwidth:           3.35 TB/s

For GEMM C = A × B:
  FLOPs = 2 × M × N × K
  Bytes = 2 × (M×K + K×N + M×N)  [FP16 everywhere; each matrix moved once]

Arithmetic Intensity (AI) = FLOPs / Bytes
Balance Point = 990 TFLOPS / 3.35 TB/s ≈ 295 FLOPs/byte

For a 4096³ GEMM:
  AI = 2×4096³ / (2×3×4096²) = 4096/3 ≈ 1365 FLOPs/byte
  → 1365 > 295, so COMPUTE BOUND (good for Tensor Cores!)
```
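The same arithmetic as a small roofline calculator. The defaults mirror the H100 figures above; `out_bytes=4` models the kernel in Step 4, which writes C in FP32. The function names are illustrative, and the model assumes ideal traffic where each matrix crosses HBM exactly once.

```python
def gemm_arithmetic_intensity(M, N, K, in_bytes=2, out_bytes=2):
    """FLOPs per byte of ideal DRAM traffic (each matrix moved exactly once)."""
    flops = 2 * M * N * K
    bytes_moved = in_bytes * (M * K + K * N) + out_bytes * M * N
    return flops / bytes_moved


def roofline_tflops(intensity, peak_tflops=990.0, bandwidth_tbs=3.35):
    """Attainable TFLOPS under a simple roofline: min(compute peak, AI x bandwidth)."""
    return min(peak_tflops, intensity * bandwidth_tbs)


ai_fp16_out = gemm_arithmetic_intensity(4096, 4096, 4096)               # 4096/3 ~ 1365.3
ai_fp32_out = gemm_arithmetic_intensity(4096, 4096, 4096, out_bytes=4)  # 1024.0
print(ai_fp16_out, ai_fp32_out, roofline_tflops(ai_fp32_out))           # ... 990.0

# A skinny GEMM (M=16) is a different story: AI ~ 15.9, so it is memory bound.
ai_skinny = gemm_arithmetic_intensity(16, 4096, 4096)
print(round(roofline_tflops(ai_skinny), 1))
```

Writing C in FP32 lowers the intensity from ~1365 to 1024 FLOPs/byte, but a 4096³ GEMM stays comfortably compute bound either way.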

---
## Step 4: Code It (30 min)

### The Production GEMM Kernel

In [None]:
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=2, num_warps=4),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def gemm_optimized(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
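    """Production-quality GEMM kernel.
    
    Optimizations:
    - Tensor Cores via tl.dot (FP16 input, FP32 accumulator)
    - Software pipelining (via num_stages)
    - Autotuned tile sizes
    - Coalesced memory access
    - Grouped program ordering for L2 cache reuse
    """
    # Flatten the 2D grid into one linear program ID (axis 0 varies fastest)
    pid = tl.program_id(0) + tl.program_id(1) * tl.num_programs(0)
    
    # Grouped ("swizzled") ordering for better L2 cache utilization, as in
    # Triton's matrix multiplication tutorial: consecutive programs cover a
    # GROUP_SIZE-tall column of C tiles before moving one tile right, so
    # nearby programs reuse the same A and B tiles while they are in L2.
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    GROUP_SIZE = 8
    num_pid_in_group = GROUP_SIZE * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE)  # last group may be short
    pid_m_final = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n_final = (pid % num_pid_in_group) // group_size_m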
    
    # Block offsets
    offs_m = pid_m_final * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n_final * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    
    # Initial pointers
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    
    # Accumulator (FP32 for precision)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    
    # Main loop over K dimension
    for k in range(0, K, BLOCK_K):
        # Boundary masks
        mask_a = (offs_m[:, None] < M) & (offs_k[None, :] + k < K)
        mask_b = (offs_k[:, None] + k < K) & (offs_n[None, :] < N)
        
        # Load tiles (compiler will pipeline these)
        a = tl.load(a_ptrs, mask=mask_a, other=0.0)
        b = tl.load(b_ptrs, mask=mask_b, other=0.0)
        
        # Tensor Core MMA
        acc += tl.dot(a, b)
        
        # Advance pointers
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    
    # Store result
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    mask_c = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, acc, mask=mask_c)
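The L2 "swizzle" is usually implemented as grouped ordering, the scheme from Triton's official matrix multiplication tutorial. Here is a pure-Python sketch of that mapping (function names are made up for illustration) that counts how many distinct A row-blocks and B column-blocks the first few programs need. Fewer distinct input tiles means more of them can be served from L2.

```python
from functools import partial


def grouped_pid_to_tile(pid, num_pid_m, num_pid_n, group_size):
    """Map a linear program ID to (pid_m, pid_n) using grouped ordering."""
    num_pid_in_group = group_size * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_size
    group_size_m = min(num_pid_m - first_pid_m, group_size)
    pid_m = first_pid_m + (pid % num_pid_in_group) % group_size_m
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n


def row_major_pid_to_tile(pid, num_pid_m, num_pid_n):
    """Naive ordering: sweep a full row of C tiles before moving down."""
    return pid // num_pid_n, pid % num_pid_n


def input_tiles_touched(mapping, num_programs):
    """Distinct A row-blocks + B column-blocks used by the first programs."""
    tiles = [mapping(pid) for pid in range(num_programs)]
    return len({m for m, _ in tiles}) + len({n for _, n in tiles})


TILES = 8  # an 8x8 grid of output tiles
grouped = partial(grouped_pid_to_tile, num_pid_m=TILES, num_pid_n=TILES, group_size=4)
row_major = partial(row_major_pid_to_tile, num_pid_m=TILES, num_pid_n=TILES)

# Both orders visit every output tile exactly once...
all_tiles = {(m, n) for m in range(TILES) for n in range(TILES)}
assert {grouped(p) for p in range(TILES * TILES)} == all_tiles
# ...but the first 16 programs need fewer distinct input tiles when grouped.
print(input_tiles_touched(row_major, 16), input_tiles_touched(grouped, 16))  # 10 8
```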

In [None]:
def matmul_optimized(a, b):
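    """Wrapper for optimized GEMM: returns C = A @ B in FP32 for FP16 CUDA inputs."""
    assert a.is_cuda and b.is_cuda, "Inputs must be CUDA tensors"
    M, K = a.shape
    K2, N = b.shape
    assert K == K2, f"Inner dimensions must match: {K} vs {K2}"
    
    # C keeps the FP32 accumulator dtype. torch.mm returns FP16, so this
    # kernel writes twice as many bytes for C as cuBLAS does in the benchmark.
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    
    # One program per BLOCK_M x BLOCK_N output tile; META is the autotuned config
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))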
    
    gemm_optimized[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c
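To check the tiling and masking logic without a GPU, the NumPy sketch below mirrors the kernel's structure: the wrapper's grid, the K loop with masked loads (`other=0.0`), the FP32 accumulator, and the masked store. It assumes NumPy's `@` as a stand-in for `tl.dot` and says nothing about performance; shapes that are not multiples of the block sizes exercise the masks.

```python
import numpy as np


def cdiv(a, b):
    return (a + b - 1) // b


def emulate_gemm(a, b, BLOCK_M=32, BLOCK_N=32, BLOCK_K=16):
    """Run the kernel's tiling logic on the CPU, one 'program' at a time."""
    M, K = a.shape
    K2, N = b.shape
    assert K == K2, f"Inner dimensions must match: {K} vs {K2}"
    c = np.empty((M, N), dtype=np.float32)
    grid = (cdiv(M, BLOCK_M), cdiv(N, BLOCK_N))  # same grid as the wrapper
    for pid_m in range(grid[0]):
        for pid_n in range(grid[1]):
            offs_m = pid_m * BLOCK_M + np.arange(BLOCK_M)
            offs_n = pid_n * BLOCK_N + np.arange(BLOCK_N)
            acc = np.zeros((BLOCK_M, BLOCK_N), dtype=np.float32)  # FP32 accumulator
            for k in range(0, K, BLOCK_K):
                offs_k = k + np.arange(BLOCK_K)
                # Masked loads: clamp indices to stay in bounds, then zero the
                # out-of-bounds lanes, like tl.load(..., mask=..., other=0.0).
                mask_a = (offs_m[:, None] < M) & (offs_k[None, :] < K)
                mask_b = (offs_k[:, None] < K) & (offs_n[None, :] < N)
                a_tile = np.where(mask_a, a[np.minimum(offs_m, M - 1)[:, None],
                                            np.minimum(offs_k, K - 1)[None, :]], 0)
                b_tile = np.where(mask_b, b[np.minimum(offs_k, K - 1)[:, None],
                                            np.minimum(offs_n, N - 1)[None, :]], 0)
                acc += a_tile.astype(np.float32) @ b_tile.astype(np.float32)
            # Masked store: only in-bounds rows and columns are written.
            rows, cols = offs_m < M, offs_n < N
            c[offs_m[rows][:, None], offs_n[cols][None, :]] = acc[rows][:, cols]
    return c


rng = np.random.default_rng(0)
a = rng.standard_normal((70, 50)).astype(np.float16)
b = rng.standard_normal((50, 45)).astype(np.float16)
c = emulate_gemm(a, b)
print(np.allclose(c, a.astype(np.float32) @ b.astype(np.float32), rtol=1e-4, atol=1e-4))
```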

In [None]:
def benchmark_final(M, N, K):
    """Final benchmark comparing optimized kernel to cuBLAS."""
    a = torch.randn(M, K, device='cuda', dtype=torch.float16)
    b = torch.randn(K, N, device='cuda', dtype=torch.float16)
    
    # Warmup and autotune
    _ = matmul_optimized(a, b)
    
    # Benchmark Triton
    ms_triton = do_bench(lambda: matmul_optimized(a, b))
    
    # Benchmark cuBLAS (PyTorch)
    ms_cublas = do_bench(lambda: torch.mm(a, b))
    
    # Calculate metrics
    flops = 2 * M * N * K
    tflops_triton = flops / (ms_triton * 1e-3) / 1e12
    tflops_cublas = flops / (ms_cublas * 1e-3) / 1e12
    
    # Verify correctness
    c_triton = matmul_optimized(a, b)
    c_cublas = torch.mm(a, b).float()
    is_correct = torch.allclose(c_triton, c_cublas, rtol=1e-2, atol=1e-2)
    
    return {
        'triton_ms': ms_triton,
        'cublas_ms': ms_cublas,
        'triton_tflops': tflops_triton,
        'cublas_tflops': tflops_cublas,
        'efficiency': tflops_triton / tflops_cublas * 100,
        'correct': is_correct,
    }

print("Optimized GEMM Benchmark")
print("=" * 70)
print(f"{'Size':<15} {'Triton':<12} {'cuBLAS':<12} {'Efficiency':<12} {'Correct':<8}")
print(f"{'':15} {'(TFLOPS)':<12} {'(TFLOPS)':<12} {'(vs cuBLAS)':<12}")
print("-" * 70)

for size in [1024, 2048, 4096, 8192]:
    result = benchmark_final(size, size, size)
    status = "PASS" if result['correct'] else "FAIL"
    eff = f"{result['efficiency']:.1f}%"
    print(f"{f'{size}x{size}':<15} {result['triton_tflops']:<12.1f} {result['cublas_tflops']:<12.1f} "
          f"{eff:<12} {status:<8}")

### Performance Analysis

In [None]:
def analyze_performance():
    """Detailed performance analysis."""
    # Get GPU info
    props = torch.cuda.get_device_properties(0)
    
    print(f"GPU: {props.name}")
    print(f"SM Count: {props.multi_processor_count}")
    print(f"Memory: {props.total_memory / 1e9:.1f} GB")
    print()
    
    # Benchmark across sizes
    print("Performance scaling:")
    print("-" * 50)
    
    sizes = [512, 1024, 2048, 4096, 8192, 16384]
    results = []
    
    for size in sizes:
        if size > 8192 and props.total_memory < 24e9:
            print(f"{size}x{size}: Skipped (GPU has < 24 GB of memory)")
            continue
        
        try:
            result = benchmark_final(size, size, size)
            results.append((size, result))
            print(f"{size}x{size}: {result['triton_tflops']:.1f} TFLOPS ({result['efficiency']:.1f}% of cuBLAS)")
        except Exception as e:
            print(f"{size}x{size}: Error - {e}")
    
    return results

results = analyze_performance()
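The 24 GB cutoff in `analyze_performance` is a conservative guard. A back-of-the-envelope estimate (a hypothetical helper, not part of the benchmark) counts FP16 A and B, the Triton kernel's FP32 C, and cuBLAS's FP16 C. It ignores allocator caching, autotuning temporaries, `do_bench`'s own buffers, and the extra copies made during the correctness check, so treat it as a lower bound.

```python
def benchmark_footprint_bytes(size):
    """Approximate live tensor bytes for one benchmark_final(size, size, size) call."""
    n2 = size * size
    fp16, fp32 = 2, 4
    # A (FP16) + B (FP16) + Triton C (FP32) + cuBLAS C (FP16)
    return n2 * (fp16 + fp16 + fp32 + fp16)


for size in [8192, 16384]:
    print(f"{size}x{size}: ~{benchmark_footprint_bytes(size) / 1e9:.2f} GB")
# 8192x8192: ~0.67 GB, 16384x16384: ~2.68 GB
```

By this estimate even the 16384³ case needs only a few GB of tensors, so the cutoff mostly trades coverage for safety margin and shorter autotuning time.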

---
## Step 5: Verify (10 min)

### Final Checklist

Our optimized GEMM should achieve:

- [ ] 80%+ of cuBLAS performance on large matrices
- [ ] Correct results (< 1% error vs cuBLAS)
- [ ] Uses Tensor Cores (high TFLOPS)
- [ ] Coalesced memory access
- [ ] Software pipelining
- [ ] Autotuned for different sizes
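One way to put a single number on "error vs cuBLAS" is a norm-based relative error, which complements the elementwise `torch.allclose` check in `benchmark_final`. This is a minimal NumPy sketch; to use it with the benchmark you would move both results to the CPU first (for example with `.cpu().numpy()`).

```python
import numpy as np


def relative_error(result, reference):
    """Frobenius-norm relative error ||result - reference|| / ||reference||."""
    result = np.asarray(result, dtype=np.float64)
    reference = np.asarray(reference, dtype=np.float64)
    return np.linalg.norm(result - reference) / np.linalg.norm(reference)


ref = np.random.default_rng(0).standard_normal((256, 256))
err = relative_error(ref * 1.005, ref)  # every entry off by 0.5%
print(f"{err:.4f}", err < 0.01)         # 0.0050 True
```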

In [None]:
def verify_optimization_goals():
    """Verify we met Week 2 goals."""
    print("Week 2 Goals Verification")
    print("=" * 50)
    
    # Test on 4096x4096 (representative size)
    result = benchmark_final(4096, 4096, 4096)
    
    goals = [
        ("80%+ of cuBLAS", result['efficiency'] >= 80),
        ("Correct results", result['correct']),
        ("High TFLOPS (>100; GPU-dependent)", result['triton_tflops'] > 100),
    ]
    
    all_passed = True
    for goal, passed in goals:
        status = "PASS" if passed else "FAIL"
        print(f"  {goal}: {status}")
        if not passed:
            all_passed = False
    
    print()
    if all_passed:
        print("All Week 2 goals achieved!")
    else:
        print("Some goals not met - may need hardware-specific tuning.")
    
    return all_passed

verify_optimization_goals()

---
## Summary

### Week 2 Achievements

| Day | Topic | Key Learning |
|-----|-------|-------------|
| 1 | Profiling | Measure before optimizing |
| 2 | Coalescing | Adjacent threads → adjacent memory |
| 3 | Bank Conflicts | Stride-33 beats stride-32 |
| 4 | Pipelining | Overlap loads with compute |
| 5 | TMA | Hopper's descriptor-based bulk tile copies (hardware computes addresses) |
| 6 | Tensor Cores | FP16 for 66x speedup |
| 7 | Integration | Combine all techniques |

### Journey So Far

```
Week 1 Day 1:  Naive Python       →  0.001 GFLOPS
Week 1 Day 2:  NumPy              →  50 GFLOPS
Week 1 Day 2:  CuPy               →  5,000 GFLOPS
Week 1 Day 7:  Tiled Triton       →  500 GFLOPS
Week 2 Day 7:  Optimized Triton   →  80%+ of cuBLAS!
```

### What's Next: Week 3

Now that we can write fast kernels, we'll tackle the **attention mechanism**:
- Dot products and softmax
- Numerical stability
- Online algorithms
- Building FlashAttention from scratch!