In [None]:
from setup_triton import setup_triton

# TRITON_INTERPRET=1 uses a Python interpreter instead of running on the GPU.
# This means that you can insert Python breakpoints to debug your kernel code!
# NOTE(review): setup_triton is project-local; use_interpreter=True presumably
# sets TRITON_INTERPRET=1 — confirm. If so, kernel timings are not
# representative, so consider disabling it before the GPU benchmarks below.
setup_triton(use_interpreter=True)

# Triton Puzzle 5: Layer Normalization

Welcome to the fifth Triton puzzle! Layer Normalization is a crucial component in modern deep learning, especially in transformers. This puzzle introduces parallel reduction patterns and online algorithms.


## Mathematical Background

Layer Normalization normalizes inputs across the feature dimension:

Given input $\mathbf{x} \in \mathbb{R}^{N \times D}$ (batch size N, feature dimension D):

$\text{LayerNorm}(\mathbf{x}_i) = \gamma \odot \frac{\mathbf{x}_i - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}} + \beta$

Where for each sample $i$:
- $\mu_i = \frac{1}{D} \sum_{j=1}^{D} x_{ij}$ (mean across features)
- $\sigma_i^2 = \frac{1}{D} \sum_{j=1}^{D} (x_{ij} - \mu_i)^2$ (variance across features)
- $\gamma, \beta \in \mathbb{R}^D$ are learned scale and shift parameters
- $\epsilon$ is a small constant for numerical stability


### The Reduction Challenge

Unlike our previous operations, LayerNorm requires **reduction** across the feature dimension:
- Each program (which handles one row) needs information from ALL of that row's features to compute its statistics
- Can't process each element independently
- Need efficient parallel reduction algorithms


In [None]:
import torch
import triton
import triton.language as tl
import numpy as np
from IPython.display import display, Image

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(f"Using device: {DEVICE}")

# Fix the global RNG seed so randomly generated inputs are reproducible.
torch.manual_seed(42)

## Implementation 1: Naive PyTorch (Two-pass)

First, let's look at a straightforward implementation that computes the statistics in two passes (first the mean, then the variance) and then normalizes:

In [None]:
def layernorm_naive(x, weight, bias, eps=1e-5):
    """
    Reference LayerNorm written with plain PyTorch tensor ops.

    Normalizes over the last dimension of ``x``. The statistics take two
    reduction passes (mean, then the variance of the centered values), and a
    final elementwise step normalizes and applies scale/shift.

    Args:
        x: Input tensor, normalized over its last dimension (typically (N, D)).
        weight: Scale (gamma), broadcast against the last dimension.
        bias: Shift (beta), broadcast against the last dimension.
        eps: Added to the variance before the square root for stability.

    Returns:
        ``weight * (x - mean) / sqrt(var + eps) + bias``; same shape as ``x``
        when ``weight`` and ``bias`` have shape (D,).
    """
    mean = x.mean(dim=-1, keepdim=True)  # pass 1: per-row mean
    # Center once and reuse it for both the variance and the normalization.
    centered = x - mean
    var = (centered ** 2).mean(dim=-1, keepdim=True)  # pass 2: biased variance (divide by D)
    std = torch.sqrt(var + eps)
    x_norm = centered / std  # normalize
    return weight * x_norm + bias  # scale and shift

## Implementation 2: PyTorch Built-in

PyTorch's built-in `torch.nn.functional.layer_norm` dispatches to optimized fused kernels (CUDA kernels when running on a GPU):


In [None]:
def layernorm_pytorch(x, weight, bias, eps=1e-5):
    """Reference LayerNorm via ``torch.nn.functional.layer_norm`` over the last dim."""
    feature_dim = x.shape[-1]
    return torch.nn.functional.layer_norm(
        x,
        normalized_shape=(feature_dim,),
        weight=weight,
        bias=bias,
        eps=eps,
    )

## Implementation 3: PyTorch Compiled

Let's try to compile a naive version:

In [None]:
@torch.compile
def layernorm_compiled(x, weight, bias, eps=1e-5):
    """
    Naive LayerNorm expression handed to ``torch.compile`` for fusion.

    Normalizes over the last dimension of ``x`` with the biased variance,
    then applies ``weight`` (scale) and ``bias`` (shift).
    """
    mu = x.mean(dim=-1, keepdim=True)
    diff = x - mu  # centered input, shared by the variance and the output
    sigma2 = diff.pow(2).mean(dim=-1, keepdim=True)
    return weight * (diff / torch.sqrt(sigma2 + eps)) + bias

## Key Concepts for This Puzzle

### 1. Parallel Reduction in Triton

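Triton provides built-in reductions such as `tl.sum`, `tl.max`, and `tl.min`, plus the general `tl.reduce` for custom combine functions:

```python
# Sum reduction along axis 0
sum_val = tl.sum(data, axis=0)

# A mean is just a sum divided by the element count
mean_val = tl.sum(data, axis=0) / num_elements
```

When a block is only partially valid (masked), load the padded lanes with `other=0.0` so they do not contribute to the sum, and divide by the real element count rather than `BLOCK_SIZE`.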

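### 2. Multi-Pass vs. Single-Pass Statistics

The most direct approach makes three passes over each row:

- Pass 1: Compute mean
- Pass 2: Compute variance using mean
- Pass 3: Normalize

Passes 1 and 2 can be fused into a single pass by accumulating both $\sum_j x_{ij}$ and $\sum_j x_{ij}^2$, then using $\sigma_i^2 = \frac{1}{D}\sum_j x_{ij}^2 - \mu_i^2$. This saves one read of the row, but the subtraction can lose precision (catastrophic cancellation) when $|\mu_i|$ is large compared to $\sigma_i$.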


### 3. Block Size Considerations

- Each program handles one sample (row)
- Ideally, one block covers all features of a row in a single load
- The wrapper caps BLOCK_SIZE at 1024, so the kernel must loop over the row when D > BLOCK_SIZE

## Implementation 4: Triton Kernel (Puzzle)

Now, let's implement LayerNorm in Triton!

### Your Task:
1. Each program handles one sample (row)
2. Load features in blocks
3. Compute mean using `tl.sum`
4. Compute variance (in the same pass if possible)
5. Normalize and apply scale/shift
6. Handle cases where feature dimension > BLOCK_SIZE

In [None]:
@triton.jit
def layernorm_kernel(
    x_ptr, y_ptr, weight_ptr, bias_ptr,
    N, D,  # N = batch size (number of rows), D = feature dimension (row length)
    eps: tl.constexpr,
    BLOCK_SIZE: tl.constexpr
):
    """
    LayerNorm kernel where each program handles one sample (row).

    Args:
        x_ptr: Pointer to the input. The launching wrapper passes a contiguous
            (N, D) buffer, so row ``i`` starts at element offset ``i * D``.
        y_ptr: Pointer to the output buffer, laid out like the input.
        weight_ptr: Pointer to the D scale values (gamma).
        bias_ptr: Pointer to the D shift values (beta).
        N: Number of rows; the wrapper launches a grid of ``(N,)`` programs.
        D: Number of features per row.
        eps: Added to the variance before the square root. As a
            ``tl.constexpr``, each distinct value is compiled separately.
        BLOCK_SIZE: Features handled per loop iteration (a power of 2). The
            wrapper caps it at 1024, so D may be larger than BLOCK_SIZE.

    Key challenges:
    - Compute mean and variance across D dimension
    - Handle D > BLOCK_SIZE by looping
    - Maintain numerical stability
    """
    # YOUR IMPLEMENTATION GOES HERE
    # Hints:
    # 1. Use tl.program_id(0) to get which sample this program handles
    # 2. Loop over features in blocks of BLOCK_SIZE (mask the final partial block)
    # 3. Accumulate sum and sum of squares for mean/variance
    #    (alternatively: sum first, then squared deviations from the mean in a
    #    second loop; E[x^2] - mean^2 can lose precision to cancellation)
    # 4. After computing statistics, loop again to normalize
    # 5. Don't forget to apply weight and bias!
    pass


def layernorm_triton(x, weight, bias, eps=1e-5):
    """
    Wrapper for the Triton LayerNorm kernel: normalize over the last dim of ``x``.

    Leading dimensions are flattened into rows, so ``x`` may be (D,), (N, D),
    or e.g. (B, T, D). This matches ``torch.nn.functional.layer_norm`` with
    ``normalized_shape=x.shape[-1:]``. Non-contiguous inputs are copied to a
    contiguous layout because the kernel addresses row ``i`` at offset ``i * D``.

    Args:
        x: Input tensor with at least one dimension.
        weight: Scale (gamma) of shape (D,).
        bias: Shift (beta) of shape (D,).
        eps: Stability constant added to the variance.

    Returns:
        A new tensor with the same shape as ``x``.

    Raises:
        ValueError: If ``x`` is 0-dimensional, or ``weight``/``bias`` do not
            have shape (D,). A kernel that loads D weight/bias values would
            otherwise read out of bounds.
    """
    # Explicit checks (not ``assert``) so they survive ``python -O``.
    if x.dim() < 1:
        raise ValueError("layernorm_triton expects x with at least one dimension")
    D = x.shape[-1]
    if weight.shape != (D,) or bias.shape != (D,):
        raise ValueError(
            f"weight and bias must have shape ({D},), "
            f"got {tuple(weight.shape)} and {tuple(bias.shape)}"
        )

    # Nothing to normalize. This also avoids an empty grid and a zero-sized BLOCK_SIZE.
    if x.numel() == 0:
        return torch.empty_like(x)

    # .contiguous() returns the tensor itself when it is already contiguous.
    x2d = x.contiguous().reshape(-1, D)
    weight = weight.contiguous()
    bias = bias.contiguous()
    N = x2d.shape[0]

    # Allocate output
    y2d = torch.empty_like(x2d)

    # Choose block size (must be power of 2); capped so large D loops in-kernel.
    BLOCK_SIZE = triton.next_power_of_2(min(D, 1024))

    # Launch grid: one program per sample (row)
    grid = (N,)

    layernorm_kernel[grid](
        x2d, y2d, weight, bias,
        N, D,
        eps=eps,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return y2d.view(x.shape)

## Solution 🧙

You shall not pass!

In [None]:
# Our solution goes here

> **Note:** A more efficient alternative is Welford's algorithm, which computes the mean and variance together in a single, numerically stable pass. We won't cover it here because it is beyond the scope of this tutorial.

## Testing Correctness

Let's verify our implementation:

In [None]:
def test_correctness(N=32, D=256, eps=1e-5, atol=1e-3, rtol=1e-3):
    """
    Test if Triton implementation matches PyTorch on several shapes.

    The caller-specified (N, D) case is checked first. It is followed by edge
    cases covering a single row, a narrow feature dimension, and D larger than
    the wrapper's 1024 BLOCK_SIZE cap (which exercises the kernel's feature loop).

    Args:
        N, D: Shape of the primary test case.
        eps: LayerNorm epsilon passed to both implementations.
        atol, rtol: Tolerances forwarded to ``torch.testing.assert_close``.

    Returns:
        True if every shape matches within tolerance. Otherwise False, after
        printing the failing shape and the mismatch report.
    """
    torch.manual_seed(42)

    # The first entry is the primary case; the rest are edge cases. Inputs are
    # drawn in this order, so the RNG stream is deterministic per case.
    cases = [
        (N, D),
        (1, 1024),    # Single sample
        (128, 64),    # Small features
        (64, 2048),   # Large features (D > BLOCK_SIZE)
    ]

    for idx, (n, d) in enumerate(cases):
        x = torch.randn(n, d, device=DEVICE, dtype=torch.float32)
        weight = torch.randn(d, device=DEVICE, dtype=torch.float32)
        bias = torch.randn(d, device=DEVICE, dtype=torch.float32)

        expected = layernorm_pytorch(x, weight, bias, eps)
        actual = layernorm_triton(x, weight, bias, eps)

        # Only the comparison is guarded; crashes in the kernel should surface.
        try:
            torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
        except AssertionError as e:
            print("❌ Test FAILED!")
            print(f"   Shape: ({n}, {d})")
            print(f"   Error: {e}")
            return False

        if idx == 0:
            print("✅ Test PASSED! Results match within tolerance.")
            print(f"   Shape: ({n}, {d})")
            print(f"   Max absolute difference: {(actual - expected).abs().max().item():.2e}")
        else:
            print(f"✅ Size ({n}, {d}) PASSED!")

    return True

# Run tests: True only if every checked shape matched the PyTorch reference.
test_passed = test_correctness()

# Display congrats message (banner + GIF) on success only.
if test_passed:
    congrats = "\n🎉 Congratulations! Your implementation is correct!"
    print(congrats)
    display(Image("figs/success.gif", width=256, height=256))

## Summary

You've successfully implemented LayerNorm, mastering several advanced concepts!

### Key Concepts Mastered:

1. **Parallel Reduction**: Computing statistics across a dimension
2. **Multi-pass vs Single-pass**: Trade-offs between simplicity and efficiency
3. **Numerical Stability**: Handling variance computation carefully
4. **Variable-length Processing**: Looping over feature blocks when D > BLOCK_SIZE and masking the partial final block
5. **Memory Efficiency**: Minimizing passes over data


### Next Steps:

Ready for CrossEntropy Loss? The final puzzle introduces:
- Fusing the output layer
- Dealing with high dimensionality
- Integrating Triton & PyTorch

Let's continue with our Triton journey!

<img src="figs/sardine-challenge.png" width="512" />

---

## Benchmarking (GPU only)

Now let's benchmark the implementations:

In [None]:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['D'],  # Feature dimension
        x_vals=[128, 256, 512, 768, 1024, 2048, 4096],
        line_arg='provider',
        line_vals=['pytorch', 'compile', 'triton'],
        line_names=['PyTorch', 'torch.compile', 'Triton'],
        styles=[('green', '-'), ('red', '--'), ('blue', '-.')],
        ylabel='GB/s',
        plot_name='layernorm-performance',
        args={'N': 1024},  # Batch size
    )
)
def benchmark(D, N, provider):
    """
    Benchmark LayerNorm: effective throughput (GB/s) of one provider at one size.

    Args:
        D: Feature dimension (swept over ``x_vals``).
        N: Batch size (fixed via ``args``).
        provider: One of 'pytorch', 'compile', 'triton'.

    Returns:
        ``(median, low, high)`` throughput in GB/s. The median time gives the
        headline value, the 95th-percentile time the low bound, and the
        5th-percentile time the high bound.

    Raises:
        ValueError: For an unknown ``provider`` (previously an
            UnboundLocalError on ``ms``).
    """
    x = torch.randn(N, D, device=DEVICE, dtype=torch.float32)
    weight = torch.randn(D, device=DEVICE, dtype=torch.float32)
    bias = torch.randn(D, device=DEVICE, dtype=torch.float32)

    impls = {
        'pytorch': layernorm_pytorch,
        'compile': layernorm_compiled,
        'triton': layernorm_triton,
    }
    if provider not in impls:
        raise ValueError(f"Unknown provider: {provider!r}")
    fn = impls[provider]

    quantiles = [0.5, 0.05, 0.95]
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: fn(x, weight, bias), quantiles=quantiles
    )

    # Effective bandwidth counts the minimum bytes that must move: x is read
    # once and y written once. weight and bias are shared by every row, so they
    # are counted once. Re-reads across rows are typically cache hits, and
    # counting them N times would roughly double the reported GB/s.
    bytes_moved = 2 * x.numel() * x.element_size()  # Read x + write y
    bytes_moved += weight.numel() * weight.element_size()
    bytes_moved += bias.numel() * bias.element_size()

    def gb_per_s(t_ms):
        # bytes / (t_ms * 1e-3 s) / 1e9 == bytes / t_ms / 1e6
        return bytes_moved / t_ms / 1e6

    return gb_per_s(ms), gb_per_s(max_ms), gb_per_s(min_ms)

print("Running benchmarks...")
# Sweep every configured D, print the table, show the plot, and keep the
# resulting DataFrame (one column per provider) for the speedup check.
results = benchmark.run(
    print_data=True,
    show_plots=True,
    save_path='',
    return_df=True,
)

## Speedup?

In [None]:
# Check if Triton is faster than PyTorch.
# `speedup` is a ratio of mean throughputs (GB/s) across all benchmarked sizes,
# so values above 1 mean Triton moves data faster than PyTorch.
avg_pytorch = results['PyTorch'].mean()
avg_triton = results['Triton'].mean()
speedup = avg_triton / avg_pytorch

if speedup > 1.0:
    print(f"\n🚀 Awesome! Triton is {speedup:.2f}x faster than PyTorch!")
    display(Image("figs/gpu.gif", width=400, height=256))
else:
    # A throughput ratio below 1 means Triton is 1 / speedup times slower
    # (e.g. speedup 0.5 -> 2.00x slower), not `speedup` times slower.
    slowdown = 1.0 / speedup
    print(f"\n🐌 Not quite yet! Triton implementation is {slowdown:.2f}x slower than PyTorch!")