# 🧬 KernelBench Triton Evolution

**Evolving high-performance GPU kernels using LLM-driven mutation and selection**

This notebook demonstrates evolving Triton softmax kernels that outperform PyTorch's baseline.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/anthropics/agentic-evolve/blob/main/showcase/kernelbench-triton-evolution/KernelBench_Evolution.ipynb)

---

## Setup

**Important**: Make sure you're using a GPU runtime!
- Go to `Runtime` → `Change runtime type` → Select `T4 GPU`

In [None]:
# Verify GPU is available
import torch

if not torch.cuda.is_available():
    raise RuntimeError("‚ùå No GPU detected! Go to Runtime ‚Üí Change runtime type ‚Üí T4 GPU")

print(f"‚úÖ GPU: {torch.cuda.get_device_name(0)}")
print(f"‚úÖ CUDA: {torch.version.cuda}")
print(f"‚úÖ PyTorch: {torch.__version__}")
!nvidia-smi --query-gpu=memory.total,memory.free --format=csv

In [None]:
# Install Triton
!pip install -q triton

import triton
print(f"‚úÖ Triton: {triton.__version__}")

---

## Baseline: PyTorch Softmax

First, let's establish our baseline performance using PyTorch's optimized softmax.

In [None]:
import torch
import torch.nn.functional as F
import time
import triton
import triton.language as tl

def benchmark(fn, x, name, warmup=10, iters=100):
    """Time ``fn(x)`` and return the median per-call latency in milliseconds.

    Args:
        fn: Callable invoked as ``fn(x)``.
        x: Argument forwarded to ``fn`` on every call.
        name: Label printed in front of the measured time.
        warmup: Number of untimed calls made before measuring.
        iters: Number of timed calls; must be at least 1.

    Returns:
        The median of the ``iters`` timed calls, in milliseconds. For an even
        number of samples this is the mean of the two middle samples.

    Raises:
        ValueError: If ``iters`` is smaller than 1.
    """
    import statistics

    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")

    # Wait for outstanding GPU work so the timer covers execution rather than
    # just the launch. Skipped when CUDA is unavailable so the helper can also
    # time CPU callables.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)

    # Warmup (not timed)
    for _ in range(warmup):
        fn(x)
    sync()

    # Benchmark
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        fn(x)
        sync()
        times.append((time.perf_counter() - start) * 1000)

    median = statistics.median(times)
    print(f"{name}: {median:.4f}ms")
    return median

# Benchmark problem sizes as (rows, columns, display label) tuples.
SHAPES = [
    (64, 2048, "64x2048 (typical)"),
    (128, 4096, "128x4096 (long context)"),
]

print("=" * 50, "BASELINE: PyTorch F.softmax", "=" * 50, sep="\n")

# Reference timings keyed by "<rows>x<cols>"; later cells divide by these
# values to report speedups.
baselines = {}
for batch, seq, name in SHAPES:
    x = torch.randn(batch, seq, dtype=torch.float32, device='cuda')
    t = benchmark(lambda inp: F.softmax(inp, dim=-1), x, name)
    baselines[f"{batch}x{seq}"] = t

print("\nBaselines recorded. Now let's evolve faster kernels!")

---

## Generation 0: Starter Kernel

Our initial population seed - based on the [Triton fused softmax tutorial](https://triton-lang.org/main/getting-started/tutorials/02-fused-softmax.html).

This is a correct but unoptimized kernel that serves as our evolution starting point.

In [None]:
@triton.jit
def softmax_kernel_gen0(
    input_ptr, output_ptr,
    input_row_stride, output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Gen 0: row-wise softmax in the style of the Triton fused-softmax tutorial.

    - Each program instance handles the row selected by ``tl.program_id(0)``
    - The row is read in one masked load of ``BLOCK_SIZE`` lanes
    - Softmax is computed as exp(x - max) / sum(exp(x - max))
    """
    pid = tl.program_id(0)

    # Column lanes; only those below n_cols touch memory.
    cols = tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < n_cols

    # Masked-out lanes read -inf, so they cannot win the max and their
    # exponential is 0, leaving the sum unaffected.
    vals = tl.load(
        input_ptr + pid * input_row_stride + cols,
        mask=in_bounds,
        other=-float("inf"),
    )

    # Subtract the row maximum before exponentiating for numerical stability.
    exps = tl.exp(vals - tl.max(vals, axis=0))
    probs = exps / tl.sum(exps, axis=0)

    # Write back the valid lanes of this row.
    tl.store(output_ptr + pid * output_row_stride + cols, probs, mask=in_bounds)


def softmax_gen0(x: torch.Tensor) -> torch.Tensor:
    """Compute a row-wise softmax of a 2-D tensor with the Gen 0 kernel.

    Args:
        x: 2-D tensor; softmax is taken over the last dimension.

    Returns:
        A new tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: If ``x`` is not 2-D (raised by the shape unpacking).
    """
    n_rows, n_cols = x.shape
    if x.numel() == 0:
        # Nothing to normalise; also avoids launching an empty grid.
        return torch.empty_like(x)

    # The kernel addresses a row as base pointer + column index, i.e. it
    # assumes a unit column stride. Make that true for arbitrary views.
    x = x.contiguous()

    # One block must cover the whole row.
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    y = torch.empty_like(x)
    softmax_kernel_gen0[(n_rows,)](
        x, y,
        x.stride(0), y.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return y


# Verify correctness against PyTorch on a 32x1024 input.
# x_test and torch_out are reused by the correctness checks of later generations.
x_test = torch.randn(32, 1024, device='cuda', dtype=torch.float32)
triton_out = softmax_gen0(x_test)
torch_out = F.softmax(x_test, dim=-1)
max_diff = (triton_out - torch_out).abs().max().item()

print(f"Gen 0 Correctness: max_diff = {max_diff:.2e}", end=" ")
print("✅" if max_diff < 1e-5 else "❌")

# Benchmark
print("\n" + "=" * 50)
print("GENERATION 0: Starter Kernel")
print("=" * 50)

# Per-shape results keyed like `baselines` ("<rows>x<cols>").
gen0_results = {}
for batch, seq, name in SHAPES:
    x = torch.randn(batch, seq, device='cuda', dtype=torch.float32)
    t = benchmark(softmax_gen0, x, name)
    # Speedup > 1 means faster than the PyTorch baseline.
    speedup = baselines[f"{batch}x{seq}"] / t
    gen0_results[f"{batch}x{seq}"] = {"time": t, "speedup": speedup}
    print(f"  → Speedup vs PyTorch: {speedup:.2f}x")

---

## Generation 1: Larger Block Size

**Mutation**: Increase BLOCK_SIZE to improve memory coalescing and reduce kernel launch overhead.

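**Hypothesis**: Larger blocks process more data per thread, reducing total thread count and improving cache utilization. Note that the 1024 floor only changes the launch for rows narrower than 1024 columns; for the 2048- and 4096-column benchmark shapes used here, Gen 1 launches with the same BLOCK_SIZE as Gen 0.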

In [None]:
@triton.jit
def softmax_kernel_gen1(
    input_ptr, output_ptr,
    input_row_stride, output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Gen 1: same computation as Gen 0; the launcher picks a larger BLOCK_SIZE.

    Mutation: the launcher forces BLOCK_SIZE to be at least 1024.
    """
    row = tl.program_id(0)
    lane = tl.arange(0, BLOCK_SIZE)
    valid = lane < n_cols

    # Base addresses of this program's input and output rows.
    src_row = input_ptr + row * input_row_stride
    dst_row = output_ptr + row * output_row_stride

    # Lanes beyond n_cols are filled with -inf so they drop out of max and sum.
    values = tl.load(src_row + lane, mask=valid, other=-float("inf"))

    # Stable softmax: shift by the row maximum, exponentiate, normalise.
    shifted = values - tl.max(values, axis=0)
    weights = tl.exp(shifted)
    total = tl.sum(weights, axis=0)
    result = weights / total

    tl.store(dst_row + lane, result, mask=valid)


def softmax_gen1(x: torch.Tensor) -> torch.Tensor:
    """Compute a row-wise softmax of a 2-D tensor with the Gen 1 kernel.

    Args:
        x: 2-D tensor; softmax is taken over the last dimension.

    Returns:
        A new tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: If ``x`` is not 2-D (raised by the shape unpacking).
    """
    n_rows, n_cols = x.shape
    if x.numel() == 0:
        # Nothing to normalise; also avoids launching an empty grid.
        return torch.empty_like(x)

    # The kernel addresses a row as base pointer + column index, i.e. it
    # assumes a unit column stride. Make that true for arbitrary views.
    x = x.contiguous()

    # MUTATION: Use larger block size, minimum 1024
    BLOCK_SIZE = max(1024, triton.next_power_of_2(n_cols))
    y = torch.empty_like(x)
    softmax_kernel_gen1[(n_rows,)](
        x, y,
        x.stride(0), y.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return y


# Verify correctness against the PyTorch reference computed for x_test
triton_out = softmax_gen1(x_test)
max_diff = (triton_out - torch_out).abs().max().item()
print(f"Gen 1 Correctness: max_diff = {max_diff:.2e}", end=" ")
print("✅" if max_diff < 1e-5 else "❌")

# Benchmark
print("\n" + "=" * 50)
print("GENERATION 1: Larger Block Size")
print("=" * 50)

# Per-shape results keyed like `baselines` ("<rows>x<cols>").
gen1_results = {}
for batch, seq, name in SHAPES:
    x = torch.randn(batch, seq, device='cuda', dtype=torch.float32)
    t = benchmark(softmax_gen1, x, name)
    # Speedup > 1 means faster than the PyTorch baseline.
    speedup = baselines[f"{batch}x{seq}"] / t
    gen1_results[f"{batch}x{seq}"] = {"time": t, "speedup": speedup}
    print(f"  → Speedup vs PyTorch: {speedup:.2f}x")

---

## Generation 2: num_warps Tuning

**Mutation**: Add `num_warps` parameter to control parallelism within each thread block.

**Hypothesis**: More warps can hide memory latency, but too many can cause register pressure.

In [None]:
@triton.jit
def softmax_kernel_gen2(
    input_ptr, output_ptr,
    input_row_stride, output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Gen 2: kernel body unchanged from Gen 1.

    The mutation is in the launcher, which passes a tuned num_warps.
    """
    offs = tl.arange(0, BLOCK_SIZE)
    keep = offs < n_cols
    r = tl.program_id(0)

    # Read this program's row; padding lanes become -inf.
    data = tl.load(
        input_ptr + r * input_row_stride + offs,
        mask=keep,
        other=-float("inf"),
    )

    # exp(x - max(x)), then divide by the row sum.
    e = tl.exp(data - tl.max(data, axis=0))
    norm = tl.sum(e, axis=0)

    out_ptrs = output_ptr + r * output_row_stride + offs
    tl.store(out_ptrs, e / norm, mask=keep)


def softmax_gen2(x: torch.Tensor) -> torch.Tensor:
    """Compute a row-wise softmax of a 2-D tensor with the Gen 2 kernel.

    Args:
        x: 2-D tensor; softmax is taken over the last dimension.

    Returns:
        A new tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: If ``x`` is not 2-D (raised by the shape unpacking).
    """
    n_rows, n_cols = x.shape
    if x.numel() == 0:
        # Nothing to normalise; also avoids launching an empty grid.
        return torch.empty_like(x)

    # The kernel addresses a row as base pointer + column index, i.e. it
    # assumes a unit column stride. Make that true for arbitrary views.
    x = x.contiguous()

    BLOCK_SIZE = max(1024, triton.next_power_of_2(n_cols))

    # MUTATION: Tune num_warps based on block size
    # Heuristic: 4 warps for small blocks, 8 for larger
    num_warps = 8 if BLOCK_SIZE >= 2048 else 4

    y = torch.empty_like(x)
    softmax_kernel_gen2[(n_rows,)](
        x, y,
        x.stride(0), y.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return y


# Verify correctness against the PyTorch reference computed for x_test
triton_out = softmax_gen2(x_test)
max_diff = (triton_out - torch_out).abs().max().item()
print(f"Gen 2 Correctness: max_diff = {max_diff:.2e}", end=" ")
print("✅" if max_diff < 1e-5 else "❌")

# Benchmark
print("\n" + "=" * 50)
print("GENERATION 2: num_warps Tuning")
print("=" * 50)

# Per-shape results keyed like `baselines` ("<rows>x<cols>").
gen2_results = {}
for batch, seq, name in SHAPES:
    x = torch.randn(batch, seq, device='cuda', dtype=torch.float32)
    t = benchmark(softmax_gen2, x, name)
    # Speedup > 1 means faster than the PyTorch baseline.
    speedup = baselines[f"{batch}x{seq}"] / t
    gen2_results[f"{batch}x{seq}"] = {"time": t, "speedup": speedup}
    print(f"  → Speedup vs PyTorch: {speedup:.2f}x")

---

## Generation 3: Autotuning

**Mutation**: Use Triton's `@triton.autotune` to automatically find best configuration.

**Hypothesis**: Let the compiler search for optimal BLOCK_SIZE and num_warps combinations.

In [None]:
def _gen3_keep_full_row_configs(configs, named_args, **kwargs):
    """Drop autotune configs whose BLOCK_SIZE cannot hold a full row.

    The kernel reads a row in a single block of BLOCK_SIZE lanes, so a config
    with BLOCK_SIZE < n_cols would only process part of the row. Such a config
    does less work and could win the tuning run while producing wrong output.
    """
    n_cols = named_args.get('n_cols', kwargs.get('n_cols'))
    usable = [cfg for cfg in configs if cfg.kwargs['BLOCK_SIZE'] >= n_cols]
    if not usable:
        raise ValueError(
            f"n_cols={n_cols} exceeds the largest autotune BLOCK_SIZE of softmax_kernel_gen3"
        )
    return usable


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 1024}, num_warps=4),
        triton.Config({'BLOCK_SIZE': 2048}, num_warps=8),
        triton.Config({'BLOCK_SIZE': 4096}, num_warps=8),
        triton.Config({'BLOCK_SIZE': 4096}, num_warps=16),
        triton.Config({'BLOCK_SIZE': 8192}, num_warps=16),
    ],
    key=['n_cols'],
    # Only tune over configs that cover the whole row. 'perf_model' and 'top_k'
    # are passed explicitly as None for Triton releases that index these keys
    # directly.
    prune_configs_by={
        'early_config_prune': _gen3_keep_full_row_configs,
        'perf_model': None,
        'top_k': None,
    },
)
@triton.jit
def softmax_kernel_gen3(
    input_ptr, output_ptr,
    input_row_stride, output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Gen 3: Autotuned kernel.

    Triton selects BLOCK_SIZE / num_warps per distinct n_cols from the configs
    above; configs too small to hold a row are pruned before tuning.
    """
    row_idx = tl.program_id(0)
    row_start_ptr = input_ptr + row_idx * input_row_stride
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets

    # Lanes past n_cols load -inf so they do not affect max or sum.
    mask = col_offsets < n_cols
    row = tl.load(input_ptrs, mask=mask, other=-float("inf"))

    # Numerically stable softmax: subtract the row max before exponentiating.
    row_max = tl.max(row, axis=0)
    row_stable = row - row_max
    numerator = tl.exp(row_stable)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator

    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_output, mask=mask)


def softmax_gen3(x: torch.Tensor) -> torch.Tensor:
    """Compute a row-wise softmax of a 2-D tensor with the autotuned Gen 3 kernel.

    Args:
        x: 2-D tensor; softmax is taken over the last dimension.

    Returns:
        A new tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: If ``x`` is not 2-D (raised by the shape unpacking), or if
            a row is wider than the largest BLOCK_SIZE the kernel is tuned for.
    """
    n_rows, n_cols = x.shape
    if x.numel() == 0:
        # Nothing to normalise; also avoids launching an empty grid.
        return torch.empty_like(x)

    # Largest BLOCK_SIZE among softmax_kernel_gen3's autotune configs. The
    # kernel reads a row in one block, so wider rows would be silently
    # truncated; reject them explicitly instead.
    max_block_size = 8192
    if n_cols > max_block_size:
        raise ValueError(
            f"softmax_gen3 supports at most {max_block_size} columns, got {n_cols}"
        )

    # The kernel addresses a row as base pointer + column index, i.e. it
    # assumes a unit column stride. Make that true for arbitrary views.
    x = x.contiguous()
    y = torch.empty_like(x)
    # BLOCK_SIZE and num_warps are supplied by the autotuner.
    softmax_kernel_gen3[(n_rows,)](
        x, y,
        x.stride(0), y.stride(0),
        n_cols,
    )
    return y


# Verify correctness against the PyTorch reference computed for x_test
triton_out = softmax_gen3(x_test)
max_diff = (triton_out - torch_out).abs().max().item()
print(f"Gen 3 Correctness: max_diff = {max_diff:.2e}", end=" ")
print("✅" if max_diff < 1e-5 else "❌")

# Benchmark (first run triggers autotuning)
print("\n" + "=" * 50)
print("GENERATION 3: Autotuned (first run includes tuning)")
print("=" * 50)

# Per-shape results keyed like `baselines` ("<rows>x<cols>").
gen3_results = {}
for batch, seq, name in SHAPES:
    x = torch.randn(batch, seq, device='cuda', dtype=torch.float32)
    # Extra warmup for autotuning
    for _ in range(5):
        softmax_gen3(x)
    torch.cuda.synchronize()

    # The autotuner picks a config per n_cols, so the 1024-column check above
    # says nothing about this shape; verify the tuned kernel here as well.
    shape_diff = (softmax_gen3(x) - F.softmax(x, dim=-1)).abs().max().item()
    print(f"{name} correctness: max_diff = {shape_diff:.2e}", end=" ")
    print("✅" if shape_diff < 1e-5 else "❌")

    t = benchmark(softmax_gen3, x, name)
    # Speedup > 1 means faster than the PyTorch baseline.
    speedup = baselines[f"{batch}x{seq}"] / t
    gen3_results[f"{batch}x{seq}"] = {"time": t, "speedup": speedup}
    print(f"  → Speedup vs PyTorch: {speedup:.2f}x")

---

## Generation 4: Fused Online Softmax

**Mutation**: Implement online softmax algorithm (inspired by Flash Attention).

**Hypothesis**: Computing max and sum in a single pass reduces memory bandwidth requirements.

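Note: The kernel below still loads each row in a single block and computes the max and the sum as two separate reductions, so it is not a true online softmax. For this simple case where we load all data at once, the benefit would be minimal anyway. The real gain comes when processing data in tiles that don't fit in registers.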

In [None]:
def _gen4_keep_full_row_configs(configs, named_args, **kwargs):
    """Drop autotune configs whose BLOCK_SIZE cannot hold a full row.

    The kernel reads a row in a single block of BLOCK_SIZE lanes, so a config
    with BLOCK_SIZE < n_cols would only process part of the row. Such a config
    does less work and could win the tuning run while producing wrong output.
    """
    n_cols = named_args.get('n_cols', kwargs.get('n_cols'))
    usable = [cfg for cfg in configs if cfg.kwargs['BLOCK_SIZE'] >= n_cols]
    if not usable:
        raise ValueError(
            f"n_cols={n_cols} exceeds the largest autotune BLOCK_SIZE of softmax_kernel_gen4"
        )
    return usable


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 1024}, num_warps=4),
        triton.Config({'BLOCK_SIZE': 2048}, num_warps=8),
        triton.Config({'BLOCK_SIZE': 4096}, num_warps=8),
        triton.Config({'BLOCK_SIZE': 4096}, num_warps=16),
    ],
    key=['n_cols'],
    # Only tune over configs that cover the whole row. 'perf_model' and 'top_k'
    # are passed explicitly as None for Triton releases that index these keys
    # directly.
    prune_configs_by={
        'early_config_prune': _gen4_keep_full_row_configs,
        'perf_model': None,
        'top_k': None,
    },
)
@triton.jit
def softmax_kernel_gen4(
    input_ptr, output_ptr,
    input_row_stride, output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Gen 4: Autotuned single-block softmax with both row pointers computed up front.

    The computation is the same max-subtract-exp-sum-divide sequence as Gen 3.
    No cache or eviction hint is passed to tl.load / tl.store here.
    """
    row_idx = tl.program_id(0)

    # Compute pointers
    input_row_ptr = input_ptr + row_idx * input_row_stride
    output_row_ptr = output_ptr + row_idx * output_row_stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # Load input; lanes past n_cols read -inf so they do not affect max or sum.
    x = tl.load(input_row_ptr + col_offsets, mask=mask, other=-float("inf"))

    # Numerically stable softmax: subtract the row max before exponentiating.
    x_max = tl.max(x, axis=0)
    x_shifted = x - x_max
    exp_x = tl.exp(x_shifted)
    sum_exp = tl.sum(exp_x, axis=0)
    out = exp_x / sum_exp

    # Plain masked store (default eviction policy).
    tl.store(output_row_ptr + col_offsets, out, mask=mask)


def softmax_gen4(x: torch.Tensor) -> torch.Tensor:
    """Compute a row-wise softmax of a 2-D tensor with the autotuned Gen 4 kernel.

    Args:
        x: 2-D tensor; softmax is taken over the last dimension.

    Returns:
        A new tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: If ``x`` is not 2-D (raised by the shape unpacking), or if
            a row is wider than the largest BLOCK_SIZE the kernel is tuned for.
    """
    n_rows, n_cols = x.shape
    if x.numel() == 0:
        # Nothing to normalise; also avoids launching an empty grid.
        return torch.empty_like(x)

    # Largest BLOCK_SIZE among softmax_kernel_gen4's autotune configs. The
    # kernel reads a row in one block, so wider rows would be silently
    # truncated; reject them explicitly instead.
    max_block_size = 4096
    if n_cols > max_block_size:
        raise ValueError(
            f"softmax_gen4 supports at most {max_block_size} columns, got {n_cols}"
        )

    # The kernel addresses a row as base pointer + column index, i.e. it
    # assumes a unit column stride. Make that true for arbitrary views.
    x = x.contiguous()
    y = torch.empty_like(x)
    # BLOCK_SIZE and num_warps are supplied by the autotuner.
    softmax_kernel_gen4[(n_rows,)](
        x, y,
        x.stride(0), y.stride(0),
        n_cols,
    )
    return y


# Verify correctness against the PyTorch reference computed for x_test
triton_out = softmax_gen4(x_test)
max_diff = (triton_out - torch_out).abs().max().item()
print(f"Gen 4 Correctness: max_diff = {max_diff:.2e}", end=" ")
print("✅" if max_diff < 1e-5 else "❌")

# Benchmark
print("\n" + "=" * 50)
print("GENERATION 4: Optimized Memory Access")
print("=" * 50)

# Per-shape results keyed like `baselines` ("<rows>x<cols>").
gen4_results = {}
for batch, seq, name in SHAPES:
    x = torch.randn(batch, seq, device='cuda', dtype=torch.float32)
    # Untimed calls so autotuning for this n_cols finishes before measuring.
    for _ in range(5):
        softmax_gen4(x)
    torch.cuda.synchronize()

    # The autotuner picks a config per n_cols, so the 1024-column check above
    # says nothing about this shape; verify the tuned kernel here as well.
    shape_diff = (softmax_gen4(x) - F.softmax(x, dim=-1)).abs().max().item()
    print(f"{name} correctness: max_diff = {shape_diff:.2e}", end=" ")
    print("✅" if shape_diff < 1e-5 else "❌")

    t = benchmark(softmax_gen4, x, name)
    # Speedup > 1 means faster than the PyTorch baseline.
    speedup = baselines[f"{batch}x{seq}"] / t
    gen4_results[f"{batch}x{seq}"] = {"time": t, "speedup": speedup}
    print(f"  → Speedup vs PyTorch: {speedup:.2f}x")

---

## Evolution Results Summary

In [None]:
import pandas as pd

# Compile results: generation label -> {shape key -> {"time", "speedup"}}.
# The baseline is given a speedup of 1.0 by definition.
all_results = {
    'PyTorch (baseline)': {k: {"time": v, "speedup": 1.0} for k, v in baselines.items()},
    'Gen 0 (starter)': gen0_results,
    'Gen 1 (larger block)': gen1_results,
    'Gen 2 (num_warps)': gen2_results,
    'Gen 3 (autotune)': gen3_results,
    'Gen 4 (optimized)': gen4_results,
}

# Create summary table: one row per (generation, shape) pair.
rows = []
for gen_name, results in all_results.items():
    for shape, metrics in results.items():
        rows.append({
            'Generation': gen_name,
            'Shape': shape,
            'Time (ms)': f"{metrics['time']:.4f}",
            'Speedup': f"{metrics['speedup']:.2f}x",
        })

df = pd.DataFrame(rows)
print("\n" + "=" * 60)
print("EVOLUTION SUMMARY")
print("=" * 60)
print(df.to_string(index=False))

# Find best generation: highest speedup averaged over all shapes, baseline
# excluded. On ties the earliest generation wins.
avg_speedups = {
    gen_name: sum(r['speedup'] for r in results.values()) / len(results)
    for gen_name, results in all_results.items()
    if gen_name != 'PyTorch (baseline)'
}
best_gen = max(avg_speedups, key=avg_speedups.get)
best_speedup = avg_speedups[best_gen]

print(f"\n🏆 Best Generation: {best_gen}")
print(f"   Average Speedup: {best_speedup:.2f}x over PyTorch")

### Speedup Visualization

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Prepare data for plotting: labels and result dicts in matching order.
generations = ['Gen 0', 'Gen 1', 'Gen 2', 'Gen 3', 'Gen 4']
generation_results = [gen0_results, gen1_results, gen2_results, gen3_results, gen4_results]
shape_keys = list(baselines.keys())

fig, ax = plt.subplots(figsize=(10, 6))

x = np.arange(len(generations))
# Split a fixed 0.7-wide slot per generation among the shapes so the grouped
# bars stay centred on the tick for any number of shapes (0.35 each for two).
width = 0.7 / len(shape_keys)

tallest_bar = 0.0
for i, shape in enumerate(shape_keys):
    speedups = [results[shape]['speedup'] for results in generation_results]
    offset = width * (i - (len(shape_keys) - 1) / 2)
    ax.bar(x + offset, speedups, width, label=shape)
    tallest_bar = max(tallest_bar, max(speedups))

ax.axhline(y=1.0, color='r', linestyle='--', label='PyTorch baseline')
ax.set_ylabel('Speedup vs PyTorch')
ax.set_xlabel('Evolution Generation')
ax.set_title('Triton Softmax Evolution Progress')
ax.set_xticks(x)
ax.set_xticklabels(generations)
ax.legend()
# Scale to the tallest plotted bar (not an average) so no bar is clipped.
ax.set_ylim(0, max(2.0, tallest_bar * 1.2))

plt.tight_layout()
plt.show()

---

## Conclusion

Through 5 generations of evolution, we explored:

1. **Gen 0**: Baseline Triton implementation from tutorial
2. **Gen 1**: Larger block sizes for better occupancy
3. **Gen 2**: num_warps tuning for latency hiding
4. **Gen 3**: Autotuning to find optimal configs
5. **Gen 4**: Memory access optimizations

### Key Learnings

- Triton's autotuning is powerful for finding good configurations
- Block size and num_warps are the main tuning knobs
- PyTorch's built-in softmax is already a highly optimized CUDA kernel, so a standalone Triton softmax has little headroom
- Bigger gains come from fusing softmax with other operations (attention)

### Next Steps

- Try Flash Attention-style online softmax for very long sequences
- Fuse with attention (Q@K^T → softmax → @V)
- Explore split-K for sequences > 8192

---

*Generated with [Agentic Evolve](https://github.com/anthropics/agentic-evolve)*