# The Misadventures of an L2 Norm Triton Kernel

This notebook walks through the iterative process of implementing a GPU-accelerated Euclidean norm (`nrm2`) routine with Triton kernels, including the challenges encountered and how they were resolved. Along the way it illustrates common pitfalls in parallel programming and shows why numerical correctness on GPUs has to be verified rather than assumed.

The goal was to efficiently calculate $||x||_2 = \sqrt{\sum_{i=1}^{N} x_i^2}$ for a given input vector $x$.
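
As a CPU-side reference for the formula above, here is a minimal pure-Python sketch. The helper name `l2_norm_reference` is hypothetical and not part of the notebook; it only restates the definition so later GPU results have something simple to be compared against.

```python
import math

def l2_norm_reference(values):
    """Return sqrt(sum(v * v)) for an iterable of numbers."""
    # math.fsum tracks partial sums to limit rounding error in the accumulation
    return math.sqrt(math.fsum(v * v for v in values))

print(l2_norm_reference([3.0, 4.0]))  # 3-4-5 triangle, prints 5.0
```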


## Prerequisites and Setup

Before we begin, make sure PyTorch, Triton, and NumPy are installed and that a CUDA-enabled GPU is available.


In [17]:
import torch
import triton
import triton.language as tl
import numpy as np
import math

# check for accelerator
if not torch.accelerator.is_available():
    raise RuntimeError("No accelerator is available. This notebook requires a GPU.")
accelerator = torch.accelerator.current_accelerator()
print(f"Detected accelerator: {accelerator}")

# input tensor
SIZE = 98432
x_torch = torch.randn(SIZE, device=accelerator, dtype=torch.float32)

# fixed block size for the non-autotuned kernels (the atomic-add and nrm2_partial kernels are autotuned)
BLOCK_SIZE = 256

print(f"\nInput tensor size: {SIZE} elements")
print(f"First 10 elements of input tensor: {x_torch[:10]}")
print(f"Using fixed BLOCK_SIZE for non-autotuned kernels: {BLOCK_SIZE}")


Detected accelerator: cuda

Input tensor size: 98432 elements
First 10 elements of input tensor: tensor([-1.3399,  0.9823, -0.1188,  0.0460, -0.1024,  1.7807, -0.6536, -0.7665,
        -0.2846, -0.3786], device='cuda:0')
Using fixed BLOCK_SIZE for non-autotuned kernels: 256


## Core Challenge: Ensuring Correct Parallel Summation for L2 Norm

The `nrm2` calculation requires squaring each element, summing all squared elements, and finally taking the square root. The most complex part from a GPU programming perspective is the highly parallel summation of potentially millions of squared values into a single scalar result, known as a **global reduction**.
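
The reduction is global because the input is split across many thread blocks, each of which sees only its own slice. The split arithmetic for this notebook's `SIZE` and `BLOCK_SIZE` can be sketched on the CPU as follows; `cdiv` here is a local stand-in written for illustration, assumed to perform the same ceiling division as `triton.cdiv`.

```python
def cdiv(a, b):
    """Ceiling division for positive integers."""
    return (a + b - 1) // b

SIZE = 98432
BLOCK_SIZE = 256

num_blocks = cdiv(SIZE, BLOCK_SIZE)
# elements that are actually in range for the final block; the rest are masked out
valid_in_last_block = SIZE - (num_blocks - 1) * BLOCK_SIZE

print(num_blocks)           # 385
print(valid_in_last_block)  # 128
```

Because the last block is only half full, the kernels below load with a mask and substitute `0.0` for out-of-range elements, which leaves the sum of squares unchanged.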


## Attempt 1: Single-Kernel Approach

My initial strategy aimed for a single Triton kernel to perform the entire squaring and summation.


### Misadventure: `tl.sum()` for Global Reduction

**Concept:** Leverage Triton's `tl.sum()` directly on the squared elements, hoping it performs a complete global reduction across all threads and blocks.


In [12]:
@triton.jit
def global_sum_misadventure(
    x_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    
    # sum of squares of the elements in the block
    partial = tl.sum(x * x) 

    # THE MISADVENTURE: tl.sum() only reduces over the values loaded by
    # THIS program (thread block). It DOES NOT sum across all blocks in
    # the grid. With many blocks launched, each one computes its own
    # `partial`, and nothing aggregates them into a single global result.
    # Summing the one-element tensor below simply returns `partial` again.
    total = tl.sum(partial[None], axis=0)

    # every block takes the sqrt of its own sum and stores it to the same
    # address, so the output holds whichever block happened to write last
    output = tl.sqrt(total)
    tl.store(output_ptr, output)


# test to show output is not a single global sum
num_blocks = triton.cdiv(SIZE, BLOCK_SIZE)
results = torch.zeros(1, device=accelerator, dtype=torch.float32)

try:
    global_sum_misadventure[(num_blocks,)](
        x_torch, results, SIZE, BLOCK_SIZE=BLOCK_SIZE
    )
    print(f"Kernel output: {results} (nrm2 of only one block because of misadventure)")
    print(f"Correct output: {torch.linalg.norm(x_torch)}")
    print("Number of blocks: ", num_blocks)
    print("Lesson: `tl.sum()` is a local reduction primitive, not a global one across all blocks, so this does not output the desired L2 norm.")
    
except Exception as e:
    print(f"\nAttempt 1 kernel launch failed: {e}")


Kernel output: tensor([16.3524], device='cuda:0') (nrm2 of only one block because of misadventure)
Correct output: 313.6213073730469
Number of blocks:  385
Lesson: `tl.sum()` is a local reduction primitive, not a global one across all blocks, so this does not output the desired L2 norm.
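
The behaviour above can be modelled on the CPU. This is a hedged pure-Python sketch, not Triton code: the function name and the explicit `finish_order` argument are hypothetical, and on a real GPU the order in which blocks finish is not something the kernel controls.

```python
import math

def simulate_single_slot_store(x, block_size, finish_order):
    """Model every program storing sqrt(its own block sum) to one shared slot."""
    output = [0.0]
    for pid in finish_order:
        block = x[pid * block_size:(pid + 1) * block_size]
        partial = sum(v * v for v in block)
        output[0] = math.sqrt(partial)  # overwrites whatever was stored before
    return output[0]

x = [3.0, 4.0, 6.0, 8.0]  # two blocks of two elements each
true_norm = math.sqrt(sum(v * v for v in x))

print(simulate_single_slot_store(x, 2, finish_order=[0, 1]))  # 10.0, block 1 wrote last
print(simulate_single_slot_store(x, 2, finish_order=[1, 0]))  # 5.0, block 0 wrote last
print(true_norm > 10.0)                                       # True
```

Neither outcome equals the norm of the full vector; the stored value is always the norm of a single block.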


### Misadventure: `tl.atomic_add()` for Global Accumulation

**Concept:** Switch to `tl.atomic_add()` within a single kernel. Each thread block computes its partial sum of squares and then atomically adds it to a single shared location in global memory, which should eventually hold the total sum.


In [13]:
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 64}),
        triton.Config({'BLOCK_SIZE': 128}),
        triton.Config({'BLOCK_SIZE': 256}),
        triton.Config({'BLOCK_SIZE': 512}),
    ],
    key=['n_elements'],
)
@triton.jit
def atomic_sum_misadventure(
    x_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    
    partial = tl.sum(x * x) 
    
    # THE MISADVENTURE: The autotuner launches the kernel repeatedly for
    # each configuration it benchmarks, and every launch adds into the same
    # `output_ptr` location. The value is never reset between launches, so
    # the result is massively inflated and numerically INCORRECT.
    # Be wary of trusting the autotuner's output when a kernel mutates
    # global state across the tuner's evaluation runs.
    # (Reference: https://github.com/triton-lang/triton/issues/6524)
    tl.atomic_add(output_ptr, partial) 


# test
print("\nDemonstrating `tl.atomic_add` accumulation issue with autotuner")

# show the autotuned accumulation problem
global_result = torch.zeros(1, device=accelerator, dtype=torch.float32)
max_blocks = triton.cdiv(SIZE, 64)

try:
    atomic_sum_misadventure[(max_blocks,)](
        x_torch, global_result, SIZE
    )    
    true_sum_of_squares = torch.sum(x_torch * x_torch).item()
    print(f"True sum of squares (torch.sum(x*x)): {true_sum_of_squares:.6f}")
    print(f"Autotuned result: {global_result.item():.6f}")
    print(f"Autotuned result is {global_result.item() / true_sum_of_squares:.2f}x the true sum of squares.")
    print("The autotuner tested multiple configurations and accumulated ALL results!")
    
except Exception as e:
    print(f"✗ Autotuned version failed: {e}")

print("\nLesson: While `atomic_add` ensures concurrent writes are safe, it does not reset state.")
print("Using it with the autotuner (which tests multiple configurations) leads to")
print("fundamentally incorrect, accumulated sums across ALL autotuner trials.")



Demonstrating `tl.atomic_add` accumulation issue with autotuner
True sum of squares (torch.sum(x*x)): 98358.335938
Autotuned result: 124440184.000000
Autotuned result is 1265.17x the true sum of squares.
The autotuner tested multiple configurations and accumulated ALL results!

Lesson: While `atomic_add` ensures concurrent writes are safe, it does not reset state.
Using it with the autotuner (which tests multiple configurations) leads to
fundamentally incorrect, accumulated sums across ALL autotuner trials.
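
A minimal CPU model of this accumulation is sketched below. Everything in it is hypothetical: `run_kernel` stands in for one kernel launch, and the counts `n_configs=4` and `runs_per_config=10` are made up for illustration rather than taken from Triton's benchmarking logic.

```python
def run_kernel(x, output):
    """Stand-in for one launch: adds sum(x*x) into output[0], like atomic_add."""
    output[0] += sum(v * v for v in x)

def simulate_autotune(x, output, n_configs, runs_per_config, reset_between_runs):
    """Hypothetical tuner that launches the kernel repeatedly while timing it."""
    for _ in range(n_configs):
        for _ in range(runs_per_config):
            if reset_between_runs:
                output[0] = 0.0  # clear the accumulator before each launch
            run_kernel(x, output)
    return output[0]

x = [1.0, 2.0, 3.0]  # true sum of squares is 14.0

inflated = simulate_autotune(x, [0.0], n_configs=4, runs_per_config=10,
                             reset_between_runs=False)
clean = simulate_autotune(x, [0.0], n_configs=4, runs_per_config=10,
                          reset_between_runs=True)

print(inflated / 14.0)  # 40.0, one contribution per benchmark launch
print(clean / 14.0)     # 1.0
```

The 1265x factor in the output above suggests the real autotuner launched the kernel many more times than the four configurations alone would imply. Triton's `autotune` decorator documents a `reset_to_zero` parameter that names arguments to clear before each trial; whether it resolves the linked issue in a given Triton version is worth checking rather than assuming.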


## Attempt 2: The Two-Kernel Approach

To overcome the challenges of single-kernel global reductions, I switched to a multi-stage approach that splits the problem into two distinct kernels.

### `nrm2_partial` Kernel: Distributed Sum of Squares

This kernel's role is to perform the element-wise squaring and sum those squares *within each thread block*. Each block then stores its partial sum to a **unique index** in a temporary global memory array. This avoids atomic contention in the first stage and ensures each block's result is isolated.


In [14]:
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_PARTIAL': 64}),
        triton.Config({'BLOCK_SIZE_PARTIAL': 128}),
        triton.Config({'BLOCK_SIZE_PARTIAL': 256}),
        triton.Config({'BLOCK_SIZE_PARTIAL': 512}),
        triton.Config({'BLOCK_SIZE_PARTIAL': 1024}),
    ],
    key=['n_elements'],
)
@triton.jit
def nrm2_partial(
    x_ptr,
    partial_ptr,
    n_elements,
    BLOCK_SIZE_PARTIAL: tl.constexpr
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE_PARTIAL
    offsets = block_start + tl.arange(0, BLOCK_SIZE_PARTIAL)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)

    # sum within this block (intra-block reduction)
    sum_of_squares = tl.sum(x * x) 
    
    # no atomic needed here and safe with autotuning, as each program writes to a distinct memory address.
    tl.store(partial_ptr + pid, sum_of_squares)

### `nrm2_final` Kernel: Final Sequential Summation

Instead of parallelizing the final sum with atomics, this kernel is launched with **only one program (thread block)**. That single program loops sequentially over all the partial sums written by the first kernel, accumulating them into a local variable.


In [15]:
@triton.jit
def nrm2_final(
    partial_ptr, 
    final_ptr,
    n_partial, 
):    
    total = tl.zeros((), dtype=tl.float32)
    
    # process partial sums sequentially
    for i in range(n_partial):
        partial_val = tl.load(partial_ptr + i)
        total += partial_val
        
    tl.store(final_ptr, total)

### `nrm2` Function

This Python function launches the two kernels and applies the final square root.


In [16]:
def nrm2(x: torch.Tensor):
    output = torch.zeros((1,), device=x.device)
    assert x.device == output.device, "Input and output tensors must be on the same device."
    n_elements = x.numel()

    # calculate the maximum number of blocks that might be needed
    min_block_size = 64  
    max_partial_sums = triton.cdiv(n_elements, min_block_size)
    partial_sums = torch.zeros(max_partial_sums, device=x.device, dtype=torch.float32)

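    # the grid is fixed at max_partial_sums programs; if the autotuner picks a
    # larger BLOCK_SIZE_PARTIAL, the surplus programs load nothing (their mask
    # is all False) and store 0.0, which does not change the final sum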
    nrm2_partial[(max_partial_sums,)](
        x,
        partial_sums,
        n_elements
    )
    
    nrm2_final[(1,)](
        partial_sums,
        output,
        max_partial_sums
    )

    # square root the sum of partial sums
    result = torch.sqrt(output)
    return result

true_norm = torch.linalg.norm(x_torch).item()
print(f"True L2 Norm (torch.linalg.norm): {true_norm:.6f}")

result_nrm2_custom = nrm2(x_torch)
print(f"Triton nrm2 custom kernel result: {result_nrm2_custom.item():.6f}")

assert math.isclose(result_nrm2_custom.item(), true_norm, rel_tol=1e-5), "Triton nrm2 result does not match torch.norm!"
print("Correct!")

True L2 Norm (torch.linalg.norm): 313.621307
Triton nrm2 custom kernel result: 313.621368
Correct!
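
The two-stage structure, including the grid sized for the smallest block size, can be modelled on the CPU. This is a pure-Python sketch under stated assumptions: `two_stage_norm` and `cdiv` are hypothetical helpers, and the tiny block sizes are chosen only to keep the numbers readable.

```python
import math

def cdiv(a, b):
    return (a + b - 1) // b

def two_stage_norm(x, block_size, min_block_size):
    """CPU model of nrm2_partial followed by nrm2_final."""
    # the grid is sized for the smallest block size, as in nrm2()
    n_programs = cdiv(len(x), min_block_size)
    partial_sums = [0.0] * n_programs

    # stage 1: each program writes its block's sum of squares to its own slot
    for pid in range(n_programs):
        # the slice is empty once pid runs past the end of the data
        block = x[pid * block_size:(pid + 1) * block_size]
        partial_sums[pid] = sum((v * v for v in block), 0.0)

    # stage 2: one sequential pass over every slot
    total = 0.0
    for value in partial_sums:
        total += value
    return math.sqrt(total), partial_sums

x = [float(i) for i in range(1, 9)]  # 1..8, sum of squares is 204
norm_small, slots_small = two_stage_norm(x, block_size=2, min_block_size=2)
norm_large, slots_large = two_stage_norm(x, block_size=4, min_block_size=2)

print(slots_small)               # [5.0, 25.0, 61.0, 113.0]
print(slots_large)               # [30.0, 174.0, 0.0, 0.0]
print(norm_small == norm_large)  # True
```

With the larger block size, the surplus slots stay at zero, so the final pass over all slots still produces the same total.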


## Key Takeaways

This process highlights lessons specific to reduction patterns in GPU kernel development, with a focus on achieving **numerical correctness** when designing parallel algorithms:

* **Global Reduction Complexity:** Achieving correct global reductions often requires multi-stage approaches (e.g., block-wise partials, then a final aggregation). A single kernel trying to sum everything is difficult to get right.

* **Atomics for Correctness, but Caution for Global State:** `tl.atomic_add` ensures correct concurrent updates to a single memory location. However, when used with repeated kernel calls in autotuning, it **accumulates results across calls, leading to fundamentally incorrect final sums**.

* **Trusting Autotuner Correctness:** Do not implicitly trust the autotuner if your kernel modifies a global state that is not reset between its internal trials.

* **Sequential Final Stages for Guaranteed Accuracy:** For the final aggregation step of a reduction, a sequential sum in a single program is a simple way to get a correct, deterministic result, but it costs performance because the loop runs serially over every partial sum.

* **Verification is Paramount:** Always verify custom kernel results against known correct implementations (like `torch.norm`) to ensure numerical accuracy.
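
The verification step in the last takeaway can be sketched as a small helper. The name `check_against_reference` is hypothetical; the numbers fed to it are the values printed earlier in this notebook, and the tolerance matches the `rel_tol=1e-5` used in the final cell.

```python
import math

def check_against_reference(candidate, reference, rel_tol=1e-5):
    """Return (ok, relative_error) for a scalar result versus a trusted value."""
    # assumes reference is nonzero, which holds for the norms in this notebook
    rel_err = abs(candidate - reference) / abs(reference)
    return math.isclose(candidate, reference, rel_tol=rel_tol), rel_err

# the two values printed by the notebook's final cell
ok, rel_err = check_against_reference(313.621368, 313.621307)
print(ok, f"{rel_err:.1e}")

# the single-block result from the first misadventure fails the same check
bad_ok, _ = check_against_reference(16.3524, 313.6213073730469)
print(bad_ok)
```

Reporting the relative error alongside the pass/fail flag makes it easier to tell a rounding-level difference from a structural bug such as a missing reduction stage.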
