# Triton Softmax Implementation

**FreeCodeCamp CUDA Course - Module 8: Triton**

Original Course: [https://www.youtube.com/watch?v=86FAWCzIe_4](https://www.youtube.com/watch?v=86FAWCzIe_4)
Source File: `02_softmax.py`

---

## Overview

Implement the softmax operation in Triton. Softmax is a fundamental operation in deep learning, used extensively in attention mechanisms and classification layers.

---

## Learning Objectives

By the end of this notebook, you will:

1. Understand the softmax operation and numerical stability
2. Implement row-wise operations in Triton
3. Use Triton's reduction operations (`tl.max`, `tl.sum`)
4. Handle 2D tensor operations with proper stride calculations
5. Compare the Triton implementation with PyTorch's optimized softmax

---

## Setup

In [None]:
# Check GPU and install Triton
!nvidia-smi
!pip install triton -q

---

## Softmax Operation

### Mathematical Definition

For an input vector $x = [x_1, x_2, ..., x_n]$:

$$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}$$

### Numerical Stability

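Direct computation overflows for large inputs: in float32, $e^{x}$ exceeds the representable range once $x$ is above roughly 88.7. Use the **max subtraction trick**, which is mathematically equivalent but keeps every exponent at or below 0: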

$$\text{softmax}(x_i) = \frac{e^{x_i - \max(x)}}{\sum_{j=1}^{n} e^{x_j - \max(x)}}$$
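
To see why the shift leaves the result unchanged, write $m = \max(x)$ and factor $e^{-m}$ out of both the numerator and the denominator. This is a short worked derivation using only the symbols defined above:

$$\frac{e^{x_i - m}}{\sum_{j=1}^{n} e^{x_j - m}} = \frac{e^{-m}\, e^{x_i}}{e^{-m} \sum_{j=1}^{n} e^{x_j}} = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}$$

Because $x_j - m \le 0$ for every $j$, each exponential lies in $(0, 1]$, and the term with $x_j = m$ contributes exactly 1, so the denominator is never smaller than 1.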

### Properties
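- Every output lies in the range (0, 1) for rows with more than one element (in floating point, extreme inputs can round an output to exactly 0 or 1)
- The outputs of each row sum to 1
- Preserves ordering (larger inputs → larger outputs)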
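
The properties above can be checked with a small CPU reference before any kernel is written. This is a minimal pure-Python sketch; the helper name `softmax_reference` is an illustrative assumption and is not part of the course file.

```python
import math

def softmax_reference(row):
    """Numerically stable softmax of a list of floats (illustrative helper)."""
    row_max = max(row)                                  # shift so the largest exponent is 0
    numerator = [math.exp(v - row_max) for v in row]
    denominator = sum(numerator)
    return [n / denominator for n in numerator]

probs = softmax_reference([1.0, 2.0, 3.0])
print(probs)                                            # roughly [0.090, 0.245, 0.665]
print(sum(probs))                                       # 1 up to rounding
print(all(0.0 < p < 1.0 for p in probs))                # every value inside (0, 1)
print(probs == sorted(probs))                           # ordering of the inputs is preserved
```

A reference like this is handy later for spot-checking individual rows of a GPU result.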

---

## Triton Softmax Kernel

In [None]:
import torch
import triton
import triton.language as tl

@triton.jit
def softmax_kernel(
    output_ptr, 
    input_ptr, 
    input_row_stride, 
    output_row_stride, 
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Compute softmax along rows of a 2D tensor.
    Each program processes one row.
    
    Args:
        output_ptr: Pointer to output tensor
        input_ptr: Pointer to input tensor
        input_row_stride: Number of elements to skip between rows in input
        output_row_stride: Number of elements to skip between rows in output
        n_cols: Number of columns (elements per row)
        BLOCK_SIZE: Block size (power of 2, must be >= n_cols)
    """
    # Each program handles one row
    row_idx = tl.program_id(axis=0)

    # Compute starting pointers for this row
    row_start_ptr = input_ptr + row_idx * input_row_stride
    out_row_start_ptr = output_ptr + row_idx * output_row_stride

    # Create column offsets and mask
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # Load the entire row into SRAM
    # Masked lanes read -inf: they never win the max, and exp(-inf) = 0 adds nothing to the sum
    row = tl.load(row_start_ptr + col_offsets, mask=mask, other=-float('inf'))

    # Step 1: Find maximum value (for numerical stability)
    row_max = tl.max(row, axis=0)
    
    # Step 2: Subtract max and exponentiate
    numerator = tl.exp(row - row_max)
    
    # Step 3: Compute sum for normalization
    denominator = tl.sum(numerator, axis=0)
    
    # Step 4: Normalize to get softmax
    softmax_output = numerator / denominator
    
    # Store the result
    tl.store(out_row_start_ptr + col_offsets, softmax_output, mask=mask)


def triton_softmax(x):
    """Wrapper function to launch Triton softmax kernel."""
    n_rows, n_cols = x.shape
    output = torch.empty_like(x)
    
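    # Determine block size: the kernel loads a whole row in one block,
    # so round n_cols up to the next power of 2 (tl.arange needs a power of 2).
    # Do not cap it below n_cols, or the columns past the cap are never computed.
    BLOCK_SIZE = triton.next_power_of_2(n_cols)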
    
    # Launch one program per row
    grid = (n_rows,)
    softmax_kernel[grid](
        output, x,
        x.stride(0), output.stride(0),
        n_cols, 
        BLOCK_SIZE=BLOCK_SIZE
    )
    return output


# Test the implementation
torch.manual_seed(0)
x = torch.randn(256, 1024, device='cuda')

# Compute softmax using PyTorch
torch_result = torch.softmax(x, dim=1)

# Compute softmax using Triton
triton_result = triton_softmax(x)

# Compare results
max_diff = torch.max(torch.abs(torch_result - triton_result))
print(f"Maximum difference between PyTorch and Triton results: {max_diff:.2e}")

# Check if results are close
is_close = torch.allclose(torch_result, triton_result, rtol=1e-5, atol=1e-5)
print(f"Results are close: {is_close}")

# Verify softmax properties
print(f"\nSoftmax properties:")
print(f"All values in (0,1): {torch.all((triton_result > 0) & (triton_result < 1))}")
print(f"Row sums equal 1: {torch.allclose(triton_result.sum(dim=1), torch.ones(256, device='cuda'))}")

---

## Understanding the Implementation

### 1. Row-Wise Processing
```python
row_idx = tl.program_id(axis=0)  # Each program = one row
```

### 2. Stride Calculations
For a contiguous 2D tensor with shape (256, 1024):
- `input_row_stride = x.stride(0) = 1024` (elements between the starts of consecutive rows)
- Row 0 starts at offset 0
- Row 1 starts at offset 1024
- Row i starts at offset `i * 1024`
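
The offset arithmetic above can be sketched in plain Python. The helper `element_offset` is a hypothetical name used only for illustration. The second half assumes a view that keeps the first 512 columns of the same buffer, which shows why the kernel takes the row stride as a separate argument from `n_cols`.

```python
def element_offset(row_idx, col_idx, row_stride, col_stride=1):
    """Offset (in elements) of x[row_idx, col_idx] from the tensor's base pointer."""
    return row_idx * row_stride + col_idx * col_stride

# Contiguous (256, 1024) tensor: the row stride equals the number of columns.
print(element_offset(1, 0, row_stride=1024))    # 1024, start of row 1
print(element_offset(3, 5, row_stride=1024))    # 3077, i.e. 3 * 1024 + 5

# Assumed view with n_cols = 512 on the same buffer: rows are still 1024 apart,
# so the last column of row 0 and the first column of row 1 are not adjacent.
print(element_offset(0, 511, row_stride=1024))  # 511
print(element_offset(1, 0, row_stride=1024))    # 1024
```

Note that the kernel in this notebook adds `col_offsets` directly to the row pointer, so it assumes a column stride of 1.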

### 3. Memory Loading
```python
row = tl.load(row_start_ptr + col_offsets, mask=mask, other=-float('inf'))
```
- Loads the entire row into SRAM (fast on-chip memory)
- Masked elements are set to -∞, which is safe: -∞ never becomes the row maximum, and exp(-∞) = 0 contributes nothing to the sum
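
The effect of the `-inf` padding can be emulated on the CPU. This is a rough pure-Python sketch: `next_power_of_2` and `load_row_padded` are hypothetical stand-ins for `triton.next_power_of_2` and the masked `tl.load`, written only to show the arithmetic.

```python
import math

def next_power_of_2(n):
    """Smallest power of 2 that is >= n, for n >= 1 (stand-in helper)."""
    return 1 << (n - 1).bit_length()

def load_row_padded(row, block_size, other=float("-inf")):
    """Emulate a masked load: lanes past the end of the row read `other`."""
    return [row[i] if i < len(row) else other for i in range(block_size)]

row = [0.5, 2.0, -1.0]
block = load_row_padded(row, next_power_of_2(len(row)))
print(block)                          # [0.5, 2.0, -1.0, -inf]

row_max = max(block)                  # the padding never wins the max
numerator = [math.exp(v - row_max) for v in block]
print(row_max, numerator[-1])         # 2.0 0.0

# The padded lane adds exactly 0.0, so the denominator is unchanged.
print(sum(numerator) == sum(math.exp(v - row_max) for v in row))  # True
```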

### 4. Reductions
```python
row_max = tl.max(row, axis=0)      # Find maximum
denominator = tl.sum(numerator, axis=0)  # Sum for normalization
```
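
Both reductions run over the `BLOCK_SIZE` lanes loaded by a single program, so they only see the whole row when `BLOCK_SIZE >= n_cols`, as the kernel docstring requires. The following pure-Python emulation of one program (the function name is an assumption for illustration) shows what goes wrong when the block is too small.

```python
import math

def emulate_softmax_program(row, block_size):
    """Emulate one program of the row-wise kernel on a Python list."""
    n_cols = len(row)
    # Lanes with col_offset >= n_cols are masked and read -inf.
    block = [row[i] if i < n_cols else float("-inf") for i in range(block_size)]
    row_max = max(block)
    numerator = [math.exp(v - row_max) for v in block]
    denominator = sum(numerator)
    out = [None] * n_cols                      # None marks elements never stored
    for i in range(min(block_size, n_cols)):   # the store only covers in-block columns
        out[i] = numerator[i] / denominator
    return out

row = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
full = emulate_softmax_program(row, block_size=8)    # block covers the row
short = emulate_softmax_program(row, block_size=4)   # block is too small

print(sum(full))      # 1 up to rounding: correct softmax
print(short[4:])      # [None, None]: the last two columns are never written
print(sum(short[:4])) # 1 up to rounding, but normalized over only 4 of the 6 values
```

In the real kernel the unwritten columns would hold whatever `torch.empty_like` left in memory, so the wrapper should always pick a block size that covers `n_cols`.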

---

## Performance Comparison

In [None]:
import time

def benchmark_softmax(shape=(1024, 1024), num_runs=100):
    """Benchmark Triton vs PyTorch softmax."""
    x = torch.randn(shape, device='cuda')
    
    # Warmup
    for _ in range(10):
        _ = torch.softmax(x, dim=1)
        _ = triton_softmax(x)
    torch.cuda.synchronize()
    
    # Benchmark PyTorch (perf_counter is a monotonic, high-resolution clock)
    start = time.perf_counter()
    for _ in range(num_runs):
        _ = torch.softmax(x, dim=1)
    torch.cuda.synchronize()
    pytorch_time = (time.perf_counter() - start) / num_runs * 1000  # ms per call
    
    # Benchmark Triton
    start = time.perf_counter()
    for _ in range(num_runs):
        _ = triton_softmax(x)
    torch.cuda.synchronize()
    triton_time = (time.perf_counter() - start) / num_runs * 1000  # ms per call
    
    print(f"Shape: {shape}")
    print(f"PyTorch: {pytorch_time:.3f} ms")
    print(f"Triton:  {triton_time:.3f} ms")
    print(f"Speedup: {pytorch_time / triton_time:.2f}x\n")

# Benchmark different shapes
benchmark_softmax((256, 1024))
benchmark_softmax((1024, 1024))
benchmark_softmax((4096, 1024))
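
Softmax is typically limited by memory traffic, so the per-call times are easier to compare across shapes when converted to effective bandwidth. The sketch below assumes each element is read once and written once and that the tensor is float32. The helper name and the 0.1 ms figure are placeholders, not measurements.

```python
def effective_bandwidth_gbps(n_rows, n_cols, time_ms, bytes_per_element=4):
    """Effective GB/s, assuming one read and one write per element."""
    bytes_moved = 2 * n_rows * n_cols * bytes_per_element
    seconds = time_ms * 1e-3
    return bytes_moved / seconds / 1e9

# Placeholder timing: substitute the milliseconds printed by benchmark_softmax.
example_ms = 0.1
print(f"{effective_bandwidth_gbps(4096, 1024, example_ms):.1f} GB/s")  # 335.5 GB/s
```

Comparing this number with the GPU's peak memory bandwidth gives a rough sense of how much headroom a kernel has left.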

---

## Numerical Stability Demo

Let's see why the max subtraction trick is important:

In [None]:
# Create a tensor with large values that would overflow
x_large = torch.tensor([[1000.0, 1001.0, 1002.0]], device='cuda')

print("Testing numerical stability:")
print(f"Input: {x_large}")

# Naive approach: exp(1000) exceeds the float32 range, so every entry becomes inf
print(f"\nNaive exp (overflows): {torch.exp(x_large)}")

# With max subtraction (stable)
x_shifted = x_large - x_large.max()
print(f"After max subtraction: {x_shifted}")
print(f"Stable exp: {torch.exp(x_shifted)}")

# Our Triton implementation handles this
result = triton_softmax(x_large)
print(f"\nTriton softmax result: {result}")
print(f"Sum: {result.sum():.6f}")

---

## Exercises

1. **Column-wise Softmax**: Modify the kernel to compute softmax along columns (dim=0)
   - Hint: Change how you organize programs and memory access

2. **Temperature Scaling**: Add a temperature parameter: `softmax(x/T)`
   - Higher temperature → more uniform distribution
   - Lower temperature → more peaky distribution

3. **Log-Softmax**: Implement log-softmax directly, which is more numerically stable than computing `log(softmax(x))`
   - `log_softmax(x) = x - max(x) - log(sum(exp(x - max(x))))`

4. **Fused Softmax + Scale**: Compute `softmax(x) * scale` in one kernel
   - Compare performance with separate operations

5. **Different Shapes**: Test with various tensor shapes
   - Very wide: (16, 8192)
   - Very tall: (8192, 128)
   - Square: (2048, 2048)
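
When working on exercises 2 and 3, a CPU reference helps to validate single rows of the kernel output. The following is a minimal pure-Python sketch; both function names are illustrative assumptions, and neither is a Triton solution.

```python
import math

def softmax_with_temperature(row, temperature=1.0):
    """Reference for softmax(x / T) on a list of floats."""
    scaled = [v / temperature for v in row]
    m = max(scaled)
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def log_softmax_reference(row):
    """Reference for x - max(x) - log(sum(exp(x - max(x))))."""
    m = max(row)
    log_sum = math.log(sum(math.exp(v - m) for v in row))
    return [v - m - log_sum for v in row]

row = [1.0, 2.0, 3.0]
print(max(softmax_with_temperature(row, temperature=0.5)))   # peakier than T = 1
print(max(softmax_with_temperature(row, temperature=10.0)))  # closer to uniform (1/3)

# Large inputs stay finite because the max is subtracted before exp.
print(log_softmax_reference([1000.0, 1001.0, 1002.0]))
```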

---

## Key Takeaways

1. **Numerical stability is crucial** - Always use the max subtraction trick
2. **Row-wise operations are natural in Triton** - One program per row
3. **Reduction operations** (`tl.max`, `tl.sum`) reduce a whole row inside a single program
4. **SRAM usage** - Loading an entire row once lets both reductions run on-chip, provided the row fits in one block
5. **Triton can approach PyTorch performance** - Use the benchmark above to check how close this kernel gets on your GPU

---

## Next Steps

Continue to **Module 9** to learn how to integrate custom CUDA operations with PyTorch!

---

## Notes

*Use this space for your learning notes:*


