# Week 2, Day 2: Coalescing Experiments

**Time:** ~1 hour

**Goal:** Understand memory coalescing through hands-on experiments and measure its performance impact.

## The Challenge

Two kernels doing the same work can have **10x different performance** based solely on memory access patterns.

Today we'll prove it.

In [None]:
import torch
import triton
import triton.language as tl
from triton.testing import do_bench
import matplotlib.pyplot as plt
import numpy as np

---
## Step 1: The Challenge (5 min)

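Let's create two kernels that copy data from one array to another.
Same operation, different access patterns: the coalesced kernel copies every element, while the strided kernel copies every `STRIDE`-th element, so neighboring threads touch addresses `STRIDE` elements apart. Because the strided kernel moves less useful data, we compare *effective* bandwidth (useful bytes per second) rather than raw time.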

In [None]:
@triton.jit
def copy_coalesced(src_ptr, dst_ptr, N, BLOCK_SIZE: tl.constexpr):
    """Coalesced copy: adjacent threads access adjacent memory.
    
    Thread 0: src[0], Thread 1: src[1], Thread 2: src[2], ...
    """
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    
    # Adjacent threads load adjacent elements
    data = tl.load(src_ptr + offsets, mask=mask)
    tl.store(dst_ptr + offsets, data, mask=mask)


@triton.jit  
def copy_strided(src_ptr, dst_ptr, N, STRIDE: tl.constexpr, BLOCK_SIZE: tl.constexpr):
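    """Strided copy: adjacent threads access memory STRIDE elements apart.
    
    Thread 0: src[0], Thread 1: src[STRIDE], Thread 2: src[2*STRIDE], ...
    Only every STRIDE-th element is copied; the rest of dst is left untouched.
    """
    pid = tl.program_id(0)
    # Each lane handles one element; neighboring lanes are STRIDE elements apart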
    thread_ids = tl.arange(0, BLOCK_SIZE)
    
    # Strided access pattern
    offsets = pid * BLOCK_SIZE * STRIDE + thread_ids * STRIDE
    mask = offsets < N
    
    data = tl.load(src_ptr + offsets, mask=mask)
    tl.store(dst_ptr + offsets, data, mask=mask)

In [None]:
def benchmark_copy_patterns(N, BLOCK_SIZE=256, strides=(2, 4, 8, 16, 32)):
    """Benchmark coalesced vs strided memory access.
    
    Returns {pattern: (ms, useful_bytes)}. The strided kernel copies only every
    STRIDE-th element, so its useful traffic is ceil(N / STRIDE) elements.
    """
    src = torch.randn(N, device='cuda', dtype=torch.float32)
    dst_coal = torch.empty_like(src)
    dst_stride = torch.empty_like(src)
    elem_bytes = src.element_size()  # 4 for FP32
    
    # Coalesced copy: every element is read once and written once
    grid_coal = (triton.cdiv(N, BLOCK_SIZE),)
    ms_coal = do_bench(lambda: copy_coalesced[grid_coal](src, dst_coal, N, BLOCK_SIZE=BLOCK_SIZE))
    results = {'coalesced': (ms_coal, N * elem_bytes * 2)}  # read + write
    
    # Strided copies with different strides
    for stride in strides:
        # Each program spans BLOCK_SIZE * stride elements but copies only BLOCK_SIZE of them
        grid_stride = (triton.cdiv(N, BLOCK_SIZE * stride),)
        ms_stride = do_bench(lambda s=stride, g=grid_stride: copy_strided[g](
            src, dst_stride, N, STRIDE=s, BLOCK_SIZE=BLOCK_SIZE))
        n_copied = triton.cdiv(N, stride)  # elements at offsets 0, stride, 2*stride, ...
        results[f'stride_{stride}'] = (ms_stride, n_copied * elem_bytes * 2)
    
    return results

# Run benchmark
N = 64 * 1024 * 1024  # 64M elements = 256 MB
results = benchmark_copy_patterns(N)

print("Memory Copy Performance (effective bandwidth = useful bytes / time)")
print("=" * 60)
for pattern, (ms, useful_bytes) in results.items():
    bw_gbs = useful_bytes / (ms * 1e-3) / 1e9
    print(f"{pattern:<15}: {ms:.3f} ms, {bw_gbs:.1f} GB/s")

# Compare time per useful byte, since the strided kernels copy less data
print("\nSlowdown vs coalesced (time per useful byte):")
ms_coal, bytes_coal = results['coalesced']
for pattern, (ms, useful_bytes) in results.items():
    if pattern != 'coalesced':
        slowdown = (ms / useful_bytes) / (ms_coal / bytes_coal)
        print(f"  {pattern}: {slowdown:.1f}x slower")

---
## Step 2: Explore (15 min)

### What's Happening Under the Hood?

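Global memory is served in fixed-size chunks, not individual elements. On NVIDIA GPUs an L1 **cache line** is 128 bytes, made of four 32-byte sectors. The diagram below treats one 128-byte line as one transaction.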

```
COALESCED ACCESS:
Thread 0  Thread 1  Thread 2  Thread 3  ...  Thread 31
   ↓         ↓         ↓         ↓              ↓
┌──────────────────────────────────────────────────┐
│  elem 0 │ elem 1 │ elem 2 │ elem 3 │ ... │ elem 31 │  ← ONE 128-byte transaction
└──────────────────────────────────────────────────┘

STRIDED ACCESS (stride=32):
Thread 0       Thread 1       Thread 2       ...
   ↓              ↓              ↓
┌────────┐    ┌────────┐    ┌────────┐
│ elem 0 │    │ elem 32│    │ elem 64│     ← 32 SEPARATE transactions!
└────────┘    └────────┘    └────────┘
```

In [None]:
def visualize_access_pattern(stride=4, warp_size=32, elems_per_line=32):
    """Visualize coalesced vs strided access for one warp.
    
    Each point is the element one thread touches. Shaded bands mark 128-byte
    cache lines (32 FP32 elements); every band the warp touches costs one transaction.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 4))
    threads = np.arange(warp_size)
    
    panels = [(axes[0], 1, 'Coalesced'), (axes[1], stride, f'Strided (stride={stride})')]
    for ax, s, name in panels:
        addresses = threads * s                  # element index touched by each thread
        lines = addresses // elems_per_line      # cache line each access falls in
        touched = np.unique(lines)
        
        # Shade every cache line the warp touches
        for i in touched:
            ax.axhspan(i * elems_per_line, (i + 1) * elems_per_line,
                       alpha=0.15, color=f'C{i % 10}')
        ax.scatter(threads, addresses, color='black', s=20)
        
        ax.set_xlabel('Thread ID')
        ax.set_ylabel('Element address (relative)')
        ax.set_title(f'{name}: {len(touched)} transaction(s)')
        ax.set_ylim(0, (lines.max() + 1) * elems_per_line)
    
    plt.tight_layout()
    plt.savefig('coalescing_viz.png', dpi=100, bbox_inches='tight')
    plt.show()
    print("Saved visualization to coalescing_viz.png")

visualize_access_pattern()

### The Math of Coalescing

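For a warp (32 threads) reading FP32 data, counting 128-byte transactions:

| Access Pattern | Bytes Needed | 128-byte Transactions | Bytes Fetched | Efficiency (needed / fetched) |
|---------------|--------------|--------------|---------------|---------------|
| Coalesced | 128 | 1 | 128 | 100% |
| Stride 2 | 128 | 2 | 256 | 50% |
| Stride 4 | 128 | 4 | 512 | 25% |
| Stride 32 | 128 | 32 | 4096 | 3.1% |

**Key insight:** Strided access does not need any more data; it fetches more bytes because each transaction is only partially used by the warp.

The table uses a simplified 128-byte model. On NVIDIA GPUs of compute capability 6.0 and later, global-memory accesses are counted in 32-byte sectors, so for FP32 the efficiency bottoms out at 4/32 = 12.5% once the stride reaches 8. The waste is still large.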
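The table's arithmetic can be reproduced with a small model. This is a sketch under stated assumptions: `warp_access_efficiency` is a hypothetical helper, and it assumes each transaction fetches one whole, aligned segment of `txn_bytes` (128 for cache lines, 32 for sectors) with no reuse from cache.

```python
def warp_access_efficiency(stride, offset_elems=0, elem_bytes=4, warp_size=32, txn_bytes=128):
    """Count transactions for one warp reading data[offset_elems + lane * stride].

    Simplified model (an assumption): each transaction fetches one whole,
    aligned txn_bytes segment, and nothing is reused from cache.
    Returns (transactions, efficiency), efficiency = bytes needed / bytes fetched.
    """
    byte_addrs = [(offset_elems + lane * stride) * elem_bytes for lane in range(warp_size)]
    segments = {addr // txn_bytes for addr in byte_addrs}  # distinct segments touched
    needed = warp_size * elem_bytes
    fetched = len(segments) * txn_bytes
    return len(segments), needed / fetched


for stride in [1, 2, 4, 8, 32]:
    lines, eff_lines = warp_access_efficiency(stride)                    # 128-byte lines
    sectors, eff_sectors = warp_access_efficiency(stride, txn_bytes=32)  # 32-byte sectors
    print(f"stride {stride:>2}: {lines:>2} lines ({eff_lines:.1%}), "
          f"{sectors:>2} sectors ({eff_sectors:.1%})")
```

Passing `offset_elems=1` models the misalignment case from rule 2 in the next step: a contiguous warp shifted by one FP32 element touches 2 lines (or 5 sectors) instead of 1 (or 4).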

---
## Step 3: The Concept (10 min)

### Rules for Coalesced Access

1. **Adjacent threads should access adjacent memory locations**
   ```python
   # GOOD: thread i accesses element i
   data[thread_id]
   
   # BAD: thread i accesses element i*stride
   data[thread_id * 32]
   ```

2. **Alignment matters**
   - Start address should be aligned to transaction size (32/64/128 bytes)
   - Misaligned access may require extra transactions

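3. **Row-major vs Column-major**
   - For 2D arrays, have adjacent threads vary the index that is contiguous in memory
   - C/NumPy/PyTorch (row-major): the elements of a row are contiguous, so adjacent threads should move along a row (varying the column index)
   - Fortran (column-major): the elements of a column are contiguous, so adjacent threads should move down a column (varying the row index)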
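Rules 1 and 3 come down to index arithmetic. This minimal sketch uses hypothetical helpers `row_major_offset` and `col_major_offset` (element offsets, not bytes) to show which thread-to-index mapping stays contiguous; for real arrays, NumPy's `.strides` and PyTorch's `.stride()` report the same layout information.

```python
def row_major_offset(row, col, n_rows, n_cols):
    """Element offset of (row, col) in a row-major array (C, NumPy, PyTorch default)."""
    return row * n_cols + col


def col_major_offset(row, col, n_rows, n_cols):
    """Element offset of (row, col) in a column-major array (Fortran)."""
    return col * n_rows + row


n_rows, n_cols = 4, 8
lanes = range(4)  # first four threads of a warp

# Adjacent threads vary the column index: contiguous in row-major
along_row = [row_major_offset(0, t, n_rows, n_cols) for t in lanes]
# Adjacent threads vary the row index: n_cols elements apart in row-major
down_col = [row_major_offset(t, 0, n_rows, n_cols) for t in lanes]
# The same "down a column" pattern is contiguous in column-major
down_col_f = [col_major_offset(t, 0, n_rows, n_cols) for t in lanes]

print(along_row)   # [0, 1, 2, 3]
print(down_col)    # [0, 8, 16, 24]
print(down_col_f)  # [0, 1, 2, 3]
```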

### Matrix Access Patterns

In [None]:
@triton.jit
def matrix_sum_rows(mat_ptr, out_ptr, M, N, BLOCK_SIZE: tl.constexpr):
    """Sum along rows - COALESCED for row-major storage.
    
    Each program handles one row; its lanes load consecutive columns,
    so adjacent lanes access adjacent memory (good!).
    """
    row = tl.program_id(0)  # grid is (M,), so every row index is in bounds
    
    # One partial sum per lane, reduced to a scalar at the end
    acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    for col_start in range(0, N, BLOCK_SIZE):
        cols = col_start + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # Row-major: mat[row, col] = mat_ptr[row * N + col]
        vals = tl.load(mat_ptr + row * N + cols, mask=mask, other=0.0)
        acc += vals
    # Store a scalar to the scalar pointer out_ptr + row
    tl.store(out_ptr + row, tl.sum(acc, axis=0))


@triton.jit
def matrix_sum_cols(mat_ptr, out_ptr, M, N, BLOCK_SIZE: tl.constexpr):
    """Sum along columns - NOT COALESCED for row-major storage.
    
    Each program handles one column; its lanes load consecutive rows,
    so adjacent lanes access memory N elements apart (bad!).
    """
    col = tl.program_id(0)  # grid is (N,), so every column index is in bounds
    
    # One partial sum per lane, reduced to a scalar at the end
    acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    for row_start in range(0, M, BLOCK_SIZE):
        rows = row_start + tl.arange(0, BLOCK_SIZE)
        mask = rows < M
        # Row-major: mat[row, col] = mat_ptr[row * N + col]
        # This is strided access! Stride = N
        vals = tl.load(mat_ptr + rows * N + col, mask=mask, other=0.0)
        acc += vals
    # Store a scalar to the scalar pointer out_ptr + col
    tl.store(out_ptr + col, tl.sum(acc, axis=0))

In [None]:
def benchmark_matrix_reduction(M, N):
    """Compare row vs column reduction performance."""
    mat = torch.randn(M, N, device='cuda', dtype=torch.float32)
    out_rows = torch.empty(M, device='cuda', dtype=torch.float32)
    out_cols = torch.empty(N, device='cuda', dtype=torch.float32)
    
    BLOCK_SIZE = 256
    
    # Row reduction (coalesced)
    ms_rows = do_bench(lambda: matrix_sum_rows[(M,)](mat, out_rows, M, N, BLOCK_SIZE=BLOCK_SIZE))
    
    # Column reduction (strided)
    ms_cols = do_bench(lambda: matrix_sum_cols[(N,)](mat, out_cols, M, N, BLOCK_SIZE=BLOCK_SIZE))
    
    # Verify correctness
    expected_rows = mat.sum(dim=1)
    expected_cols = mat.sum(dim=0)
    
    return {
        'rows_ms': ms_rows,
        'cols_ms': ms_cols,
        'slowdown': ms_cols / ms_rows,
        # atol guards sums that land near zero, where rtol alone is too strict
        'rows_correct': torch.allclose(out_rows, expected_rows, rtol=1e-3, atol=1e-3),
        'cols_correct': torch.allclose(out_cols, expected_cols, rtol=1e-3, atol=1e-3),
    }

# Test with square matrix
for size in [1024, 2048, 4096]:
    result = benchmark_matrix_reduction(size, size)
    print(f"Matrix {size}x{size}:")
    print(f"  Row reduction (coalesced): {result['rows_ms']:.3f} ms")
    print(f"  Col reduction (strided):   {result['cols_ms']:.3f} ms")
    print(f"  Slowdown: {result['slowdown']:.1f}x")
    print(f"  Correct: rows={result['rows_correct']}, cols={result['cols_correct']}")
    print()

---
## Step 4: Code It (30 min)

### Exercise: Fix the Strided Access

The column reduction is slow because of strided access. How can we fix it?

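**Strategy:** Process the matrix in 2D tiles. Each program owns a block of `BLOCK_N` adjacent columns and walks down the rows one tile at a time. Every tile row is a contiguous run of `BLOCK_N` elements, so the loads are coalesced, and the partial column sums accumulate in registers. (In CUDA you might stage tiles in shared memory by hand; in Triton the compiler decides where the tile lives.)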
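Before reading the Triton kernel, it may help to see the tiling as plain loops. This pure-Python sketch (`col_sum_tiled_reference` is a hypothetical name, not part of Triton) mirrors the kernel's structure: one outer iteration per column block, an inner walk over row tiles, and one accumulator per column.

```python
def col_sum_tiled_reference(mat, block_m=32, block_n=32):
    """Pure-Python model of a tiled column sum over a list-of-rows matrix.

    Each outer iteration plays the role of one Triton program (one column block);
    the inner loop walks down the rows one block_m x block_n tile at a time.
    """
    M, N = len(mat), len(mat[0])
    out = [0.0] * N
    for col_start in range(0, N, block_n):                # one "program" per column block
        cols = range(col_start, min(col_start + block_n, N))
        acc = [0.0] * len(cols)                           # one accumulator per column
        for row_start in range(0, M, block_m):            # one tile per iteration
            for r in range(row_start, min(row_start + block_m, M)):
                # A tile row is a contiguous run of columns in row-major storage
                for j, c in enumerate(cols):
                    acc[j] += mat[r][c]
        for j, c in enumerate(cols):
            out[c] = acc[j]
    return out


mat = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
print(col_sum_tiled_reference(mat, block_m=1, block_n=2))  # [5.0, 7.0, 9.0]
```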

In [None]:
@triton.jit
def matrix_sum_cols_tiled(
    mat_ptr, out_ptr, M, N,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
):
    """Column sum with tiled access for better coalescing.
    
    Strategy: Load BLOCK_M x BLOCK_N tiles, accumulate partial sums per column.
    Within each tile, we get coalesced access.
    """
    # Each program handles BLOCK_N columns
    col_block = tl.program_id(0)
    col_start = col_block * BLOCK_N
    
    # Accumulators for BLOCK_N columns
    acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
    
    # Process matrix in row tiles
    for row_start in range(0, M, BLOCK_M):
        # Load a BLOCK_M x BLOCK_N tile
        rows = row_start + tl.arange(0, BLOCK_M)
        cols = col_start + tl.arange(0, BLOCK_N)
        
        # Create masks
        row_mask = rows[:, None] < M
        col_mask = cols[None, :] < N
        mask = row_mask & col_mask
        
        # Load tile (coalesced within each row)
        tile = tl.load(
            mat_ptr + rows[:, None] * N + cols[None, :],
            mask=mask,
            other=0.0
        )
        
        # Sum along rows (within this tile)
        acc += tl.sum(tile, axis=0)
    
    # Store results
    cols = col_start + tl.arange(0, BLOCK_N)
    mask = cols < N
    tl.store(out_ptr + cols, acc, mask=mask)

In [None]:
def benchmark_col_reduction_variants(M, N):
    """Compare different column reduction implementations."""
    mat = torch.randn(M, N, device='cuda', dtype=torch.float32)
    out_naive = torch.empty(N, device='cuda', dtype=torch.float32)
    out_tiled = torch.empty(N, device='cuda', dtype=torch.float32)
    
    BLOCK_SIZE = 256
    BLOCK_M, BLOCK_N = 32, 32
    
    # Naive (strided)
    ms_naive = do_bench(lambda: matrix_sum_cols[(N,)](mat, out_naive, M, N, BLOCK_SIZE=BLOCK_SIZE))
    
    # Tiled (better coalescing)
    grid_tiled = (triton.cdiv(N, BLOCK_N),)
    ms_tiled = do_bench(lambda: matrix_sum_cols_tiled[grid_tiled](mat, out_tiled, M, N, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N))
    
    # PyTorch baseline
    ms_torch = do_bench(lambda: mat.sum(dim=0))
    
    # Verify
    expected = mat.sum(dim=0)
    
    return {
        'naive_ms': ms_naive,
        'tiled_ms': ms_tiled,
        'torch_ms': ms_torch,
        'speedup': ms_naive / ms_tiled,
        # atol guards sums that land near zero, where rtol alone is too strict
        'naive_correct': torch.allclose(out_naive, expected, rtol=1e-3, atol=1e-3),
        'tiled_correct': torch.allclose(out_tiled, expected, rtol=1e-3, atol=1e-3),
    }

# Benchmark
print("Column Reduction Comparison")
print("=" * 60)

for size in [1024, 2048, 4096]:
    result = benchmark_col_reduction_variants(size, size)
    print(f"\nMatrix {size}x{size}:")
    print(f"  Naive (strided):  {result['naive_ms']:.3f} ms")
    print(f"  Tiled:            {result['tiled_ms']:.3f} ms")
    print(f"  PyTorch:          {result['torch_ms']:.3f} ms")
    print(f"  Tiled speedup:    {result['speedup']:.1f}x vs naive")
    print(f"  Correct:          naive={result['naive_correct']}, tiled={result['tiled_correct']}")

### Real-World Example: Matrix Transpose

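Transpose is the classic coalescing problem. A direct transpose has to make one side strided:
- Read rows (coalesced) and write columns (strided), or
- Read columns (strided) and write rows (coalesced).

**Solution:** Stage each tile so that both the global reads and the global writes are coalesced. In CUDA this means an explicit shared-memory buffer. In Triton, `tl.trans` changes the tile's layout and the compiler typically performs that exchange through shared memory. Triton's compiler also analyzes pointer contiguity and may coalesce the naive kernel's stores on its own, so the gap you measure may be smaller than in hand-written CUDA.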
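A per-warp count makes the problem concrete. This sketch reuses the 128-byte line model from the coalescing table; `lines_touched`, the matrix size, and the tile origin are hypothetical. After staging, the warp writes a row of the output tile instead of a column.

```python
def lines_touched(offsets_elems, elem_bytes=4, line_bytes=128):
    """Distinct 128-byte lines touched by a warp, given per-lane element offsets."""
    return len({off * elem_bytes // line_bytes for off in offsets_elems})


M = N = 1024              # illustrative square FP32 matrix, row-major
r0, c0 = 64, 128          # hypothetical tile origin (aligned to 32 elements)
lanes = range(32)         # one warp

read = [r0 * N + (c0 + t) for t in lanes]          # src[r0, c0 + t]: contiguous
naive_write = [(c0 + t) * M + r0 for t in lanes]   # dst[c0 + t, r0]: M elements apart
staged_write = [c0 * M + (r0 + t) for t in lanes]  # dst[c0, r0 + t]: contiguous again

print(lines_touched(read), lines_touched(naive_write), lines_touched(staged_write))  # 1 32 1
```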

In [None]:
@triton.jit
def transpose_naive(src_ptr, dst_ptr, M, N, BLOCK: tl.constexpr):
    """Naive transpose - strided writes.
    
    The tile keeps its source layout, where neighboring lanes hold neighboring
    columns; each of those lands in a different row of dst, M elements apart.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    
    offs_m = pid_m * BLOCK + tl.arange(0, BLOCK)
    offs_n = pid_n * BLOCK + tl.arange(0, BLOCK)
    
    # Load tile (coalesced for row-major)
    src_offs = offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tile = tl.load(src_ptr + src_offs, mask=mask)
    
    # Store without transposing the tile: src[m, n] -> dst[n, m] = dst_ptr + n * M + m.
    # Adjacent columns n are M elements apart in dst (strided writes!)
    dst_offs = offs_m[:, None] + offs_n[None, :] * M
    tl.store(dst_ptr + dst_offs, tile, mask=mask)


@triton.jit
def transpose_smem(src_ptr, dst_ptr, M, N, BLOCK: tl.constexpr):
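    """Transpose with an explicit layout swap for coalesced writes.
    
    1. Load tile from global memory (coalesced)
    2. tl.trans swaps the tile's axes; the compiler typically moves the
       data between layouts through shared memory
    3. Store to global memory (coalesced: adjacent lanes write adjacent columns of dst)
    """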
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    
    offs_m = pid_m * BLOCK + tl.arange(0, BLOCK)
    offs_n = pid_n * BLOCK + tl.arange(0, BLOCK)
    
    # Load tile (coalesced reads from row-major)
    src_offs = offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tile = tl.load(src_ptr + src_offs, mask=mask)
    
    # Swap axes so the fastest-varying index is offs_m, which is contiguous in dst
    tile_t = tl.trans(tile)
    
    # Store transposed tile (coalesced writes to row-major output)
    # Output has shape (N, M): this tile covers rows offs_n and columns offs_m
    dst_offs = offs_n[:, None] * M + offs_m[None, :]
    mask_t = (offs_n[:, None] < N) & (offs_m[None, :] < M)
    tl.store(dst_ptr + dst_offs, tile_t, mask=mask_t)

In [None]:
def benchmark_transpose(M, N):
    """Benchmark transpose implementations."""
    src = torch.randn(M, N, device='cuda', dtype=torch.float32)
    dst_naive = torch.empty(N, M, device='cuda', dtype=torch.float32)
    dst_smem = torch.empty(N, M, device='cuda', dtype=torch.float32)
    
    BLOCK = 32
    grid = (triton.cdiv(M, BLOCK), triton.cdiv(N, BLOCK))
    
    ms_naive = do_bench(lambda: transpose_naive[grid](src, dst_naive, M, N, BLOCK=BLOCK))
    ms_smem = do_bench(lambda: transpose_smem[grid](src, dst_smem, M, N, BLOCK=BLOCK))
    ms_torch = do_bench(lambda: src.T.contiguous())
    
    # Calculate bandwidth
    bytes_moved = M * N * 4 * 2  # read + write
    
    return {
        'naive_ms': ms_naive,
        'smem_ms': ms_smem,
        'torch_ms': ms_torch,
        'naive_gbps': bytes_moved / (ms_naive * 1e-3) / 1e9,
        'smem_gbps': bytes_moved / (ms_smem * 1e-3) / 1e9,
        'correct': torch.allclose(dst_naive, src.T) and torch.allclose(dst_smem, src.T),
    }

print("Transpose Benchmark")
print("=" * 60)

for size in [1024, 2048, 4096]:
    result = benchmark_transpose(size, size)
    print(f"\nMatrix {size}x{size}:")
    print(f"  Naive:   {result['naive_ms']:.3f} ms ({result['naive_gbps']:.0f} GB/s)")
    print(f"  SMEM:    {result['smem_ms']:.3f} ms ({result['smem_gbps']:.0f} GB/s)")
    print(f"  PyTorch: {result['torch_ms']:.3f} ms")
    print(f"  Speedup: {result['naive_ms'] / result['smem_ms']:.1f}x")
    print(f"  Correct: {result['correct']}")

---
## Step 5: Verify (10 min)

### Quiz

**Q1:** Which access pattern is coalesced for a row-major 2D array?
```python
A) data[row, thread_id]      # Varying column
B) data[thread_id, col]      # Varying row  
C) data[thread_id, thread_id]  # Diagonal
```

**Q2:** A kernel reads 128 bytes per warp but uses four 128-byte memory transactions. What's the efficiency?

**Q3:** You have a column reduction that's 8x slower than expected. What's your first optimization strategy?

In [None]:
# Answers
print("Quiz Answers")
print("=" * 50)
print()
print("Q1: A) data[row, thread_id]")
print("    Row-major stores consecutive columns contiguously.")
print("    Adjacent threads accessing adjacent columns = coalesced.")
print()
print("Q2: 25% efficiency")
print("    4 transactions of 128 bytes each = 512 bytes transferred")
print("    Only 128 bytes actually needed → 128/512 = 25%")
print()
print("Q3: Tile the reduction")
print("    Load tiles with coalesced access, reduce within tiles.")
print("    Or transpose first, then do row reduction.")

---
## Summary

### Key Takeaways

1. **Coalescing = adjacent threads access adjacent memory**
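2. **Strided access wastes bandwidth:** the warp needs the same bytes but pays for more, partially used transactions
3. **Stage through shared memory** to turn strided writes into coalesced ones (in Triton, `tl.trans` plus the compiler typically does the staging)
4. **For 2D arrays:** have adjacent threads vary the contiguous index (the column index for row-major)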
5. **Tiling helps** by allowing coalesced access within tiles

### Patterns to Remember

| Operation | Good Pattern | Bad Pattern |
|-----------|-------------|-------------|
| Vector access | `data[tid]` | `data[tid * stride]` |
| Row-major matrix | `mat[row, tid]` | `mat[tid, col]` |
| Reduction | Along contiguous dim | Along strided dim |
| Transpose | Via shared memory | Direct strided write |

### Tomorrow: Bank Conflicts

We used shared memory to fix coalescing, but shared memory has its own access pattern issues: **bank conflicts**.

In [None]:
# Cleanup
import os
if os.path.exists('coalescing_viz.png'):
    os.remove('coalescing_viz.png')