In [1]:
import torch

import triton
import triton.language as tl

DEVICE = torch.device('cuda:0')

@triton.jit
def _softmax_fwd_kernel(
    out_ptr,
    stride_out_row,
    x_ptr,
    stride_x_row,
    num_cols: tl.constexpr,
    block_size: tl.constexpr,
):
    """Numerically stable softmax over a single row.

    Each program instance handles the row selected by ``tl.program_id(0)``.
    The row is read with one masked load of ``block_size`` lanes, so a row is
    only fully covered when ``block_size >= num_cols``.

    Columns are addressed as ``row_base + column_index``; only the row strides
    are passed in, so consecutive columns are taken to be adjacent elements
    for both the input and the output.

    Args:
        out_ptr: base pointer of the output.
        stride_out_row: element offset between consecutive output rows.
        x_ptr: base pointer of the input.
        stride_x_row: element offset between consecutive input rows.
        num_cols: number of valid columns in a row (compile-time constant).
        block_size: number of lanes loaded per row (compile-time constant).
    """
    pid = tl.program_id(0)
    offsets = tl.arange(0, block_size)
    in_bounds = offsets < num_cols

    # Per-lane addresses for this program's input and output row.
    src_ptrs = x_ptr + pid * stride_x_row + offsets
    dst_ptrs = out_ptr + pid * stride_out_row + offsets

    # Lanes beyond num_cols are filled with -inf: they cannot win the max and
    # exp(-inf) == 0, so they add nothing to the normalising sum.
    values = tl.load(src_ptrs, mask=in_bounds, other=float('-inf'))

    # Subtract the row maximum before exponentiating so the largest exponent
    # is exp(0) == 1 and nothing overflows.
    shifted = values - tl.max(values, axis=0)
    exps = tl.exp(shifted)
    probs = exps / tl.sum(exps, axis=0)

    # Masked store: padding lanes are never written.
    tl.store(dst_ptrs, probs, mask=in_bounds)
    
def softmax(x: torch.Tensor) -> torch.Tensor:
    """Row-wise softmax of a 2D tensor using the Triton kernel.

    One program instance of ``_softmax_fwd_kernel`` is launched per row, and
    each instance loads its whole row as a single block of
    ``triton.next_power_of_2(cols)`` lanes.

    Args:
        x: 2D tensor of shape ``(rows, cols)``. Any row/column strides are
            accepted; a copy is made only when the columns are not adjacent
            in memory.
            NOTE(review): presumably ``x`` must live on a device Triton can
            launch on (e.g. CUDA) -- confirm against the callers.

    Returns:
        A new row-major tensor with the same shape, dtype and device as ``x``
        holding the softmax of each row. An empty input yields an empty
        output without launching the kernel.

    Raises:
        AssertionError: if ``x`` is not 2-dimensional.
    """

    assert x.dim() == 2, "Only accepts 2D tensors"
    rows, cols = x.shape

    # Allocate the output explicitly row-major. ``torch.empty_like`` would keep
    # the strides of e.g. a transposed input, but the kernel writes columns at
    # consecutive addresses.
    sm_out = torch.empty((rows, cols), dtype=x.dtype, device=x.device)

    # Nothing to compute; also avoids an empty launch grid (rows == 0) and a
    # zero-sized block (cols == 0).
    if rows == 0 or cols == 0:
        return sm_out

    # The kernel addresses columns as ``row_ptr + col_index``, i.e. it needs a
    # column stride of 1. Arbitrary row strides are fine, so only copy when
    # the column stride is not 1.
    if x.stride(1) != 1:
        x = x.contiguous()

    BLOCK_SIZE = triton.next_power_of_2(cols)
    MAX_WARPS = 16
    # BLOCK_SIZE is a power of two, so this yields 4 warps below 2048 lanes,
    # 8 warps at 2048, and MAX_WARPS from 4096 lanes upwards.
    num_warps = min(2 ** (2 + BLOCK_SIZE // 2048), MAX_WARPS)
    grid = (rows,)

    _softmax_fwd_kernel[grid](
        sm_out,
        sm_out.stride(0),
        x,
        x.stride(0),
        cols,
        block_size=BLOCK_SIZE,
        num_warps=num_warps,
    )

    return sm_out

In [2]:
import torch.nn.functional as F

# Hand-computed softmax of the first row of a small example, using plain
# PyTorch ops.
x = torch.tensor(
    [
        [1, 2, 3, 4, 5],
        [5, 4, 3, 2, 1],
    ],
    dtype=torch.float32,
    device=DEVICE,
)

row0 = x[0]
# Named `row0_max` rather than `max` so the Python builtin `max` is not
# shadowed for the rest of the kernel session.
row0_max = torch.max(row0, dim=0)[0]
# Shift by the maximum so the largest exponent is exp(0) == 1.
safe_row0 = row0 - row0_max
sm_out = torch.exp(safe_row0)
# In-place is safe here: `sm_out` is the fresh tensor returned by torch.exp.
sm_out /= torch.sum(sm_out, dim=0)
sm_out

tensor([0.0117, 0.0317, 0.0861, 0.2341, 0.6364], device='cuda:0')