In [None]:
import torch
import triton
import triton.language as tl

# !pip install entmax
from entmax import entmax_bisect 

### Goal: Implement entmax bisection in Triton

The $\alpha$-entmax mapping is given by:
$$
{\alpha\text{-entmax}(x)}_i = [(\alpha - 1)x_i - \tau]_+^{\frac{1}{\alpha-1}} 
$$

Thus, to compute entmax one must find the threshold $\tau$ that satisfies $\sum_i [(\alpha - 1)x_i - \tau]_+^{\frac{1}{\alpha-1}} = 1$, i.e., the $\tau$ for which the output entries sum to one.

Let us define the function $f(\tau)$ whose root we are trying to find:
$$
f(\tau) = \sum_i [(\alpha - 1)x_i - \tau]_+^{\frac{1}{\alpha-1}} - 1
$$

The entmax package (`pip install entmax`) computes this $\tau$ with the bisection algorithm. Since every term of $f$ is non-increasing in $\tau$, the root can be bracketed between two bounds and the bracket halved repeatedly. A high-level description of the algorithm is:

$$
\begin{align*}
\text{Input: }& x \in \mathbb{R}^n, \alpha \in \mathbb{R}, T \text{ iterations} \\
\text{(1): }& \max \leftarrow \max_i \, (\alpha - 1)x_i \\
\text{(2): }& \text{Initialize } \tau_\text{lo} = \max - 1,\tau_\text{hi} = \max - n^{1-\alpha}, \tau = \frac{\tau_\text{lo} + \tau_\text{hi}}{2} \\
\text{(3): }& \text{For t in T iterations do:} \\
&\quad\text{Compute } f(\tau) \\
&\quad\text{If } f(\tau) > 0: \tau_\text{lo} = \tau \text{ else } \tau_\text{hi} = \tau \\
&\quad\tau \leftarrow \frac{\tau_\text{lo} + \tau_\text{hi}}{2} \\
\text{(4): }& \text{Store element-wise }  [(\alpha - 1)x_i - \tau]_+^{\frac{1}{\alpha-1}} \\
\end{align*}
$$

### Complete the code below. 
Assume the following:
- The input is a matrix $b \times n$, and we want to perform the entmax transformation along the last dimension. 
- $n$ is a power of two, so it is a multiple of the tile size (`TILE = min(n, 1024)`) and no masking is required.

In [None]:
@triton.jit
def _ent_bisect(x_ptr, y_ptr, alpha, n_iter, N: tl.constexpr, TILE: tl.constexpr):
    """Exercise stub: row-wise alpha-entmax via bisection.

    `entmax_triton` launches this kernel with one program per row of the input.

    Args:
        x_ptr: pointer to the input matrix (presumably row-major, rows of length N).
        y_ptr: pointer to the output buffer allocated with `torch.empty_like(x)`.
        alpha: entmax exponent.
        n_iter: number of bisection iterations to run.
        N: number of columns of the input (a power of two).
        TILE: number of elements to process per step; the wrapper passes min(N, 1024).

    Until implemented, this kernel writes nothing, so `entmax_triton` returns
    the uninitialized contents of its output buffer.
    """
    
    # YOUR IMPLEMENTATION GOES HERE
    # 1) Calculate the maximum.
    # 2) Run bisection `n_iter` times (be careful about possible NaNs!).
    # 3) Finally apply the entmax function and store the result.
    
    pass

def entmax_triton(x, alpha=1.5, n_iter=50):
    """Apply alpha-entmax along the last dimension of a 2-D tensor with a Triton kernel.

    Launches one `_ent_bisect` program per row. Each program runs `n_iter`
    bisection steps to find the row's threshold tau and writes the result.

    Args:
        x: 2-D tensor of shape (rows, cols); `cols` must be a power of two.
           Presumably a CUDA tensor, since it is handed to a Triton kernel.
        alpha: entmax exponent. The formula uses 1 / (alpha - 1), so alpha must not be 1.
        n_iter: number of bisection iterations.

    Returns:
        A contiguous tensor with the same shape, dtype and device as `x`.

    Raises:
        AssertionError: if the number of columns is not a power of two.
    """
    rows, cols = x.shape
    # `cols & (cols - 1) == 0` is the power-of-two test. Unlike int.bit_count(),
    # it also works on Python < 3.10.
    assert cols > 0 and cols & (cols - 1) == 0, "We require the number of columns to be a power of 2."
    # The kernel addresses row r at element offset r * cols, so the rows must be
    # packed contiguously. Copy if the caller passed a strided view.
    x = x.contiguous()
    # Both values are powers of two, so TILE always divides cols (no masking needed).
    TILE = min(cols, 1024)

    # launch with as many programs as rows in x
    grid = (rows,)

    # allocate output tensor (contiguous, matching the layout the kernel writes)
    y = torch.empty_like(x)

    # launch the kernel
    _ent_bisect[grid](x, y, alpha, n_iter, cols, TILE)

    return y

In [None]:
# Compare the Triton kernel against the entmax package on random data.
torch.manual_seed(0)  # make the comparison reproducible across runs

b, n = 16, 4096 * 2
alpha = 1.5
n_iter = 50

x = torch.randn((b, n), device='cuda', dtype=torch.float32).contiguous()
# Use the same number of bisection iterations for both implementations.
y_ref = entmax_bisect(x, alpha=alpha, n_iter=n_iter)
y_triton = entmax_triton(x, alpha=alpha, n_iter=n_iter)

print(f"Max error: {(y_ref - y_triton).abs().max().item()}")

### ⬇️ Below you can check how well your solution performs against entmax's package.

In [None]:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['size'], 
        x_vals=[4096, 8192, 16384],
        line_arg='provider', 
        line_vals=['triton', 'torch'],  
        line_names=['Triton', 'Torch'],
        styles=[('blue', '-'), ('green', '--')], 
        ylabel='Time (ms)', 
        plot_name='entmax-perf',
        args={},
    ))
def benchmark(size, provider):
    """Time one entmax implementation on a (2048, size) float32 CUDA input.

    Args:
        size: number of columns, i.e. the dimension entmax is applied over.
        provider: 'triton' for `entmax_triton`, 'torch' for the entmax
            package's `entmax_bisect`.

    Returns:
        The 0.5-quantile (median) runtime in ms reported by `triton.testing.do_bench`.

    Raises:
        ValueError: for an unknown provider.
    """
    alpha = 1.5
    n_iter = 20
    x = torch.rand((2048, size), device="cuda", dtype=torch.float32)

    quantiles = [0.5, 0.2, 0.8]
    if provider == 'torch':
        # Pass alpha explicitly so both providers always solve the same problem.
        fn = lambda: entmax_bisect(x, alpha=alpha, n_iter=n_iter, dim=1)
    elif provider == 'triton':
        fn = lambda: entmax_triton(x, alpha=alpha, n_iter=n_iter)
    else:
        raise ValueError(f"unknown provider: {provider!r}")
    ms, _min_ms, _max_ms = triton.testing.do_bench(fn, quantiles=quantiles, warmup=500, rep=1000)
    return ms

In [None]:
# Sweep all configured sizes for both providers, then print the timing table and show the plot.
benchmark.run(show_plots=True, print_data=True)

## ⚠️ Solution below!

In [None]:
@triton.jit
def alpha_entmax(x, tau, alpha):
    """Element-wise entmax term [(alpha - 1) * x - tau]_+ ** (1 / (alpha - 1)).

    The power is evaluated as exp2(log2(z) / (alpha - 1)). log2 is undefined
    for non-positive z, so those entries are replaced by 0 through tl.where and
    the discarded log2 values never reach the result.
    """
    shifted = (alpha - 1) * x - tau
    inv_exponent = 1 / (alpha - 1)
    powered = tl.exp2(inv_exponent * tl.log2(shifted))
    return tl.where(shifted > 0, powered, 0.0)


@triton.jit
def _ent_bisect(x_ptr, y_ptr, alpha, n_iter, N: tl.constexpr, TILE: tl.constexpr):
    """Row-wise alpha-entmax via bisection; one program handles one row.

    Args:
        x_ptr: pointer to the input. Row r starts at element r * N, so rows are
            assumed to be stored contiguously.
        y_ptr: pointer to the output, with the same layout as the input.
        alpha: entmax exponent (assumed > 1 for the bracketing bounds below).
        n_iter: number of bisection iterations.
        N: row length. It must be a multiple of TILE because loads are unmasked.
        TILE: number of elements loaded per step.
    """
    # Row handled by this program. Widen to int64 before multiplying so the
    # row offset does not overflow int32 once rows * N >= 2**31.
    row_start = tl.program_id(0).to(tl.int64) * N
    x_ptr += row_start
    y_ptr += row_start

    # same as torch.arange
    offsets = tl.arange(0, TILE)

    # 1) Row maximum. Seed it with the first tile rather than a sentinel such as
    #    -1e3, which would be wrong for rows whose entries all lie below it.
    max_val = tl.max(tl.load(x_ptr + offsets).to(tl.float32))
    for idx in range(TILE, N, TILE):
        x = tl.load(x_ptr + idx + offsets).to(tl.float32)
        max_val = tl.maximum(max_val, tl.max(x))
    # alpha_entmax works on (alpha - 1) * x, so the bounds use the scaled max.
    max_val *= (alpha - 1.0)

    # 2) Bracket the root of f(tau) = sum_i alpha_entmax(x_i, tau, alpha) - 1:
    #    at tau = max - 1 the max entry alone contributes 1, so f >= 0;
    #    at tau = max - N**(1 - alpha) every entry contributes at most 1/N, so f <= 0.
    tau_lower = max_val - 1.0
    tau_upper = max_val - tl.exp2((1 - alpha) * tl.log2(1.0 * N))
    tau = (tau_lower + tau_upper) / 2.0

    for _ in range(n_iter):
        f_tau = -1.0
        for idx in range(0, N, TILE):
            x = tl.load(x_ptr + idx + offsets).to(tl.float32)
            # alpha_entmax zeroes non-positive entries, so log2 NaNs never leak in
            f_tau += tl.sum(alpha_entmax(x, tau, alpha))

        # f is non-increasing in tau: a positive value means tau is still too small.
        if f_tau > 0:
            tau_lower = tau
        else:
            tau_upper = tau
        tau = (tau_lower + tau_upper) / 2.0

    # 3) Apply entmax with the final threshold and store
    #    (tl.store casts the fp32 result to the output element type).
    for idx in range(0, N, TILE):
        x = tl.load(x_ptr + idx + offsets).to(tl.float32)
        tl.store(y_ptr + idx + offsets, alpha_entmax(x, tau, alpha))

def entmax_triton(x, alpha=1.5, n_iter=50):
    """Apply alpha-entmax along the last dimension of a 2-D tensor with a Triton kernel.

    Launches one `_ent_bisect` program per row. Each program runs `n_iter`
    bisection steps to find the row's threshold tau and writes the result.

    Args:
        x: 2-D tensor of shape (rows, cols); `cols` must be a power of two.
           Presumably a CUDA tensor, since it is handed to a Triton kernel.
        alpha: entmax exponent. The formula uses 1 / (alpha - 1), so alpha must not be 1.
        n_iter: number of bisection iterations.

    Returns:
        A contiguous tensor with the same shape, dtype and device as `x`.

    Raises:
        AssertionError: if the number of columns is not a power of two.
    """
    rows, cols = x.shape
    # `cols & (cols - 1) == 0` is the power-of-two test. Unlike int.bit_count(),
    # it also works on Python < 3.10.
    assert cols > 0 and cols & (cols - 1) == 0, "We require the number of columns to be a power of 2."
    # The kernel addresses row r at element offset r * cols, so the rows must be
    # packed contiguously. Copy if the caller passed a strided view.
    x = x.contiguous()
    # Both values are powers of two, so TILE always divides cols (no masking needed).
    TILE = min(cols, 1024)

    # launch with as many programs as rows in x
    grid = (rows,)

    # allocate output tensor (contiguous, matching the layout the kernel writes)
    y = torch.empty_like(x)

    # launch the kernel
    _ent_bisect[grid](x, y, alpha, n_iter, cols, TILE)

    return y