In [None]:
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from einops import rearrange


# Weighted Sum Kernel

In [None]:
@triton.jit
def weighted_sum_fwd(
    x_ptr, weight_ptr,            # Input pointers
    output_ptr,                   # Output pointer
    x_stride_row, x_stride_dim,   # Strides for x tensor
    weight_stride_dim,           # Likely 1
    output_stride_row,           # Likely 1
    ROWS, D,
    ROWS_TILE_SIZE: tl.constexpr, D_TILE_SIZE: tl.constexpr,  # Tile shapes
):
    """Compute output[r] = sum_d x[r, d] * weight[d] for one tile of rows.

    x is addressed as a (ROWS, D) matrix through ``x_stride_row``/``x_stride_dim``,
    weight as a length-D vector and output as a length-ROWS vector with stride
    ``output_stride_row``. Program ``program_id(0)`` handles rows
    [pid * ROWS_TILE_SIZE, (pid + 1) * ROWS_TILE_SIZE) and walks the D axis in
    chunks of D_TILE_SIZE, accumulating into a float32 buffer. Covering all rows
    therefore needs cdiv(ROWS, ROWS_TILE_SIZE) programs on axis 0.
    """
    # Compute the index of the row tile this program instance will handle
    row_tile_idx = tl.program_id(0)

    # Create block pointers for inputs and output
    # x tile: ROWS_TILE_SIZE rows x D_TILE_SIZE columns, starting at column 0.
    x_block_ptr = tl.make_block_ptr(
        x_ptr,
        shape=(ROWS, D),
        strides=(x_stride_row, x_stride_dim),
        offsets=(row_tile_idx * ROWS_TILE_SIZE, 0),
        block_shape=(ROWS_TILE_SIZE, D_TILE_SIZE),
        order=(1, 0),
    )

    # weight tile: the D_TILE_SIZE entries matching the current x columns.
    weight_block_ptr = tl.make_block_ptr(
        weight_ptr,
        shape=(D,),
        strides=(weight_stride_dim,),
        offsets=(0,),
        block_shape=(D_TILE_SIZE,),
        order=(0,),
    )

    # output tile: one value per row of this program's row tile.
    output_block_ptr = tl.make_block_ptr(
        output_ptr,
        shape=(ROWS,),
        strides=(output_stride_row,),
        offsets=(row_tile_idx * ROWS_TILE_SIZE,),
        block_shape=(ROWS_TILE_SIZE,),
        order=(0,),
    )

    # Initialize output buffer (float32 accumulator, one slot per row in the tile)
    output = tl.zeros((ROWS_TILE_SIZE,), dtype=tl.float32)

    for i in range(tl.cdiv(D, D_TILE_SIZE)):
        # Out-of-range rows/columns are padded with zeros, so partial tiles at the
        # matrix edges add nothing to the dot products.
        row = tl.load(x_block_ptr, boundary_check=(0, 1), padding_option="zero")
        weight = tl.load(weight_block_ptr, boundary_check=(0,), padding_option="zero")
        # Broadcast weight across the row axis and reduce over the column axis.
        output += tl.sum(row * weight[None, :], axis=1)

        # Advance block pointers
        x_block_ptr = x_block_ptr.advance((0, D_TILE_SIZE))
        weight_block_ptr = weight_block_ptr.advance((D_TILE_SIZE,))

    # Store the result (rows past ROWS are skipped by the boundary check)
    tl.store(output_block_ptr, output, boundary_check=(0,))


In [None]:
class WeightedSumFunc(torch.autograd.Function):
    """y[...] = sum_d x[..., d] * weight[d], forward pass via the ``weighted_sum_fwd`` Triton kernel.

    The backward pass uses plain PyTorch ops so the function is differentiable
    with respect to both ``x`` and ``weight``.
    """

    @staticmethod
    def forward(ctx, x, weight):
        """Contract the last dimension of ``x`` with the 1-D ``weight`` vector.

        Args:
            x: CUDA tensor of shape (..., D).
            weight: CUDA tensor of shape (D,).

        Returns:
            Tensor of shape x.shape[:-1] (default float dtype, as allocated below).

        Raises:
            AssertionError: on a shape mismatch or non-CUDA inputs.
        """
        D = x.shape[-1]
        input_shape = x.shape

        assert len(weight.shape) == 1 and weight.shape[0] == D, "Dimension mismatch"
        assert x.is_cuda and weight.is_cuda, "Expected CUDA tensors"

        # Flatten all leading dims: (..., D) -> (ROWS, D). The kernel's pointer
        # arithmetic is simplest on a contiguous matrix, so force one (no-op if
        # the reshape already produced a contiguous view).
        x = rearrange(x, "... d -> (...) d").contiguous()

        ctx.save_for_backward(x, weight)
        ctx.D_TILE_SIZE = 32  # columns reduced per loop iteration inside the kernel
        ctx.ROWS_TILE_SIZE = 16  # rows handled by each program instance
        ctx.input_shape = input_shape

        n_rows = x.shape[0]
        # The kernel addresses the output as a flat vector with a single row
        # stride, so allocate it 1-D (stride 1). Allocating it with shape
        # input_shape[:-1] would make y.stride(0) the stride of the *first*
        # leading dim and scatter results to the wrong (and out-of-bounds) slots.
        # torch.empty leaves garbage, but the kernel writes every one of the n_rows slots.
        y = torch.empty(n_rows, device=x.device)

        if n_rows > 0:  # avoid launching an empty grid
            weighted_sum_fwd[(triton.cdiv(n_rows, ctx.ROWS_TILE_SIZE),)](
                x, weight,
                y,
                x.stride(0), x.stride(1),
                weight.stride(0),
                y.stride(0),
                ROWS=n_rows, D=D,
                ROWS_TILE_SIZE=ctx.ROWS_TILE_SIZE, D_TILE_SIZE=ctx.D_TILE_SIZE,
            )

        return y.view(input_shape[:-1])

    @staticmethod
    def backward(ctx, grad_out):
        """Gradients of y = x @ weight: dL/dx = g * weight, dL/dweight = sum_rows g * x."""
        x, weight = ctx.saved_tensors  # x is the flattened (ROWS, D) input
        grad_rows = grad_out.reshape(-1, 1)  # (ROWS, 1), one upstream grad per row

        grad_x = grad_weight = None
        if ctx.needs_input_grad[0]:
            grad_x = (grad_rows * weight).to(x.dtype).reshape(ctx.input_shape)
        if ctx.needs_input_grad[1]:
            grad_weight = (grad_rows * x).sum(dim=0).to(weight.dtype)
        return grad_x, grad_weight

In [None]:
def weighted_sum(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Functional entry point: contract the last axis of ``x`` with ``weight``.

    Routes the call through ``WeightedSumFunc.apply`` so it goes via the
    ``torch.autograd.Function`` machinery rather than calling the kernel directly.
    """
    autograd_op = WeightedSumFunc.apply
    result = autograd_op(x, weight)
    return result

In [None]:

x = torch.randn(2, 3, 4, device="cuda", requires_grad=False)  # batch_size=2, seq_len=3, dim=4
weight = torch.randn(4, device="cuda", requires_grad=False)  # dim=4

y = weighted_sum(x, weight)  # shape -> (2, 3)
print("Output:", y)

# Plain PyTorch reference: multiply by weight and reduce the last dim.
y_ref = (x * weight).sum(dim=-1)
print("Matches reference:", torch.allclose(y, y_ref, atol=1e-5))

# Forward-only demo: neither tensor requires grad, so nothing is backpropagated here.
loss = y.sum()
print(loss)

# Element-wise Sum Kernel

In [None]:
@triton.jit
def sum_kernel(
    arr1_ptr,
    arr2_ptr,
    output_ptr,
    N, BLOCK_SIZE :tl.constexpr
):
    """Element-wise ``output[i] = arr1[i] + arr2[i]`` for i < N.

    Program ``pid`` covers the chunk [pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE);
    lanes at or beyond N are masked off for both loads and the store.
    """
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < N

    lhs = tl.load(arr1_ptr + idx, mask=in_bounds)
    rhs = tl.load(arr2_ptr + idx, mask=in_bounds)
    tl.store(output_ptr + idx, lhs + rhs, mask=in_bounds)

In [None]:
def add(arr1: torch.Tensor, arr2: torch.Tensor, block_size: int = 1024):
    """Element-wise ``arr1 + arr2`` computed by ``sum_kernel``.

    Args:
        arr1, arr2: CUDA tensors with the same number of elements.
        block_size: elements handled per Triton program (must be a power of two).

    Returns:
        A new contiguous tensor with the shape and dtype of ``arr1``.

    Raises:
        AssertionError: if the inputs have different element counts (the kernel
            would otherwise read past the end of the smaller one).
    """
    assert arr1.numel() == arr2.numel(), "add expects inputs with the same number of elements"
    arr1_c = arr1.contiguous()
    arr2_c = arr2.contiguous()
    # The kernel writes element i to storage offset i, so the output must be
    # contiguous. empty_like(arr1) would copy a permuted arr1's strides and the
    # linear writes would land at the wrong logical positions.
    output = torch.empty_like(arr1_c, memory_format=torch.contiguous_format)
    N = output.numel()
    if N == 0:
        return output

    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
    sum_kernel[grid](arr1_c, arr2_c, output, N, BLOCK_SIZE=block_size)
    return output

# ReLU Kernel

In [None]:
@triton.jit
def relu_kernel(
    input_ptr, output_ptr,
    N, BLOCK_SIZE: tl.constexpr
):
    """Element-wise ReLU: ``output[i] = input[i] if input[i] > 0 else 0`` for i < N.

    One program per BLOCK_SIZE-element chunk; lanes at or beyond N are masked.
    Because the test is ``> 0``, NaN inputs map to 0.
    """
    idx = BLOCK_SIZE * tl.program_id(axis=0) + tl.arange(0, BLOCK_SIZE)
    valid = idx < N

    vals = tl.load(input_ptr + idx, mask=valid)
    positive = vals > 0
    tl.store(output_ptr + idx, tl.where(positive, vals, 0), mask=valid)

In [None]:
def relu(x: torch.Tensor, block_size: int = 1024):
    """ReLU computed by ``relu_kernel``.

    Args:
        x: CUDA tensor of any shape.
        block_size: elements handled per Triton program (must be a power of two).

    Returns:
        A new contiguous tensor with the shape and dtype of ``x``.
    """
    x_c = x.contiguous()
    # The kernel writes linearly, so the output must be contiguous; empty_like(x)
    # would inherit the strides of a permuted x and misplace the results.
    output = torch.empty_like(x_c, memory_format=torch.contiguous_format)
    N = output.numel()
    if N == 0:
        return output

    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
    relu_kernel[grid](x_c, output, N, BLOCK_SIZE=block_size)
    return output

In [None]:
def test_relu():
    """Compare the Triton ``relu`` against ``torch.relu`` and print pass/fail per case."""
    test_cases = [
        torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32, device="cuda"),
        torch.tensor([-1.0, -2.0, -3.0], dtype=torch.float32, device="cuda"),
        torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float32, device="cuda"),
        torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device="cuda"),
        torch.randn(1024, dtype=torch.float32, device="cuda"),
        # Not a multiple of the block size: exercises the tail mask across several programs.
        torch.randn(3000, dtype=torch.float32, device="cuda"),
        # Non-contiguous (transposed) input: the output must still line up element-wise.
        torch.randn(64, 48, dtype=torch.float32, device="cuda").t(),
    ]

    for i, x in enumerate(test_cases):
        # torch.relu keeps the reference independent of any separate `F` alias.
        expected = torch.relu(x)
        result = relu(x)
        if not torch.allclose(result, expected, atol=1e-6):
            print(f"Test case {i} failed")
            print("Input:    ", x)
            print("Expected: ", expected)
            print("Result:   ", result)
        else:
            print(f"Test case {i} passed")

test_relu()

# Comparison Kernel

In [None]:
@triton.jit
def comparison_kernel(
    a_ptr, b_ptr, output_ptr, N, BLOCK_SIZE: tl.constexpr
):
    """Element-wise equality: ``output[i] = (a[i] == b[i])`` for i < N.

    BLOCK_SIZE must be ``tl.constexpr`` because it sizes ``tl.arange``; as a
    plain runtime argument the kernel fails to compile. The boolean result is
    converted to the output pointer's element type on store.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N

    a = tl.load(a_ptr + offsets, mask=mask)
    b = tl.load(b_ptr + offsets, mask=mask)

    output = a == b
    tl.store(output_ptr + offsets, output, mask=mask)

In [None]:
def comparison(x: torch.Tensor, y: torch.Tensor, block_size: int = 1024):
    """Element-wise ``x == y`` via ``comparison_kernel``.

    Args:
        x, y: CUDA tensors with the same number of elements.
        block_size: elements handled per Triton program (must be a power of two).

    Returns:
        A new contiguous ``torch.int`` tensor shaped like ``x`` holding 1 where
        the elements are equal and 0 elsewhere.

    Raises:
        AssertionError: if the inputs have different element counts.
    """
    assert x.numel() == y.numel(), "comparison expects inputs with the same number of elements"
    x_c = x.contiguous()
    y_c = y.contiguous()
    # Contiguous output so the kernel's linear writes match logical positions.
    output = torch.empty(x_c.shape, dtype=torch.int, device=x_c.device)
    N = output.numel()
    if N == 0:
        return output

    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
    comparison_kernel[grid](x_c, y_c, output, N, BLOCK_SIZE=block_size)
    return output

# Softmax 1-d Kernel

In [None]:
@triton.jit
def softmax_kernel(
    x_ptr, output_ptr, N, BLOCK_SIZE: tl.constexpr
):
    """Softmax over the whole length-N vector; each program stores one BLOCK_SIZE chunk.

    The previous version normalised every chunk independently, which is wrong
    whenever N > BLOCK_SIZE. Here every program first streams over all N
    elements to get the global max and the global sum of exp(x - max), then
    normalises and stores only its own chunk
    [pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE). The statistics passes are repeated
    by every program (O(N) reads each), which keeps this a single launch.
    Arithmetic is done in float32; the store converts to the output element type.
    """
    pid = tl.program_id(axis=0)
    lanes = tl.arange(0, BLOCK_SIZE)

    # Pass 1: global maximum. Masked lanes read -inf so they never win.
    g_max = -float("inf")
    for start in range(0, N, BLOCK_SIZE):
        offs = start + lanes
        chunk = tl.load(x_ptr + offs, mask=offs < N, other=-float("inf")).to(tl.float32)
        g_max = tl.maximum(g_max, tl.max(chunk, axis=0))

    # Pass 2: normaliser sum(exp(x - max)); masked lanes contribute exp(-inf) = 0.
    g_sum = 0.0
    for start in range(0, N, BLOCK_SIZE):
        offs = start + lanes
        chunk = tl.load(x_ptr + offs, mask=offs < N, other=-float("inf")).to(tl.float32)
        g_sum += tl.sum(tl.exp(chunk - g_max), axis=0)

    # Pass 3: normalise and store this program's chunk only.
    offsets = pid * BLOCK_SIZE + lanes
    mask = offsets < N
    x = tl.load(x_ptr + offsets, mask=mask, other=-float("inf")).to(tl.float32)
    output = tl.exp(x - g_max) / g_sum
    tl.store(output_ptr + offsets, output, mask=mask)

In [None]:
def soft(x: torch.Tensor, block_size: int = 1024):
    """Softmax over all elements of ``x`` (treated as one flat vector) via ``softmax_kernel``.

    Args:
        x: CUDA tensor of any shape; its ``numel()`` elements form the vector.
        block_size: elements per Triton program (must be a power of two).

    Returns:
        A new contiguous tensor with the shape and dtype of ``x``.
    """
    x_c = x.contiguous()
    # Contiguous output so the kernel's linear writes match logical positions.
    output = torch.empty_like(x_c, memory_format=torch.contiguous_format)
    N = output.numel()
    if N == 0:
        return output

    # One program per BLOCK_SIZE chunk of the output.
    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
    softmax_kernel[grid](x_c, output, N, BLOCK_SIZE=block_size)
    return output

In [None]:
def test_softmax():
    """Check ``soft`` against ``torch.softmax`` on a small fixed vector and two random ones."""
    tol = 1e-5

    small = torch.tensor([1.0, 2.0, 3.0, 4.0], device='cuda')
    print("Input:", small)
    got = soft(small)
    print("Softmax Output:", got)
    ref = torch.softmax(small, dim=0)
    print("Expected:", ref)
    print("Test 1 passed:", torch.allclose(got, ref, atol=tol))

    # Random vectors: one exactly one block long, one spanning several blocks.
    for label, length in (("Test 2 passed:", 1024), ("Test 3 passed:", 4096)):
        vec = torch.randn(length, device='cuda')
        got = soft(vec)
        ref = torch.softmax(vec, dim=0)
        print(label, torch.allclose(got, ref, atol=tol))

test_softmax()

# Softmax n-d Kernel

In [None]:
@triton.jit
def softmax_nd_kernel(
    input_ptr, output_ptr,
    input_row_stride, input_col_stride,
    output_row_stride, output_col_stride,
    N, BLOCK_SIZE: tl.constexpr
):
    """Row-wise softmax: program ``pid`` normalises row ``pid`` of a strided 2-D view.

    The row is loaded as a single BLOCK_SIZE-wide tile, so only columns
    0..BLOCK_SIZE-1 are processed; callers must ensure N <= BLOCK_SIZE.
    Masked lanes read -inf and so add nothing to the sum (exp(-inf - max) = 0
    for a finite max).
    """
    row = tl.program_id(axis=0)
    cols = tl.arange(0, BLOCK_SIZE)
    valid = cols < N

    in_ptrs = input_ptr + row * input_row_stride + cols * input_col_stride
    vals = tl.load(in_ptrs, mask=valid, other=-float('inf'))

    numer = tl.exp(vals - tl.max(vals, axis=0))
    denom = tl.sum(numer, axis=0)

    out_ptrs = output_ptr + row * output_row_stride + cols * output_col_stride
    tl.store(out_ptrs, numer / denom, mask=valid)


In [None]:
def soft_nd(x: torch.Tensor, BLOCK_SIZE=1024):
    """Softmax over the last dimension of ``x`` using ``softmax_nd_kernel``.

    Args:
        x: CUDA tensor of shape (..., N) with at least one dimension.
        BLOCK_SIZE: minimum tile width (power of two). It is raised to the next
            power of two >= N when needed, because the kernel loads each row as
            a single tile and would otherwise ignore columns past BLOCK_SIZE.

    Returns:
        Tensor with the same shape as ``x``.
    """
    original_shape = x.shape
    N = x.shape[-1]

    # Collapse leading dims: (..., N) -> (num_rows, N). May be a strided view;
    # the kernel takes explicit strides, so that is fine.
    x_2d = rearrange(x, "... d -> (...) d")
    num_rows = x_2d.shape[0]
    output = torch.empty_like(x_2d)
    if output.numel() == 0:  # zero rows or zero columns: nothing to launch
        return output.reshape(original_shape)

    block = max(BLOCK_SIZE, triton.next_power_of_2(N))

    # The kernel handles exactly one row per program instance, so launch one
    # program per row (not cdiv(num_rows, BLOCK_SIZE), which computed only row 0).
    grid = (num_rows,)

    softmax_nd_kernel[grid](
        x_2d,
        output,
        x_2d.stride(0),
        x_2d.stride(1),
        output.stride(0),
        output.stride(1),
        N,
        BLOCK_SIZE=block,
    )

    return output.reshape(original_shape)


In [None]:
def test_soft_nd():
    """Compare ``soft_nd`` with ``torch.softmax(dim=-1)`` for inputs of rank 1-4."""
    atol = 1e-5

    cases = [
        ("1D test:", torch.tensor([1.0, 2.0, 3.0], device='cuda')),
        ("2D test:", torch.randn(4, 8, device='cuda')),
        ("3D test:", torch.randn(3, 5, 8, device='cuda')),
        ("4D test:", torch.randn(2, 3, 4, 8, device='cuda')),
        # Batched edge case: rows exactly as wide as the default block.
        ("Large 2D test:", torch.randn(10, 1024, device='cuda')),
        # Non-contiguous rows: exercises the explicit stride arguments.
        ("Transposed 2D test:", torch.randn(8, 4, device='cuda').t()),
    ]

    for label, x in cases:
        out = soft_nd(x)
        expected = torch.softmax(x, dim=-1)
        print(label, torch.allclose(out, expected, atol=atol))

test_soft_nd()


# Online Softmax Kernel

In [None]:
@triton.jit
def online_softmax_kernel(
    input_ptr, output_ptr,
    input_row_stride, input_col_stride,
    output_row_stride, output_col_stride,
    N, BLOCK_SIZE: tl.constexpr
):
    """Row-wise softmax for rows of any length using the online (streaming) update.

    Program ``pid`` owns row ``pid``. Pass 1 walks the row in BLOCK_SIZE chunks,
    keeping a running max and a running normaliser. When a chunk raises the max,
    the normaliser is rescaled by exp(old_max - new_max); otherwise the chunk's
    sum is simply added. Pass 2 re-reads the row and writes exp(x - max) / norm.
    """
    row = tl.program_id(axis=0)
    lanes = tl.arange(0, BLOCK_SIZE)
    row_in = input_ptr + row * input_row_stride
    row_out = output_ptr + row * output_row_stride
    n_chunks = (N + BLOCK_SIZE - 1) // BLOCK_SIZE

    run_max = -float('inf')
    run_norm = 0.0

    for chunk_id in range(n_chunks):
        cols = chunk_id * BLOCK_SIZE + lanes
        in_row = cols < N
        chunk = tl.load(row_in + cols * input_col_stride, mask=in_row, other=-float('inf'))

        next_max = tl.maximum(tl.max(chunk), run_max)
        rescale = tl.exp(run_max - next_max)
        chunk_sum = tl.sum(tl.exp(chunk - next_max))
        run_norm = tl.where(next_max > run_max,
                            run_norm * rescale + chunk_sum,
                            run_norm + chunk_sum)
        run_max = next_max

    for chunk_id in range(n_chunks):
        cols = chunk_id * BLOCK_SIZE + lanes
        in_row = cols < N
        chunk = tl.load(row_in + cols * input_col_stride, mask=in_row, other=-float('inf'))
        probs = tl.exp(chunk - run_max) / run_norm
        tl.store(row_out + cols * output_col_stride, probs, mask=in_row)

In [None]:
def online_softmax_pt(input: torch.Tensor, block_size=1024):
    """Row-wise softmax of a 2-D CUDA tensor using ``online_softmax_kernel``.

    Args:
        input: CUDA tensor of shape (B, N).
        block_size: chunk width streamed by each program (power of two). It is
            capped at next_power_of_2(N) so short rows do not waste lanes.

    Returns:
        Tensor of shape (B, N) with softmax taken over the last dimension.

    Raises:
        AssertionError: if ``input`` is not 2-D or not on CUDA.
    """
    assert input.dim() == 2, "input tensor should be 2D"
    assert input.device.type == "cuda", "input tensor should be on CUDA"

    input = input.contiguous()
    B, N = input.size()

    output = torch.empty_like(input)
    if input.numel() == 0:  # no rows or no columns: nothing to launch
        return output

    block = min(block_size, triton.next_power_of_2(N))
    grid = (B,)  # one program per row

    online_softmax_kernel[grid](
        input,
        output,
        input.stride(0), input.stride(1),
        output.stride(0), output.stride(1),
        N,
        # BLOCK_SIZE is a constexpr without a default, so it must be passed
        # explicitly; omitting it makes the launch fail.
        BLOCK_SIZE=block,
    )

    return output

In [None]:
B, N = 4, 128
x = torch.randn(B, N, device='cuda', dtype=torch.float32)

# Triton softmax
out_triton = online_softmax_pt(x)

# PyTorch reference (torch.softmax needs no extra import in this cell)
out_torch = torch.softmax(x, dim=-1)

max_diff = (out_triton - out_torch).abs().max().item()
print(f"Max difference: {max_diff:.6e}")

# Rows longer than one block exercise the multi-chunk online rescaling.
x_long = torch.randn(B, 3000, device='cuda', dtype=torch.float32)
max_diff_long = (online_softmax_pt(x_long) - torch.softmax(x_long, dim=-1)).abs().max().item()
print(f"Max difference (N=3000): {max_diff_long:.6e}")