In [2]:
import torch
from torch import Tensor
from math import ceil
from torch.autograd import Function

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
import triton
import triton.language as tl

In [33]:
def naive_softmax(x:Tensor) -> Tensor:
    max = x.max()
    x -= x.max(dim = -1)[0][:, None] # this is done for numerical stability
    num = torch.exp(x)
    den = torch.sum(num, dim=1)[:, None]
    return num / den

# x = torch.rand(100, 200).to(DEVICE)
# triton.testing.assert_close(naive_softmax(x), torch.nn.functional.softmax(x, dim=1))

In [45]:
@triton.jit
def fused_softmax_kernel(x_pointer, y_pointer, x_stride, y_stride, n_rows, n_cols, block_size: tl.constexpr):
    # get the program id: each program of the grid handles one (or more) rows of the tensor
    pid = tl.program_id(axis=0)

    # strided execution: can run the program in a strided way (e.g. for row 0, 8, 16, ...)
    row_step = tl.num_programs(axis=0) # n. of programs running on given axis

    # loop through the rows executed by program with this pid
    for row_idx in tl.range(pid, n_rows, row_step):
        x_row_pointer = x_pointer + row_idx * x_stride

        col_offset = tl.arange(0, block_size)
        x_col_pointer = x_row_pointer + col_offset
        mask = n_cols < block_size

        # compute the softmax (with shift for numerical stab.)
        row = tl.load(x_col_pointer, mask, other=-float('inf'))

        max = tl.max(row, axis=0)
        row_shift = row - max

        num = tl.exp(row_shift)
        den = tl.sum(num, axis=0)
        y = num / den

        y_row_pointer = y_pointer + row_idx * y_stride
        y_col_offset = y_row_pointer + col_offset
        tl.store(y_col_offset, y, mask)


def fused_softmax_triton(x:Tensor, block_size:int=1024) -> Tensor:
    assert x.is_cuda

    n_rows, n_cols = x.shape
    y = torch.empty_like(x)
    grid = ceil(n_rows / block_size),
    BLOCK_SIZE = triton.next_power_of_2(n_cols)  # Used to tile the row

    fused_softmax_kernel[grid](x, y, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE)

    return y




In [49]:
x = torch.rand(8, 16).to(DEVICE)
# print(naive_softmax(x))
# print(fused_softmax_triton(x))
# torch.allclose(naive_softmax(x), fused_softmax_triton(x))
# print((naive_softmax(x) - fused_softmax_triton(x)))
print((naive_softmax(x)[0,0], fused_softmax_triton(x)[0,0]))

(tensor(0.0942, device='cuda:0'), tensor(0.9908, device='cuda:0'))


In [52]:
print(naive_softmax(x.clone()))
print(fused_softmax_triton(x.clone()))

tensor([[0.0942, 0.0576, 0.0905, 0.0420, 0.0682, 0.0375, 0.0534, 0.0587, 0.0899,
         0.0518, 0.0820, 0.0586, 0.0553, 0.0363, 0.0866, 0.0375],
        [0.0777, 0.0814, 0.0803, 0.0514, 0.0616, 0.0827, 0.0537, 0.0421, 0.0649,
         0.0782, 0.0569, 0.0477, 0.0480, 0.0680, 0.0559, 0.0495],
        [0.0383, 0.0878, 0.0644, 0.0642, 0.0370, 0.0480, 0.0630, 0.0936, 0.0647,
         0.0486, 0.0428, 0.0714, 0.0444, 0.0764, 0.0898, 0.0657],
        [0.0700, 0.0701, 0.0549, 0.0801, 0.0880, 0.0803, 0.0426, 0.0735, 0.0394,
         0.0735, 0.0372, 0.0427, 0.0469, 0.0712, 0.0503, 0.0791],
        [0.0453, 0.0886, 0.0429, 0.0579, 0.0941, 0.0549, 0.0694, 0.0471, 0.0582,
         0.0747, 0.0542, 0.0631, 0.0714, 0.0925, 0.0494, 0.0363],
        [0.0411, 0.0372, 0.0799, 0.0953, 0.0631, 0.0579, 0.0665, 0.0808, 0.0661,
         0.0374, 0.0369, 0.0571, 0.0450, 0.0654, 0.0804, 0.0899],
        [0.0931, 0.0446, 0.0550, 0.0430, 0.0585, 0.0821, 0.0819, 0.0491, 0.0851,
         0.0445, 0.0677, 0.0673, 0.05