# Programming on GPUs

This notebook teaches the basics of GPU programming through hands-on exercises. We start with [Numba](https://numba.pydata.org/), a just-in-time compiler for Python that provides low-level GPU control, then move to [Triton](https://openai.com/index/triton/), OpenAI's high-level Python-like GPU programming language.

**Learning approach:** This notebook emphasizes interactive coding with minimal upfront theory. You'll learn by doing. If you get stuck, ask for hints from your favorite chat assistant without requesting the complete solution.

**Sources:** 
- Nvidia [CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-programming-guide/01-introduction/programming-model.html)
- [GPU-Puzzles](https://github.com/srush/GPU-Puzzles) by Sasha Rush

In [None]:
import numba
import numpy as np
from numba import cuda

In [None]:
import warnings
warnings.filterwarnings(
    action="ignore", category=numba.NumbaPerformanceWarning, module="numba"
)

## Core GPU Concepts

Before we start coding, let's understand the GPU's execution model:

- **Streaming Multiprocessor (SM)**: The GPU's computation unit (analogous to a CPU core)
- **Thread**: The smallest unit of execution (in the puzzles below, each thread processes one element)
- **Thread Block (Block)**: A group of threads guaranteed to run on a single SM (can share memory and synchronize)
- **Grid**: Thread blocks are organized into a 1D, 2D, or 3D grid
- **Warp**: Within a thread block, threads are grouped into warps of 32 threads. All threads in a warp execute the same instruction simultaneously (SIMT: Single-Instruction Multiple-Threads)

**Mental model:** Grid → Blocks → Warps → Threads (each level groups the one after it)
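To make the hierarchy concrete, here is a plain-Python sketch (no GPU needed) of where one thread sits in a 1D launch. It assumes 1D blocks and the 32-thread warp size described above; `locate_thread` is a hypothetical helper that models the arithmetic, not part of Numba.

```python
WARP_SIZE = 32  # threads per warp on NVIDIA GPUs


def locate_thread(block_idx, thread_idx, block_dim):
    """Return (global_id, warp_id, lane) for one thread of a 1D launch.

    Plain-Python model of the indexing, not a Numba API.
    """
    global_id = block_idx * block_dim + thread_idx  # position in the whole grid
    warp_id = thread_idx // WARP_SIZE               # which warp of its block
    lane = thread_idx % WARP_SIZE                   # position inside that warp
    return global_id, warp_id, lane


# A launch of 2 blocks x 64 threads: each block is split into 2 warps of 32.
for block_idx, thread_idx in [(0, 0), (0, 31), (0, 32), (1, 33)]:
    print((block_idx, thread_idx), "->", locate_thread(block_idx, thread_idx, block_dim=64))
```

Warps are formed from consecutive thread indices, which is why thread 32 starts a new warp at lane 0.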

Let's start coding!

### Puzzle 1: Map

**Goal:** Add 10 to each element of an array using parallel threads.

In [None]:
def map_spec(a):
    return a + 10

# Size of our array
SIZE = 4

# Create input and output arrays
a = np.arange(SIZE, dtype=np.float32)  # [0, 1, 2, 3]
out = np.zeros(SIZE, dtype=np.float32)

map_spec(a)

**Task:** Implement this using Numba so that each thread adds 10 to exactly one element of the array.

**Hint:** Use `cuda.threadIdx.x` to get the current thread's index within its block.
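Conceptually, `kernel[grid, block](args)` runs the kernel body once per thread, and each copy sees a different `threadIdx.x`. The CPU sketch below mimics that with an ordinary loop. `launch_1d` and `double_body` are hypothetical names, and the body doubles each element instead of adding 10 so the puzzle stays unsolved. On a real GPU the calls run in parallel, in no guaranteed order.

```python
def launch_1d(kernel_body, grid_dim, block_dim, *args):
    """Sequential CPU stand-in for `kernel[grid_dim, block_dim](*args)`."""
    for block_idx in range(grid_dim):
        for thread_idx in range(block_dim):
            # Every "thread" runs the same body with its own indices.
            kernel_body(block_idx, thread_idx, block_dim, *args)


def double_body(block_idx, thread_idx, block_dim, out, a):
    # Hypothetical kernel body: each thread handles exactly one element.
    i = thread_idx
    out[i] = 2 * a[i]


a = [0.0, 1.0, 2.0, 3.0]
out = [0.0] * len(a)
launch_1d(double_body, 1, len(a), out, a)
print(out)  # [0.0, 2.0, 4.0, 6.0]
```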

In [None]:
# Define the CUDA kernel
@cuda.jit
def map_kernel(out, a):
    # Get the thread index
    i = cuda.threadIdx.x
    # Each thread adds 10 to one element
    # your code here


# Copy arrays to GPU
a_device = cuda.to_device(a)
out_device = cuda.to_device(out)

# kernel[grid, block](args)
# Launch kernel: grid = 1 block, block = SIZE threads
map_kernel[1, SIZE](out_device, a_device)

# Copy result back to CPU
result = out_device.copy_to_host()

# Verify result
expected = map_spec(a)
print(f"Input:    {a}")
print(f"Output:   {result}")
print(f"Expected: {expected}")
print(f"Correct:  {np.allclose(result, expected)}")

### Puzzle 2: Vector Addition

**Goal:** Add two vectors element-wise using parallel threads.

In [None]:
def zip_spec(a, b):
    return a + b

out = np.zeros(SIZE)
a = np.arange(SIZE)
b = np.arange(SIZE)
zip_spec(a,b)

In [None]:
# Define the CUDA kernel
@cuda.jit
def zip_kernel(out, a, b):
    # Get the thread index
    i = cuda.threadIdx.x
    # your code here

# A function to move vectors on device
def init_pb(a=a, b=b, out=out):
    a_device = cuda.to_device(a)
    b_device = cuda.to_device(b)
    out_device = cuda.to_device(out)
    return a_device, b_device, out_device

a_device, b_device, out_device = init_pb()

# Launch kernel: 1 block, SIZE threads
zip_kernel[1, SIZE](out_device, a_device, b_device)

# Copy result back to CPU
result = out_device.copy_to_host()

# Verify result
expected = zip_spec(a, b)
print(f"Input a:  {a}")
print(f"Input b:  {b}")
print(f"Output:   {result}")
print(f"Expected: {expected}")
print(f"Correct:  {np.allclose(result, expected)}")

**Experiment:** What happens if you launch more threads than the array size?

In [None]:
a_device, b_device, out_device = init_pb()

NUM_THREADS = 2*SIZE
zip_kernel[1, NUM_THREADS](out_device, a_device, b_device)

# Copy result back to CPU
result = out_device.copy_to_host()

# Verify result
expected = zip_spec(a, b)
print(f"Output:   {result}")
print(f"Expected: {expected}")
print(f"Correct:  {np.allclose(result, expected)}")

**Result:** The output may still look correct, but the launch is unsafe: the extra threads (indices 4 to 7) read and write past the end of the arrays, which can crash the kernel or silently corrupt other data.
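Why can the output still look right? The toy model below places `out`, `a`, `b` and some unrelated data back-to-back in one flat list, then runs the unguarded addition with 8 threads for 4 elements. The layout is an assumption made purely for illustration (real device allocations are placed differently), threads run one after another here rather than concurrently, and `toy_unguarded_zip` is a hypothetical helper that supports at most `2 * size` threads.

```python
def toy_unguarded_zip(size, num_threads, filler=9):
    """Toy flat memory laid out as out | a | b | unrelated data, each `size` slots long."""
    memory = [0] * size + list(range(size)) + list(range(size)) + [filler] * size
    out_base, a_base, b_base = 0, size, 2 * size
    # Simulated threads run one after another here; a GPU runs them concurrently.
    for i in range(num_threads):
        # No guard: threads with i >= size read and write outside their own arrays.
        memory[out_base + i] = memory[a_base + i] + memory[b_base + i]
    return {"out": memory[out_base:out_base + size], "a": memory[a_base:a_base + size]}


result = toy_unguarded_zip(size=4, num_threads=8)
print(result)  # {'out': [0, 2, 4, 6], 'a': [9, 10, 11, 12]}
```

In this model `out` comes back correct while `a` has been silently overwritten, which is exactly the kind of bug a guard clause prevents.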

**Task:** Add a guard clause to prevent threads from accessing memory beyond the array bounds.

In [None]:
# CUDA kernel with Guard
@cuda.jit
def zip_guard_kernel(out, a, b, size):
    # Get the thread index
    i = cuda.threadIdx.x
    # your code here

a_device, b_device, out_device = init_pb()

NUM_THREADS = 2*SIZE
zip_guard_kernel[1, NUM_THREADS](out_device, a_device, b_device, SIZE)

# Copy result back to CPU
result = out_device.copy_to_host()

# Verify result
expected = zip_spec(a, b)
print(f"Output:   {result}")
print(f"Expected: {expected}")
print(f"Correct:  {np.allclose(result, expected)}")

### Puzzle 3: 2D Matrices

**Goal:** Apply the map operation to a 2D matrix using 2D thread blocks.

In [None]:
a = np.arange(SIZE * SIZE).reshape((SIZE, SIZE))
out = map_spec(a)
out

**Key insight:** Thread blocks can be organized in 2D or 3D shapes, which simplifies mapping threads to 2D/3D data structures.
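With a `(SIZE, SIZE)` block, each of the 16 threads gets a distinct `(threadIdx.x, threadIdx.y)` pair. The plain-Python sketch below lists them in the order CUDA numbers the threads of a 2D block (linear thread ID `x + y * blockDim.x`, so `threadIdx.x` varies fastest). `block_threads_2d` is a hypothetical helper that only models the indexing.

```python
def block_threads_2d(dim_x, dim_y):
    """List (threadIdx.x, threadIdx.y) pairs of a 2D block in linear thread-ID order."""
    pairs = []
    for linear_id in range(dim_x * dim_y):
        # linear_id = x + y * dim_x, so x is the remainder and y the quotient
        pairs.append((linear_id % dim_x, linear_id // dim_x))
    return pairs


print(block_threads_2d(4, 4)[:6])  # [(0, 0), (1, 0), (2, 0), (3, 0), (0, 1), (1, 1)]
```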

**Task:** Use a 2D thread block where each thread handles one matrix element.

**Hint:** Use `cuda.threadIdx.x` and `cuda.threadIdx.y` to get both coordinates.

In [None]:
@cuda.jit
def map_2d_kernel(out, a, size):
    i = cuda.threadIdx.x
    j = cuda.threadIdx.y
    # your code here

a_device, b_device, out_device = init_pb(a=a, out=np.zeros_like(out))

THREAD_BLOCK = (SIZE, SIZE)
map_2d_kernel[1, THREAD_BLOCK](out_device, a_device, SIZE)

result = out_device.copy_to_host()

# Verify result
expected = map_spec(a)
print(f"Output:   {result}")
print(f"Expected: {expected}")
print(f"Correct:  {np.allclose(result, expected)}")

### Puzzle 4: Broadcasting

**Goal:** Add two vectors with broadcasting (column + row → matrix).

In [None]:
a = np.arange(SIZE).reshape(SIZE, 1)
b = np.arange(SIZE).reshape(1, SIZE)
out = a + b
out

In [None]:
@cuda.jit
def broadcast_kernel(out, a, b, size):
    i = cuda.threadIdx.x
    j = cuda.threadIdx.y
    # your code here
    

a_device, b_device, out_device = init_pb(a=a, b=b, out=np.zeros_like(out))

THREAD_BLOCK = (2*SIZE, 3*SIZE)  # deliberately more threads than output cells: guard both i and j
broadcast_kernel[1, THREAD_BLOCK](out_device, a_device, b_device, SIZE)

result = out_device.copy_to_host()

# Verify result
expected = a + b
print(f"Output:   {result}")
print(f"Expected: {expected}")
print(f"Correct:  {np.allclose(result, expected)}")

**New concept:** So far we've used only **thread blocks**. Now let's use the **grid** dimension.

Like thread blocks, grids can be 1D, 2D, or 3D.

**Task:** Use a 2D grid with a single thread per block to compute the broadcast addition.

**Hint:** Use `cuda.blockIdx.x` and `cuda.blockIdx.y` to get the block's position in the grid.

In [None]:
@cuda.jit
def broadcast_grid_kernel(out, a, b, size):
    i = cuda.blockIdx.x 
    j = cuda.blockIdx.y
    # your code here


a_device, b_device, out_device = init_pb(a=a, b=b, out=np.zeros_like(out))

# 1 thread per block, 2D grid
THREADS = 1
GRID = (SIZE, SIZE)
broadcast_grid_kernel[GRID, THREADS](out_device, a_device, b_device, SIZE)

result = out_device.copy_to_host()

# Verify result
expected = a + b
print(f"Output:   {result}")
print(f"Expected: {expected}")
print(f"Correct:  {np.allclose(result, expected)}")

**Next challenge:** Implement the same operation using a 1D grid and 1D thread blocks.

**Hint:** You'll need to compute 2D indices (i, j) from 1D block and thread indices.

In [None]:
@cuda.jit
def broadcast_grid_kernel(out, a, b, size):
    # your code here
    

a_device, b_device, out_device = init_pb(a=a, b=b, out=np.zeros_like(out))

THREADS = SIZE
GRID = SIZE
broadcast_grid_kernel[GRID, THREADS](out_device, a_device, b_device, SIZE)

result = out_device.copy_to_host()

# Verify result
expected = a + b
print(f"Output:   {result}")
print(f"Expected: {expected}")
print(f"Correct:  {np.allclose(result, expected)}")

**Final challenge:** Use a 2D grid with 2D thread blocks.

Here each block is 2×2 threads (`cuda.blockDim.x = cuda.blockDim.y = 2`, so 4 threads per block), and the grid is 2×2 blocks (`SIZE//2 = 2` blocks per dimension, so 4 blocks).

**Understanding the indexing:** Each element (i, j) in the output is computed by combining block and thread indices. The table shows 10 of the 16 threads:

| blockIdx.x | blockIdx.y | threadIdx.x | threadIdx.y | **i** | **j** | Computes |
|------------|------------|-------------|-------------|-------|-------|----------|
| 0 | 0 | 0 | 0 | **0** | **0** | out[0,0] |
| 0 | 0 | 1 | 0 | **1** | **0** | out[1,0] |
| 0 | 0 | 0 | 1 | **0** | **1** | out[0,1] |
| 0 | 0 | 1 | 1 | **1** | **1** | out[1,1] |
| 1 | 0 | 0 | 0 | **2** | **0** | out[2,0] |
| 1 | 0 | 1 | 0 | **3** | **0** | out[3,0] |
| 0 | 1 | 0 | 0 | **0** | **2** | out[0,2] |
| 0 | 1 | 1 | 1 | **1** | **3** | out[1,3] |
| 1 | 1 | 0 | 0 | **2** | **2** | out[2,2] |
| 1 | 1 | 1 | 1 | **3** | **3** | out[3,3] |

**Formula:** `i = blockIdx.x * blockDim.x + threadIdx.x` (similarly for j)
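The table lists only some of the threads. The plain-Python check below applies the formula to all 16 (block, thread) combinations of this launch and confirms that every cell of the 4×4 output is computed by exactly one thread. `global_index` is a hypothetical helper name.

```python
from itertools import product

BLOCK_DIM = 2  # blockDim.x == blockDim.y == 2
GRID_DIM = 2   # SIZE // 2 blocks per grid dimension


def global_index(block_idx, block_dim, thread_idx):
    """i = blockIdx * blockDim + threadIdx (the same rule is used for x and y)."""
    return block_idx * block_dim + thread_idx


cells = []
for bx, by, tx, ty in product(range(GRID_DIM), range(GRID_DIM), range(BLOCK_DIM), range(BLOCK_DIM)):
    cells.append((global_index(bx, BLOCK_DIM, tx), global_index(by, BLOCK_DIM, ty)))

# 16 threads, 16 distinct output cells: no gaps and no duplicates.
print(len(cells), len(set(cells)))  # 16 16
```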

In [None]:
@cuda.jit
def broadcast_grid_kernel(out, a, b, size):
    # your code here
    

a_device, b_device, out_device = init_pb(a=a, b=b, out=np.zeros_like(out))

THREADS = (SIZE//2 , SIZE//2)
GRID = (SIZE//2, SIZE//2)  
broadcast_grid_kernel[GRID, THREADS](out_device, a_device, b_device, SIZE)

result = out_device.copy_to_host()

# Verify result
expected = a + b
print(f"Output:   {result}")
print(f"Expected: {expected}")
print(f"Correct:  {np.allclose(result, expected)}")

## GPU Memory Hierarchy

<div>
<img src="https://docs.nvidia.com/cuda/cuda-programming-guide/_images/gpu-cpu-system-diagram.png" width="700"/>
</div>

**Key memory types:**

- **Global Memory (DRAM)**: Large but slow (~100s of cycles latency)
  - Accessible to all SMs in the GPU
  - System memory (host DRAM) is even slower (PCIe transfer required)
  
- **Shared Memory**: Small but fast (latency on the order of tens of cycles, far below global memory)
  - On-chip memory shared by threads within a block
  - Programmer-managed cache (you control what gets loaded)
  - Limited size (typically 48-228 KB per SM, depending on the architecture)
  
- **Registers**: Fastest, private to each thread
  - Directly accessible by thread (no load/store needed)
  - Very limited: if a kernel needs more registers than are available, values spill to slower local memory and performance degrades

**Performance strategy:** Minimize global memory accesses by using shared memory as a manually-managed cache.

### Puzzle 5: Pooling (Sliding Window Sum)

**Goal:** Compute a sliding window sum where each output element is the sum of up to 3 input elements (current and 2 previous).

In [None]:
def pool_spec(a):
    out = np.zeros(a.shape)
    for i in range(a.shape[0]):
        out[i] = a[max(i - 2, 0) : i + 1].sum()
    return out

SIZE = 8
a = np.arange(SIZE)
out = pool_spec(a)
out

In [None]:
@cuda.jit
def pool_kernel(out, a, size):
    i = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
    if i < size:
        # Manually compute sum - can't use slicing in CUDA!
        temp_sum = 0.0
        for k in range(max(i - 2, 0), i + 1):
            temp_sum += a[k]  # Use global memory
        out[i] = temp_sum

a_device, b_device, out_device = init_pb(a=a, out=np.zeros_like(out))

THREADS = SIZE//2
GRID = (2,1)  
pool_kernel[GRID, THREADS](out_device, a_device, SIZE)

result = out_device.copy_to_host()

# Verify result
expected = pool_spec(a)
print(f"Output:   {result}")
print(f"Expected: {expected}")
print(f"Correct:  {np.allclose(result, expected)}")
r = 3  # window size: each output reads up to r inputs
naive_reads = sum(min(i + 1, r) for i in range(SIZE))
print(f"Global memory reads (naive): {naive_reads} for SIZE={SIZE}")

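**Problem:** The naive implementation reads global memory 21 times for SIZE=8: each output re-reads up to 3 inputs, so most inputs are fetched three times.

**Solution:** Use shared memory. Each block copies its inputs into shared memory once, plus a 2-element halo, so it reads at most TPB + 2 = 6 values from global memory: at most 12 reads in total for SIZE=8 (10 in practice, since Block 0's halo is zero padding).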
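The read counts for this example can be reproduced with a small counting model. It is a sketch that counts array reads only and ignores the L1/L2 caches, which in practice may absorb some repeated global reads; the helper names are hypothetical.

```python
def naive_global_reads(size, window=3):
    """Output i reads a[max(i - window + 1, 0) .. i] directly from global memory."""
    return sum(min(i + 1, window) for i in range(size))


def shared_global_reads(size, tpb, window=3):
    """Each block copies its tpb inputs plus a (window - 1)-element halo once.

    Halo slots before index 0 are zero-filled, so they cost no global read.
    """
    halo = window - 1
    reads = 0
    for block_start in range(0, size, tpb):
        first = max(block_start - halo, 0)  # first global index this block loads
        last = min(block_start + tpb, size)  # one past the last index it loads
        reads += last - first
    return reads


print(naive_global_reads(8), shared_global_reads(8, tpb=4))  # 21 10
```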

**Memory hierarchy visualization:**
```
┌─────────────────────────────────────┐
│  Global Memory (a, out)             │  ← Slow, accessible to ALL threads
│  - High latency (~100s cycles)      │
│  - Large capacity (GB)              │
└─────────────────────────────────────┘
         ↓                    ↓
   ┌─────────┐          ┌─────────┐
   │ Block 0 │          │ Block 1 │
   │ Shared  │          │ Shared  │      ← Fast, accessible only within block
   │ Memory  │          │ Memory  │         (~tens of cycles)
   │ (fast)  │          │ (fast)  │
   └─────────┘          └─────────┘
```

**How it works - Halo Loading:**
```
Global:      [0, 1, 2, 3, 4, 5, 6, 7]
Block 0 loads:             Block 1 loads:
         ↓                          ↓
Shared: [0, 0, 0, 1, 2, 3] Shared: [2, 3, 4, 5, 6, 7]
         └──┘  └────────┘           └──┘  └────────┘
         halo  main data            halo  main data
         (boundary padding)         (overlap with Block 0)
```
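The halo layout in the diagram can be reproduced on the CPU. The sketch below builds one block's shared array from the global input, zero-padding positions before index 0. In the CUDA kernel, each thread would fill one or two of these slots instead of a single loop filling all of them; `load_shared_with_halo` is a hypothetical helper.

```python
def load_shared_with_halo(a, block_idx, tpb, halo=2):
    """Return one block's shared contents: `halo` look-back values, then `tpb` values."""
    block_start = block_idx * tpb
    shared = []
    for k in range(tpb + halo):
        src = block_start - halo + k  # global index mirrored by shared slot k
        shared.append(a[src] if 0 <= src < len(a) else 0)  # zero-pad out-of-range slots
    return shared


a = list(range(8))
print(load_shared_with_halo(a, 0, tpb=4))  # [0, 0, 0, 1, 2, 3]
print(load_shared_with_halo(a, 1, tpb=4))  # [2, 3, 4, 5, 6, 7]
```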

**Key constraints:**
1. Shared memory size must be a **compile-time constant** (not runtime variable)
2. After loading shared memory, call `cuda.syncthreads()` to ensure all threads see the data before using it
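Why the barrier matters: threads in a block read shared-memory slots written by *other* threads, so no thread may read until every thread has written. The sketch below models one block with Python threads and uses `threading.Barrier` as a stand-in for `cuda.syncthreads()`. The rotate-left task and `rotate_left_in_block` are hypothetical, chosen only to show the write, wait, read pattern.

```python
import threading


def rotate_left_in_block(values):
    """Each worker writes one shared slot, waits at a barrier, then reads its right neighbour's slot."""
    n = len(values)
    shared = [None] * n
    out = [None] * n
    barrier = threading.Barrier(n)  # stands in for cuda.syncthreads()

    def thread_body(t):
        shared[t] = values[t]         # phase 1: fill shared memory
        barrier.wait()                # no thread continues until all have written
        out[t] = shared[(t + 1) % n]  # phase 2: read data written by another thread

    workers = [threading.Thread(target=thread_body, args=(t,)) for t in range(n)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    return out


print(rotate_left_in_block([10, 20, 30, 40]))  # [20, 30, 40, 10]
```

Without `barrier.wait()`, a thread could read its neighbour's slot before it was filled and get `None`. Like `cuda.syncthreads()`, this barrier only coordinates the threads of one block.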

**Task:** Implement the pooling kernel using shared memory with halo zones.

In [None]:
TPB = 4  # Threads per block
SharedMem = TPB + 2  # must be a compile-time constant: Numba freezes global values when it compiles the kernel
@cuda.jit
def pool_kernel_shared(out, a, size):
    # Allocate shared memory with HALO (extra elements for boundary)
    # Need TPB + 2 elements: TPB for this block's inputs plus 2 for the look-back halo
    shared = cuda.shared.array(SharedMem, numba.float32)
    # your code here

a_device, b_device, out_device = init_pb(a=a, out=np.zeros_like(out))


GRID = (SIZE // TPB, 1)  # (2, 1) for SIZE=8, TPB=4
pool_kernel_shared[GRID, TPB](out_device, a_device, SIZE)

result = out_device.copy_to_host()

expected = pool_spec(a)
print(f"Output:   {result}")
print(f"Expected: {expected}")
print(f"Correct:  {np.allclose(result, expected)}")

### Puzzle 6: Dot Product (Parallel Reduction)

**Goal:** Compute dot product of two vectors: `dot(a, b) = sum(a[i] * b[i])`

**Challenge:** Serial code is trivial, but parallel reduction is complex because we need to **combine results from multiple threads**.

In [None]:
def dot_spec(a,b):
    tot = 0
    for i in range(len(a)):
        tot += a[i]*b[i]
    return tot

SIZE = 8
a = np.arange(SIZE, dtype=np.float32)
b = np.arange(SIZE, dtype=np.float32)
dot_spec(a,b)

**Solution approach:** Tree-based reduction within each block, then atomic addition across blocks.

**Visual - Tree Reduction:**

![Tree-based parallel sum](https://www.cs.uaf.edu/2012/fall/cs441/lecture/tree_sum_16td.png)

**How it works:**
1. Each thread computes one element-wise product: `a[i] * b[i]`
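2. Store results in shared memory, then call `cuda.syncthreads()` so every product is visible to the block
3. Reduce in log2(n) steps: stride = n/2, n/4, n/8, ..., 1, synchronizing again after each step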
4. Thread 0 writes block's partial sum to output

In [None]:
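def dot_tree(a, b):
    # CPU simulation of the tree reduction; assumes len(a) is a power of 2
    size = len(a)
    shared_mem = np.zeros(size)
    for i in range(size):
        shared_mem[i] = a[i] * b[i]      # step 1: one product per "thread"
    stride = size // 2
    while stride > 0:
        for i in range(stride):          # on the GPU, threads i < stride run in parallel
            shared_mem[i] += shared_mem[i + stride]
        stride //= 2
    return shared_mem[0]
dot_tree(a, b)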

**Task:** Implement tree-based dot product in Numba with the following configuration:
- **256 threads per block** (fixed)
- **Multiple blocks** to cover input size (calculated automatically)

**Hint:** After the within-block reduction, use `cuda.atomic.add()` to safely accumulate partial sums from all blocks.
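The two-level scheme can be checked on the CPU before writing the kernel. The plain-Python model below (hypothetical helper names) mirrors the configuration above: 256 "threads" per block, idle threads contributing 0, a tree reduction inside each block, and one addition per block standing in for `cuda.atomic.add`. It assumes `threads_per_block` is a power of two.

```python
def tree_reduce(values):
    """Pairwise tree sum over a power-of-two-length list (one block's shared array)."""
    stride = len(values) // 2
    while stride > 0:
        for t in range(stride):  # on the GPU: threads t < stride, then a barrier
            values[t] += values[t + stride]
        stride //= 2
    return values[0]


def blocked_dot(a, b, threads_per_block=256):
    total = 0.0  # plays the role of out[0]
    num_blocks = (len(a) + threads_per_block - 1) // threads_per_block
    for block in range(num_blocks):
        shared = [0.0] * threads_per_block  # threads past the end contribute 0
        for t in range(threads_per_block):
            i = block * threads_per_block + t
            if i < len(a):
                shared[t] = a[i] * b[i]
        total += tree_reduce(shared)  # on the GPU: one cuda.atomic.add per block
    return total


x = [float(i) for i in range(800)]
print(blocked_dot(x, x))  # 170346800.0
```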

In [None]:
SIZE = 800
a = np.arange(SIZE, dtype=np.float32)
b = np.arange(SIZE, dtype=np.float32)

threads_per_block = 256
blocks_per_grid = (SIZE + threads_per_block - 1) // threads_per_block
print(f"threads per block: {threads_per_block}")
print(f"blocks per grid: {blocks_per_grid}")

Within each block, you can implement the tree-based sum by first filling a 256-element shared array with the products `a[i]*b[i]`, where `i` is the thread's global index (threads with `i >= size` should store 0). Then reduce this array across the block. The last step adds the per-block partial sums together, with each block contributing its result. For this last step, you will likely want `cuda.atomic.add`, described below:

### Understanding Atomic Operations

**Why do we need atomics?**

When multiple threads write to the same memory location, non-atomic operations can lose updates due to **race conditions**.

**The Problem Without Atomics:**

When you write `out[0] = out[0] + value`, it's actually 3 separate steps:
```python
# out[0] = out[0] + value breaks down to:
temp = out[0]        # 1. READ:   read the current value
temp = temp + value  # 2. MODIFY: add to it
out[0] = temp        # 3. WRITE:  write it back
```

**Race condition example:**
```
Initial: out[0] = 0

Thread A (Block 0):              Thread B (Block 1):
1. READ: temp_A = 0
2. MODIFY: temp_A = 0 + 5
                                 1. READ: temp_B = 0      ← Still sees 0!
3. WRITE: out[0] = 5
                                 2. MODIFY: temp_B = 0 + 3 ← Uses old value!
                                 3. WRITE: out[0] = 3      ← Overwrites 5!

Final: out[0] = 3  ❌ Should be 8!

**What Atomics Do:**

`cuda.atomic.add(out, 0, value)` performs the entire read-modify-write as one **indivisible** operation. Conceptually, the memory location is locked until the update completes:

```
Initial: out[0] = 0

Thread A (Block 0):              Thread B (Block 1):
🔒 LOCK out[0]
1. READ: temp_A = 0
2. MODIFY: temp_A = 0 + 5
3. WRITE: out[0] = 5
🔓 UNLOCK out[0]
                                 🔒 LOCK out[0]  ← Must wait for unlock
                                 1. READ: temp_B = 5      ← Sees updated value!
                                 2. MODIFY: temp_B = 5 + 3
                                 3. WRITE: out[0] = 8
                                 🔓 UNLOCK out[0]

Final: out[0] = 8  ✅ Correct!
```

**Key takeaway:** Use atomics when multiple threads update the same memory location.
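Both timelines above can be replayed deterministically. In the sketch below (hypothetical `run_schedule` helper), `schedule` says which thread takes its next step; without atomics each thread needs three separate steps (READ, MODIFY, WRITE), while with `atomic=True` a thread's whole update happens at its first turn, so no interleaving can lose it.

```python
def run_schedule(schedule, increments, atomic=False):
    """Replay an interleaving of read-modify-write updates to a shared out[0]."""
    out = [0]
    temps = {}
    step = {t: 0 for t in increments}  # 0 = READ next, 1 = MODIFY next, 2 = WRITE next, 3 = done
    for t in schedule:
        if step[t] == 3:
            continue  # this thread has already finished
        if atomic:
            out[0] += increments[t]  # READ + MODIFY + WRITE with no interruption
            step[t] = 3
        elif step[t] == 0:
            temps[t] = out[0]  # READ
            step[t] = 1
        elif step[t] == 1:
            temps[t] += increments[t]  # MODIFY
            step[t] = 2
        else:
            out[0] = temps[t]  # WRITE
            step[t] = 3
    return out[0]


incs = {"A": 5, "B": 3}
interleaved = ["A", "A", "B", "A", "B", "B"]  # the race timeline above
print(run_schedule(interleaved, incs))               # 3 (lost update)
print(run_schedule(interleaved, incs, atomic=True))  # 8
```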

In [None]:
@cuda.jit
def dot_kernel_numba(a, b, out, size):
    shared = cuda.shared.array(256, numba.float32)
    # your code here
    

expected = np.dot(a, b)
a_device, b_device, out_device = init_pb(a=a, b=b, out=np.zeros_like([expected]))

size = a_device.shape[0]
dot_kernel_numba[blocks_per_grid, threads_per_block](a_device, b_device, out_device, size)

result = out_device.copy_to_host()
print(f"CUDA result: {result[0]}")
print(f"NumPy result: {expected}")
print(f"Match: {np.allclose(result[0], expected)}")

## Numba vs Triton: Conceptual Comparison

### Numba CUDA: Grid and Block Dimensions

**Key concepts:**
- `kernel[grid, block](args)` - Launch syntax
- **Grid** = `(blocks_x, blocks_y, blocks_z)` - How many blocks
- **Block** = `(threads_x, threads_y, threads_z)` - Threads per block
- **Total threads** = `grid_x × grid_y × grid_z × block_x × block_y × block_z`
- **Manual indexing**: You compute indices using `blockIdx`, `threadIdx`, `blockDim`

### Triton: Program Grid (Higher Abstraction)

In **Triton**, you specify a **program grid** and work with **program IDs**. Triton handles the low-level threading automatically.

**Key concepts:**
- `kernel[grid](args, BLOCK_SIZE=...)` - Launch syntax
- **Grid** = `(programs_x, programs_y, programs_z)` - Number of program instances
- **No explicit thread dimensions** - Triton maps each program onto GPU threads automatically (the `num_warps` launch option tunes how many warps a program uses)
- **Work with blocks of data** using `tl.arange()` and vectorized operations
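To see that the two models cover the same data, the plain-Python sketch below lists the indices touched by a guarded Numba launch (one scalar `i` per thread) and by a masked Triton launch (one vector of offsets per program). The helpers are hypothetical; `cdiv` reproduces the arithmetic of `triton.cdiv`.

```python
def cdiv(n, d):
    """Ceiling division, the same arithmetic as triton.cdiv."""
    return (n + d - 1) // d


def numba_indices(n, threads_per_block):
    """Indices handled by a guarded Numba launch: one scalar i per thread."""
    touched = []
    for block_idx in range(cdiv(n, threads_per_block)):
        for thread_idx in range(threads_per_block):
            i = block_idx * threads_per_block + thread_idx
            if i < n:  # the `if i < size` guard
                touched.append(i)
    return touched


def triton_indices(n, block_size):
    """Indices handled by a masked Triton launch: one offset vector per program."""
    touched = []
    for pid in range(cdiv(n, block_size)):
        offsets = [pid * block_size + k for k in range(block_size)]  # pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = [o < n for o in offsets]                              # offsets < n
        touched.extend(o for o, keep in zip(offsets, mask) if keep)
    return touched


n, block = 1000, 256
print(cdiv(n, block), "blocks/programs,", cdiv(n, block) * block, "lanes for", n, "elements")
print(numba_indices(n, block) == triton_indices(n, block) == list(range(n)))  # True
```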

---

### Comparison Table

| Aspect | Numba CUDA | Triton |
|--------|------------|--------|
| **Launch syntax** | `kernel[grid, block](args)` | `kernel[grid](args, BLOCK=...)` |
| **Grid represents** | Number of **blocks** | Number of **programs** |
| **Block/Thread control** | Explicit: `(tx, ty, tz)` per block | Abstracted: work on data blocks |
| **Thread indexing** | Manual: `blockIdx`, `threadIdx` | Automatic: `tl.program_id()` + `tl.arange()` |
| **Typical grid** | `(n_blocks_x, n_blocks_y, n_blocks_z)` | `(n_programs_x, n_programs_y, n_programs_z)` |
| **Typical block** | `(threads_x, threads_y, threads_z)` | N/A (implicit in `BLOCK_SIZE`) |
| **Memory access** | Per-thread scalar indexing | Vectorized block operations |
| **Abstraction level** | Low-level (like CUDA C) | High-level (compiler optimizes) |
| **Synchronization** | Explicit: `cuda.syncthreads()` | Mostly automatic |

---

### Key Takeaway

- **Numba CUDA**: You think in terms of **blocks of threads** (2-level hierarchy: grid → blocks → threads)
- **Triton**: You think in terms of **programs operating on data blocks** (1-level: grid → programs, with automatic vectorization)

**Mental model:** Each Triton program ≈ one CUDA block, but Triton auto-vectorizes the thread-level work.

---

### Practical Example: Vector Addition

Below is the guarded vector-addition kernel from Puzzle 2, generalized to multiple blocks:

In [None]:
# CUDA kernel with Guard
@cuda.jit
def zip_guard_kernel(out, a, b, size):
    i = cuda.blockIdx.x * cuda.blockDim.x + cuda.threadIdx.x
    if i < size:
        out[i] = a[i] + b[i]

SIZE = 1000
out = np.zeros(SIZE)
a = np.arange(SIZE)
b = np.arange(SIZE)

a_device, b_device, out_device = init_pb(a=a, b=b, out=out)

threads_per_block = 256
blocks_per_grid = (SIZE + threads_per_block - 1) // threads_per_block
zip_guard_kernel[blocks_per_grid, threads_per_block](out_device, a_device, b_device, SIZE)

# Copy result back to CPU
result = out_device.copy_to_host()

# Verify result
expected = zip_spec(a, b)
print(f"Correct:  {np.allclose(result, expected)}")

### Triton Implementation: Vector Addition

**Key differences from Numba:**
- Works with PyTorch tensors (no manual memory management)
- Processes multiple elements per program (vectorized)
- `BLOCK_SIZE` is a compile-time constant for optimization

In [None]:
import triton
import triton.language as tl
import torch
from einops import rearrange

def get_device(index: int = 0) -> torch.device:
    """Try to use the GPU if possible, otherwise, use CPU."""
    if torch.cuda.is_available():
        return torch.device(f"cuda:{index}")
    else:
        return torch.device("cpu")

In [None]:
@triton.jit
def zip_guard_triton(a_ptr, b_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr):
    # Triton uses program_id (block index) instead of explicit blockIdx/threadIdx
    # Each "program" processes BLOCK_SIZE elements at once (vectorized)
    pid = tl.program_id(0)
    
    # Triton computes offsets for a *vector* of BLOCK_SIZE elements
    # Unlike Numba where each thread processes 1 element,
    # Triton processes multiple elements per program instance
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    
    # Triton uses mask-based guards for vectorized operations
    # instead of scalar if-statements (if i < size)
    mask = offset < n
    
    # Vectorized load: loads BLOCK_SIZE elements at once with mask
    # Numba loads scalar: a[i]
    a = tl.load(a_ptr + offset, mask=mask)
    b = tl.load(b_ptr + offset, mask=mask)
    
    # Vectorized computation (same as Numba but on vectors)
    c = a + b
    
    # Vectorized store with mask (vs. scalar store in Numba)
    tl.store(out_ptr + offset, c, mask=mask)


# Triton works directly with PyTorch tensors (no manual copy_to_host)
# Numba requires explicit device memory management (init_pb, copy_to_host)
a = torch.randn(SIZE, device=get_device())
b = torch.randn(SIZE, device=get_device())
out = torch.empty_like(a)


# Launch syntax differences:
# - BLOCK_SIZE is a compile-time constant (tl.constexpr) for optimization
# - Only need to specify grid dimensions (not threads_per_block)
# - Triton auto-vectorizes within each program
BLOCK_SIZE = 256
grid = (triton.cdiv(SIZE, BLOCK_SIZE),)  # Only grid size, not block size
zip_guard_triton[grid](a, b, out, SIZE, BLOCK_SIZE=BLOCK_SIZE)

expected = zip_spec(a, b)
print(f"Correct:  {np.allclose(out.cpu().numpy(), expected.cpu().numpy())}")

### Triton Dot Product

**Task:** Implement dot product in Triton.

**Useful functions:**
- [`tl.sum()`](https://triton-lang.org/main/python-api/generated/triton.language.sum.html) - Parallel reduction within a block
- [`tl.atomic_add()`](https://triton-lang.org/main/python-api/generated/triton.language.atomic_add.html) - Atomic addition across programs

In [None]:
@triton.jit
def dot_kernel(
    a_ptr,      # Pointer to first input vector
    b_ptr,      # Pointer to second input vector  
    out_ptr,    # Pointer to output scalar
    size,       # Size of vectors
    BLOCK_SIZE: tl.constexpr,  # Elements per program
):
    # Program ID (analogous to blockIdx.x)
    pid = tl.program_id(0)
    
    # Compute offsets for this program's block
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    
    # Mask for boundary handling
    mask = offsets < size
    
    # Load data; other=0.0 fills masked-out lanes with zeros so they add nothing to tl.sum
    a = tl.load(a_ptr + offsets, mask=mask, other=0.0)
    b = tl.load(b_ptr + offsets, mask=mask, other=0.0)
    
    # Element-wise multiplication
    products = a * b
    
    # Reduce sum within this block (automatic parallel reduction!)
    block_sum = tl.sum(products)
    
    # Atomic add to output (single thread per block does this)
    tl.atomic_add(out_ptr, block_sum)


def dot_triton(a, b):
    """Wrapper function to launch the kernel"""
    # Allocate output
    out = torch.zeros(1, device=a.device, dtype=a.dtype)
    
    # Grid and block configuration
    size = a.shape[0]
    BLOCK_SIZE = 256  
    grid = (triton.cdiv(size, BLOCK_SIZE),)
    
    # Launch kernel
    dot_kernel[grid](a, b, out, size, BLOCK_SIZE=BLOCK_SIZE)
    
    return out


# Usage example
SIZE = 10
a = torch.arange(SIZE, dtype=torch.float32, device=get_device())
b = torch.arange(SIZE, dtype=torch.float32, device=get_device())

result_triton = dot_triton(a, b)
result_torch = torch.dot(a, b)

print(f"Triton result: {result_triton.item()}")
print(f"PyTorch result: {result_torch.item()}")
print(f"Match: {torch.allclose(result_triton, result_torch)}")

### Triton Softmax

**Goal:** Implement softmax for a batch of vectors `z` of shape `(batch_size, dim)`: `torch.softmax(z, dim=1)`

**Key concept - Tensor Strides:**

Strides give, for each dimension, how many elements to skip in the underlying flat memory to move one step along that dimension.

**Visual:**
```
Memory: [a, b, c, d, e, f, g, h, i, j, k, l]
Shape (3, 4), stride (4, 1):
  [[a, b, c, d],    ← skip 4 for next row, skip 1 for next col
   [e, f, g, h],
   [i, j, k, l]]
```

**Why it matters:** To process one row in softmax, we need to:
1. Find the starting address: `row_start_ptr = base_ptr + row_idx * row_stride`
2. Load all elements in that row; the column stride is 1 here, because the wrapper calls `x.contiguous()` before launching the kernel
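The address arithmetic behind strides fits in a few lines of plain Python. The sketch below (hypothetical helpers) computes row-major strides for a shape, shows that transposing only swaps the strides, and locates a row's start the same way the kernel will: `base + row * row_stride`.

```python
def contiguous_strides(shape):
    """Row-major (C-order) strides in elements: the last dimension moves fastest."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def flat_offset(index, strides):
    """Offset of element `index` from the base pointer: sum of index * stride."""
    return sum(i * s for i, s in zip(index, strides))


memory = list("abcdefghijkl")
x_strides = contiguous_strides((3, 4))    # (4, 1)
y_strides = (x_strides[1], x_strides[0])  # transpose swaps strides: (1, 4)

print(memory[flat_offset((1, 2), x_strides)])  # g  (x[1, 2])
print(memory[flat_offset((2, 1), y_strides)])  # g  (y[2, 1] is the same element)
row_start = flat_offset((2, 0), x_strides)     # base + row * row_stride
print(memory[row_start:row_start + 4])         # ['i', 'j', 'k', 'l']
```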

In [None]:
# Contiguous 2D tensor (3×4)
x = torch.randn(3, 4)
x.stride()  # (4, 1)
# - Move to next row: skip 4 elements
# - Move to next column: skip 1 element

In [None]:
# Transposed (now 4×3)
y = x.t()
y.stride()  # (1, 4)
# - Move to next row: skip 1 element (was column)
# - Move to next column: skip 4 elements (was row)
# Note: y shares memory with x, just different access pattern

**Implementation strategy:**
- Each program processes one complete row independently
- Each program computes softmax over all columns in its row
- Use row_stride to locate the starting address for each row
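Before writing the Triton kernel, it can help to have a plain-Python reference of the same strategy (one "program" per row, located via `row_stride`). The sketch makes two assumptions not stated above: it subtracts each row's maximum before exponentiating, the standard trick for numerical stability (softmax is unchanged by subtracting a constant from a row, and `math.exp(1002.0)` alone would raise `OverflowError`), and it pads columns beyond `num_cols` with `-inf` so they contribute `exp(-inf) = 0`. In Triton, the analogous padding would be passing `other=-float('inf')` to `tl.load`. `softmax_rows` is a hypothetical name.

```python
import math


def softmax_rows(flat, num_rows, num_cols, row_stride, block_size):
    """Row-wise softmax over a flat buffer, one simulated program per row."""
    out = [0.0] * (num_rows * row_stride)
    for row in range(num_rows):  # program_id(0) == row
        start = row * row_stride  # row start: base + row * row_stride
        # Load a power-of-two block of columns; missing columns become -inf.
        block = [flat[start + c] if c < num_cols else -math.inf for c in range(block_size)]
        row_max = max(block)
        exps = [math.exp(v - row_max) for v in block]  # padded lanes give exp(-inf) == 0.0
        denom = sum(exps)
        for c in range(num_cols):  # masked store: only the real columns
            out[start + c] = exps[c] / denom
    return out


probs = softmax_rows([1.0, 2.0, 3.0, 1.0, 1.0, 1.0], num_rows=2, num_cols=3, row_stride=3, block_size=4)
print(probs)  # each row of 3 values sums to 1
```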

In [None]:
@triton.jit
def triton_softmax_kernel(x_ptr, y_ptr, x_row_stride, y_row_stride, num_cols, BLOCK_SIZE: tl.constexpr):
    assert num_cols <= BLOCK_SIZE
    # Process each row independently
    # your code here

def triton_softmax(x: torch.Tensor):
    x = x.contiguous()
    # Allocate output tensor
    y = torch.empty_like(x)
    # Determine grid
    M, N = x.shape                          # Number of rows x number of columns
    block_size = triton.next_power_of_2(N)  # One block spans a whole row; tl.arange needs a power-of-2 length
    num_blocks = M                          # One program per row
    # Launch kernel
    triton_softmax_kernel[(num_blocks,)](
        x_ptr=x, y_ptr=y,
        x_row_stride=x.stride(0), y_row_stride=y.stride(0),
        num_cols=N, BLOCK_SIZE=block_size
    )
    return y

In [None]:
torch.manual_seed(0)
x = torch.randn(1823, 781, device=get_device())
y_triton = triton_softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

In [None]:
y_triton = triton_softmax(x.t())
y_torch = torch.softmax(x.t(), axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

In [None]:
DEVICE = get_device()

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=["Triton", "Torch"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `line_arg`
    ))

def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: triton_softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=False)

## Triton Block Pointers (Advanced)

**Block pointers** are Triton's high-level abstraction for tiled memory access, eliminating error-prone manual pointer arithmetic.

**Core concept:** Instead of computing `ptr + offset` manually, block pointers encapsulate:
- **Where you are** in the tensor (offsets)
- **What tile** you're accessing (block_shape)
- **How to navigate** memory (strides)

---

### Example: 2D Tensor Block Pointer

```python
x_block_ptr = tl.make_block_ptr(
    x_ptr,                                    # Base address
    shape=(ROWS, D),                          # Full tensor: ROWS × D
    strides=(x_stride_row, x_stride_dim),     # Jump between rows/cols
    offsets=(row_tile_idx * ROWS_TILE_SIZE, 0),  # Start at row tile, col 0
    block_shape=(ROWS_TILE_SIZE, D_TILE_SIZE),   # Tile size
    order=(1, 0),                             # Row-major layout
)
```

**Visual:**
```
Full tensor (ROWS × D):
┌─────────────────────────────────┐
│ [0,0]  ......  [0, D-1]         │ ← row_tile_idx=0 loads ROWS_TILE_SIZE rows
├─────────────────────────────────┤
│ [ROWS_TILE_SIZE, 0] ...         │ ← row_tile_idx=1
├─────────────────────────────────┤
│  ...                            │
└─────────────────────────────────┘
    └─ D_TILE_SIZE ─┘  (tile width)
```
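The semantics of a block pointer can be emulated in plain Python to see what a tile load returns at the edges. In the sketch below, `load_tile` and `advance` are hypothetical helpers that mimic `tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")` and `tl.advance` on a flat row-major buffer; they are not Triton's API.

```python
def load_tile(flat, shape, strides, offsets, block_shape, pad=0.0):
    """Gather a 2D tile starting at `offsets`; out-of-bounds elements become `pad`."""
    tile = []
    for r in range(block_shape[0]):
        row = []
        for c in range(block_shape[1]):
            i, j = offsets[0] + r, offsets[1] + c
            inside = i < shape[0] and j < shape[1]  # boundary check on both dimensions
            row.append(flat[i * strides[0] + j * strides[1]] if inside else pad)
        tile.append(row)
    return tile


def advance(offsets, delta):
    """Shift the tile origin; shape, strides and tile size stay the same."""
    return tuple(o + d for o, d in zip(offsets, delta))


ROWS, D = 3, 5
flat = [float(10 * i + j) for i in range(ROWS) for j in range(D)]  # x[i, j] = 10*i + j
origin = (2, 0)  # row_tile_idx = 1 with ROWS_TILE_SIZE = 2
tile0 = load_tile(flat, (ROWS, D), (D, 1), origin, (2, 4))
tile1 = load_tile(flat, (ROWS, D), (D, 1), advance(origin, (0, 4)), (2, 4))
print(tile0)  # [[20.0, 21.0, 22.0, 23.0], [0.0, 0.0, 0.0, 0.0]]
print(tile1)  # [[24.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
```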

---

### Example: 1D Tensor Block Pointer

```python
weight_block_ptr = tl.make_block_ptr(
    weight_ptr,
    shape=(D,),                    # 1D vector
    strides=(weight_stride_dim,),  # Element spacing
    offsets=(0,),                  # Start at beginning
    block_shape=(D_TILE_SIZE,),    # Load D_TILE_SIZE elements
    order=(0,),                    # 1D ordering
)
```

---

### Usage Pattern: Tiled Computation

```python
# Initialize accumulator
output = tl.zeros((ROWS_TILE_SIZE,), dtype=tl.float32)

# Loop over D dimension in tiles
for i in range(tl.cdiv(D, D_TILE_SIZE)):
    # Load current tiles; out-of-bounds elements along the checked dims become zeros
    row = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero")
    # Shape: (ROWS_TILE_SIZE, D_TILE_SIZE)
    
    weight = tl.load(weight_block_ptr, boundary_check=(0,), padding_option="zero")
    # Shape: (D_TILE_SIZE,)
    
    # Compute weighted sum: sum over columns for each row
    output += tl.sum(row * weight[None, :], axis=1)  # Accumulate
    
    # Advance to next tile
    x_block_ptr = tl.advance(x_block_ptr, (0, D_TILE_SIZE))      # Move right
    weight_block_ptr = tl.advance(weight_block_ptr, (D_TILE_SIZE,))

# Write result
tl.store(output_block_ptr, output, boundary_check=(0,))
```
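
The same loop structure can be traced on the CPU. This NumPy sketch is an emulation, not Triton code: the outer loop plays the role of the grid (one iteration per row tile), the inner loop steps across `D` like `advance`, and zero-filled tiles stand in for `boundary_check` with `padding_option="zero"`. Tile sizes are chosen small and non-dividing on purpose.

```python
import numpy as np


def tiled_weighted_sum(x, w, rows_tile_size=4, d_tile_size=3):
    """NumPy emulation of the tiled weighted-sum loop (hypothetical helper)."""
    rows, d = x.shape
    out = np.empty(rows, dtype=np.float32)
    n_row_tiles = (rows + rows_tile_size - 1) // rows_tile_size   # cdiv
    for row_tile_idx in range(n_row_tiles):                        # one "program" each
        r0 = row_tile_idx * rows_tile_size
        acc = np.zeros(rows_tile_size, dtype=np.float32)            # like tl.zeros
        for c0 in range(0, d, d_tile_size):                         # advance by d_tile_size
            tile = np.zeros((rows_tile_size, d_tile_size), dtype=np.float32)
            wt = np.zeros(d_tile_size, dtype=np.float32)
            blk = x[r0:r0 + rows_tile_size, c0:c0 + d_tile_size]
            tile[: blk.shape[0], : blk.shape[1]] = blk              # zero-padded load
            wblk = w[c0:c0 + d_tile_size]
            wt[: wblk.shape[0]] = wblk
            acc += (tile * wt[None, :]).sum(axis=1)
        n_valid = min(rows_tile_size, rows - r0)                    # boundary-checked store
        out[r0:r0 + n_valid] = acc[:n_valid]
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((10, 7)).astype(np.float32)   # 10 rows: last row tile is partial
w = rng.standard_normal(7).astype(np.float32)         # D=7: last D tile is partial
print(np.allclose(tiled_weighted_sum(x, w), (x * w).sum(axis=-1), atol=1e-5))  # True
```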

---

### Key Features

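1. **Boundary checking:** `boundary_check=(0, 1)` masks elements that fall outside the tensor along dimensions 0 and 1, and `padding_option="zero"` fills them with zeros, so tiles that don't fit perfectly are safe to load
   - Dimension 0 (rows): `ROWS` may not be divisible by `ROWS_TILE_SIZE`
   - Dimension 1 (cols): `D` may not be divisible by `D_TILE_SIZE`

2. **Tiled computation:** Process the large `D` dimension in chunks of `D_TILE_SIZE`, accumulating partial sums in a float32 accumulator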

3. **Clean navigation:** `.advance()` moves to next tile without manual offset math
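
Zero padding is harmless here only because the reduction is a sum: padded zeros add nothing. A small NumPy illustration (not Triton code) of why the padding value should be the identity of the reduction. For a max over negative values, zero padding would corrupt the result, and a masked load with `other=float("-inf")` would be the appropriate choice instead.

```python
import numpy as np

values = np.array([-3.0, -1.5, -2.0], dtype=np.float32)  # a 3-element row
TILE = 4                                                  # tile wider than the row

zero_padded = np.zeros(TILE, dtype=np.float32)
zero_padded[: values.size] = values
neg_inf_padded = np.full(TILE, -np.inf, dtype=np.float32)
neg_inf_padded[: values.size] = values

print(zero_padded.sum(), values.sum())   # -6.5 -6.5  (0 is the identity of +)
print(zero_padded.max(), values.max())   # 0.0 -1.5   (zero padding corrupts max)
print(neg_inf_padded.max())              # -1.5       (-inf is the identity of max)
```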

---

### Comparison: Manual vs. Block Pointers

```python
# OLD WAY (manual pointer arithmetic):
row_offsets = row_tile_idx * ROWS_TILE_SIZE + tl.arange(0, ROWS_TILE_SIZE)
col_offsets = tl.arange(0, D_TILE_SIZE)
x_ptrs = x_ptr + row_offsets[:, None] * x_stride_row + col_offsets[None, :] * x_stride_dim
mask = (row_offsets < ROWS)[:, None] & (col_offsets < D)[None, :]
row = tl.load(x_ptrs, mask=mask, other=0.0)  # masked-off elements read as 0

# NEW WAY (block pointers):
x_block_ptr = tl.make_block_ptr(...)
row = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero")
```
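
The manual version can be traced on the CPU as well. This NumPy sketch (sizes chosen for illustration) builds the same 2D offset grid and mask as the "old way" snippet, then performs a masked gather. That is roughly the bookkeeping a block pointer with `boundary_check=(0, 1)` and zero padding handles for you.

```python
import numpy as np

ROWS, D = 5, 6
ROWS_TILE_SIZE, D_TILE_SIZE = 4, 4
x = np.arange(ROWS * D, dtype=np.float32).reshape(ROWS, D)
x_flat = x.ravel()                          # the "memory" behind x_ptr
x_stride_row, x_stride_dim = D, 1           # contiguous row-major strides

row_tile_idx = 1                            # second row tile: rows 4..7
row_offsets = row_tile_idx * ROWS_TILE_SIZE + np.arange(ROWS_TILE_SIZE)
col_offsets = np.arange(D_TILE_SIZE)
# Same broadcasting as the Triton snippet: a 2D grid of element offsets
offs = row_offsets[:, None] * x_stride_row + col_offsets[None, :] * x_stride_dim
mask = (row_offsets < ROWS)[:, None] & (col_offsets < D)[None, :]

# Masked gather: never index out of range, and use 0.0 where mask is False
tile = np.where(mask, x_flat[np.where(mask, offs, 0)], 0.0)
print(tile)  # first row is x[4, :4]; rows 5..7 do not exist and are zero
```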

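**Benefits:**
- Bounds checking is declared per dimension instead of hand-building masks
- Moving to the next tile is a single `.advance()` call
- The compiler sees the tile's shape and memory order, which it can use for optimization (particularly on newer GPUs such as H100)
- Non-contiguous tensors are handled by passing their strides once, when the pointer is created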
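
Stride-based addressing is what makes non-contiguous inputs work, whether the strides go into a block pointer or into manual offsets. A NumPy sketch with a transposed view (the same idea applies to a PyTorch `x.T`, whose `stride()` is already reported in elements):

```python
import numpy as np

base = np.arange(12, dtype=np.float32).reshape(3, 4)
xt = base.T                                       # (4, 3) view, not contiguous
strides = tuple(s // xt.itemsize for s in xt.strides)
print(xt.flags["C_CONTIGUOUS"], strides)          # False (1, 4)

# Address arithmetic with the view's own strides reads the right elements
rows, cols = np.meshgrid(np.arange(4), np.arange(3), indexing="ij")
gathered = base.ravel()[rows * strides[0] + cols * strides[1]]
print(np.array_equal(gathered, xt))               # True
```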

---

The code below is adapted from the [Stanford CS336 course](https://github.com/stanford-cs336/assignment2-systems).

### Integrating Triton Kernels with PyTorch

**Example:** Weighted sum operation using block pointers for efficient tiled computation.

In [None]:
import torch
import triton
import triton.language as tl


def weighted_sum(x, weight):
    # Reference implementation: x has n-dim shape [..., D], and weight has 1D shape [D]
    return (weight * x).sum(axis=-1)


@triton.jit
def weighted_sum_fwd(
    x_ptr, weight_ptr,  # Input pointers
    output_ptr,  # Output pointer
    x_stride_row, x_stride_dim,  # Strides tell us how to move one element in each axis of a tensor
    weight_stride_dim,  # Likely 1
    output_stride_row,  # Likely 1
    ROWS, D,
    ROWS_TILE_SIZE: tl.constexpr, D_TILE_SIZE: tl.constexpr,  # Tile shapes must be known at compile time
):
    # Each instance will compute the weighted sum of a tile of rows of x.
    # `tl.program_id` gives us a way to check which thread block we're running in
    row_tile_idx = tl.program_id(0)
    
    # Block pointers give us a way to select from an ND region of memory
    # and move our selection around.
    # The block pointer must know:
    # - The pointer to the first element of the tensor
    # - The overall shape of the tensor to handle out-of-bounds access
    # - The strides of each dimension to use the memory layout properly
    # - The ND coordinates of the starting block, i.e., "offsets"
    # - The block shape to load/store at a time
    # - The order of the dimensions in memory, from minor (fastest-varying) to major
    #   axes (= np.argsort(strides)), used for optimizations, especially useful on H100
    
    x_block_ptr = tl.make_block_ptr(
        x_ptr,
        shape=(ROWS, D),
        strides=(x_stride_row, x_stride_dim),
        offsets=(row_tile_idx * ROWS_TILE_SIZE, 0),
        block_shape=(ROWS_TILE_SIZE, D_TILE_SIZE),
        order=(1, 0),
    )
    
    weight_block_ptr = tl.make_block_ptr(
        weight_ptr,
        shape=(D,),
        strides=(weight_stride_dim,),
        offsets=(0,),
        block_shape=(D_TILE_SIZE,),
        order=(0,),
    )
    
    output_block_ptr = tl.make_block_ptr(
        output_ptr,
        shape=(ROWS,),
        strides=(output_stride_row,),
        offsets=(row_tile_idx * ROWS_TILE_SIZE,),
        block_shape=(ROWS_TILE_SIZE,),
        order=(0,),
    )
    
    # Initialize a buffer to write to
    output = tl.zeros((ROWS_TILE_SIZE,), dtype=tl.float32)
    
    for i in range(tl.cdiv(D, D_TILE_SIZE)):
        # Load the tiles the block pointers currently point to
        # Since ROWS_TILE_SIZE might not divide ROWS, and D_TILE_SIZE might not divide D,
        # we need boundary checks for both dimensions
        row = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero")  # (ROWS_TILE_SIZE, D_TILE_SIZE)
        weight = tl.load(weight_block_ptr, boundary_check=(0,), padding_option="zero")  # (D_TILE_SIZE,)
        
        # Compute the weighted sum of the row.
        output += tl.sum(row * weight[None, :], axis=1)
        
        # Move the pointers to the next tile.
        # These are (rows, columns) coordinate deltas
        x_block_ptr = x_block_ptr.advance((0, D_TILE_SIZE))  # Move by D_TILE_SIZE in the last dimension
        weight_block_ptr = weight_block_ptr.advance((D_TILE_SIZE,))  # Move by D_TILE_SIZE
    
    # Write output to the output block pointer (a single scalar per row).
    # Since ROWS_TILE_SIZE might not divide ROWS, we need boundary checks
    tl.store(output_block_ptr, output, boundary_check=(0,))

In [None]:
from einops import rearrange


def weighted_sum_triton(x: torch.Tensor, weight: torch.Tensor):
    D = x.shape[-1]
    output_dims = x.shape[:-1]

    # Reshape input tensor to 2D: (..., D) -> (n_rows, D)
    x = rearrange(x, "... d -> (...) d")
    # Need to initialize empty result tensor. Note that these elements are not necessarily 0!
    y = torch.empty(x.shape[0], device=x.device)

    D_TILE_SIZE = triton.next_power_of_2(D) // 16  # Roughly 16 loops through the embedding dimension
    ROWS_TILE_SIZE = 16  # Each program instance (thread block) processes 16 rows at a time

    # Launch one program instance per tile of ROWS_TILE_SIZE rows in a 1D grid.
    n_rows = y.numel()
    weighted_sum_fwd[(triton.cdiv(n_rows, ROWS_TILE_SIZE),)](
        x, weight,
        y,
        x.stride(0), x.stride(1),
        weight.stride(0),
        y.stride(0),
        ROWS=n_rows, D=D,
        ROWS_TILE_SIZE=ROWS_TILE_SIZE, D_TILE_SIZE=D_TILE_SIZE,
    )

    return y.view(output_dims)
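
It helps to check what this wrapper actually launches. The sketch below recomputes the tile sizes and grid in plain Python; `next_power_of_2` and `launch_config` are local helpers written for illustration, with `next_power_of_2` assumed to match `triton.next_power_of_2` for positive integers. Note the last case: for `D < 16` the heuristic yields a tile width of 0, so small embedding dimensions would need a minimum tile width.

```python
def next_power_of_2(n):
    """Smallest power of two >= n (pure-Python stand-in for triton.next_power_of_2)."""
    return 1 << (n - 1).bit_length()


def launch_config(n_rows, d, rows_tile_size=16):
    """Tile width and 1D grid used by weighted_sum_triton, recomputed on the CPU."""
    d_tile_size = next_power_of_2(d) // 16              # ~16 iterations over D
    grid = ((n_rows + rows_tile_size - 1) // rows_tile_size,)
    return d_tile_size, grid


print(launch_config(64 * 64, 2048))  # (128, (256,)): the check_equal3 input
print(launch_config(64, 128))        # (8, (4,)): one batch of the regression example
print(launch_config(64, 8))          # (0, (4,)): zero-width tile for D < 16
```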

In [None]:
def check_equal3(f1, f2):
    x = torch.randn(64, 64, 2048, device=get_device())
    w = torch.randn(2048, device=get_device())
    y1 = f1(x, w)
    y2 = f2(x, w)
    assert torch.allclose(y1, y2, atol=1e-4)

In [None]:
check_equal3(weighted_sum, weighted_sum_triton)

In [None]:
class WeightedSumFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight):
        D, output_dims = x.shape[-1], x.shape[:-1]

        # Reshape input tensor to 2D
        x_reshaped = rearrange(x, "... d -> (...) d")
        ctx.output_dims = output_dims

        # Cache x and weight to be used in the backward pass, when we
        # only receive the gradient wrt. the output tensor, and
        # need to compute the gradients wrt. x and weight.
        # Save the 2D view so that backward's (n_rows, D) broadcasting
        # also works for inputs with more than one leading dimension.
        ctx.save_for_backward(x_reshaped, weight)

        assert len(weight.shape) == 1 and weight.shape[0] == D, "Dimension mismatch"
        assert x.is_cuda and weight.is_cuda, "Expected CUDA tensors"
        assert (
            x_reshaped.is_contiguous()
        ), "Our pointer arithmetic will assume contiguous x"

        D_TILE_SIZE = (
            triton.next_power_of_2(D) // 16
        )  # Roughly 16 loops through the embedding dimension
        ROWS_TILE_SIZE = 16  # Each program instance (thread block) processes 16 rows at a time

        # Need to initialize empty result tensor. Note that these elements are not necessarily 0!
        y = torch.empty(x_reshaped.shape[0], device=x.device)

        # Launch one program instance per tile of ROWS_TILE_SIZE rows in a 1D grid.
        n_rows = y.numel()
        weighted_sum_fwd[(triton.cdiv(n_rows, ROWS_TILE_SIZE),)](
            x_reshaped,
            weight,
            y,
            x_reshaped.stride(0),
            x_reshaped.stride(1),
            weight.stride(0),
            y.stride(0),
            ROWS=n_rows,
            D=D,
            ROWS_TILE_SIZE=ROWS_TILE_SIZE,
            D_TILE_SIZE=D_TILE_SIZE,
        )

        return y.view(output_dims)

    # Here you should make a triton kernel for the backward instead of plain PyTorch!
    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve saved tensors
        x, weight = ctx.saved_tensors

        # Reshape grad_output to match forward pass
        grad_output_flat = grad_output.reshape(-1)

        # Gradient wrt weight: sum over all samples
        # d/dw (w^T x) = x
        # So grad_weight = sum_i grad_output[i] * x[i]
        grad_weight = (grad_output_flat[:, None] * x).sum(dim=0)

        # Gradient wrt x: broadcast weight
        # d/dx (w^T x) = w
        # So grad_x = grad_output * w
        grad_x = grad_output_flat[:, None] * weight[None, :]

        # Reshape grad_x back to original shape
        grad_x = grad_x.view(*ctx.output_dims, -1)

        return grad_x, grad_weight
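
The backward formulas can be sanity-checked without a GPU. This NumPy sketch (helper names are hypothetical) compares them against central finite differences of the scalar loss L = Σ_i g_i y_i, where `g` plays the role of `grad_output` and y_i = w·x_i is the weighted sum of row i.

```python
import numpy as np


def weighted_sum_np(x, w):
    return (x * w).sum(axis=-1)


def analytic_grads(x, w, g):
    """Formulas from backward(): grad_x[i] = g[i] * w, grad_w = sum_i g[i] * x[i]."""
    return g[:, None] * w[None, :], (g[:, None] * x).sum(axis=0)


def numeric_grad(f, p, eps=1e-6):
    """Central finite differences of scalar f() w.r.t. array p (perturbed in place)."""
    grad = np.zeros_like(p)
    for idx in np.ndindex(p.shape):
        old = p[idx]
        p[idx] = old + eps
        hi = f()
        p[idx] = old - eps
        lo = f()
        p[idx] = old                      # restore the original value
        grad[idx] = (hi - lo) / (2 * eps)
    return grad


rng = np.random.default_rng(0)
x = rng.standard_normal((5, 3))           # float64 for accurate differences
w = rng.standard_normal(3)
g = rng.standard_normal(5)                # upstream gradient dL/dy


def loss():
    return float(g @ weighted_sum_np(x, w))


gx, gw = analytic_grads(x, w, g)
print(np.allclose(gx, numeric_grad(loss, x)), np.allclose(gw, numeric_grad(loss, w)))
```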

In [None]:
class LinearRegressionTriton(torch.nn.Module):
    """
    Linear regression using the custom Triton weighted sum kernel.

    Model: y = w^T x + b
    """

    def __init__(self, input_dim: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(input_dim, device="cuda") * 0.01)
        self.bias = torch.nn.Parameter(torch.zeros(1, device="cuda"))

    def forward(self, x):
        # x: (batch_size, input_dim)
        # Use our custom weighted sum function
        return WeightedSumFunc.apply(x, self.weight) + self.bias


def generate_regression_data(n_samples=1000, input_dim=128, noise_std=0.1, seed=42):
    """
    Generate synthetic linear regression data.

    Returns:
        X: (n_samples, input_dim) feature matrix
        y: (n_samples,) continuous target values
        true_weight: (input_dim,) true weight vector used for generation
        true_bias: (1,) true bias value used for generation
    """
    torch.manual_seed(seed)

    # Generate random features
    X = torch.randn(n_samples, input_dim, device="cuda")

    # Create true weights for data generation
    true_weight = torch.randn(input_dim, device="cuda")
    true_bias = torch.randn(1, device="cuda")

    # Generate target values: y = w^T x + b + noise
    y = X @ true_weight + true_bias

    # Add Gaussian noise
    noise = torch.randn(n_samples, device="cuda") * noise_std
    y = y + noise

    return X, y, true_weight, true_bias


def train_linear_regression(
    model, X_train, y_train, X_val, y_val, epochs=100, lr=0.01, batch_size=64
):
    """
    Train the linear regression model.

    Args:
        model: LinearRegressionTriton model
        X_train, y_train: Training data
        X_val, y_val: Validation data
        epochs: Number of training epochs
        lr: Learning rate
        batch_size: Batch size for training

    Returns:
        train_losses: List of training losses (MSE) per epoch
        val_losses: List of validation losses (MSE) per epoch
        val_r2_scores: List of validation R² scores per epoch
    """
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = torch.nn.MSELoss()

    train_losses = []
    val_losses = []
    val_r2_scores = []

    n_batches = (len(X_train) + batch_size - 1) // batch_size

    for epoch in range(epochs):
        # Training
        model.train()
        epoch_loss = 0.0

        # Shuffle data
        perm = torch.randperm(len(X_train), device="cuda")
        X_train_shuffled = X_train[perm]
        y_train_shuffled = y_train[perm]

        for i in range(n_batches):
            start_idx = i * batch_size
            end_idx = min(start_idx + batch_size, len(X_train))

            X_batch = X_train_shuffled[start_idx:end_idx]
            y_batch = y_train_shuffled[start_idx:end_idx]

            # Forward pass
            optimizer.zero_grad()
            y_pred = model(X_batch)
            loss = criterion(y_pred, y_batch)

            # Backward pass
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_train_loss = epoch_loss / n_batches
        train_losses.append(avg_train_loss)

        # Validation
        model.eval()
        with torch.no_grad():
            y_val_pred = model(X_val)
            val_loss = criterion(y_val_pred, y_val).item()
            val_losses.append(val_loss)

            # Calculate R² score
            ss_res = ((y_val - y_val_pred) ** 2).sum()
            ss_tot = ((y_val - y_val.mean()) ** 2).sum()
            r2_score = 1 - (ss_res / ss_tot)
            val_r2_scores.append(r2_score.item())

        # Print progress every 10 epochs
        if (epoch + 1) % 10 == 0 or epoch == 0:
            print(
                f"Epoch {epoch+1}/{epochs}: "
                f"Train Loss = {avg_train_loss:.4f}, "
                f"Val Loss = {val_loss:.4f}, "
                f"Val R² = {r2_score.item():.4f}"
            )

    return train_losses, val_losses, val_r2_scores
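
The validation R² compares the model's squared error with that of always predicting the mean of `y_val`. A small NumPy sketch of the same formula with a few reference points:

```python
import numpy as np


def r2_score(y_true, y_pred):
    """Coefficient of determination, as computed in the validation loop above."""
    ss_res = ((y_true - y_pred) ** 2).sum()
    ss_tot = ((y_true - y_true.mean()) ** 2).sum()
    return 1.0 - ss_res / ss_tot


y = np.array([1.0, 2.0, 3.0, 4.0])
print(r2_score(y, y))                           # 1.0: perfect predictions
print(r2_score(y, np.full_like(y, y.mean())))   # 0.0: predicting the mean
print(r2_score(y, y[::-1]))                     # -3.0: worse than the mean
```

So values near 1 mean the learned model explains almost all of the target variance, while negative values mean it does worse than a constant.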



In [None]:
def main():
    """
    Main function to demonstrate linear regression with custom Triton kernel.
    """
    print("=" * 80)
    print("Linear Regression with Custom Triton Weighted Sum Kernel")
    print("=" * 80)

    # Check CUDA availability
    if not torch.cuda.is_available():
        print("ERROR: CUDA is not available. This example requires a GPU.")
        return

    # Hyperparameters
    n_train = 800
    n_val = 200
    input_dim = 128
    epochs = 100
    lr = 0.01
    batch_size = 64
    noise_std = 0.1

    print(f"\nDataset configuration:")
    print(f"  Training samples: {n_train}")
    print(f"  Validation samples: {n_val}")
    print(f"  Input dimension: {input_dim}")
    print(f"  Noise std: {noise_std}")
    print(f"\nTraining configuration:")
    print(f"  Epochs: {epochs}")
    print(f"  Learning rate: {lr}")
    print(f"  Batch size: {batch_size}")
    print()

    # Generate data
    print("Generating synthetic linear regression data...")
    X, y, true_weight, true_bias = generate_regression_data(
        n_samples=n_train + n_val, input_dim=input_dim, noise_std=noise_std
    )

    # Split into train and validation
    X_train, y_train = X[:n_train], y[:n_train]
    X_val, y_val = X[n_train:], y[n_train:]

    print(f"Training set: X shape = {X_train.shape}, y shape = {y_train.shape}")
    print(f"Validation set: X shape = {X_val.shape}, y shape = {y_val.shape}")
    print(f"\nTrue parameters:")
    print(f"  True weight norm: {true_weight.norm().item():.6f}")
    print(f"  True bias: {true_bias.item():.6f}")
    print()

    # Initialize model
    print("Initializing linear regression model with Triton kernel...")
    model = LinearRegressionTriton(input_dim=input_dim)
    model = torch.compile(model)
    print(
        f"Model parameters: weight shape = {model.weight.shape}, bias shape = {model.bias.shape}"
    )
    print()

    # Train model
    print("Starting training...\n")
    train_losses, val_losses, val_r2_scores = train_linear_regression(
        model,
        X_train,
        y_train,
        X_val,
        y_val,
        epochs=epochs,
        lr=lr,
        batch_size=batch_size,
    )

    # Final evaluation
    print("\n" + "=" * 80)
    print("Training Complete!")
    print("=" * 80)
    print(f"Final Training Loss (MSE): {train_losses[-1]:.4f}")
    print(f"Final Validation Loss (MSE): {val_losses[-1]:.4f}")
    print()

    # Compare learned parameters with true parameters
    print("=" * 80)
    print("Parameter Comparison: Learned vs True")
    print("=" * 80)

    learned_weight = model.weight.data
    learned_bias = model.bias.data

    # Compute various comparison metrics
    weight_diff = learned_weight - true_weight
    weight_mse = (weight_diff**2).mean().item()
    weight_mae = weight_diff.abs().mean().item()
    bias_diff = (learned_bias - true_bias).abs().item()

    print(f"\nWeight Statistics:")
    print(f"  True weight norm:        {true_weight.norm().item():.6f}")
    print(f"  Learned weight norm:     {learned_weight.norm().item():.6f}")
    print(f"  Weight MSE:              {weight_mse:.6f}")
    print(f"  Weight MAE:              {weight_mae:.6f}")
    
    print(f"\nBias Statistics:")
    print(f"  True bias:               {true_bias.item():.6f}")
    print(f"  Learned bias:            {learned_bias.item():.6f}")
    print(f"  Bias absolute difference: {bias_diff:.6f}")
    print()

    print("=" * 80)
    print("Example completed successfully!")
    print("=" * 80)
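
The parameter comparison above can be put against a closed-form reference: since the data come from a known linear model, ordinary least squares recovers the weights directly. This is a NumPy sketch under the same data-generating assumptions (Gaussian features and weights, `noise_std=0.1`), with `input_dim` reduced to 16 for speed and NumPy's RNG in place of PyTorch's, so the numbers will not match the run above. A column of ones is appended so the bias is fitted too. SGD after 100 epochs need not reach this solution exactly; treat it as a reference point.

```python
import numpy as np

rng = np.random.default_rng(42)
n_samples, input_dim, noise_std = 1000, 16, 0.1
X = rng.standard_normal((n_samples, input_dim))
true_weight = rng.standard_normal(input_dim)
true_bias = rng.standard_normal()
y = X @ true_weight + true_bias + noise_std * rng.standard_normal(n_samples)

# Solve min ||[X, 1] @ [w; b] - y||^2 in closed form
X1 = np.hstack([X, np.ones((n_samples, 1))])
coef, *_ = np.linalg.lstsq(X1, y, rcond=None)
w_hat, b_hat = coef[:-1], coef[-1]
print(np.abs(w_hat - true_weight).max(), abs(b_hat - true_bias))  # both small (noise-limited)
```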

In [None]:
main()