In [2]:
import torch
from torch import Tensor
from math import ceil
from torch.autograd import Function

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
import triton
import triton.language as tl

In [33]:
def naive_softmax(x:Tensor) -> Tensor:
    max = x.max()
    x -= x.max(dim = -1)[0][:, None] # this is done for numerical stability
    num = torch.exp(x)
    den = torch.sum(num, dim=1)[:, None]
    return num / den

# x = torch.rand(100, 200).to(DEVICE)
# triton.testing.assert_close(naive_softmax(x), torch.nn.functional.softmax(x, dim=1))

In [45]:
@triton.jit
def fused_softmax_kernel(x_pointer, y_pointer, x_stride, y_stride, n_rows, n_cols, block_size: tl.constexpr):
    # get the program id: each program of the grid handles one (or more) rows of the tensor
    pid = tl.program_id(axis=0)

    # strided execution: can run the program in a strided way (e.g. for row 0, 8, 16, ...)
    row_step = tl.num_programs(axis=0) # n. of programs running on given axis

    # loop through the rows executed by program with this pid
    for row_idx in tl.range(pid, n_rows, row_step):
        x_row_pointer = x_pointer + row_idx * x_stride

        col_offset = tl.arange(0, block_size)
        x_col_pointer = x_row_pointer + col_offset
        mask = n_cols < block_size

        # compute the softmax (with shift for numerical stab.)
        row = tl.load(x_col_pointer, mask, other=-float('inf'))

        max = tl.max(row, axis=0)
        row_shift = row - max

        num = tl.exp(row_shift)
        den = tl.sum(num, axis=0)
        y = num / den

        y_row_pointer = y_pointer + row_idx * y_stride
        y_col_offset = y_row_pointer + col_offset
        tl.store(y_col_offset, y, mask)


def fused_softmax_triton(x:Tensor, block_size:int=1024) -> Tensor:
    assert x.is_cuda

    n_rows, n_cols = x.shape
    y = torch.empty_like(x)
    grid = ceil(n_rows / block_size),
    BLOCK_SIZE = triton.next_power_of_2(n_cols)  # Used to tile the row

    fused_softmax_kernel[grid](x, y, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE)

    return y




In [48]:
x = torch.rand(8, 16).to(DEVICE)
print(naive_softmax(x))
print(fused_softmax_triton(x))
torch.allclose(naive_softmax(x), fused_softmax_triton(x))
print((naive_softmax(x) - fused_softmax_triton(x)))
print((naive_softmax(x)[0,0], fused_softmax_triton(x)[0,0]))

tensor([[0.0428, 0.0593, 0.0543, 0.0390, 0.0805, 0.0934, 0.0425, 0.0617, 0.0456,
         0.0625, 0.0457, 0.0873, 0.0536, 0.1016, 0.0796, 0.0507],
        [0.0373, 0.0662, 0.0768, 0.0648, 0.0401, 0.0830, 0.0595, 0.0485, 0.0743,
         0.0403, 0.0648, 0.0835, 0.0578, 0.0635, 0.0639, 0.0756],
        [0.0690, 0.0838, 0.0886, 0.0692, 0.0522, 0.0561, 0.0501, 0.0531, 0.0724,
         0.0529, 0.0493, 0.0611, 0.0696, 0.0464, 0.0630, 0.0634],
        [0.0358, 0.0417, 0.0766, 0.0613, 0.0903, 0.0744, 0.0598, 0.0830, 0.0815,
         0.0489, 0.0381, 0.0759, 0.0710, 0.0602, 0.0438, 0.0579],
        [0.0563, 0.0745, 0.0683, 0.0878, 0.0602, 0.0612, 0.0491, 0.0928, 0.0869,
         0.0430, 0.0741, 0.0445, 0.0470, 0.0560, 0.0427, 0.0558],
        [0.0591, 0.0365, 0.0557, 0.0573, 0.0466, 0.0552, 0.0730, 0.0765, 0.0816,
         0.0854, 0.0600, 0.0621, 0.0599, 0.0375, 0.0875, 0.0661],
        [0.0584, 0.0595, 0.0545, 0.0773, 0.0755, 0.0902, 0.0878, 0.0412, 0.0646,
         0.0822, 0.0712, 0.0410, 0.04