# Week 1, Day 4: Your First Triton Kernel

**Time:** ~1 hour

**Goal:** Write vector addition in Triton and understand index arithmetic.

## The Challenge

Today we move from using libraries to writing our own GPU code.

**Why Triton?**
- Python-like syntax (easier than CUDA C++)
- Automatic memory coalescing
- Block-level programming (you think in tiles, not threads)

In [None]:
import numpy as np
import time

try:
    import torch
    import triton
    import triton.language as tl
    TRITON_AVAILABLE = True
    print(f"Triton version: {triton.__version__}")
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
except ImportError as e:
    TRITON_AVAILABLE = False
    print(f"Triton or PyTorch not available: {e}")
    print("Install with: pip install torch triton")

---
## Step 1: The Challenge (5 min)

**Task:** Implement `C = A + B` for vectors of 1M elements.

This is the "Hello World" of GPU programming.

In [None]:
# Our target: element-wise vector addition
# C[i] = A[i] + B[i] for all i

# NumPy version (CPU)
def vector_add_numpy(a, b):
    return a + b

# Test
a = np.array([1, 2, 3, 4], dtype=np.float32)
b = np.array([10, 20, 30, 40], dtype=np.float32)
c = vector_add_numpy(a, b)
print(f"A: {a}")
print(f"B: {b}")
print(f"C = A + B: {c}")

---
## Step 2: Explore - Index Arithmetic (15 min)

### The Key Concept: Program IDs and Block Processing

In Triton, you don't write code for individual threads. Instead:
- Each **program** processes a **block** of elements
- `tl.program_id(axis)` returns the index of the block this program handles
- You compute offsets to load/store the right data

In [None]:
# Simulate Triton's indexing model
def simulate_triton_indexing(n_elements, block_size):
    """Understand how Triton programs divide work."""
    
    n_programs = (n_elements + block_size - 1) // block_size
    
    print(f"Total elements: {n_elements}")
    print(f"Block size: {block_size}")
    print(f"Number of programs: {n_programs}")
    print()
    
    for pid in range(min(n_programs, 4)):  # Show first 4 programs
        # Each program computes its starting offset
        block_start = pid * block_size
        
        # Offsets within this program's block
        offsets = block_start + np.arange(block_size)
        
        # Mask for elements that are valid (in-bounds)
        mask = offsets < n_elements
        
        valid_offsets = offsets[mask]
        
        print(f"Program {pid}:")
        print(f"  block_start = {pid} * {block_size} = {block_start}")
        print(f"  offsets = [{block_start}, {block_start+1}, ..., {block_start+block_size-1}]")
        print(f"  valid elements: {len(valid_offsets)} (indices {valid_offsets[0]}..{valid_offsets[-1]})")
        print()
    
    if n_programs > 4:
        print(f"... ({n_programs - 4} more programs)")

simulate_triton_indexing(n_elements=1000, block_size=256)

### The Index Formula

```python
# For 1D arrays:
block_start = program_id * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
```

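This is the block-level counterpart of the per-thread index in CUDA, where each thread computes one scalar index instead of one program computing a whole vector of offsets: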
```python
# In CUDA terms:
global_idx = blockIdx.x * blockDim.x + threadIdx.x
```
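
To make the correspondence concrete, here is a small CPU-only NumPy sketch (not Triton code). It computes the per-thread indices the way each CUDA thread would, then checks that they equal the single offsets vector one Triton program builds. The function names `cuda_style_indices` and `triton_style_offsets` are hypothetical, chosen only for this illustration.

```python
import numpy as np

def cuda_style_indices(block_idx, block_dim):
    """One scalar index per thread, as each CUDA thread would compute it."""
    return [block_idx * block_dim + thread_idx for thread_idx in range(block_dim)]

def triton_style_offsets(pid, block_size):
    """One vector of indices per program, mirroring
    pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)."""
    return pid * block_size + np.arange(block_size)

per_thread = cuda_style_indices(block_idx=2, block_dim=8)
per_program = triton_style_offsets(pid=2, block_size=8)

print("CUDA, one index per thread:   ", per_thread)
print("Triton, one vector per program:", per_program.tolist())
```

Both lines print the same eight indices. The difference is only in who computes them: eight threads each produce one value, or one program produces all eight at once.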

---
## Step 3: The Concept - Triton Kernel Structure (10 min)

A Triton kernel has this structure:

```python
@triton.jit
def my_kernel(
    # Pointers to input/output tensors
    input_ptr, output_ptr,
    # Size information
    n_elements,
    # Compile-time constants
    BLOCK_SIZE: tl.constexpr,
):
    # 1. Calculate which elements this program handles
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    
    # 2. Create mask for valid elements
    mask = offsets < n_elements
    
    # 3. Load data
    data = tl.load(input_ptr + offsets, mask=mask)
    
    # 4. Compute
    result = data * 2  # Example: double each element
    
    # 5. Store result
    tl.store(output_ptr + offsets, result, mask=mask)
```
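
If you do not have a GPU handy, the five steps can be traced on the CPU. The sketch below is a NumPy emulation under the assumption that programs are independent, so running them one after another in a loop gives the same result as running them in parallel. `emulate_double_kernel` is a hypothetical helper, not part of Triton.

```python
import numpy as np

def emulate_double_kernel(x, block_size):
    """Run the five kernel steps once per program, sequentially on the CPU."""
    n_elements = x.size
    out = np.empty_like(x)
    # Number of programs in the grid (ceiling division)
    n_programs = (n_elements + block_size - 1) // block_size

    for pid in range(n_programs):  # on a GPU these iterations run in parallel
        # 1. Which elements does this program handle?
        offsets = pid * block_size + np.arange(block_size)
        # 2. Mask out offsets past the end of the array
        mask = offsets < n_elements
        # 3. Load only the in-bounds elements
        data = x[offsets[mask]]
        # 4. Compute
        result = data * 2
        # 5. Store only the in-bounds elements
        out[offsets[mask]] = result
    return out

x = np.arange(10, dtype=np.float32)
y = emulate_double_kernel(x, block_size=4)
print(y)
```

With 10 elements and a block size of 4, the loop runs three programs, and the last one touches only two valid elements.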

---
## Step 4: Code It - Vector Addition Kernel (30 min)

In [None]:
if TRITON_AVAILABLE:
    @triton.jit
    def vector_add_kernel(
        # Pointers to the input and output vectors
        a_ptr,
        b_ptr,
        c_ptr,
        # Number of elements in the vectors
        n_elements,
        # Block size (must be a power of 2)
        BLOCK_SIZE: tl.constexpr,
    ):
        """Compute C = A + B element-wise."""
        
        # Step 1: Identify which block this program handles
        pid = tl.program_id(axis=0)  # Which block am I?
        
        # Step 2: Calculate the offsets for this block
        # Each program handles BLOCK_SIZE consecutive elements
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        
        # Step 3: Create a mask for valid elements
        # (handles case where n_elements is not divisible by BLOCK_SIZE)
        mask = offsets < n_elements
        
        # Step 4: Load data from A and B
        a = tl.load(a_ptr + offsets, mask=mask, other=0.0)
        b = tl.load(b_ptr + offsets, mask=mask, other=0.0)
        
        # Step 5: Compute the sum
        c = a + b
        
        # Step 6: Store the result
        tl.store(c_ptr + offsets, c, mask=mask)
    
    print("Kernel defined; Triton will JIT-compile it on first launch.")
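
The `mask` and `other=0.0` arguments are easiest to understand on the last block, where some offsets fall past the end of the vector. The following NumPy sketch is a CPU analogy for the semantics only, not a description of how Triton implements them. `masked_load` and `masked_store` are hypothetical stand-ins for `tl.load` and `tl.store`.

```python
import numpy as np

def masked_load(buffer, offsets, mask, other=0.0):
    """CPU stand-in for tl.load(ptr + offsets, mask=mask, other=other)."""
    out = np.full(offsets.shape, other, dtype=buffer.dtype)  # masked-off lanes get `other`
    out[mask] = buffer[offsets[mask]]                        # in-bounds lanes read memory
    return out

def masked_store(buffer, offsets, values, mask):
    """CPU stand-in for tl.store(ptr + offsets, values, mask=mask)."""
    buffer[offsets[mask]] = values[mask]  # masked-off lanes write nothing

n_elements, block_size, pid = 10, 4, 2  # pid=2 is the last program here
a = np.arange(n_elements, dtype=np.float32)
b = np.ones(n_elements, dtype=np.float32)
c = np.zeros(n_elements, dtype=np.float32)

offsets = pid * block_size + np.arange(block_size)  # 8, 9, 10, 11
mask = offsets < n_elements                          # only 8 and 9 are valid

a_blk = masked_load(a, offsets, mask)
b_blk = masked_load(b, offsets, mask)
masked_store(c, offsets, a_blk + b_blk, mask)

print("loaded a block:", a_blk)
print("tail of c:     ", c[8:])
```

The two out-of-bounds lanes still take part in the addition, but they operate on the `other` fill value and are never written back, so nothing outside the vector is read or modified.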

In [None]:
if TRITON_AVAILABLE:
    def vector_add_triton(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Python wrapper to launch the Triton kernel."""
        
        # Ensure inputs are on GPU, contiguous, and the same shape
        assert a.is_cuda and b.is_cuda
        assert a.is_contiguous() and b.is_contiguous()
        assert a.shape == b.shape
        
        # Allocate output tensor
        c = torch.empty_like(a)
        
        n_elements = a.numel()
        
        # Choose block size (power of 2, typically 256-1024)
        BLOCK_SIZE = 1024
        
        # Calculate grid size (number of programs to launch)
        grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
        
        # Launch the kernel!
        vector_add_kernel[grid](
            a, b, c,
            n_elements,
            BLOCK_SIZE=BLOCK_SIZE,
        )
        
        return c
    
    print("Wrapper function defined!")

In [None]:
if TRITON_AVAILABLE:
    # Test correctness
    print("Testing correctness...")
    
    # Create test data
    n = 1000
    a = torch.randn(n, device='cuda', dtype=torch.float32)
    b = torch.randn(n, device='cuda', dtype=torch.float32)
    
    # Compute with Triton
    c_triton = vector_add_triton(a, b)
    
    # Compute with PyTorch (reference)
    c_torch = a + b
    
    # Check if they match
    max_diff = (c_triton - c_torch).abs().max().item()
    print(f"Max difference: {max_diff}")
    print(f"Results match: {torch.allclose(c_triton, c_torch)}")

In [None]:
if TRITON_AVAILABLE:
    # Benchmark
    print("\nBenchmarking...")
    
    n = 10_000_000  # 10M elements
    a = torch.randn(n, device='cuda', dtype=torch.float32)
    b = torch.randn(n, device='cuda', dtype=torch.float32)
    
    # Warmup
    for _ in range(10):
        _ = vector_add_triton(a, b)
    torch.cuda.synchronize()
    
    # Benchmark Triton
    start = time.perf_counter()
    for _ in range(100):
        c = vector_add_triton(a, b)
    torch.cuda.synchronize()
    triton_time = (time.perf_counter() - start) / 100 * 1000
    
    # Benchmark PyTorch
    start = time.perf_counter()
    for _ in range(100):
        c = a + b
    torch.cuda.synchronize()
    torch_time = (time.perf_counter() - start) / 100 * 1000
    
    # Calculate bandwidth
    # We read 2 vectors, write 1 vector, each of n elements * 4 bytes
    bytes_total = 3 * n * 4
    triton_bw = bytes_total / (triton_time / 1000) / 1e9  # GB/s
    torch_bw = bytes_total / (torch_time / 1000) / 1e9  # GB/s
    
    print(f"Vector size: {n:,} elements ({n * 4 / 1e6:.1f} MB each)")
    print()
    print(f"Triton: {triton_time:.3f} ms, {triton_bw:.0f} GB/s")
    print(f"PyTorch: {torch_time:.3f} ms, {torch_bw:.0f} GB/s")
    print()
    print(f"Speedup: {torch_time/triton_time:.2f}x")

### Understanding the Results

For element-wise operations:
- Performance is **memory-bound** (limited by bandwidth, not compute)
- Triton should roughly match PyTorch, since both are limited by the same memory bandwidth and both access memory efficiently
- The metric to watch is **bandwidth** (GB/s), not FLOPS

**Peak bandwidth** (approximate):
- H100: ~3.35 TB/s
- A100: ~1.6 TB/s (40 GB) to ~2.0 TB/s (80 GB)
- RTX 4090: ~1.0 TB/s
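
To compare a measurement against these peaks, divide the achieved bandwidth by the peak for your GPU. The sketch below repeats the bandwidth arithmetic from the benchmark cell as two small helpers. The helper names and the 0.1 ms timing are hypothetical placeholders, not measured results, so substitute your own numbers.

```python
def effective_bandwidth_gbs(n_elements, elapsed_ms, bytes_per_element=4, n_streams=3):
    """GB/s moved by an element-wise op.

    n_streams=3 assumes two vectors are read and one is written,
    as in vector addition.
    """
    bytes_total = n_streams * n_elements * bytes_per_element
    return bytes_total / (elapsed_ms / 1000) / 1e9

def fraction_of_peak(bandwidth_gbs, peak_gbs):
    """Share of the GPU's peak memory bandwidth that was achieved."""
    return bandwidth_gbs / peak_gbs

# Hypothetical timing, NOT a measurement: 10M float32 elements in 0.1 ms
bw = effective_bandwidth_gbs(10_000_000, elapsed_ms=0.1)
print(f"{bw:.0f} GB/s, {fraction_of_peak(bw, peak_gbs=2000):.0%} of a ~2.0 TB/s peak")
```

Vector addition performs one floating-point operation for every 12 bytes moved, which is why the fraction of peak bandwidth is a more telling number here than FLOPS.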

---
## Step 5: Verify - Exercises (10 min)

### Exercise 1: Modify the kernel

Change the kernel to compute `C = A * B` (element-wise multiplication).

In [None]:
if TRITON_AVAILABLE:
    @triton.jit
    def vector_mul_kernel(
        a_ptr, b_ptr, c_ptr,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
    ):
        """Compute C = A * B element-wise (exercise solution)."""
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        
        a = tl.load(a_ptr + offsets, mask=mask)
        b = tl.load(b_ptr + offsets, mask=mask)
        
        # The only change from vector_add_kernel: multiply instead of add
        c = a * b
        
        tl.store(c_ptr + offsets, c, mask=mask)
    
    # Test it
    a = torch.tensor([1, 2, 3, 4], device='cuda', dtype=torch.float32)
    b = torch.tensor([10, 20, 30, 40], device='cuda', dtype=torch.float32)
    c = torch.empty_like(a)
    
    grid = (triton.cdiv(a.numel(), 256),)
    vector_mul_kernel[grid](a, b, c, a.numel(), BLOCK_SIZE=256)
    
    print(f"A: {a}")
    print(f"B: {b}")
    print(f"C = A * B: {c}")
    print(f"Expected: {a * b}")

### Exercise 2: Index calculation quiz

In [None]:
# Q1: If BLOCK_SIZE=256 and pid=3, what's the first element processed?
BLOCK_SIZE = 256
pid = 3
first_element = pid * BLOCK_SIZE
print(f"Q1: First element for pid={pid}: {first_element}")

# Q2: For n_elements=1000, how many programs are needed with BLOCK_SIZE=256?
n_elements = 1000
num_programs = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE
print(f"Q2: Number of programs: {num_programs}")

# Q3: In the last program (pid=3), how many valid elements are there?
last_pid = num_programs - 1
block_start = last_pid * BLOCK_SIZE
valid_elements = n_elements - block_start
print(f"Q3: Valid elements in last program: {valid_elements}")
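
The quiz arithmetic generalizes to any vector length and block size. Here is a minimal sketch that wraps it in one function, assuming `n_elements` is at least 1. The name `launch_geometry` is hypothetical.

```python
def launch_geometry(n_elements, block_size):
    """Return (number of programs, valid elements in the last program)."""
    num_programs = (n_elements + block_size - 1) // block_size  # ceiling division
    valid_in_last = n_elements - (num_programs - 1) * block_size
    return num_programs, valid_in_last

for n in (1000, 1024, 1025):
    programs, tail = launch_geometry(n, block_size=256)
    print(f"n={n}: {programs} programs, {tail} valid elements in the last one")
```

Note the jump at `n=1025`: a single extra element costs a whole extra program in which only one lane does useful work, which is exactly the case the mask exists to handle.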

---
## Summary

| Concept | Key Point |
|---------|----------|
| Program ID | `tl.program_id(axis)` - which block you are |
| Offsets | `pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)` |
| Mask | Handle out-of-bounds elements |
| Load/Store | `tl.load()` / `tl.store()` with pointer + offset |
| Grid | Number of programs to launch |

### Key Formula

```python
# The universal index pattern:
pid = tl.program_id(axis=0)
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
```

---
## Next: Day 5 - Memory Hierarchy

Tomorrow we'll learn why our naive matmul would be 10-50x slower than CuPy - it's all about memory access patterns.

[Continue to 05_memory_hierarchy.ipynb](./05_memory_hierarchy.ipynb)