In [None]:
from setup_triton import setup_triton

# TRITON_INTERPRET=1 makes Triton run kernels with a Python interpreter instead
# of on the GPU. This means that you can insert Python breakpoints to debug your
# kernel code! `setup_triton(use_interpreter=True)` presumably turns this mode on
# (and `use_interpreter=False` off, which real GPU timing needs) -- confirm in setup_triton.
setup_triton(use_interpreter=True)

# Triton Puzzle 3: Fused Entmax

Welcome to the third Triton puzzle! Now we'll tackle a more complex operation: fused entmax. 

### What you'll learn:
- How entmax can be computed via the bisection algorithm
- How it can be parallelized in Triton


## Mathematical Background

The $\alpha$-entmax mapping is given by:
$$
{\alpha\text{-entmax}(x)}_i = [(\alpha - 1)x_i - \tau]_+^{\frac{1}{\alpha-1}} 
$$

Thus, to calculate entmax one needs to find the threshold $\tau$ that satisfies $\sum_i [(\alpha - 1)x_i - \tau]_+^{\frac{1}{\alpha-1}} = 1$, i.e., the value of $\tau$ for which the output entries sum to one.

Let us define the function $f(\tau)$ whose root we are trying to find:
$$
f(\tau) = \sum_i [(\alpha - 1)x_i - \tau]_+^{\frac{1}{\alpha-1}} - 1
$$

Currently, the entmax package (`pip install entmax`) computes this $\tau$ for an arbitrary $\alpha$ with the bisection algorithm. A high-level description of the bisection algorithm goes like this:

$$
\begin{align*}
\text{Input: }& x \in \mathbb{R}^n, \alpha \in \mathbb{R}, T \text{ iterations} \\
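\text{(1): }& m \leftarrow \max_i \, (\alpha - 1)x_i \\ 
\text{(2): }& \text{Initialize } \tau_\text{lo} = m - 1,\tau_\text{hi} = m - n^{1-\alpha}, \tau = \frac{\tau_\text{lo} + \tau_\text{hi}}{2} \\
\text{(3): }& \text{For t in T iterations do:} \\
&\quad\text{Compute } f(\tau) \\
&\quad\text{If } f(\tau) > 0: \tau_\text{lo} = \tau \text{ else } \tau_\text{hi} = \tau \\
&\quad\tau \leftarrow \frac{\tau_\text{lo} + \tau_\text{hi}}{2} \\
\text{(4): }& \text{Store element-wise }  [(\alpha - 1)x_i - \tau]_+^{\frac{1}{\alpha-1}} \\
\end{align*}
$$

Note that $m$ is the maximum of the *scaled* scores $(\alpha - 1)x_i$, not of $x$ itself. This choice is what makes the initial interval bracket the root. At $\tau = m - 1$ the largest entry equals $1$, so $f(\tau_\text{lo}) \geq 0$. At $\tau = m - n^{1-\alpha}$ every entry is at most $\left(n^{1-\alpha}\right)^{\frac{1}{\alpha-1}} = \frac{1}{n}$, so $f(\tau_\text{hi}) \leq 0$. Since $f$ is non-increasing in $\tau$, each iteration halves an interval that still contains the root.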

In [None]:
!pip install entmax

In [None]:
import numpy as np
import torch
import triton
import triton.language as tl
from entmax import entmax_bisect
from IPython.display import display, Image

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'
print(f"Using device: {DEVICE}")

# Fix the global PyTorch RNG so the random inputs below are reproducible.
torch.manual_seed(42)

### Complete the code below
Assume the following:
- The input is a matrix $b \times n$, and we want to perform the entmax transformation along the last dimension. 
- $n$ is a power of two, so no masking is required.

In [None]:
@triton.jit
def _ent_bisect(x_ptr, y_ptr, alpha, n_iter, N: tl.constexpr, TILE: tl.constexpr):
    # Puzzle stub: compute alpha-entmax for one row of the input.
    # `entmax_triton` launches one program per row (grid = (rows,)).
    #
    # Arguments:
    #   x_ptr, y_ptr -- input / output tensors. The offsets below treat row r as
    #                   the N elements starting at r * N, so both tensors must be
    #                   contiguous (dense, row-major).
    #   alpha        -- entmax alpha; the output exponent is 1 / (alpha - 1).
    #   n_iter       -- number of bisection iterations to run.
    #   N            -- row length (number of columns), a compile-time constant.
    #   TILE         -- chunk size for loads/stores; `entmax_triton` passes
    #                   min(N, 1024). Because N is a power of two, TILE divides N
    #                   and no masking is needed.

    # get row that this program (thread block) will be responsible for
    curr_row = tl.program_id(0)

    # move pointers to the start of this row in the input and output tensors
    x_ptr += curr_row * N
    y_ptr += curr_row * N

    # same as torch.arange: element indices 0..TILE-1 within one chunk
    offsets = tl.arange(0, TILE)

    # YOUR IMPLEMENTATION GOES HERE
    # 1) Calculate the maximum row-wise (the initial bounds tau_lo = m - 1 and
    #    tau_hi = m - N**(1 - alpha) only bracket the root when m is the max of
    #    (alpha - 1) * x, not of x).
    # 2) Run bisection `n_iter` times to find tau (be careful about possible NaNs!).
    #    Hint: apply the [.]_+ clamp before the fractional power -- a negative
    #    base raised to 1 / (alpha - 1) (e.g. via exp2/log2) yields NaN.
    # 3) Finally apply the entmax function with tau and store the result.
    pass


def entmax_triton(x, alpha=1.5, n_iter=50):
    """Apply alpha-entmax along the last dimension of ``x`` using ``_ent_bisect``.

    Args:
        x: input tensor. The size of its last dimension must be a power of 2;
            any leading dimensions are flattened into rows (a 2-D
            ``(rows, cols)`` matrix is the common case).
        alpha: entmax alpha, forwarded to the kernel.
        n_iter: number of bisection iterations, forwarded to the kernel.

    Returns:
        A new tensor with the same shape and dtype as ``x``.

    Raises:
        AssertionError: if the last dimension is not a (positive) power of 2.
    """
    cols = x.shape[-1]
    # Portable power-of-two test (int.bit_count needs Python >= 3.10); rejects 0.
    assert cols > 0 and (cols & (cols - 1)) == 0, "We require the number of columns to be a power of 2."

    # The kernel addresses row r at offset r * cols, i.e. it assumes a dense
    # row-major layout. Flatten leading dims and force that layout; for an
    # already-contiguous 2-D input this is a no-op (no copy).
    x_rows = x.reshape(-1, cols).contiguous()
    rows = x_rows.shape[0]

    # allocate output tensor (contiguous, because x_rows is)
    y_rows = torch.empty_like(x_rows)

    if rows == 0:
        # Nothing to compute; skip launching an empty grid.
        return y_rows.reshape(x.shape)

    TILE = 1024 if cols > 1024 else cols

    # launch with as many programs as rows
    grid = (rows,)
    _ent_bisect[grid](x_rows, y_rows, alpha, n_iter, cols, TILE)

    return y_rows.reshape(x.shape)

## Solution 🧙

In [None]:
# Our solution goes here

## Testing Correctness

Verify that our implementation matches the reference PyTorch implementation (`entmax_bisect` from the `entmax` package):

In [None]:
def test_correctness(n_rows=100, n_cols=2048, atol=1e-5, rtol=1e-5, alpha=1.5, n_iter=50):
    """Test if the Triton implementation matches the `entmax` package.

    Draws a seeded random ``(n_rows, n_cols)`` float32 matrix on ``DEVICE`` and
    compares ``entmax_triton`` against the reference ``entmax_bisect`` with
    ``torch.testing.assert_close``.

    Args:
        n_rows: number of rows of the random input.
        n_cols: number of columns (entmax dimension); must be a power of 2.
        atol: absolute tolerance for the comparison.
        rtol: relative tolerance for the comparison.
        alpha: entmax alpha used by both implementations.
        n_iter: bisection iterations used by both implementations.

    Returns:
        True if the outputs match within tolerance, False otherwise.
    """
    torch.manual_seed(42)

    # Build the input from the requested shape so the report below is accurate.
    x = torch.randn((n_rows, n_cols), device=DEVICE, dtype=torch.float32).contiguous()

    # Reference: PyTorch bisection from the `entmax` package
    expected = entmax_bisect(x, alpha=alpha, n_iter=n_iter)

    # Compute with Triton
    actual = entmax_triton(x, alpha=alpha, n_iter=n_iter)

    try:
        torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
    except AssertionError as e:
        print("❌ Test FAILED!")
        print(f"   Error: {e}")
        return False

    print("✅ Test PASSED! Results match within tolerance.")
    print(f"   Shape tested: ({n_rows}, {n_cols})")
    print(f"   Max absolute difference: {(actual - expected).abs().max().item():.2e}")
    return True

# Run tests
test_passed = test_correctness()

# Display congrats message (only when the comparison succeeded)
if test_passed:
    print("\n🎉 Congratulations! Your implementation is correct!")
    display(Image("figs/success.gif", width=256, height=256))

## Summary

In this tutorial, you learned:

1. **Bisection Algorithm**: How to implement root-finding algorithms in parallel on GPUs
2. **Multi-pass Kernels**: Processing data multiple times to find global statistics
3. **Sparse Activations**: Understanding entmax as a sparse alternative to softmax
4. **Numerical Stability**: Using `exp2` and `log2` to handle power operations safely

### Key Insights:

- **Algorithm Parallelization**: Each row's $\tau$ can be computed independently
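- **Memory Efficiency**: Each row is handled by a single program, so $\tau$ and its bounds stay on-chip for the whole bisection and the output is written to global memory only once. When a row fits in a single tile, every iteration runs entirely on-chip; longer rows are re-read tile by tile.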
- **Tiling Strategy**: Processing long rows in `TILE`-sized chunks
- **Sparse Outputs**: Unlike softmax, entmax can produce exact zeros. Can we leverage this?


### Performance Tips:

- Use enough bisection iterations (20-50) for convergence
- Choose `TILE` (here `min(n, 1024)`) based on your typical sequence lengths
- Pre-compute multiplications when possible


Next puzzle: matmul!

<img src="figs/sardine-challenge.png" width="512" />

---

## Benchmarking (GPU only)

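Below you can check how well your solution performs against `entmax_bisect` from the `entmax` package. Run it on a GPU with the Triton interpreter disabled (e.g. `setup_triton(use_interpreter=False)` in the first cell, restarting the kernel if necessary), because timings measured under the interpreter are not meaningful.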

In [None]:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['size'],
        x_vals=[4096, 8192, 16384],
        line_arg='provider',
        line_vals=['triton', 'torch'],
        line_names=['Triton', 'Torch'],
        styles=[('blue', '-'), ('green', '--')],
        ylabel='Time (ms)',
        plot_name='entmax-perf',
        args={},
    ))
def benchmark(size, provider):
    """Time one entmax implementation on a random ``(2048, size)`` float32 CUDA matrix.

    Args:
        size: number of columns (the entmax dimension); a power of 2.
        provider: ``'triton'`` for ``entmax_triton`` or ``'torch'`` for
            ``entmax_bisect`` from the `entmax` package.

    Returns:
        Median runtime in milliseconds, as measured by ``triton.testing.do_bench``.

    Raises:
        ValueError: if ``provider`` is not one of the two names above.
    """
    alpha = 1.5
    n_iter = 20
    x = torch.rand((2048, size), device="cuda", dtype=torch.float32)

    # Both providers get the same alpha / n_iter so the comparison is fair.
    if provider == 'torch':
        fn = lambda: entmax_bisect(x, alpha=alpha, n_iter=n_iter, dim=1)
    elif provider == 'triton':
        fn = lambda: entmax_triton(x, alpha=alpha, n_iter=n_iter)
    else:
        raise ValueError(f"unknown provider: {provider!r}")

    # quantiles -> (median, 20th percentile, 80th percentile)
    quantiles = [0.5, 0.2, 0.8]
    ms, _min_ms, _max_ms = triton.testing.do_bench(fn, quantiles=quantiles, warmup=500, rep=1000)
    return ms

In [None]:
# The benchmark allocates CUDA tensors, so only run it when a GPU is present.
if torch.cuda.is_available():
    benchmark.run(print_data=True, show_plots=True)
else:
    print("Skipping benchmark: no CUDA device available.")