# Week 2, Day 4: Software Pipelining

**Time:** ~1 hour

**Goal:** Learn to hide memory latency by overlapping loads with computation.

## The Challenge

A global memory load takes roughly 400 cycles, while the compute it feeds takes roughly 4 cycles. If every load has to finish before compute starts, the compute units sit idle more than 99% of the time!

In [None]:
import torch
import triton
import triton.language as tl
from triton.testing import do_bench

---
## Step 1: The Challenge (5 min)

### Sequential Execution

```
Time →
┌─────────┐     ┌─────────┐     ┌─────────┐
│ Load A  │     │ Load B  │     │ Compute │     (Repeat)
└─────────┘     └─────────┘     └─────────┘
   400 cy          400 cy          4 cy

Total: 804 cycles per iteration
Compute utilization: 4/804 = 0.5%
```

### Pipelined Execution

```
Time →
┌─────────┬─────────┬─────────┬─────────┬──────
│ Load A₀ │ Load A₁ │ Load A₂ │ Load A₃ │ ...
├─────────┼─────────┼─────────┼─────────┼──────
│         │ Load B₀ │ Load B₁ │ Load B₂ │ ...
├─────────┼─────────┼─────────┼─────────┼──────
│         │         │Compute₀ │Compute₁ │ ...
└─────────┴─────────┴─────────┴─────────┴──────

After warmup: All stages run in parallel!
```
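
The two timelines above can be compared with a small cycle-count model. This is a sketch under simplifying assumptions: the function names are hypothetical, every stage takes a fixed number of cycles, and each stage handles one tile at a time. It is not a model of any particular GPU.

```python
def sequential_cycles(n_iters, stage_cycles):
    """Each iteration runs every stage back to back before the next one starts."""
    return n_iters * sum(stage_cycles)


def pipelined_cycles(n_iters, stage_cycles):
    """Linear pipeline: the first iteration pays the full latency,
    every later iteration adds only the slowest stage."""
    if n_iters == 0:
        return 0
    return sum(stage_cycles) + (n_iters - 1) * max(stage_cycles)


STAGES = (400, 400, 4)  # Load A, Load B, Compute (cycle counts from the diagrams)
N_ITERS = 100           # assumed loop trip count, for illustration only

seq = sequential_cycles(N_ITERS, STAGES)
pipe = pipelined_cycles(N_ITERS, STAGES)
print(f"sequential: {seq} cycles")
print(f"pipelined:  {pipe} cycles")
print(f"speedup:    {seq / pipe:.2f}x")
```

With these numbers the model gives 80400 cycles sequentially and 40404 pipelined. Overlap removes the stalls between stages, but the slowest stage still sets the steady-state rate.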

---
## Step 2: Explore (15 min)

### Double Buffering

The simplest pipelining technique: use 2 buffers.

```
While computing with Buffer A:
  → Load next data into Buffer B

While computing with Buffer B:
  → Load next data into Buffer A
```

This requires:
1. **Async memory operations** (loads that don't block)
2. **Synchronization** (ensure load completes before use)
3. **Multiple buffers** (to hold data for different stages)
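
The three requirements can be illustrated on the CPU with a background thread standing in for the async copy engine. This is only an analogy, not GPU code: `load_tile` and `double_buffered_sum` are hypothetical names, and a Python `Future` plays the role of the synchronization primitive.

```python
from concurrent.futures import ThreadPoolExecutor


def load_tile(data, start, size):
    """Stand-in for a slow memory load: returns one tile of `data`."""
    return data[start:start + size]


def double_buffered_sum(data, tile_size):
    """Sum `data` tile by tile, fetching tile i+1 while tile i is being summed."""
    starts = list(range(0, len(data), tile_size))
    if not starts:
        return 0
    total = 0
    with ThreadPoolExecutor(max_workers=1) as pool:
        # Prologue: start loading the first tile before any compute.
        in_flight = pool.submit(load_tile, data, starts[0], tile_size)
        for i in range(len(starts)):
            # Synchronization: block until the tile we are about to use has arrived.
            current = in_flight.result()
            # Async load: start fetching the next tile into the "other" buffer.
            if i + 1 < len(starts):
                in_flight = pool.submit(load_tile, data, starts[i + 1], tile_size)
            # Compute on the current tile while the next load is in flight.
            total += sum(current)
    return total


result = double_buffered_sum(list(range(10)), tile_size=4)
print(result)
```

At any moment at most two tiles are alive: the one being summed and the one being fetched, which is exactly the two-buffer scheme described above.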

In [None]:
@triton.jit
def matmul_no_pipeline(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
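    """Matmul with a plain load-then-compute loop.

    The loads are only guaranteed to be synchronous when the kernel is
    launched with num_stages=1; with a larger value the compiler may
    pipeline this loop as well.
    """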
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    
    for k in range(0, K, BLOCK_K):
        # Synchronous loads - kernel waits for these
        a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] + k < K), other=0.0)
        b = tl.load(b_ptrs, mask=(offs_k[:, None] + k < K) & (offs_n[None, :] < N), other=0.0)
        
        # Compute
        acc += tl.dot(a, b)
        
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))

---
## Step 3: The Concept (10 min)

### Pipelining in Triton

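Triton exposes a `num_stages` option for automatic pipelining. It is not a kernel argument: pass it as a launch keyword, or set it in a `triton.Config` when autotuning:

```python
# Launch-time option; the kernel signature does not declare it.
kernel[grid](..., num_stages=2)  # 2-stage pipeline (double buffer)

# Or as part of an autotuning config:
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=2)
```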

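When `num_stages > 1`, the Triton compiler pipelines eligible loops automatically:
1. Uses async copy operations where the hardware has them (`cp.async` on NVIDIA Ampere and newer)
2. Manages multiple SMEM buffers
3. Inserts the required waits and barriers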
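
Conceptually, the transformation splits the loop into a prologue that issues the first loads, a steady state that issues one load ahead for every compute, and an epilogue that drains the remaining tiles. The sketch below prints that event order in plain Python. It is a conceptual model with a hypothetical function name, not Triton's actual lowering, and it assumes `num_stages - 1` loads run ahead of the compute.

```python
def pipelined_schedule(n_iters, num_stages):
    """Return the issue order of ('load', i) and ('compute', i) events."""
    ahead = num_stages - 1  # how many loads run ahead of the compute
    events = []
    # Prologue: issue the first `ahead` loads before any compute.
    for i in range(min(ahead, n_iters)):
        events.append(("load", i))
    # Steady state, then epilogue once there is nothing left to load.
    for i in range(n_iters):
        if i + ahead < n_iters:
            events.append(("load", i + ahead))
        events.append(("compute", i))
    return events


for stages in (1, 2, 3):
    order = " ".join(f"{kind[0].upper()}{i}" for kind, i in pipelined_schedule(4, stages))
    print(f"num_stages={stages}: {order}")
```

With `num_stages=1` every load is immediately followed by its compute. With deeper pipelines the loads move earlier, so each compute finds its tile already requested `num_stages - 1` iterations before.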

### Stage Count Tradeoffs

| Stages | SMEM Usage | Latency Hiding | Complexity |
|--------|------------|----------------|------------|
| 1 | Minimum | None | Simple |
| 2 | 2x | Moderate | Medium |
| 3 | 3x | Good | Medium |
| 4+ | 4x+ | Excellent | Higher |
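
The SMEM column can be made concrete with a rough estimate. The helper below is a sketch that assumes one A tile and one B tile are buffered per stage and that elements are 2-byte fp16; `smem_bytes` is a hypothetical name, and the compiler's real allocation can differ (padding, layout, or a different number of buffers).

```python
def smem_bytes(block_m, block_n, block_k, num_stages, elem_bytes=2):
    """Estimate shared memory needed to buffer the A and B tiles of a matmul."""
    a_tile = block_m * block_k * elem_bytes  # BLOCK_M x BLOCK_K tile of A
    b_tile = block_k * block_n * elem_bytes  # BLOCK_K x BLOCK_N tile of B
    return num_stages * (a_tile + b_tile)


for stages in (1, 2, 3, 4):
    small = smem_bytes(64, 64, 32, stages) // 1024
    large = smem_bytes(128, 128, 32, stages) // 1024
    print(f"num_stages={stages}: 64x64x32 -> {small} KiB, 128x128x32 -> {large} KiB")
```

Under these assumptions each extra stage costs 8 KiB for 64x64x32 tiles and 16 KiB for 128x128x32 tiles, which is why large tiles and deep pipelines compete for the same budget.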

In [None]:
@triton.jit
def matmul_pipelined(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    NUM_STAGES: tl.constexpr,
):
    """Matmul with software pipelining.

    The loop body is identical to the unpipelined kernel. The Triton compiler
    pipelines it according to the num_stages launch option (or triton.Config).
    NUM_STAGES is a constexpr kept for illustration; the body never reads it.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    
    # Main loop - Triton compiler will pipeline this
    for k in range(0, K, BLOCK_K):
        a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] + k < K), other=0.0)
        b = tl.load(b_ptrs, mask=(offs_k[:, None] + k < K) & (offs_n[None, :] < N), other=0.0)
        
        acc += tl.dot(a, b)
        
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))

In [None]:
def matmul_wrapper(a, b, kernel_fn, BLOCK_M=64, BLOCK_N=64, BLOCK_K=32, num_stages=1):
    """Launch a matmul kernel that declares a NUM_STAGES constexpr, such as matmul_pipelined."""
    M, K = a.shape
    K2, N = b.shape
    assert K == K2
    
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    
    kernel_fn[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
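        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
        NUM_STAGES=num_stages,  # constexpr declared by matmul_pipelined (unused in its body)
        num_stages=num_stages,  # launch option that actually sets the pipeline depth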
    )
    return c

def benchmark_pipeline_stages(M, N, K):
    """Benchmark matmul with different pipeline depths."""
    a = torch.randn(M, K, device='cuda', dtype=torch.float16)
    b = torch.randn(K, N, device='cuda', dtype=torch.float16)
    
    results = {}
    flops = 2 * M * N * K  # one multiply and one add per inner-product term
    
    for stages in [1, 2, 3, 4]:
        ms = do_bench(lambda s=stages: matmul_wrapper(a, b, matmul_pipelined, num_stages=s))
        tflops = flops / (ms * 1e-3) / 1e12
        results[stages] = {'ms': ms, 'tflops': tflops}
    
    # PyTorch baseline
    ms_torch = do_bench(lambda: torch.mm(a, b))
    tflops_torch = flops / (ms_torch * 1e-3) / 1e12
    results['torch'] = {'ms': ms_torch, 'tflops': tflops_torch}
    
    return results
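
The throughput figure in the benchmark comes from the usual `2 * M * N * K` operation count for a matmul. As a sanity check on the units, the conversion can be pulled out into a small helper. `matmul_tflops` is a hypothetical name, and the 1.0 ms timing below is a made-up input, not a measurement.

```python
def matmul_tflops(M, N, K, ms):
    """Convert a matmul runtime in milliseconds to TFLOPS."""
    flops = 2 * M * N * K      # one multiply and one add per inner-product term
    seconds = ms * 1e-3
    return flops / seconds / 1e12


# Hypothetical timing: a 4096^3 matmul that takes 1.0 ms.
print(f"{matmul_tflops(4096, 4096, 4096, 1.0):.1f} TFLOPS")
```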

---
## Step 4: Code It (30 min)

### Autotuning Pipeline Depth

In [None]:
# Triton autotuning configs
configs = [
    triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=1),
    triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=2),
    triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=3),
    triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=2),
    triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=3),
]

@triton.autotune(configs=configs, key=['M', 'N', 'K'])
@triton.jit
def matmul_autotuned(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """Autotuned matmul - Triton finds best config including pipeline depth."""
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    
    for k in range(0, K, BLOCK_K):
        a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] + k < K), other=0.0)
        b = tl.load(b_ptrs, mask=(offs_k[:, None] + k < K) & (offs_n[None, :] < N), other=0.0)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))

In [None]:
def benchmark_autotuned(M, N, K):
    """Benchmark autotuned kernel."""
    a = torch.randn(M, K, device='cuda', dtype=torch.float16)
    b = torch.randn(K, N, device='cuda', dtype=torch.float16)
    c = torch.empty(M, N, device='cuda', dtype=torch.float32)
    
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
    
    # Warmup + autotune
    matmul_autotuned[grid](
        a, b, c, M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    
    ms = do_bench(lambda: matmul_autotuned[grid](
        a, b, c, M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    ))
    
    flops = 2 * M * N * K
    tflops = flops / (ms * 1e-3) / 1e12
    
    # Compare to PyTorch
    ms_torch = do_bench(lambda: torch.mm(a, b))
    tflops_torch = flops / (ms_torch * 1e-3) / 1e12
    
    return {
        'triton_ms': ms,
        'triton_tflops': tflops,
        'torch_ms': ms_torch,
        'torch_tflops': tflops_torch,
        'efficiency': tflops / tflops_torch * 100,
    }

print("Autotuned Matmul Performance")
print("=" * 60)

for size in [1024, 2048, 4096]:
    result = benchmark_autotuned(size, size, size)
    print(f"\nSize {size}x{size}:")
    print(f"  Triton:    {result['triton_ms']:.3f} ms, {result['triton_tflops']:.1f} TFLOPS")
    print(f"  PyTorch:   {result['torch_ms']:.3f} ms, {result['torch_tflops']:.1f} TFLOPS")
    print(f"  Efficiency: {result['efficiency']:.1f}%")

### Producer-Consumer Pattern

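In CUDA, pipelining is written out explicitly with asynchronous copies and waits. At the PTX level (Ampere and newer):

```cpp
// Producer (memory load)
cp.async.cg.shared.global [smem], [gmem], 16;  // Async 16-byte copy, global -> shared
cp.async.commit_group;  // Close the copies issued so far into one group

// Consumer (compute)
cp.async.wait_group N;  // Wait until at most N of the most recent groups are still pending
// All older groups have landed. After a block-level barrier (e.g. __syncthreads()),
// every thread can safely read that data.
```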

Triton abstracts this - the compiler generates proper barriers.
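
The group bookkeeping is easy to get backwards: in PTX, `cp.async.wait_group N` returns once at most `N` of the most recently committed groups are still pending, so `N = 0` waits for everything. The toy model below mimics that bookkeeping in plain Python. `AsyncCopyQueue` is a hypothetical class for illustration, not a CUDA or Triton API, and it treats a copy as complete the moment it is waited on.

```python
from collections import deque


class AsyncCopyQueue:
    """Toy model of commit_group / wait_group bookkeeping (not real CUDA)."""

    def __init__(self):
        self.staged = []        # copies issued since the last commit
        self.pending = deque()  # committed groups still in flight, oldest first
        self.completed = []     # copies whose data is safe to read

    def cp_async(self, name):
        self.staged.append(name)

    def commit_group(self):
        self.pending.append(self.staged)
        self.staged = []

    def wait_group(self, n):
        """Retire the oldest groups until at most `n` groups are still pending."""
        while len(self.pending) > n:
            self.completed.extend(self.pending.popleft())


q = AsyncCopyQueue()
for tile in range(3):          # issue loads for three tiles, one group per tile
    q.cp_async(f"A{tile}")
    q.cp_async(f"B{tile}")
    q.commit_group()

q.wait_group(1)                # the newest group may stay in flight
print("safe to use:", q.completed)
print("still pending:", list(q.pending))
```

After `wait_group(1)` tiles 0 and 1 are usable while tile 2 is still in flight, which is the overlap a pipelined loop relies on.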

---
## Step 5: Verify (10 min)

### Quiz

**Q1:** What's the minimum number of pipeline stages needed to hide memory latency?

A) 1 (no pipelining)  
B) latency / compute_time  
C) 2 (double buffering is always enough)  
D) As many as SMEM allows

**Q2:** What's the tradeoff of more pipeline stages?

A) More SMEM usage  
B) More register pressure  
C) Longer warmup/drain  
D) All of the above

**Q3:** Why might 3-stage pipelining be worse than 2-stage on some GPUs?

In [None]:
print("Quiz Answers")
print("=" * 50)
print()
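print("Q1: B) latency / compute_time")
print("    To fully hide 400-cycle latency with 100-cycle compute,")
print("    you need at least 4 stages (400/100 = 4). Round up when the")
print("    ratio is not a whole number; compute_time here is one full")
print("    loop iteration (a whole tile), not a single ~4-cycle operation.")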
print()
print("Q2: D) All of the above")
print("    More stages = more buffers = more SMEM and registers.")
print("    Also longer warmup (fill pipeline) and drain (empty it).")
print()
print("Q3: SMEM limitation")
print("    If 3 buffers exceed SMEM capacity, occupancy drops.")
print("    Lower occupancy can hurt more than pipelining helps.")
print("    Always benchmark - optimal stages depend on kernel size and GPU.")
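
The reasoning behind Q1 and Q3 can be put into numbers. The sketch below uses hypothetical helper names and assumed figures: 16 KiB of SMEM per stage (A and B tiles for 128x128x32 fp16) and a 164 KiB per-SM budget as on an A100-class GPU. Neither value is queried from a device, and real occupancy also depends on registers and threads per block.

```python
import math


def min_stages(latency_cycles, compute_cycles):
    """Q1: pipeline depth needed so a load finishes before its data is used."""
    return math.ceil(latency_cycles / compute_cycles)


def blocks_per_sm(stage_bytes, num_stages, smem_per_sm):
    """Q3: thread blocks that fit on one SM if SMEM is the only limit."""
    return smem_per_sm // (stage_bytes * num_stages)


print("min stages:", min_stages(400, 100))

STAGE_BYTES = 16 * 1024   # assumed: A and B tiles for 128x128x32 fp16
SMEM_PER_SM = 164 * 1024  # assumed A100-class budget, not read from hardware
for stages in (2, 3, 4):
    fit = blocks_per_sm(STAGE_BYTES, stages, SMEM_PER_SM)
    print(f"num_stages={stages}: {fit} blocks/SM")
```

Under these assumptions, going from 2 to 3 stages drops the SMEM-limited residency from 5 blocks per SM to 3, which is the kind of occupancy loss the Q3 answer warns about.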

---
## Summary

### Key Takeaways

1. **Memory latency (~400 cycles) dominates compute (~4 cycles)**
2. **Pipelining overlaps loads with compute** using multiple buffers
3. **Triton's `num_stages`** sets the pipeline depth; the compiler generates the pipelined loop
4. **More stages = better latency hiding** but more SMEM usage
5. **Autotune to find optimal** - it depends on kernel and GPU

### Pipeline Depth Guidelines

| GPU | Typical Optimal Stages |
|-----|------------------------|
| Ampere (A100) | 3-4 stages |
| Hopper (H100) | 2-3 stages (TMA helps!) |

### Tomorrow: TMA

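On Hopper GPUs, the Tensor Memory Accelerator (TMA) performs bulk async copies in hardware: a single thread issues the copy, and the TMA unit handles address generation and data movement, leaving the rest of the SM free for compute. That makes pipelined loads cheaper still!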