# Week 1, Day 3: GPU Architecture

**Time:** ~1 hour

**Goal:** Understand how GPUs execute code - SMs, warps, and the SIMT model.

## The Challenge

CuPy gave us a 50-100x speedup "for free." But to write our own fast kernels, we need to understand:
1. How GPU hardware is organized
2. How thousands of threads execute in parallel
3. What makes GPU code fast (or slow)

In [None]:
import numpy as np

try:
    import cupy as cp
    GPU_AVAILABLE = True
except ImportError:
    GPU_AVAILABLE = False
    print("CuPy not available - conceptual content still applies")

---
## Step 1: The Challenge (5 min)

**Question:** How can a GPU execute millions of threads efficiently?

Let's query our GPU to see what we're working with.

In [None]:
if GPU_AVAILABLE:
    props = cp.cuda.runtime.getDeviceProperties(0)
    
    print("=" * 50)
    print("GPU SPECIFICATIONS")
    print("=" * 50)
    print(f"Name: {props['name'].decode()}")
    print(f"Compute Capability: {props['major']}.{props['minor']}")
    print(f"\nStreaming Multiprocessors (SMs): {props['multiProcessorCount']}")
    print(f"Max threads per SM: {props['maxThreadsPerMultiProcessor']}")
    print(f"Max threads per block: {props['maxThreadsPerBlock']}")
    print(f"\nTotal memory: {props['totalGlobalMem'] / 1e9:.1f} GB")
    print(f"Shared memory per block: {props['sharedMemPerBlock'] / 1024:.0f} KB")
    print(f"Registers per block: {props['regsPerBlock']}")
    print(f"\nWarp size: {props['warpSize']} threads")
else:
    print("Example: NVIDIA H100 SXM GPU")
    print("  SMs: 132")
    print("  Max threads per SM: 2048")
    print("  Warp size: 32 threads")
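
How can a GPU run millions of threads? It does not run them all at once. Each SM keeps a bounded number of threads resident and schedules them in warps. Here is a minimal sketch of that arithmetic. It assumes the H100 SXM figures from the fallback branch above, so swap in the values your own GPU reports:

```python
def resident_capacity(num_sms, max_threads_per_sm, warp_size=32):
    """Return (threads resident on the whole GPU, warps resident per SM)."""
    # Each SM holds at most max_threads_per_sm threads, grouped into warps
    warps_per_sm = max_threads_per_sm // warp_size
    # Threads that can be resident on the GPU at the same moment
    resident_threads = num_sms * max_threads_per_sm
    return resident_threads, warps_per_sm

# Assumed inputs: H100 SXM (132 SMs, 2048 threads per SM)
threads, warps = resident_capacity(num_sms=132, max_threads_per_sm=2048)
print(f"Resident threads on the GPU: {threads:,}")
print(f"Resident warps per SM: {warps}")
```

Launches larger than this run in waves. Blocks that do not fit stay queued until earlier blocks finish and release their SM resources.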

---
## Step 2: Explore - The GPU Hierarchy (15 min)

### GPU Organization

```
GPU Device
├── SM 0 (Streaming Multiprocessor)
│   ├── Warp 0 (32 threads)
│   ├── Warp 1 (32 threads)
│   ├── ...
│   └── Shared Memory (per-SM)
├── SM 1
├── ...
└── Global Memory (HBM on datacenter GPUs, GDDR on most consumer GPUs)
```

### Key Concepts

| Level | What it is | Count |
|-------|-----------|------|
| GPU | The whole device | 1 |
| SM | Streaming Multiprocessor | Tens to ~150 per GPU (e.g., 108 on A100, 132 on H100 SXM) |
| Warp | 32 threads executing in lockstep | Up to 64 resident per SM (architecture-dependent) |
| Thread | Individual execution context | Up to 2048 resident per SM |
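
Warps are formed from consecutive thread indices within a block. Threads 0-31 are warp 0, threads 32-63 are warp 1, and so on. Here is a small sketch of that mapping. It assumes a 1D block; for 2D and 3D blocks, CUDA first linearizes the thread index with x varying fastest, then splits it the same way:

```python
def warp_and_lane(thread_idx, warp_size=32):
    """Map a 1D thread index within a block to (warp index, lane index)."""
    # Consecutive threads share a warp; the lane is the position inside it
    return divmod(thread_idx, warp_size)

for t in [0, 31, 32, 63, 255]:
    warp, lane = warp_and_lane(t)
    print(f"thread {t:3d} -> warp {warp}, lane {lane}")
```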

In [None]:
# Interactive visualization concept
print("GPU HIERARCHY VISUALIZATION")
print("="*50)
print("")
print("  [GPU Device]")
print("  │")
print("  ├── [SM 0] ─── [Warp 0] ─ T0  T1  T2  ... T31")
print("  │         ├── [Warp 1] ─ T32 T33 T34 ... T63")
print("  │         ├── ...")
print("  │         └── [Shared Memory: 48-228 KB]")
print("  │")
print("  ├── [SM 1] ─── [Warp 0] ─ ...")
print("  │         └── ...")
print("  │")
print("  └── [Global Memory: 40-192 GB HBM]")
print("")
print("")
print("For interactive visualization, see:")
print("  ../lessons/gpu-architecture.html")

### The Warp: The Most Important Concept

**A warp is 32 threads that execute the SAME instruction at the SAME time.**

This is called SIMT (Single Instruction, Multiple Threads).

Think of it like:
- 32 soldiers marching in perfect sync
- Each soldier carries different data
- But they all do the same action at the same time

In [None]:
# Simulate warp execution
def simulate_warp_execution(instructions, data_per_thread):
    """Simulate how a warp executes instructions."""
    num_threads = 32
    thread_data = data_per_thread.copy()
    
    print(f"Warp with {num_threads} threads")
    print(f"Initial data: {thread_data[:8]}... (showing first 8 threads)")
    print()
    
    for inst in instructions:
        print(f"All 32 threads execute: {inst}")
        if inst == "MULTIPLY BY 2":
            thread_data = [x * 2 for x in thread_data]
        elif inst == "ADD 10":
            thread_data = [x + 10 for x in thread_data]
        print(f"  Result: {thread_data[:8]}...")
        print()
    
    return thread_data

# Example: Each thread starts with its thread ID
initial_data = list(range(32))
instructions = ["MULTIPLY BY 2", "ADD 10"]

result = simulate_warp_execution(instructions, initial_data)

---
## Step 3: The Concept - Warp Divergence (10 min)

### What happens when threads take different paths?

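If threads in the same warp take different paths through an `if-else`, the warp executes BOTH paths, one after the other:
1. Execute the `if` branch (threads on the `else` path are masked off and idle)
2. Execute the `else` branch (threads on the `if` path are masked off and idle)

This is called **warp divergence**. Each pass uses only some of the 32 lanes, so the branch costs roughly the sum of both paths instead of one. If every thread in the warp takes the same path, there is no divergence and no penalty.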

In [None]:
def simulate_divergence(thread_ids, condition_fn, label="condition"):
    """Simulate warp divergence for one if-else inside a single warp.

    Simplified model: each branch taken by at least one thread costs one
    pass through the warp; threads on the other path are masked off.
    """
    # Evaluate the branch condition for each thread
    conditions = [condition_fn(tid) for tid in thread_ids]
    
    true_threads = [tid for tid, c in zip(thread_ids, conditions) if c]
    false_threads = [tid for tid, c in zip(thread_ids, conditions) if not c]
    
    print("Warp executes if-else statement")
    print(f"Condition: {label}")
    print()
    
    passes = 0
    for branch, active, masked in [("IF", true_threads, false_threads),
                                   ("ELSE", false_threads, true_threads)]:
        if not active:
            continue  # No thread takes this branch, so the warp skips it
        passes += 1
        print(f"Pass {passes}: execute {branch} branch")
        print(f"  Active threads ({len(active)}): {active[:8]}...")
        print(f"  Masked threads ({len(masked)}): {masked[:8]}...")
        print()
    
    # Each pass uses only its active lanes, so lane efficiency is 1 / passes
    print(f"Total passes: {passes} ({100 // passes}% lane efficiency)")
    print("Without divergence: 1 pass (100% lane efficiency)")
    return passes

# Half the threads go one way, half go another
thread_ids = list(range(32))
simulate_divergence(thread_ids, lambda tid: tid % 2 == 0, label="thread_id % 2 == 0")
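
The printout shows which lanes are active in each pass, but not what they compute. The NumPy sketch below uses the same model and also produces results. Each branch runs under a per-lane mask, the results are merged, and the output matches a plain per-thread `if-else`. The branch bodies (`x * 2` and `x + 10`) are arbitrary choices for illustration. This models masking conceptually; it does not show the instructions any particular compiler emits:

```python
import numpy as np

def run_divergent_warp(data, condition):
    """Execute `if cond: x * 2 else: x + 10` across one warp using lane masks."""
    mask = condition(np.arange(data.size))  # per-lane branch outcome
    out = np.empty_like(data)
    passes = 0
    # IF pass: only lanes where the condition holds are active
    if mask.any():
        out[mask] = data[mask] * 2
        passes += 1
    # ELSE pass: only the remaining lanes are active
    if (~mask).any():
        out[~mask] = data[~mask] + 10
        passes += 1
    return out, passes

data = np.arange(32)
out, passes = run_divergent_warp(data, lambda tid: tid % 2 == 0)
# Reference: a plain per-thread if-else (data[i] == i here)
expected = np.array([x * 2 if x % 2 == 0 else x + 10 for x in range(32)])
print(out[:8], "passes:", passes)
print("matches per-thread if-else:", np.array_equal(out, expected))
```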

### Avoiding Divergence

**Bad:** Condition based on thread ID within warp
```python
if thread_id % 2 == 0:  # Every warp diverges: even and odd lanes split
    do_something()
else:
    do_other()
```

**Good:** Condition that is the same for every thread in a warp, e.g. one based on data where neighboring elements usually agree
```python
if input_value > threshold:  # Uniform per warp when nearby values agree
    do_something()
```
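
Whether a condition diverges depends on how its outcomes fall within each 32-thread warp, not on how many threads take each side overall. The NumPy sketch below counts diverging warps in an assumed 256-thread block. It compares the per-thread condition from the **Bad** example with one that depends only on the warp index (`thread_id // 32`) and is therefore uniform inside every warp:

```python
import numpy as np

def count_divergent_warps(condition, num_threads=256, warp_size=32):
    """Count warps whose lanes disagree on condition(thread_id).

    Assumes num_threads is a multiple of warp_size.
    """
    tids = np.arange(num_threads)
    # One row per warp, one column per lane
    outcomes = condition(tids).reshape(-1, warp_size)
    # A warp diverges if some lanes take the branch and others do not
    diverged = outcomes.any(axis=1) & ~outcomes.all(axis=1)
    return int(diverged.sum()), outcomes.shape[0]

per_thread = count_divergent_warps(lambda tid: tid % 2 == 0)
per_warp = count_divergent_warps(lambda tid: (tid // 32) % 2 == 0)
print(f"tid % 2 == 0:         {per_thread[0]} of {per_thread[1]} warps diverge")
print(f"(tid // 32) % 2 == 0: {per_warp[0]} of {per_warp[1]} warps diverge")
```

Both conditions send half the threads each way. Only the first splits threads within a warp.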

**Better:** No branches at all
```python
result = a * mask + b * (1 - mask)  # Branchless: mask is 1 where the condition holds, else 0
```
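
The following NumPy check uses made-up random inputs to confirm that the branchless blend computes the same values as the branch. `mask` is 1.0 where the condition holds and 0.0 elsewhere. NumPy's `np.where` performs the same select directly:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(1024)
a, b = x * 2.0, x + 10.0  # values produced by the two "branches"
threshold = 0.0

# Branchy reference: choose per element with an if-else
branchy = np.array([ai if xi > threshold else bi for xi, ai, bi in zip(x, a, b)])

# Branchless: blend with a 0/1 mask instead of branching
mask = (x > threshold).astype(x.dtype)
branchless = a * mask + b * (1 - mask)

print("arithmetic blend matches:", np.array_equal(branchless, branchy))
print("np.where matches:        ", np.array_equal(np.where(x > threshold, a, b), branchy))
```

The arithmetic blend assumes both values are finite. If either is `inf` or `NaN`, multiplying by 0 yields `NaN`, so a select operation such as `np.where` is the safer form. GPU compilers also often turn short branches into predicated or select instructions on their own.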

---
## Step 4: Code It - Blocks and Grids (30 min)

### Programming Model: Blocks and Grids

When you launch a GPU kernel, you specify:
- **Grid**: How many blocks to launch
- **Block**: How many threads per block

```
Grid (your problem)
├── Block 0 (e.g., 256 threads = 8 warps)
│   ├── Warp 0: threads 0-31
│   ├── Warp 1: threads 32-63
│   └── ...
├── Block 1
└── Block N-1
```

In [None]:
def calculate_kernel_config(problem_size, threads_per_block=256):
    """Calculate how to map a problem to GPU threads."""
    
    # Number of blocks needed (round up)
    num_blocks = (problem_size + threads_per_block - 1) // threads_per_block
    
    # Warps per block (a partial warp still occupies a full warp, so round up)
    warps_per_block = (threads_per_block + 31) // 32
    
    # Total threads launched
    total_threads = num_blocks * threads_per_block
    
    print(f"Problem size: {problem_size:,} elements")
    print(f"Threads per block: {threads_per_block}")
    print(f"Warps per block: {warps_per_block}")
    print(f"Number of blocks: {num_blocks:,}")
    print(f"Total threads: {total_threads:,}")
    print(f"Extra threads (idle, must be skipped by a bounds check): {total_threads - problem_size:,}")
    
    return num_blocks, threads_per_block

# Example: Processing a 1M element vector
calculate_kernel_config(1_000_000)
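
The extra threads in the last block still execute the kernel. Each thread must therefore check `idx < n` before touching memory; otherwise the tail threads read and write past the end of the array. The NumPy sketch below shows that guard for the 1M-element configuration above by simulating every launched thread's global index. The vector-add body is an assumed example:

```python
import numpy as np

def simulated_vector_add(x, y, threads_per_block=256):
    n = x.size
    num_blocks = (n + threads_per_block - 1) // threads_per_block
    # Global index of every launched thread: blockIdx.x * blockDim.x + threadIdx.x
    block_idx = np.repeat(np.arange(num_blocks), threads_per_block)
    thread_idx = np.tile(np.arange(threads_per_block), num_blocks)
    idx = block_idx * threads_per_block + thread_idx
    # Bounds check: threads past the end of the array do nothing
    valid = idx < n
    out = np.zeros_like(x)
    out[idx[valid]] = x[idx[valid]] + y[idx[valid]]
    return out, int((~valid).sum())

x = np.arange(1_000_000, dtype=np.float32)
y = np.ones_like(x)
out, idle = simulated_vector_add(x, y)
print("correct:", np.array_equal(out, x + y), "| idle threads:", idle)
```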

In [None]:
# Example: 2D problem (matrix)
def calculate_2d_config(rows, cols, block_x=16, block_y=16):
    """Calculate grid configuration for 2D problem."""
    
    grid_x = (cols + block_x - 1) // block_x
    grid_y = (rows + block_y - 1) // block_y
    
    threads_per_block = block_x * block_y
    total_blocks = grid_x * grid_y
    total_threads = total_blocks * threads_per_block
    
    print(f"Matrix size: {rows} x {cols} = {rows*cols:,} elements")
    print(f"Block size: {block_x} x {block_y} = {threads_per_block} threads")
    print(f"Grid size: {grid_x} x {grid_y} = {total_blocks:,} blocks")
    print(f"Total threads: {total_threads:,}")
    
    return (grid_x, grid_y), (block_x, block_y)

# Example: 1024x1024 matrix
calculate_2d_config(1024, 1024)
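
The 2D grid overshoots in the same way when the matrix is not a multiple of the block shape, so each thread needs a `row < rows and col < cols` guard. The NumPy sketch below builds every launched thread's (row, col) and checks that, after the guard, each element is covered exactly once. The 100 x 70 matrix is an arbitrary example:

```python
import numpy as np

def check_2d_coverage(rows, cols, block_x=16, block_y=16):
    grid_x = (cols + block_x - 1) // block_x
    grid_y = (rows + block_y - 1) // block_y
    # Every value of blockIdx.x * blockDim.x + threadIdx.x along x (same for y)
    xs = np.arange(grid_x * block_x)
    ys = np.arange(grid_y * block_y)
    col, row = np.meshgrid(xs, ys)  # one (row, col) entry per launched thread
    valid = (row < rows) & (col < cols)  # the bounds guard
    hits = np.zeros((rows, cols), dtype=int)
    np.add.at(hits, (row[valid], col[valid]), 1)  # count writes per element
    return bool((hits == 1).all()), int((~valid).sum())

ok, idle = check_2d_coverage(100, 70)
print("every element covered once:", ok, "| idle threads:", idle)
```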

### Thread Indexing

Each thread reads its position from built-in variables (CUDA names shown):
- `threadIdx.x`: Position within block (0 to blockDim.x-1)
- `blockIdx.x`: Which block (0 to gridDim.x-1)
- Global index: `blockIdx.x * blockDim.x + threadIdx.x`

In [None]:
def simulate_thread_indexing(grid_dim, block_dim):
    """Simulate how threads calculate their global index."""
    
    print(f"Grid: {grid_dim} blocks, Block: {block_dim} threads")
    print()
    
    for block_idx in range(min(grid_dim, 3)):  # Show first 3 blocks
        print(f"Block {block_idx}:")
        for thread_idx in range(min(block_dim, 8)):  # Show first 8 threads
            global_idx = block_idx * block_dim + thread_idx
            print(f"  threadIdx={thread_idx}, blockIdx={block_idx} -> global={global_idx}")
        if block_dim > 8:
            print(f"  ...")
        print()

simulate_thread_indexing(grid_dim=4, block_dim=32)
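
The global-index formula can also be inverted, which helps when working out which block and thread handled a given element. Here is a small sketch using `divmod` that checks the round trip for the same 4-block, 32-thread launch:

```python
def split_global_index(global_idx, block_dim):
    """Invert global = block_idx * block_dim + thread_idx."""
    block_idx, thread_idx = divmod(global_idx, block_dim)
    return block_idx, thread_idx

grid_dim, block_dim = 4, 32
for g in range(grid_dim * block_dim):
    b, t = split_global_index(g, block_dim)
    # Round trip: recombining must give back the same global index
    assert b * block_dim + t == g and 0 <= t < block_dim and b < grid_dim

print(split_global_index(70, 32))
```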

### Occupancy

**Occupancy** = active warps / maximum warps per SM

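Higher occupancy often (but not always) means better performance. With more resident warps, the scheduler can switch to a ready warp while others wait on memory. Once latency is already hidden, further occupancy stops helping.

Occupancy is limited by:
1. Registers per thread
2. Shared memory per block
3. Threads per block (warps per block, plus a hardware cap on resident blocks per SM)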

In [None]:
def calculate_occupancy(threads_per_block, regs_per_thread, smem_per_block,
                       max_warps_per_sm=64, max_regs_per_sm=65536, max_smem_per_sm=102400,
                       max_blocks_per_sm=32):
    """Calculate theoretical occupancy.

    Simplified model: ignores register and shared-memory allocation granularity.
    The default limits are illustrative; real values depend on the GPU's
    compute capability.
    """
    
    # A partial warp still occupies a full warp slot, so round up
    warps_per_block = (threads_per_block + 31) // 32
    # Registers are allocated per warp, so count all 32 lanes of every warp
    regs_per_block = warps_per_block * 32 * regs_per_thread
    
    # Blocks limited by warp slots
    blocks_by_warps = max_warps_per_sm // warps_per_block
    
    # Blocks limited by registers (no constraint if the kernel uses none)
    blocks_by_regs = max_regs_per_sm // regs_per_block if regs_per_block > 0 else max_blocks_per_sm
    
    # Blocks limited by shared memory (no constraint if the kernel uses none)
    blocks_by_smem = max_smem_per_sm // smem_per_block if smem_per_block > 0 else max_blocks_per_sm
    
    # Resident blocks: the most constrained resource, capped by the hardware block limit
    active_blocks = min(max_blocks_per_sm, blocks_by_warps, blocks_by_regs, blocks_by_smem)
    active_warps = active_blocks * warps_per_block
    occupancy = active_warps / max_warps_per_sm * 100
    
    print("Configuration:")
    print(f"  Threads/block: {threads_per_block}")
    print(f"  Registers/thread: {regs_per_thread}")
    print(f"  Shared memory/block: {smem_per_block} bytes")
    print()
    print("Limits:")
    print(f"  By block limit: {max_blocks_per_sm} blocks")
    print(f"  By warps: {blocks_by_warps} blocks")
    print(f"  By registers: {blocks_by_regs} blocks")
    print(f"  By shared memory: {blocks_by_smem} blocks")
    print()
    print("Result:")
    print(f"  Active blocks/SM: {active_blocks}")
    print(f"  Active warps/SM: {active_warps}")
    print(f"  Occupancy: {occupancy:.0f}%")
    
    return occupancy

# Example configurations
print("=" * 50)
print("High occupancy config:")
print("=" * 50)
calculate_occupancy(256, 32, 0)

print()
print("=" * 50)
print("Register-limited config:")
print("=" * 50)
calculate_occupancy(256, 128, 0)

---
## Step 5: Verify - Quiz (10 min)

In [None]:
# Q1: How many threads in a warp?
print("Q1: A warp contains 32 threads")
print("    It has been 32 on every NVIDIA architecture so far;")
print("    query warpSize rather than hard-coding it")

In [None]:
# Q2: What's the global index formula?
print("Q2: global_idx = blockIdx.x * blockDim.x + threadIdx.x")
print("")
print("For 2D:")
print("  row = blockIdx.y * blockDim.y + threadIdx.y")
print("  col = blockIdx.x * blockDim.x + threadIdx.x")

In [None]:
# Q3: What happens with divergence?
print("Q3: When threads in a warp diverge (if-else):")
print("    Both paths execute serially")
print("    Inactive threads are masked")
print("    Throughput is reduced")

In [None]:
# Q4: What can threads in the same block do?
print("Q4: Threads in the same block can:")
print("    - Share data via shared memory")
print("    - Synchronize with __syncthreads()")
print("")
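print("    Threads in different blocks CANNOT:")
print("    - Access each other's shared memory")
print("      (exception: thread block clusters on Hopper and newer)")
print("    - Synchronize with __syncthreads(), which covers only one block")
print("")
print("    They can still exchange data through global memory,")
print("    e.g. between kernel launches or with atomic operations")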

---
## Summary

| Concept | Key Point |
|---------|----------|
| SM | Independent execution unit on GPU |
| Warp | 32 threads executing in lockstep |
| Block | Group of warps sharing resources |
| Grid | All blocks for your kernel |
| Divergence | Serializes branch paths within a warp; avoid when possible |
| Occupancy | More active warps usually means better latency hiding |

### Interactive Resources

For interactive visualizations of these concepts, see:
- [GPU Architecture Lesson](../lessons/gpu-architecture.html) - Warp execution demo, occupancy calculator

### What's Next?

Now that we understand GPU architecture, let's write our first kernel!

---
## Next: Day 4 - Your First Triton Kernel

Tomorrow we'll write vector addition in Triton and understand index arithmetic.

[Continue to 04_first_triton_kernel.ipynb](./04_first_triton_kernel.ipynb)