# Triton Puzzle 2: Fused Softmax

Welcome to the second Triton puzzle! Now we'll tackle a more complex operation: fused softmax. This builds on what you learned in vector addition while introducing new concepts.

### What you'll learn:
- How to work with 2D data (matrices) in Triton
- Performing reduction operations (max, sum) in SRAM
- Why kernel fusion is crucial for performance
- Introduction to `num_warps` for better parallelism


## Mathematical Background

The softmax function transforms a vector of real numbers into a probability distribution. For a vector $\mathbf{x} = [x_1, x_2, ..., x_n]$:

$$\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}$$

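However, this naive formulation is numerically unstable: $e^{x_i}$ overflows for large inputs (in float32, once $x_i$ exceeds roughly 88.7). The stable version subtracts the maximum before exponentiating. The factor $e^{-\max(\mathbf{x})}$ cancels between numerator and denominator, so the result is unchanged: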

$$\text{softmax}(x)_i = \frac{e^{x_i - \max(\mathbf{x})}}{\sum_{j=1}^{n} e^{x_j - \max(\mathbf{x})}}$$
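
To make the overflow concrete, here is a minimal pure-Python sketch of both formulas. The helper names `softmax_naive_list` and `softmax_stable_list` are made up for this illustration and are not used elsewhere in the notebook. Python floats are float64, so the overflow threshold here is higher than on a float32 GPU tensor, but the failure mode is the same.

```python
import math

def softmax_naive_list(xs):
    """Direct translation of the first formula; exp() overflows for large inputs."""
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_stable_list(xs):
    """Shifted formula: every exponent is <= 0, so exp() stays in [0, 1]."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

small = [1.0, 2.0, 3.0]
large = [1000.0, 1001.0, 1002.0]  # the same values shifted by 999

# On small inputs the two formulas agree.
assert all(
    math.isclose(a, b)
    for a, b in zip(softmax_naive_list(small), softmax_stable_list(small))
)

# On large inputs the naive version fails: math.exp(1000.0) raises OverflowError.
try:
    softmax_naive_list(large)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True

print("naive overflowed:", naive_overflowed)
print("stable result:  ", softmax_stable_list(large))
```

Because softmax only depends on differences between entries, the stable version returns the same probabilities for `small` and `large`.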

### Memory Operations Analysis

#### Naive Implementation (4 steps):
1. Find max: Read $n$ → Write $1$
2. Subtract max: Read $n+1$ → Write $n$
3. Exponential: Read $n$ → Write $n$
4. Normalize (two kernels): sum (Read $n$ → Write $1$), then divide (Read $n+1$ → Write $n$)

Total, ignoring the per-row scalars: 5 reads + 3 writes per element = **8 memory operations per element**

#### Fused Implementation (1 kernel):
1. Load once, compute everything in SRAM, store once

Total: 1 read + 1 write per element = **2 memory operations per element**

**4x reduction in memory traffic!**

In [None]:
import torch
import triton
import triton.language as tl
import numpy as np
from IPython.display import Image, display

DEVICE = torch.device(f'cuda:{torch.cuda.current_device()}')
print(f"Using device: {DEVICE}")

## Implementation 1: Naive PyTorch (Unfused)

Let's start with a naive implementation that clearly shows each step:

In [None]:
def softmax_naive(x):
    """
    Naive softmax implementation - multiple kernel launches.
    Input: (M, N) matrix
    Output: (M, N) matrix with softmax applied row-wise
    """
    # Step 1: Find maximum per row (first kernel)
    x_max = x.max(dim=1, keepdim=True)[0]
    
    # Step 2: Subtract maximum for numerical stability (second kernel)
    x_shifted = x - x_max
    
    # Step 3: Exponentiate (third kernel)
    x_exp = torch.exp(x_shifted)
    
    # Step 4: Sum (fourth kernel), then divide to normalize (fifth kernel)
    x_sum = x_exp.sum(dim=1, keepdim=True)
    
    return x_exp / x_sum

## Implementation 2: PyTorch Built-in

PyTorch's built-in softmax is already optimized and fused:

In [None]:
def softmax_pytorch(x):
    """PyTorch's built-in softmax - already fused."""
    return torch.softmax(x, dim=1)

## Implementation 3: PyTorch Compiled

Let's see if `torch.compile` can fuse our naive implementation:

In [None]:
@torch.compile
def softmax_compiled(x):
    """Compiled version of naive softmax."""
    x_max = x.max(dim=1, keepdim=True)[0]
    x_shifted = x - x_max
    x_exp = torch.exp(x_shifted)
    x_sum = x_exp.sum(dim=1, keepdim=True)
    return x_exp / x_sum

## Implementation 4: Triton Kernel (Puzzle)

Now implement fused softmax in Triton! Here are the key concepts for this puzzle:

### 1. Working with 2D Data
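- The input is an (M, N) matrix, and softmax is applied to each row independently
- `BLOCK_SIZE` is the number of columns processed at a time, so a row longer than `BLOCK_SIZE` is handled in several chunks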

### 2. Reductions in Triton
- `tl.max(x, axis=0)`: Find maximum along an axis
- `tl.sum(x, axis=0)`: Sum along an axis
- These operations happen entirely in fast SRAM!
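
When a row is reduced chunk by chunk, the last chunk may run past the end of the row, and the out-of-range lanes need a fill value that does not change the result. The sketch below emulates that in plain Python. `masked_load` is a hypothetical stand-in for `tl.load(..., mask=..., other=...)`, not a Triton function, and the row and block size are arbitrary.

```python
import math

def masked_load(row, start, block_size, other):
    """Emulate a masked block load: out-of-range lanes receive `other`."""
    return [row[i] if i < len(row) else other
            for i in range(start, start + block_size)]

row = [0.5, -1.0, 2.0, 0.0, 1.5]  # n_cols = 5
BLOCK_SIZE = 4                     # chunks cover indices 0..3 and 4..7

row_max = -math.inf
row_sum = 0.0
for start in range(0, len(row), BLOCK_SIZE):
    # -inf is the identity for max: padded lanes can never win.
    row_max = max(row_max, max(masked_load(row, start, BLOCK_SIZE, other=-math.inf)))
    # 0.0 is the identity for sum: padded lanes contribute nothing.
    row_sum += sum(masked_load(row, start, BLOCK_SIZE, other=0.0))

print(row_max, row_sum)
```

The fill value depends on the reduction: padding a max with `0.0` would be wrong for a row of all-negative values.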

### 3. Program Organization
- Each program handles one row of the matrix
- Programs run in parallel across different SMs (Streaming Multiprocessors)
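
The row-per-program layout comes down to simple offset arithmetic: program `pid` starts at `pid * stride` and then adds the column indices. Below is a small pure-Python sketch of that indexing, using a flat list in place of GPU memory. `row_offsets` is a name invented for this example.

```python
def row_offsets(row_idx, row_stride, n_cols):
    """Flat offsets of one row's elements: row_idx * row_stride + col."""
    start = row_idx * row_stride
    return [start + col for col in range(n_cols)]

# A contiguous 3 x 4 matrix stored flat: the row stride equals n_cols.
n_rows, n_cols = 3, 4
flat = list(range(n_rows * n_cols))  # 0..11

# One "program" per row, each reading only its own slice of the buffer.
rows = [[flat[o] for o in row_offsets(pid, n_cols, n_cols)]
        for pid in range(n_rows)]
print(rows)

# With padded storage the stride is larger than n_cols
# (here: 6 slots per row, only the first 4 belong to the row).
print(row_offsets(2, 6, n_cols))
```

Passing the stride explicitly, rather than assuming it equals `n_cols`, is what lets the same indexing work for padded or sliced tensors.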

### 4. Introduction to Warps
- A **warp** is a group of 32 threads (on NVIDIA GPUs) that execute in lockstep
- `num_warps` controls how many warps are assigned to each program
- More warps can help with parallelism but use more resources
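
As a rough way to reason about `num_warps`, the sketch below computes how many threads a program gets and how many block elements each thread covers on average. `pick_num_warps` copies the thresholds of the wrapper's heuristic in this notebook. The per-thread figure is only an average: the actual distribution depends on the layout the Triton compiler chooses.

```python
THREADS_PER_WARP = 32  # warp size on NVIDIA GPUs

def pick_num_warps(block_size):
    """Same thresholds as the wrapper's heuristic in this notebook."""
    if block_size >= 2048:
        return 8
    elif block_size >= 1024:
        return 4
    return 2

def elements_per_thread(block_size, num_warps):
    """Average number of block elements handled by each thread."""
    return block_size / (num_warps * THREADS_PER_WARP)

for block_size in (256, 1024, 4096):
    nw = pick_num_warps(block_size)
    print(f"BLOCK_SIZE={block_size:5d}  num_warps={nw}  "
          f"threads={nw * THREADS_PER_WARP:3d}  "
          f"elements/thread={elements_per_thread(block_size, nw):.0f}")
```

Too few warps leaves each thread with a lot of work and little latency hiding. Too many warps increases register and scheduling pressure, which is why the best value is usually found by experiment.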


### Your Task:
Complete the kernel to:
1. Load a row of data (handling the case where the row is longer than `BLOCK_SIZE`)
2. Find the maximum value in the row
3. Compute exponentials with numerical stability
4. Sum the exponentials
5. Normalize and store the result

In [None]:
@triton.jit
def softmax_kernel(
    input_ptr,      # Pointer to input matrix
    output_ptr,     # Pointer to output matrix
    n_cols,         # Number of columns (row size)
    input_stride,   # Stride between rows in input
    output_stride,  # Stride between rows in output
    BLOCK_SIZE: tl.constexpr,  # Size of blocks to process
):
    """
    Fused softmax kernel. Each program computes one row.
    
    Key concepts:
    - Each program gets a unique row index via tl.program_id(0)
    - We process the row in chunks of BLOCK_SIZE
    - All operations (max, exp, sum) happen in SRAM
    """
    # Get the row index for this program
    row_idx = tl.program_id(0)
    
    # Calculate the starting pointer for this row
    input_row_start = input_ptr + row_idx * input_stride
    output_row_start = output_ptr + row_idx * output_stride
    
    # YOUR IMPLEMENTATION GOES HERE
    # Hints:
    # 1. Process the row in BLOCK_SIZE chunks (use a `range` loop; `tl.range` also works in recent Triton versions)
    # 2. Use tl.max() to find the maximum
    # 3. Keep a running max and a running sum across chunks
    # 4. Remember to mask when n_cols is not divisible by BLOCK_SIZE
    pass


def softmax_triton(x):
    """Wrapper for the Triton softmax kernel."""
    # Ensure input is contiguous
    x = x.contiguous()
    
    # Get dimensions
    n_rows, n_cols = x.shape
    
    # Allocate output
    output = torch.empty_like(x)
    
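    # tl.arange needs a power-of-2 size. Picking BLOCK_SIZE >= n_cols means each
    # chunk loop in the kernel runs once; a chunked kernel also works with a smaller power of 2.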
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    
    # Simple heuristic for num_warps:
    # - More warps for larger blocks
    # - Minimum 2, maximum 8
    if BLOCK_SIZE >= 2048:
        num_warps = 8
    elif BLOCK_SIZE >= 1024:
        num_warps = 4
    else:
        num_warps = 2
    
    # Launch grid: one program per row
    grid = (n_rows,)
    
    # Launch kernel
    softmax_kernel[grid](
        x, output,
        n_cols,
        x.stride(0),  # Stride between rows
        output.stride(0),
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    
    return output

## Solution (Hidden)

<font size="6">🧙</font> You shall not pass! 

In [None]:
@triton.jit
def softmax_kernel(
    input_ptr,
    output_ptr,
    n_cols,
    input_stride,
    output_stride,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Fused softmax kernel. Each program computes one row.
    """
    # Get the row index for this program
    row_idx = tl.program_id(0)
    
    # Calculate the starting pointer for this row
    input_row_start = input_ptr + row_idx * input_stride
    output_row_start = output_ptr + row_idx * output_stride
    
    # Step 1: Find maximum value in the row
    # We process in chunks of BLOCK_SIZE
    row_max = float('-inf')
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        # Create column indices for this block
        cols = col_offset + tl.arange(0, BLOCK_SIZE)
        
        # Mask to handle the last block if n_cols % BLOCK_SIZE != 0
        mask = cols < n_cols
        
        # Load values from HBM to SRAM
        vals = tl.load(input_row_start + cols, mask=mask, other=float('-inf'))
        
        # Update maximum
        row_max = tl.maximum(row_max, tl.max(vals, axis=0))
    
    # Step 2: Compute exp(x - max) and sum
    # We need another pass through the data
    exp_sum = 0.0
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        cols = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = cols < n_cols
        
        # Load values again
        vals = tl.load(input_row_start + cols, mask=mask, other=float('-inf'))
        
        # Compute exp(x - max) for numerical stability
        exp_vals = tl.exp(vals - row_max)
        
        # Mask out invalid values
        exp_vals = tl.where(mask, exp_vals, 0.0)
        
        # Add to sum
        exp_sum += tl.sum(exp_vals, axis=0)
    
    # Step 3: Normalize and store
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        cols = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = cols < n_cols
        
        # Load values one more time
        vals = tl.load(input_row_start + cols, mask=mask, other=float('-inf'))
        
        # Compute final softmax values
        exp_vals = tl.exp(vals - row_max)
        softmax_vals = exp_vals / exp_sum
        
        # Store results back to HBM
        tl.store(output_row_start + cols, softmax_vals, mask=mask)
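
The solution above walks over the row three times (max, sum, normalize). A common alternative, usually called online softmax, merges the first two passes by rescaling the running sum whenever the running max grows. The pure-Python sketch below shows only that recurrence. It is not the notebook's solution, and `online_softmax` and `reference_softmax` are names made up for this example.

```python
import math

def reference_softmax(row):
    """Plain stable softmax, used only to check the result."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def online_softmax(row, block_size):
    """Two passes: one for (running max, rescaled running sum), one to normalize."""
    m = -math.inf  # running maximum seen so far
    d = 0.0        # running sum of exp(v - m) over the elements seen so far
    for start in range(0, len(row), block_size):
        chunk = row[start:start + block_size]
        m_new = max(m, max(chunk))
        # Rescale the old sum from base m to base m_new, then add this chunk.
        d = d * math.exp(m - m_new) + sum(math.exp(v - m_new) for v in chunk)
        m = m_new
    return [math.exp(v - m) / d for v in row]

row = [1.0, 3.0, -2.0, 5.0, 0.5, 4.0, 2.0]
result = online_softmax(row, block_size=3)
print(result)
```

On the first chunk `m` is `-inf`, so the rescaling factor `exp(m - m_new)` is 0 and the empty running sum is simply discarded. A row consisting entirely of `-inf` would need separate handling, which this sketch omits.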


## FLOP and Memory Analysis

Let's analyze the computational complexity:

In [None]:
def analyze_softmax_ops(n_rows, n_cols):
    """Analyze operations for softmax computation."""
    
    # FLOPs per row:
    # - Finding max: n_cols comparisons
    # - Subtraction: n_cols ops
    # - Exp: n_cols ops (counted as multiple FLOPs)
    # - Sum: n_cols additions
    # - Division: n_cols ops
    
    exp_flops = 10  # Approximate FLOPs for exponential
    
    flops_per_row = (
        n_cols +           # max
        n_cols +           # subtract
        n_cols * exp_flops + # exp
        n_cols +           # sum
        n_cols             # divide
    )
    
    total_flops = n_rows * flops_per_row
    
    # Memory operations (in bytes)
    element_size = 4  # float32
    
    naive_mem_ops = {
        'reads': n_rows * n_cols * element_size * 5,   # Read 5 times (max, subtract, exp, sum, divide)
        'writes': n_rows * n_cols * element_size * 3,  # Write 3 times (subtract, exp, divide)
        'total': n_rows * n_cols * element_size * 8
    }
    
    fused_mem_ops = {
        'reads': n_rows * n_cols * element_size * 1,   # Read once
        'writes': n_rows * n_cols * element_size * 1,  # Write once
        'total': n_rows * n_cols * element_size * 2
    }
    
    return {
        'flops': total_flops,
        'naive_memory': naive_mem_ops,
        'fused_memory': fused_mem_ops,
        'memory_reduction': naive_mem_ops['total'] / fused_mem_ops['total']
    }

# Example analysis
n_rows, n_cols = 1024, 2048
analysis = analyze_softmax_ops(n_rows, n_cols)

print(f"Softmax analysis for {n_rows}x{n_cols} matrix:")
print(f"  Total FLOPs: {analysis['flops']:,}")
print(f"  Naive memory traffic: {analysis['naive_memory']['total'] / 1e6:.2f} MB")
print(f"  Fused memory traffic: {analysis['fused_memory']['total'] / 1e6:.2f} MB")
print(f"  Memory reduction: {analysis['memory_reduction']:.1f}x")


## Testing Correctness

Verify our implementation matches PyTorch:

In [None]:
def test_correctness(n_rows=100, n_cols=2048, atol=1e-5, rtol=1e-5):
    """Test if Triton implementation matches PyTorch."""
    torch.manual_seed(42)
    x = torch.randn(n_rows, n_cols, device=DEVICE, dtype=torch.float32)
    
    # Compute with PyTorch
    expected = softmax_pytorch(x)
    
    # Compute with Triton
    actual = softmax_triton(x)
    
    try:
        torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
        print(f"✅ Test PASSED! Results match within tolerance.")
        print(f"   Shape tested: ({n_rows}, {n_cols})")
        print(f"   Max absolute difference: {(actual - expected).abs().max().item():.2e}")
        
        # Test numerical stability with large values
        x_large = torch.randn(10, 100, device=DEVICE) * 100
        expected_large = softmax_pytorch(x_large)
        actual_large = softmax_triton(x_large)
        torch.testing.assert_close(actual_large, expected_large, atol=atol, rtol=rtol)
        print(f"✅ Numerical stability test PASSED!")
        
        return True
    except AssertionError as e:
        print(f"❌ Test FAILED!")
        print(f"   Error: {e}")
        return False

# Run tests
test_passed = test_correctness()

# Display congrats message
if test_passed:
    print("\n🎉 Congratulations! Your implementation is correct!")
    display(Image("https://c.tenor.com/9d2wq28eb9UAAAAC/tenor.gif", width=256, height=256))

## Benchmarking

Let's benchmark all implementations using Triton's tools:

In [None]:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['n_cols'],  # Column size as x-axis
        x_vals=[128 * i for i in range(1, 33)],  # 128 to 4096
        line_arg='provider',
        line_vals=['naive', 'pytorch', 'compiled', 'triton'],
        line_names=['Naive', 'PyTorch', 'Compiled', 'Triton'],
        styles=[('blue', '-'), ('green', '--'), ('orange', '-.'), ('red', ':')],
        ylabel='GB/s',
        plot_name='softmax-performance',
        args={'n_rows': 1024},  # Fix number of rows
    )
)
def benchmark(n_rows, n_cols, provider):
    """Benchmark softmax implementations."""
    x = torch.randn(n_rows, n_cols, device=DEVICE, dtype=torch.float32)
    
    quantiles = [0.5, 0.05, 0.95]
    
    if provider == 'naive':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: softmax_naive(x), quantiles=quantiles
        )
    elif provider == 'pytorch':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: softmax_pytorch(x), quantiles=quantiles
        )
    elif provider == 'compiled':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: softmax_compiled(x), quantiles=quantiles
        )
    elif provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: softmax_triton(x), quantiles=quantiles
        )
    
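    # Effective bandwidth: every provider is credited with the ideal fused traffic
    # (one read + one write per element), so the curves are directly comparable.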
    gbps = lambda ms: 2 * n_rows * n_cols * x.element_size() * 1e-9 / (ms * 1e-3)
    
    return gbps(ms), gbps(max_ms), gbps(min_ms)

# Run benchmarks
print("Running benchmarks...")
results = benchmark.run(print_data=True, return_df=True, save_path='')
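
To read the GB/s axis, it helps to work through the formula once by hand. The sketch below repeats the arithmetic of the `gbps` lambda as a standalone function. `effective_gbps` is a name invented here, and the 0.05 ms timing is a hypothetical value chosen only to illustrate the calculation, not a measurement.

```python
def effective_gbps(n_rows, n_cols, element_size, ms):
    """Ideal fused traffic (one read + one write per element) divided by time."""
    bytes_moved = 2 * n_rows * n_cols * element_size
    return bytes_moved * 1e-9 / (ms * 1e-3)

# 1024 x 2048 float32 matrix: 2 * 1024 * 2048 * 4 bytes = 16,777,216 bytes.
gbps_example = effective_gbps(1024, 2048, element_size=4, ms=0.05)
print(f"{gbps_example:.1f} GB/s")  # about 335.5 GB/s for this hypothetical timing
```

Since the numerator is fixed, an implementation that takes twice as long shows up at half the GB/s, even if it actually moved more bytes than the ideal.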

## Speedup?

In [None]:
# Check if Triton is faster than PyTorch
avg_pytorch = results['PyTorch'].mean()
avg_triton = results['Triton'].mean()
speedup = avg_triton / avg_pytorch

if speedup > 1.0:
    print(f"\n🚀 Awesome! Triton is {speedup:.2f}x faster than PyTorch!")
    display(Image("https://c.tenor.com/QFFzqAIAvnIAAAAd/tenor.gif", width=400, height=256))
else:
    print(f"\n🐌🐌🐌 Triton reaches only {speedup:.2f}x of PyTorch's throughput. 🐌🐌🐌")

## Summary

In this tutorial, you learned:

1. **Kernel Fusion**: Why combining operations is crucial for performance
2. **2D Operations**: How to work with matrices in Triton
3. **Reductions**: Using `tl.max()` and `tl.sum()` for row-wise operations
4. **num_warps**: Introduction to controlling thread parallelism
5. **Numerical Stability**: Implementing stable softmax computation

### Key Insights:

- **Memory Bandwidth**: Fused softmax moves 4x less data through global memory than the naive version (2 vs. 8 memory operations per element)
- **SRAM Utilization**: All intermediate values stay in fast SRAM
- **Block Processing**: Handling rows larger than SRAM capacity
- **Parallelism**: Each row is processed independently by different programs

### Performance Tips:

- Adjust `BLOCK_SIZE` based on your typical row sizes
- Experiment with `num_warps` for your specific GPU
- Consider `num_stages` (next puzzle!) for even better performance

### Next Steps:

Ready for matrix multiplication? The next puzzle introduces more advanced concepts like:
- 2D block tiling
- Shared memory optimization  
- `num_stages` for pipelining
- Auto-tuning for optimal performance


<img src="sardine-challenge.png" width="800" />