# Week 1, Day 7: Your First Fast Matmul

**Time:** ~1 hour

**Goal:** Implement tiled matrix multiplication and achieve significant speedup.

## The Challenge

This is the culmination of Week 1. We'll combine everything:
- GPU architecture understanding (Day 3)
- Triton kernel programming (Day 4)
- Memory hierarchy and coalescing (Day 5)
- Tiling for data reuse (Day 6)

**Target:** Achieve 500+ GFLOPS (getting into the range of real performance)

In [None]:
import numpy as np
import time

try:
    import torch
    import triton
    import triton.language as tl
    print(f"Triton: {triton.__version__}")
    print(f"PyTorch: {torch.__version__}")
    # Importing the libraries is not enough: the kernels also need a CUDA device
    GPU_AVAILABLE = torch.cuda.is_available()
    if GPU_AVAILABLE:
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    else:
        print("No CUDA GPU detected; GPU cells will be skipped")
except ImportError:
    GPU_AVAILABLE = False
    print("GPU libraries not available")

---
## Step 1: The Challenge (5 min)

### Matrix Multiplication: C = A @ B

```
A: M x K
B: K x N  
C: M x N

C[i,j] = sum(A[i,k] * B[k,j] for k in range(K))
```
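Before any GPU work, it helps to have a ground-truth implementation of the formula above. The sketch below translates `C[i,j] = sum(...)` directly into Python loops. The function name `matmul_naive` is our own and is not from any library. The sketch is deliberately slow and only useful for checking results on tiny matrices.

```python
import numpy as np

def matmul_naive(A, B):
    """Direct translation of C[i,j] = sum(A[i,k] * B[k,j] for k in range(K))."""
    M, K = A.shape
    K2, N = B.shape
    assert K == K2, "inner dimensions must match"
    C = np.zeros((M, N), dtype=np.float64)
    for i in range(M):          # one output row at a time
        for j in range(N):      # one output column at a time
            C[i, j] = sum(A[i, k] * B[k, j] for k in range(K))
    return C

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 3))
B = rng.standard_normal((3, 5))
C = matmul_naive(A, B)
print(C.shape, np.allclose(C, A @ B))  # (4, 5) True
```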

### FLOP Count

- Each output element: K multiplications + K-1 additions ≈ 2K floating-point operations (FLOPs)
- Total elements: M × N
- **Total FLOPs: 2 × M × N × K** (dividing this count by the runtime in seconds, times 1e-9, gives GFLOPS)
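The "≈ 2K" step can be checked by counting operations explicitly. The sketch below (the helper name is hypothetical) tallies multiplies and adds per output element. The exact count is M·N·(2K-1). The usual 2·M·N·K convention overcounts by M·N, a relative error of 1/(2K), which is about 0.05% at K = 1024.

```python
def count_matmul_ops(M, N, K):
    """Count multiplies and adds in the naive loop, one output element at a time."""
    mults = adds = 0
    for _ in range(M * N):
        mults += K        # A[i,k] * B[k,j] for every k
        adds += K - 1     # summing K products needs K-1 additions
    return mults, adds

M, N, K = 8, 6, 4
mults, adds = count_matmul_ops(M, N, K)
exact = mults + adds        # M*N*(2K - 1) = 48 * 7
approx = 2 * M * N * K      # the 2MNK convention used for GFLOPS
print(mults, adds, exact, approx)
```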

In [None]:
def calculate_matmul_gflops(M, N, K, time_seconds):
    """Calculate GFLOPS for matmul."""
    flops = 2 * M * N * K
    gflops = flops / (time_seconds * 1e9)
    return gflops

# Reference: rough FP32 ballpark figures; real numbers depend heavily on the hardware
print("Performance targets:")
print("  NumPy (CPU): ~50-100 GFLOPS")
print("  Naive GPU kernel: ~100-500 GFLOPS")
print("  Tiled GPU kernel: ~500-2000 GFLOPS")
print("  cuBLAS/Triton autotuned: ~2000-5000 GFLOPS")
print("  Peak H100 SXM (FP32, without Tensor Cores): ~67 TFLOPS")
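GFLOPS targets are easier to reason about as time budgets. The sketch below inverts the formula in `calculate_matmul_gflops`, turning a GFLOPS figure into the runtime a given matmul would need to reach it. The 4096³ size is just an example.

```python
def time_budget_ms(M, N, K, gflops):
    """Milliseconds a matmul may take while still reaching the given GFLOPS."""
    flops = 2 * M * N * K
    return flops / (gflops * 1e9) * 1e3

for target in [100, 500, 2000]:
    ms = time_budget_ms(4096, 4096, 4096, target)
    print(f"4096^3 matmul at {target:>5} GFLOPS -> {ms:8.1f} ms")
```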

---
## Step 2: Explore - The Tiled Algorithm (15 min)

### Algorithm Overview

```
For each output tile (block_m, block_n):
    Initialize accumulator = zeros(TILE_M, TILE_N)
    
    For k_tile in range(0, K, TILE_K):
        # Load tiles from global memory
        tile_a = A[block_m*TILE_M : (block_m+1)*TILE_M, 
                   k_tile : k_tile+TILE_K]
        tile_b = B[k_tile : k_tile+TILE_K,
                   block_n*TILE_N : (block_n+1)*TILE_N]
        
        # Compute partial product (in registers/shared memory)
        accumulator += tile_a @ tile_b
    
    # Write result to global memory
    C[block_m*TILE_M : ..., block_n*TILE_N : ...] = accumulator
```
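The pseudocode above runs almost unchanged on the CPU. The NumPy sketch below has one deliberate difference from the Triton kernel written later: NumPy slicing simply clips at the matrix edge, whereas the kernel pads edge tiles with masked loads. It checks that accumulating over K-tiles reproduces `A @ B` even when no dimension is a multiple of its tile size.

```python
import numpy as np

def matmul_tiled(A, B, TILE_M=64, TILE_N=64, TILE_K=32):
    """NumPy version of the tiled loop; slicing clips at the matrix edges."""
    M, K = A.shape
    K2, N = B.shape
    assert K == K2, "inner dimensions must match"
    C = np.empty((M, N), dtype=np.float32)
    for m0 in range(0, M, TILE_M):          # output tile row (block_m)
        for n0 in range(0, N, TILE_N):      # output tile column (block_n)
            acc = np.zeros((min(TILE_M, M - m0), min(TILE_N, N - n0)), dtype=np.float32)
            for k0 in range(0, K, TILE_K):  # march along K one tile at a time
                tile_a = A[m0:m0 + TILE_M, k0:k0 + TILE_K]
                tile_b = B[k0:k0 + TILE_K, n0:n0 + TILE_N]
                acc += tile_a @ tile_b      # partial product for this K slice
            C[m0:m0 + TILE_M, n0:n0 + TILE_N] = acc
    return C

rng = np.random.default_rng(0)
A = rng.standard_normal((100, 70), dtype=np.float32)
B = rng.standard_normal((70, 90), dtype=np.float32)
C = matmul_tiled(A, B, TILE_M=32, TILE_N=32, TILE_K=16)
print(np.allclose(C, A @ B, rtol=1e-4, atol=1e-4))
```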

In [None]:
# Visualize the tiled algorithm
def visualize_tiled_matmul(M, N, K, TILE_M, TILE_N, TILE_K):
    """Show how tiled matmul processes data."""
    
    num_tiles_m = (M + TILE_M - 1) // TILE_M
    num_tiles_n = (N + TILE_N - 1) // TILE_N
    num_tiles_k = (K + TILE_K - 1) // TILE_K
    
    print(f"Matrix dimensions: A={M}x{K}, B={K}x{N}, C={M}x{N}")
    print(f"Tile sizes: TILE_M={TILE_M}, TILE_N={TILE_N}, TILE_K={TILE_K}")
    print()
    print(f"Grid of output tiles: {num_tiles_m} x {num_tiles_n} = {num_tiles_m * num_tiles_n} tiles")
    print(f"K-dimension tiles: {num_tiles_k} (loop iterations per output tile)")
    print()
    print(f"For each output tile, we iterate {num_tiles_k} times:")
    print(f"  - Load a {TILE_M}x{TILE_K} tile of A")
    print(f"  - Load a {TILE_K}x{TILE_N} tile of B")
    print("  - Compute partial matmul and accumulate")
    print()
    
    # Data reuse: each program streams a full strip of A's rows and B's columns,
    # so an A tile is read by every program in its tile row (num_tiles_n of them)
    # and a B tile by every program in its tile column (num_tiles_m of them).
    print("Data reuse:")
    print(f"  Each A tile is loaded {num_tiles_n} times (once per tile column of C)")
    print(f"  Each B tile is loaded {num_tiles_m} times (once per tile row of C)")
    print(f"  Once loaded, each A element feeds {TILE_N} multiply-adds and each B element feeds {TILE_M}")

visualize_tiled_matmul(1024, 1024, 1024, 64, 64, 32)
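The reuse numbers translate directly into global-memory traffic. The sketch below counts element reads and ignores caches; the helper name is hypothetical. It compares a naive kernel, where every output element re-reads a full row of A and a full column of B, with the tiled scheme, where each A element is read once per tile column of C and each B element once per tile row.

```python
def global_reads(M, N, K, TILE_M, TILE_N):
    """Element reads from global memory, ignoring any caching."""
    naive = 2 * M * N * K                    # K reads of A + K reads of B per output
    tiles_m = (M + TILE_M - 1) // TILE_M
    tiles_n = (N + TILE_N - 1) // TILE_N
    tiled = tiles_n * (M * K) + tiles_m * (K * N)
    return naive, tiled

naive, tiled = global_reads(1024, 1024, 1024, 64, 64)
print(f"naive: {naive:,} reads, tiled: {tiled:,} reads, ratio: {naive / tiled:.0f}x")
```

With square tiles the ratio works out to the tile size, which is why larger tiles cut traffic further, until they no longer fit on-chip.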

---
## Step 3: The Concept - Triton's Block-Level Programming (10 min)

### Key Triton Operations for Matmul

```python
# 1. Load 2D tiles
tile_a = tl.load(a_ptr + offsets, mask=mask)

# 2. Matrix multiplication on tiles  
# Triton has a built-in dot product:
result = tl.dot(tile_a, tile_b)  # Efficient matmul!

# 3. Accumulate
accumulator += result
```

`tl.dot()` is the key. On GPUs with Tensor Cores, Triton compiles it to Tensor Core instructions; for FP32 inputs on Ampere and newer GPUs, it uses TF32 precision by default.
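The kernel below loads edge tiles with `mask=` and `other=0.0`. Zero padding is safe for `tl.dot` because padded entries contribute `0 * x = 0` to every sum. The NumPy sketch below imitates that masked load. The helper `masked_load_2d` is hypothetical and is not a Triton API. The sketch checks that padded tiles still give the exact product in the valid region.

```python
import numpy as np

def masked_load_2d(X, row0, col0, TILE_R, TILE_C):
    """Hypothetical stand-in for tl.load(..., mask=in_bounds, other=0.0)."""
    tile = np.zeros((TILE_R, TILE_C), dtype=X.dtype)     # the other=0.0 fill
    valid = X[row0:row0 + TILE_R, col0:col0 + TILE_C]    # in-bounds part only
    tile[:valid.shape[0], :valid.shape[1]] = valid
    return tile

M, N, K = 5, 3, 7
TILE_M, TILE_N, TILE_K = 8, 4, 4    # no dimension is a multiple of its tile size
rng = np.random.default_rng(1)
A = rng.standard_normal((M, K)).astype(np.float32)
B = rng.standard_normal((K, N)).astype(np.float32)

acc = np.zeros((TILE_M, TILE_N), dtype=np.float32)
for k0 in range(0, K, TILE_K):
    acc += masked_load_2d(A, 0, k0, TILE_M, TILE_K) @ masked_load_2d(B, k0, 0, TILE_K, TILE_N)

print(np.allclose(acc[:M, :N], A @ B, atol=1e-5))          # valid region matches
print(np.all(acc[M:, :] == 0) and np.all(acc[:, N:] == 0))  # padding stays zero
```

The masked store at the end of the kernel then simply discards the padded rows and columns.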

---
## Step 4: Code It - Tiled Matmul Kernel (30 min)

In [None]:
if GPU_AVAILABLE:
    @triton.jit
    def matmul_kernel(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        # Matrix dimensions
        M, N, K,
        # Strides (elements to skip to go to next row)
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        # Tile sizes (must be compile-time constants)
        TILE_M: tl.constexpr,
        TILE_N: tl.constexpr,
        TILE_K: tl.constexpr,
    ):
        """Tiled matrix multiplication: C = A @ B"""
        
        # ===== Step 1: Identify which tile this program computes =====
        pid_m = tl.program_id(0)  # Row tile index
        pid_n = tl.program_id(1)  # Column tile index
        
        # ===== Step 2: Compute base offsets for this tile =====
        # Rows of A and C that this tile handles
        rm = pid_m * TILE_M + tl.arange(0, TILE_M)
        # Columns of B and C that this tile handles
        cn = pid_n * TILE_N + tl.arange(0, TILE_N)
        
        # ===== Step 3: Initialize accumulator =====
        # This holds the partial results for this output tile
        acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
        
        # ===== Step 4: Loop over K dimension in tiles =====
        for k_start in range(0, K, TILE_K):
            # K indices for this iteration
            rk = k_start + tl.arange(0, TILE_K)
            
            # ----- Load tile of A -----
            # A[rm, rk] - need 2D indexing
            # Offsets: rm[:, None] * stride_am + rk[None, :] * stride_ak
            a_offsets = rm[:, None] * stride_am + rk[None, :] * stride_ak
            a_mask = (rm[:, None] < M) & (rk[None, :] < K)
            a = tl.load(a_ptr + a_offsets, mask=a_mask, other=0.0)
            
            # ----- Load tile of B -----
            # B[rk, cn]
            b_offsets = rk[:, None] * stride_bk + cn[None, :] * stride_bn
            b_mask = (rk[:, None] < K) & (cn[None, :] < N)
            b = tl.load(b_ptr + b_offsets, mask=b_mask, other=0.0)
            
            # ----- Compute partial matmul -----
            # This is the magic: tl.dot compiles to efficient matmul
            acc += tl.dot(a, b)
        
        # ===== Step 5: Store result =====
        c_offsets = rm[:, None] * stride_cm + cn[None, :] * stride_cn
        c_mask = (rm[:, None] < M) & (cn[None, :] < N)
        tl.store(c_ptr + c_offsets, acc, mask=c_mask)
    
    print("Matmul kernel compiled!")
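Inside the kernel, matrices are flat pointers plus strides, and `rm[:, None] * stride_am + rk[None, :] * stride_ak` broadcasts two 1D index ranges into a 2D grid of offsets. The NumPy sketch below re-enacts one program of `matmul_kernel` on flattened arrays, using the same offset and mask expressions, and then runs every program in the grid. The helper `emulate_program` is hypothetical and only illustrates the indexing. It is not how Triton executes the kernel.

```python
import numpy as np

def emulate_program(a_flat, b_flat, M, N, K,
                    stride_am, stride_ak, stride_bk, stride_bn,
                    pid_m, pid_n, TILE_M, TILE_N, TILE_K):
    """Hypothetical NumPy re-enactment of one matmul_kernel program."""
    rm = pid_m * TILE_M + np.arange(TILE_M)
    cn = pid_n * TILE_N + np.arange(TILE_N)
    acc = np.zeros((TILE_M, TILE_N), dtype=np.float32)
    for k_start in range(0, K, TILE_K):
        rk = k_start + np.arange(TILE_K)
        a_offsets = rm[:, None] * stride_am + rk[None, :] * stride_ak
        a_mask = (rm[:, None] < M) & (rk[None, :] < K)
        # Gather with a safe dummy offset (0) where masked, then apply other=0.0
        a = np.where(a_mask, a_flat[np.where(a_mask, a_offsets, 0)], 0.0)
        b_offsets = rk[:, None] * stride_bk + cn[None, :] * stride_bn
        b_mask = (rk[:, None] < K) & (cn[None, :] < N)
        b = np.where(b_mask, b_flat[np.where(b_mask, b_offsets, 0)], 0.0)
        acc += a @ b
    c_mask = (rm[:, None] < M) & (cn[None, :] < N)
    return rm, cn, acc, c_mask

M, N, K = 50, 40, 30
TILE = 16
rng = np.random.default_rng(2)
A = rng.standard_normal((M, K)).astype(np.float32)
B = rng.standard_normal((K, N)).astype(np.float32)
C = np.zeros((M, N), dtype=np.float32)

# Row-major strides in elements: A is (M, K) -> (K, 1); B is (K, N) -> (N, 1)
for pid_m in range((M + TILE - 1) // TILE):
    for pid_n in range((N + TILE - 1) // TILE):
        rm, cn, acc, c_mask = emulate_program(A.ravel(), B.ravel(), M, N, K,
                                              K, 1, N, 1, pid_m, pid_n, TILE, TILE, TILE)
        rows, cols = np.nonzero(c_mask)    # masked store: in-bounds lanes only
        C[rm[rows], cn[cols]] = acc[rows, cols]

print(np.allclose(C, A @ B, rtol=1e-4, atol=1e-4))
```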

In [None]:
if GPU_AVAILABLE:
    def matmul_triton(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Matrix multiplication using our Triton kernel."""
        
        assert a.is_cuda and b.is_cuda
        assert a.shape[1] == b.shape[0], f"Shape mismatch: {a.shape} x {b.shape}"
        
        M, K = a.shape
        K, N = b.shape
        
        c = torch.empty((M, N), device=a.device, dtype=a.dtype)
        
        # Tile sizes (tune these for your GPU!)
        TILE_M, TILE_N, TILE_K = 64, 64, 32
        
        # Grid: one program per output tile
        grid = (
            triton.cdiv(M, TILE_M),
            triton.cdiv(N, TILE_N),
        )
        
        matmul_kernel[grid](
            a, b, c,
            M, N, K,
            a.stride(0), a.stride(1),
            b.stride(0), b.stride(1),
            c.stride(0), c.stride(1),
            TILE_M=TILE_M,
            TILE_N=TILE_N,
            TILE_K=TILE_K,
        )
        
        return c
    
    print("Wrapper function defined!")
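The wrapper passes strides rather than assuming a row-major layout, so a transposed view can be handed to the kernel without a copy. It also sizes the grid with ceiling division. The NumPy sketch below illustrates both ideas. NumPy reports strides in bytes while `torch.Tensor.stride()` reports elements, so the helper `element_strides` (our own name) converts. `cdiv` here mirrors the role of `triton.cdiv`.

```python
import numpy as np

def element_strides(x):
    """Convert NumPy byte strides to element strides (the units torch's .stride() uses)."""
    return tuple(s // x.itemsize for s in x.strides)

def cdiv(a, b):
    """Ceiling division, playing the role of triton.cdiv in the grid computation."""
    return (a + b - 1) // b

a = np.zeros((1000, 300), dtype=np.float32)
print(element_strides(a))     # (300, 1): row-major
print(element_strides(a.T))   # (1, 300): transposed view, same memory
print((cdiv(1000, 64), cdiv(300, 64)))   # (16, 5): grid for a 1000 x 300 output
```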

In [None]:
if GPU_AVAILABLE:
    # Test correctness
    print("Testing correctness...")
    
    M, N, K = 512, 512, 512
    a = torch.randn(M, K, device='cuda', dtype=torch.float32)
    b = torch.randn(K, N, device='cuda', dtype=torch.float32)
    
    # Our kernel
    c_triton = matmul_triton(a, b)
    
    # Reference (PyTorch)
    c_torch = a @ b
    
    # Check
    max_diff = (c_triton - c_torch).abs().max().item()
    print(f"Max difference: {max_diff:.6f}")
    print(f"Max diff / mean |C|: {max_diff / c_torch.abs().mean().item():.6f}")
    
    # tl.dot may use TF32 Tensor Cores for float32 inputs on Ampere+ GPUs, while
    # PyTorch's float32 matmul defaults to full FP32, so allow a looser tolerance
    is_correct = torch.allclose(c_triton, c_torch, rtol=1e-2, atol=1e-1)
    print(f"Results match: {is_correct}")
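Why allow differences of this size? On Ampere-and-newer GPUs, Triton's `tl.dot` defaults to TF32 for float32 inputs, which keeps only 10 mantissa bits. Recent Triton releases expose this choice through an `input_precision` argument. The NumPy sketch below emulates only the input rounding, using round-to-nearest-even, to estimate the gap against a full-precision reference. Actual hardware conversion and accumulation order may differ, so treat the printed number as a rough guide.

```python
import numpy as np

def round_to_tf32(x):
    """Emulate TF32 input rounding: keep 10 of float32's 23 mantissa bits (nearest even)."""
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    lsb = (bits >> np.uint32(13)) & np.uint32(1)   # lowest mantissa bit that survives
    rounded = (bits + np.uint32(0x0FFF) + lsb) & np.uint32(0xFFFFE000)
    return rounded.view(np.float32)

rng = np.random.default_rng(3)
K = 512
A = rng.standard_normal((256, K)).astype(np.float32)
B = rng.standard_normal((K, 256)).astype(np.float32)

exact = A.astype(np.float64) @ B.astype(np.float64)
tf32 = round_to_tf32(A).astype(np.float64) @ round_to_tf32(B).astype(np.float64)
max_diff = np.abs(tf32 - exact).max()
print(f"Max |TF32-input result - FP64 result| for K={K}: {max_diff:.4f}")
```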

In [None]:
if GPU_AVAILABLE:
    # Benchmark
    def benchmark_matmul(fn, M, N, K, warmup=10, repeat=100):
        """Benchmark a matmul function."""
        a = torch.randn(M, K, device='cuda', dtype=torch.float32)
        b = torch.randn(K, N, device='cuda', dtype=torch.float32)
        
        # Warmup
        for _ in range(warmup):
            _ = fn(a, b)
        torch.cuda.synchronize()
        
        # Benchmark
        start = time.perf_counter()
        for _ in range(repeat):
            c = fn(a, b)
        torch.cuda.synchronize()
        elapsed = (time.perf_counter() - start) / repeat
        
        gflops = calculate_matmul_gflops(M, N, K, elapsed)
        return elapsed * 1000, gflops
    
    print("Benchmarking matmul implementations...")
    print()
    print(f"{'Size':>12} {'Triton (ms)':>12} {'Triton GFLOPS':>14} {'PyTorch (ms)':>13} {'Speedup':>8}")
    print("-" * 63)
    
    for size in [512, 1024, 2048, 4096]:
        triton_time, triton_gflops = benchmark_matmul(matmul_triton, size, size, size)
        torch_time, torch_gflops = benchmark_matmul(lambda a, b: a @ b, size, size, size)
        
        # Speedup > 1 means our Triton kernel is faster than PyTorch
        speedup = torch_time / triton_time
        label = f"{size}x{size}"
        print(f"{label:>12} {triton_time:>12.2f} {triton_gflops:>14.0f} {torch_time:>13.2f} {speedup:>7.2f}x")
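`benchmark_matmul` reports the mean over all repeats, which a single slow iteration can skew. A common alternative is to time each call and take the median. The sketch below shows that pattern on the CPU with NumPy; the helper name is our own. A GPU version would need `torch.cuda.synchronize()` (or CUDA events) around every sample, because kernel launches are asynchronous.

```python
import time
import numpy as np

def benchmark_median(fn, *args, warmup=3, repeat=20):
    """Time fn(*args) call by call and return the median in seconds."""
    for _ in range(warmup):
        fn(*args)
    samples = []
    for _ in range(repeat):
        start = time.perf_counter()
        fn(*args)
        samples.append(time.perf_counter() - start)
    return float(np.median(samples))

n = 256
rng = np.random.default_rng(4)
a = rng.standard_normal((n, n)).astype(np.float32)
b = rng.standard_normal((n, n)).astype(np.float32)
t = benchmark_median(np.matmul, a, b)
print(f"NumPy {n}x{n}: {t * 1e3:.3f} ms, {2 * n**3 / (t * 1e9):.1f} GFLOPS")
```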

### Analyzing the Results

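Your kernel will likely land around 500-2000 GFLOPS, depending on the GPU and tile sizes:
- It can be competitive with PyTorch for medium-sized matrices
- PyTorch (cuBLAS) is often faster, because its kernels are heavily tuned per GPU and matrix shape
- On Ampere and newer GPUs, `tl.dot` uses TF32 for FP32 inputs by default while PyTorch's FP32 `a @ b` does not, so the two are not being compared at the same precision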

**To go faster, we would need:**
- Better tile sizes (autotuning)
- Lower-precision Tensor Core inputs (FP16/BF16) for higher peak throughput
- More sophisticated memory access patterns
- Pipelining (overlap load and compute)
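One reason tile sizes matter is arithmetic intensity (Day 5). Each K-loop iteration does 2·TILE_M·TILE_N·TILE_K FLOPs and loads TILE_M·TILE_K + TILE_K·TILE_N elements. The sketch below assumes FP32 (4 bytes per element) and computes FLOPs per byte for a few tile shapes.

```python
def tile_intensity(TILE_M, TILE_N, TILE_K, bytes_per_elem=4):
    """FLOPs per byte loaded from global memory in one K-loop iteration."""
    flops = 2 * TILE_M * TILE_N * TILE_K
    bytes_loaded = bytes_per_elem * (TILE_M * TILE_K + TILE_K * TILE_N)
    return flops / bytes_loaded

for tiles in [(32, 32, 32), (64, 64, 32), (128, 128, 32)]:
    print(tiles, f"{tile_intensity(*tiles):.1f} FLOP/byte")
```

TILE_K cancels out: the ratio is TILE_M·TILE_N / (2·(TILE_M + TILE_N)). Larger output tiles raise intensity, but they also need more registers and shared memory per program. That trade-off is exactly what autotuning explores.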

These are topics for Week 2!

---
## Step 5: Verify - Understanding Check (10 min)

In [None]:
# Q1: Why do we loop over K in tiles?
print("Q1: Loop over K dimension in tiles because:")
print("    - A full TILE_M x K strip of A (and K x TILE_N strip of B) doesn't fit in shared memory/registers")
print("    - Each iteration loads small tiles of A and B")
print("    - Partial products are accumulated in fast registers")
print("    - Result: vs. a naive kernel, global reads of A drop by ~TILE_N and reads of B by ~TILE_M")

In [None]:
# Q2: What does tl.dot() do?
print("Q2: tl.dot(a, b) computes tile-level matrix multiplication")
print("    - Input: two 2D tile tensors")
print("    - Output: a TILE_M x TILE_N tile (accumulated in FP32 in our kernel)")
print("    - On modern GPUs: compiles to Tensor Core instructions")
print("    - This is where most of the FLOPs happen!")

In [None]:
# Q3: Why use 2D grid?
print("Q3: 2D grid maps naturally to output matrix:")
print("    - pid_m = row tile index")
print("    - pid_n = column tile index")
print("    - Each program computes one TILE_M x TILE_N output tile")
print("    - Total programs = cdiv(M, TILE_M) x cdiv(N, TILE_N) (rounded up to cover the edges)")

In [None]:
# Q4: Where does data reuse happen?
print("Q4: Data reuse in the K-loop:")
print("    - Each element of a loaded A tile is used TILE_N times (once per output column in the tile)")
print("    - Each element of a loaded B tile is used TILE_M times (once per output row in the tile)")
print("    - Reuse happens within the tl.dot() operation")
print("    - Accumulated result stays in registers until loop ends")

---
## Week 1 Summary

### What We Learned

| Day | Topic | Key Takeaway |
|-----|-------|-------------|
| 1 | NumPy Baseline | GFLOPS measurement, CPU performance |
| 2 | CuPy | 50-100x speedup with GPU |
| 3 | GPU Architecture | SMs, warps, SIMT model |
| 4 | First Kernel | Index arithmetic, program IDs |
| 5 | Memory Hierarchy | Coalescing, arithmetic intensity |
| 6 | Tiling | Data reuse, shared memory |
| 7 | Fast Matmul | Putting it all together |

### Performance Journey

| Implementation | GFLOPS | Improvement |
|---------------|--------|------------|
| Naive Python | 0.001 | Baseline |
| NumPy | 50-100 | 50,000x |
| CuPy | 500-2000 | 10-20x over NumPy |
| Our Triton kernel | 500-2000 | Competitive! |

### What's Next (Week 2)

- **Profiling** - Use Nsight Compute to find bottlenecks
- **Autotuning** - Find optimal tile sizes automatically
- **Tensor Cores** - FP16/BF16 for 2-4x more FLOPS
- **Memory optimizations** - Bank conflicts, async copies
- **Target:** 80%+ of theoretical peak performance

---
## Congratulations!

You've written your first fast GPU kernel from scratch! You now understand:
- Why GPUs are fast (massive parallelism)
- How to think about GPU programming (tiles, blocks, warps)
- Why memory matters more than compute for most kernels
- How to implement tiled algorithms for data reuse

This foundation will serve you throughout your GPU programming journey.

**Next:** Week 2 - The Memory Game (optimization deep dive)