# Week 2, Day 3: Bank Conflict Laboratory

**Time:** ~1 hour

**Goal:** Understand shared memory bank conflicts and how to avoid them.

## The Challenge

Yesterday we used shared memory to fix coalescing issues. But shared memory has its own trap: **bank conflicts**.

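In the worst case, a 32-way bank conflict turns one shared memory access into 32 serialized ones, so code bound by shared memory traffic can run up to **32x slower** than a conflict-free version.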

In [None]:
import torch
import triton
import triton.language as tl
from triton.testing import do_bench
import numpy as np

---
## Step 1: The Challenge (5 min)

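Shared memory is organized into **32 banks**, each 4 bytes wide. Each bank can serve one 32-bit word per clock cycle, so a warp whose threads hit 32 different banks is served in a single pass.

When multiple threads in a warp access **different addresses** in the **same bank**, those accesses must be serialized → **bank conflict**.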

In [None]:
@triton.jit
def smem_access_stride1(out_ptr, BLOCK: tl.constexpr):
    """Stride-1 access pattern: NO bank conflicts.

    Thread 0 → Bank 0, Thread 1 → Bank 1, ... Thread 31 → Bank 31

    Illustration only: Triton's `tl` language has no explicit shared-memory
    allocation (the compiler decides what goes to SMEM), so this kernel just
    computes the index a CUDA kernel would use for smem[tid]. It never
    touches shared memory.
    """
    tid = tl.arange(0, BLOCK)

    # Stride-1 index: thread i would access element i
    smem_idx = tid

    # Stand-in for repeated smem[tid] reads (register arithmetic only)
    acc = tl.zeros((BLOCK,), dtype=tl.float32)
    for _ in range(1000):
        acc += smem_idx.to(tl.float32)
    
    tl.store(out_ptr + tid, acc)


@triton.jit
def smem_access_stride32(out_ptr, BLOCK: tl.constexpr):
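    """Stride-32 access pattern: MAXIMUM bank conflicts (32-way).

    All 32 threads in a warp access the SAME bank:
    Thread 0 → Bank 0, Thread 1 → Bank 0, ... Thread 31 → Bank 0

    Illustration only, like smem_access_stride1: this computes the
    smem[tid * 32] index in registers and never touches shared memory,
    so timing the two kernels will not show a bank-conflict difference.
    """
    tid = tl.arange(0, BLOCK)

    # Stride-32 index: thread i would access element i*32 → all in bank 0
    smem_idx = tid * 32

    # Stand-in for repeated smem[tid * 32] reads (register arithmetic only)
    acc = tl.zeros((BLOCK,), dtype=tl.float32)
    for _ in range(1000):
        acc += (smem_idx % 1024).to(tl.float32)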
    
    tl.store(out_ptr + tid, acc)

---
## Step 2: Explore (15 min)

### How Banks Work

```
Shared Memory Organization (32 banks, 4-byte words):

Address:    0    4    8   12   16  ...  124  128  132  ...
Bank:       0    1    2    3    4  ...   31    0    1  ...

Bank number = (address / 4) % 32
```

### Access Patterns

| Pattern (4-byte elements) | Bank Access | Conflicts |
|---------|------------|----------|
| `smem[tid]` | Each thread → different bank | None |
| `smem[tid * 2]` | 2 threads per bank (16 banks used) | 2-way |
| `smem[tid * 32]` | All threads → same bank | 32-way |
| `smem[tid * 33]` | Each thread → different bank | None |

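**Exception:** Threads that read the *same* 32-bit word do not conflict, because the word is broadcast to all of them. If ALL threads read the SAME address, the whole warp is served in one pass. Only *different* addresses in the same bank cause a conflict.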
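Because of broadcast, the right quantity to count is *distinct words per bank*, not accesses per bank. Here is a minimal sketch that does this. It assumes 32 banks and takes the 4-byte word index read by each thread. `conflict_degree` is a hypothetical helper, not a CUDA or Triton API.

```python
from collections import defaultdict

NUM_BANKS = 32

def conflict_degree(word_indices):
    """Serialized passes needed for one warp-wide shared memory read.

    word_indices: the 4-byte word index read by each thread of the warp.
    Threads reading the same word share a broadcast, so only *distinct*
    words that fall in the same bank count against each other.
    """
    words_per_bank = defaultdict(set)
    for w in word_indices:
        words_per_bank[w % NUM_BANKS].add(w)
    return max(len(words) for words in words_per_bank.values())

warp = range(32)
print(conflict_degree([t for t in warp]))       # stride 1 → 1
print(conflict_degree([t * 32 for t in warp]))  # stride 32 → 32
print(conflict_degree([7 for _ in warp]))       # same word → 1 (broadcast)
print(conflict_degree([t % 2 for t in warp]))   # two words in two banks → 1
```

Counting accesses per bank, as `analyze_access_pattern` does, gives the same answer for every nonzero stride, because those addresses are all distinct. The two methods only disagree when threads share a word, as in the broadcast case.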

In [None]:
def calculate_bank(address, word_size=4):
    """Calculate which bank an address maps to."""
    return (address // word_size) % 32

def analyze_access_pattern(stride, num_threads=32, word_size=4):
    """Analyze bank conflicts for a given stride pattern."""
    addresses = [i * stride * word_size for i in range(num_threads)]
    banks = [calculate_bank(addr, word_size) for addr in addresses]
    
    # Count accesses per bank
    bank_counts = {}
    for b in banks:
        bank_counts[b] = bank_counts.get(b, 0) + 1
    
    max_conflicts = max(bank_counts.values())
    
    return {
        'banks': banks,
        'unique_banks': len(set(banks)),
        'max_way_conflict': max_conflicts,
        'serialization_factor': max_conflicts,
    }

# Analyze different strides
print("Bank Conflict Analysis")
print("=" * 60)
print(f"{'Stride':<10} {'Unique Banks':<15} {'Max Conflict':<15} {'Slowdown':<10}")
print("-" * 60)

for stride in [1, 2, 4, 8, 16, 32, 33, 64, 65]:
    result = analyze_access_pattern(stride)
    conflict = f"{result['max_way_conflict']}-way"
    slowdown = f"{result['serialization_factor']}x" if result['serialization_factor'] > 1 else "None"
    print(f"{stride:<10} {result['unique_banks']:<15} {conflict:<15} {slowdown:<10}")

In [None]:
def visualize_bank_conflicts(stride):
    """Visualize which bank each thread accesses."""
    print(f"\nStride = {stride}:")
    print("Thread:  ", end="")
    for t in range(32):
        print(f"{t:3d}", end="")
    print()
    
    print("Bank:    ", end="")
    banks = [(t * stride) % 32 for t in range(32)]
    for b in banks:
        print(f"{b:3d}", end="")
    print()
    
    # Show conflicts
    bank_counts = {}
    for b in banks:
        bank_counts[b] = bank_counts.get(b, 0) + 1
    
    max_conflict = max(bank_counts.values())
    if max_conflict > 1:
        print(f"→ {max_conflict}-way bank conflict!")
    else:
        print("→ No conflicts (optimal)")

visualize_bank_conflicts(1)
visualize_bank_conflicts(2)
visualize_bank_conflicts(32)
visualize_bank_conflicts(33)

---
## Step 3: The Concept (10 min)

### The Padding Trick

To avoid bank conflicts with power-of-2 strides, add **padding** so that the row length is no longer a multiple of 32 words.

```cpp
// Without padding: stride-32 access → 32-way conflict
__shared__ float smem[32][32];  // smem[row][col]
// Thread t accessing smem[t][0] → all hit bank 0!

// With padding: stride-33 access → no conflict
__shared__ float smem[32][33];  // Extra column
// Thread t accessing smem[t][0] → each hits different bank!
```
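Here is the same arithmetic in Python. Thread t reads column `col` of row t from a row-major float tile, so its word index is `t * row_len + col`. This sketch assumes 4-byte elements and a single warp of 32 threads. `column_banks` is a hypothetical helper for illustration.

```python
NUM_BANKS = 32

def column_banks(row_len, col=0, num_threads=32):
    """Bank hit by each thread t reading smem[t][col] from a row-major
    float tile declared as smem[num_threads][row_len]."""
    return [(t * row_len + col) % NUM_BANKS for t in range(num_threads)]

unpadded = column_banks(32)  # float smem[32][32]
padded = column_banks(33)    # float smem[32][33]
print(len(set(unpadded)))    # → 1  (every thread in bank 0: 32-way conflict)
print(len(set(padded)))      # → 32 (one thread per bank: conflict-free)
```

The padding wastes one float per row (32 of 1056 words here) in exchange for conflict-free column reads.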

### Why Stride-33 Works

```
Thread 0: address 0 → bank 0
Thread 1: address 33*4 = 132 → bank (132/4) % 32 = 33 % 32 = 1
Thread 2: address 66*4 = 264 → bank (264/4) % 32 = 66 % 32 = 2
...
Thread 31: address 1023*4 = 4092 → bank (4092/4) % 32 = 1023 % 32 = 31
```

Each thread hits a different bank, because 33 ≡ 1 (mod 32): thread t lands in bank (33 · t) % 32 = t.
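The stride-33 case generalizes. With 4-byte elements and a full warp, thread t at stride s hits bank (s · t) mod 32. The warp then has a gcd(s, 32)-way conflict, so every odd stride is conflict-free. The brute-force check below tests that rule. It covers only this setup (nonzero stride, 32 threads, 4-byte words) and says nothing about other element sizes.

```python
from math import gcd

NUM_BANKS = 32

def brute_force_degree(stride, num_threads=32):
    """Max number of distinct words in one bank when thread t reads word stride*t."""
    hits = {}
    for t in range(num_threads):
        word = stride * t
        hits.setdefault(word % NUM_BANKS, set()).add(word)
    return max(len(words) for words in hits.values())

# The gcd rule matches the brute-force count for every nonzero stride tried
for s in range(1, 130):
    assert brute_force_degree(s) == gcd(s, NUM_BANKS), s

print([brute_force_degree(s) for s in (1, 2, 32, 33, 48, 64, 65)])
# → [1, 2, 32, 1, 16, 32, 1]
```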

In [None]:
@triton.jit
def transpose_bank_conflict(
    src_ptr, dst_ptr, M, N,
    BLOCK: tl.constexpr
):
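    """Transpose, named for the pattern that conflicts in hand-written CUDA.

    In CUDA, writing a tile into float smem[32][32] row-by-row and reading it
    back column-by-column gives 32-way bank conflicts. Triton exposes no
    shared-memory indexing: the compiler decides how tl.trans is lowered and
    how any shared-memory buffer is laid out, so this body is identical to
    transpose_no_bank_conflict below.
    """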
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    
    offs_m = pid_m * BLOCK + tl.arange(0, BLOCK)
    offs_n = pid_n * BLOCK + tl.arange(0, BLOCK)
    
    # Load tile (coalesced from global memory)
    src_offs = offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tile = tl.load(src_ptr + src_offs, mask=mask)
    
    # Transpose the tile. Any trip through shared memory (and its layout)
    # is chosen by the Triton compiler, not by this code
    tile_t = tl.trans(tile)
    
    dst_offs = offs_n[:, None] * M + offs_m[None, :]
    mask_t = (offs_n[:, None] < N) & (offs_m[None, :] < M)
    tl.store(dst_ptr + dst_offs, tile_t, mask=mask_t)


@triton.jit
def transpose_no_bank_conflict(
    src_ptr, dst_ptr, M, N,
    BLOCK: tl.constexpr
):
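    """Transpose avoiding bank conflicts.

    Same body as transpose_bank_conflict: in Triton, a conflict-free
    shared-memory layout is the compiler's job (by padding or swizzling the
    buffer). For explicit control in CUDA, you would pad the tile to [32][33].
    """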
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    
    offs_m = pid_m * BLOCK + tl.arange(0, BLOCK)
    offs_n = pid_n * BLOCK + tl.arange(0, BLOCK)
    
    # Load tile
    src_offs = offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tile = tl.load(src_ptr + src_offs, mask=mask)
    
    # Triton handles transpose efficiently
    tile_t = tl.trans(tile)
    
    dst_offs = offs_n[:, None] * M + offs_m[None, :]
    mask_t = (offs_n[:, None] < N) & (offs_m[None, :] < M)
    tl.store(dst_ptr + dst_offs, tile_t, mask=mask_t)
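Padding costs one extra word per row. The other common fix is **XOR swizzling**: store logical element (row, col) of a 32×32 float tile at physical column `col ^ row`. The pure-Python sketch below checks that, with swizzling, both row writes and column reads spread one warp across all 32 banks. It assumes 4-byte elements and one warp, and the helper names are hypothetical. Triton's swizzled shared layouts build on the same XOR idea but take more parameters than this.

```python
NUM_BANKS = 32
TILE = 32  # 32 x 32 tile of 4-byte floats, row length 32 words (no padding)

def plain_word(row, col):
    """Word offset of (row, col) in a plain row-major tile."""
    return row * TILE + col

def swizzled_word(row, col):
    """Word offset of logical (row, col) in an XOR-swizzled tile.
    col ^ row stays in 0..31 and is a permutation of the columns of each row."""
    return row * TILE + (col ^ row)

def banks_for_warp(word_fn, fixed, read_column):
    """Banks touched when thread t accesses (fixed, t) for a row write,
    or (t, fixed) for a column read."""
    if read_column:
        return [word_fn(t, fixed) % NUM_BANKS for t in range(TILE)]
    return [word_fn(fixed, t) % NUM_BANKS for t in range(TILE)]

for name, fn in (("plain", plain_word), ("swizzled", swizzled_word)):
    row_write = len(set(banks_for_warp(fn, 3, read_column=False)))
    col_read = len(set(banks_for_warp(fn, 3, read_column=True)))
    print(f"{name:9s} row write: {row_write} banks, column read: {col_read} banks")
# plain     row write: 32 banks, column read: 1 banks
# swizzled  row write: 32 banks, column read: 32 banks
```

Each thread touches a distinct word here, so "32 banks" means conflict-free and "1 bank" means a 32-way conflict.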

---
## Step 4: Code It (30 min)

### Exercise: Reduction with Bank Conflicts

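Parallel reduction is a classic source of bank conflicts in hand-written CUDA. The Triton kernels below reduce with `tl.sum`, which leaves the shared-memory pattern to the compiler. The benchmark therefore compares reduction strategies rather than conflict patterns.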
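For contrast, here is a simulation of what a classic hand-written CUDA tree reduction does to the banks. This pure-Python sketch assumes 4-byte floats, a 1024-thread block, and only looks at warp 0 and the first read `sdata[index]` of each step. *Interleaved* addressing (`index = 2 * s * tid`) doubles the stride every step. *Sequential* addressing (`sdata[tid] += sdata[tid + s]`) keeps it at 1.

```python
NUM_BANKS = 32
BLOCK = 1024
WARP = range(32)  # thread IDs of warp 0

def degree(words):
    """Distinct words per bank, maximized over banks (1 = conflict-free)."""
    per_bank = {}
    for w in words:
        per_bank.setdefault(w % NUM_BANKS, set()).add(w)
    return max((len(v) for v in per_bank.values()), default=0)

def interleaved(s):
    # CUDA: for (s = 1; s < BLOCK; s *= 2) { index = 2*s*tid; if (index < BLOCK) ... }
    return degree([2 * s * t for t in WARP if 2 * s * t < BLOCK])

def sequential(s):
    # CUDA: for (s = BLOCK/2; s > 0; s >>= 1) { if (tid < s) sdata[tid] += sdata[tid + s]; }
    return degree([t for t in WARP if t < s])

strides = [2 ** i for i in range(10)]  # s = 1, 2, 4, ..., 512
print("interleaved:", [interleaved(s) for s in strides])
# → [2, 4, 8, 16, 32, 16, 8, 4, 2, 1]
print("sequential: ", [sequential(s) for s in strides])
# → [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
```

Interleaved addressing climbs to a 32-way conflict at s = 16. The late steps drop back down only because fewer lanes of warp 0 are still active. Sequential addressing is conflict-free at every step.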

In [None]:
@triton.jit
def reduce_sum_naive(x_ptr, out_ptr, N, BLOCK: tl.constexpr):
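    """Per-block sum with tl.sum (simplified, NOT a full reduction).

    Every program reduces its own block, but only program 0 stores its result,
    so out_ptr ends up holding the first block's partial sum. Useful for timing only.
    """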
    pid = tl.program_id(0)
    
    # Load block of data
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    
    # Reduce using Triton's built-in (optimized)
    total = tl.sum(x)
    
    # Store partial sum
    if pid == 0:
        # Only first program stores (simplified)
        tl.store(out_ptr, total)


@triton.jit
def reduce_sum_atomic(x_ptr, out_ptr, N, BLOCK: tl.constexpr):
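    """Reduction with atomic adds for partial sums.

    Each program adds its block sum to out_ptr with tl.atomic_add. Float atomics
    make the summation order (and so the last few bits of the result) nondeterministic.
    """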
    pid = tl.program_id(0)
    
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    
    # Local sum
    partial = tl.sum(x)
    
    # Atomic add to global output
    tl.atomic_add(out_ptr, partial)

In [None]:
def benchmark_reduction(N):
    """Benchmark different reduction approaches."""
    x = torch.randn(N, device='cuda', dtype=torch.float32)
    out_naive = torch.zeros(1, device='cuda', dtype=torch.float32)
    out_atomic = torch.zeros(1, device='cuda', dtype=torch.float32)
    
    BLOCK = 1024
    grid = (triton.cdiv(N, BLOCK),)
    
    # Note: naive version only gets first block's sum (simplified demo)
    ms_naive = do_bench(lambda: reduce_sum_naive[grid](x, out_naive, N, BLOCK=BLOCK))
    
    # Atomic version gets the full sum (out_atomic.zero_() runs inside the timed lambda, so its cost is included)
    ms_atomic = do_bench(lambda: reduce_sum_atomic[grid](x, out_atomic.zero_(), N, BLOCK=BLOCK))
    
    # PyTorch baseline
    ms_torch = do_bench(lambda: x.sum())
    
    return {
        'naive_ms': ms_naive,
        'atomic_ms': ms_atomic,
        'torch_ms': ms_torch,
    }

print("Reduction Benchmark")
print("=" * 50)

for exp in range(20, 28, 2):
    N = 2 ** exp
    result = benchmark_reduction(N)
    print(f"N = 2^{exp} ({N:,}):")
    print(f"  Naive:   {result['naive_ms']:.4f} ms")
    print(f"  Atomic:  {result['atomic_ms']:.4f} ms")
    print(f"  PyTorch: {result['torch_ms']:.4f} ms")

### Matrix Multiplication: The Bank Conflict Trap

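In tiled matmul, we load A and B tiles into shared memory. Whether the inner loop conflicts depends on which index varies across the 32 threads of a warp at a fixed k. The order in which a single thread walks through k does not matter.

```python
# Inner loop: C[i,j] += A[i,k] * B[k,j]; at each step every thread uses the same k.
# Assume threadIdx.x -> j and threadIdx.y -> i (a warp may span several rows i).
#
# A stored row-major in SMEM:
#   Threads with the same i read the same A[i, k] -> broadcast (OK)
#   Threads with different i read A[i, k] one row length apart
#   -> conflict whenever two of those rows start a multiple of 32 words apart
#      (e.g. row length 32); padding the row fixes it
#
# B stored row-major in SMEM:
#   Threads with consecutive j read B[k, j], B[k, j+1], ... -> consecutive banks (OK)
#   One thread walking B[0, j], B[1, j], ... over time is not a conflict:
#   conflicts only arise between threads of one warp in the same instruction
```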
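To make the thread mapping concrete, this sketch assumes a CUDA-style 16×16 thread block with `threadIdx.x → j` and `threadIdx.y → i`. Warp 0 then covers rows i = 0, 1 and columns j = 0..15. The float tiles are stored row-major in shared memory, and `row_len` is the (possibly padded) number of words per row. At one fixed k, the sketch counts distinct words per bank for the A and B reads. `warp_degree`, `a_read`, and `b_read` are hypothetical helpers.

```python
NUM_BANKS = 32
# Warp 0 of a 16 x 16 thread block: lane = ty * 16 + tx, with i = ty and j = tx
WARP_THREADS = [(lane // 16, lane % 16) for lane in range(32)]  # (i, j) pairs

def warp_degree(words):
    """Max distinct words in one bank; same-word reads are broadcast."""
    per_bank = {}
    for w in words:
        per_bank.setdefault(w % NUM_BANKS, set()).add(w)
    return max(len(v) for v in per_bank.values())

def a_read(row_len, k=0):
    # Thread (i, j) reads A_s[i][k] from a row-major tile with row_len words per row
    return warp_degree([i * row_len + k for i, _ in WARP_THREADS])

def b_read(row_len, k=0):
    # Thread (i, j) reads B_s[k][j] from a row-major tile with row_len words per row
    return warp_degree([k * row_len + j for _, j in WARP_THREADS])

print("A, row length 16:", a_read(16))  # rows 16 words apart -> banks 0 and 16 -> 1
print("A, row length 32:", a_read(32))  # rows 32 words apart -> same bank -> 2
print("A, row length 33:", a_read(33))  # padded -> banks 0 and 1 -> 1
print("B, row length 32:", b_read(32))  # consecutive j -> consecutive banks -> 1
```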

In [None]:
@triton.jit
def matmul_check_access(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
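    """Tiled matmul, annotated with the access pattern of each operand.

    - A tile: BLOCK_M x BLOCK_K, consumed along K within each row
    - B tile: BLOCK_K x BLOCK_N, consumed along K down each column
    In a hand-written CUDA kernel you choose (and pad) the shared-memory layout
    of these tiles yourself. In Triton, tl.dot's operands are typically staged
    through shared memory in a layout the compiler picks, usually swizzled to
    avoid bank conflicts.
    """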
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    
    for k in range(0, K, BLOCK_K):
        # Load A tile: for row-major A (stride_ak == 1), K is the contiguous
        # dimension, so this global load can coalesce along K
        a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] + k < K), other=0.0)

        # Load B tile: for row-major B (stride_bn == 1), N is the contiguous
        # dimension, so the load coalesces along N; consecutive K values are
        # stride_bk apart. Staging B for tl.dot (including any shared-memory
        # layout change) is left to the compiler.
        b = tl.load(b_ptrs, mask=(offs_k[:, None] + k < K) & (offs_n[None, :] < N), other=0.0)
        
        acc += tl.dot(a, b)
        
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, acc, mask=mask)

---
## Step 5: Verify (10 min)

### Quiz

**Q1:** How many banks does shared memory have on modern NVIDIA GPUs?

A) 16  
B) 32  
C) 64  
D) 128

**Q2:** Which access pattern causes the worst bank conflicts?

A) `smem[threadIdx.x]`  
B) `smem[threadIdx.x * 2]`  
C) `smem[threadIdx.x * 32]`  
D) `smem[threadIdx.x * 33]`

**Q3:** How do you fix a stride-32 bank conflict in a 2D shared memory array?

A) Use stride-1 access  
B) Add padding to make it stride-33  
C) Use global memory instead  
D) Reduce the number of threads

In [None]:
print("Quiz Answers")
print("=" * 50)
print()
print("Q1: B) 32 banks")
print("    NVIDIA GPUs since Fermi (compute capability 2.0) have 32 banks.")
print()
print("Q2: C) smem[threadIdx.x * 32]")
print("    All 32 threads access the same bank → 32-way conflict.")
print("    Stride-33 (D) actually has NO conflicts!")
print()
print("Q3: B) Add padding to make it stride-33")
print("    Common pattern: float smem[32][33] instead of [32][32]")
print("    The extra column breaks the bank conflict pattern.")

### Checking for Bank Conflicts with Nsight Compute

```bash
# Profile and check for bank conflicts
ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum,l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum python your_kernel.py
```

These metrics show:
- `..._op_ld.sum`: Bank conflicts on shared memory loads
- `..._op_st.sum`: Bank conflicts on shared memory stores

**Goal:** Both should be 0 for optimal performance.

---
## Summary

### Key Takeaways

1. **Shared memory has 32 banks** (4-byte word granularity)
2. **Bank = (byte address / 4) % 32**
3. **Bank conflicts serialize access**: an n-way conflict takes n passes (up to 32x for a full warp)
4. **Stride-32 is worst**, stride-33 is conflict-free (for 4-byte elements, every odd stride is)
5. **Padding fixes power-of-2 conflicts** (`float smem[N][N+1]`)
6. **Broadcast is free** (threads reading the same word → no conflict)

### Conflict-Free Patterns

| Pattern | Bank Access | Conflicts |
|---------|------------|----------|
| `smem[tid]` | Sequential | None |
| `smem[tid * 33]` | All different | None |
| `smem[const]` | Broadcast | None |
| `smem[row][col ^ row]` (32×32 float tile) | XOR swizzle | None |

### Tomorrow: Pipelining

Now that we understand memory access patterns at both global and shared memory levels, we'll learn to **hide latency** by overlapping memory loads with computation.