In [1]:
import torch

import triton
import triton.language as tl
import matplotlib
import pandas as pd
import numpy as np


In [2]:
# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`.
@triton.jit
def leaky_relu(x):
    # x = x + 1
    return tl.where(x >= 0, x, 0.01 * x)


In [3]:
# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
#   - A list of `triton.Config` objects that define different configurations of
#       meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
#   - An auto-tuning *key* whose change in values will trigger evaluation of all the
#       provided configs
@triton.autotune(
    configs=[
        # triton.Config({'BLOCK_SIZE_B': 32, 'BLOCK_SIZE_D': 256, 'BLOCK_SIZE_E': 64, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
        # triton.Config({'BLOCK_SIZE_B': 16, 'BLOCK_SIZE_D': 256, 'BLOCK_SIZE_E': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_B': 32, 'BLOCK_SIZE_D': 128, 'BLOCK_SIZE_E': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_B': 32, 'BLOCK_SIZE_D': 64, 'BLOCK_SIZE_E': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_B': 16, 'BLOCK_SIZE_D': 128, 'BLOCK_SIZE_E': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_B': 32, 'BLOCK_SIZE_D': 32, 'BLOCK_SIZE_E': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_B': 16, 'BLOCK_SIZE_D': 32, 'BLOCK_SIZE_E': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
        # triton.Config({'BLOCK_SIZE_B': 16, 'BLOCK_SIZE_D': 64, 'BLOCK_SIZE_E': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
        triton.Config({'BLOCK_SIZE_B': 32, 'BLOCK_SIZE_E': 32}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_B': 64, 'BLOCK_SIZE_E': 32}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_B': 32, 'BLOCK_SIZE_E': 64}, num_stages=2, num_warps=4),
        # triton.Config({'BLOCK_SIZE_B': 64, 'BLOCK_SIZE_E': 64}, num_stages=2, num_warps=4),
        # triton.Config({'BLOCK_SIZE_B': 128, 'BLOCK_SIZE_E': 64}, num_stages=2, num_warps=4),
        # triton.Config({'BLOCK_SIZE_B': 64, 'BLOCK_SIZE_E': 128}, num_stages=2, num_warps=4),
        # triton.Config({'BLOCK_SIZE_B': 128, 'BLOCK_SIZE_E': 128}, num_stages=2, num_warps=4),
    ],
    key=['H', 'B', 'D', 'E'],
)
@triton.jit
def mlp_wide_kernel(
    # Pointers to matrices
    x_ptr, w1_ptr, w2_ptr, o_ptr,
    # Matrix dimensions
    H, B, D: tl.constexpr, E,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
    # by to get the element one row down (A has M rows).
    stride_xh, stride_xb, stride_xd,
    stride_w1h, stride_w1d, stride_w1e,
    stride_w2h, stride_w2e, stride_w2d,
    stride_oh, stride_ob, stride_od,
    # Meta-parameters
    BLOCK_SIZE_B: tl.constexpr, BLOCK_SIZE_E: tl.constexpr,
    ACTIVATION: tl.constexpr,
):
    """Kernel for computing the mlp
    Z = X @ W1, H = f(Z), O = H @ W2.
    - X has shape (B * H, D)
    - W1 has shape (H, D, E)
    - W2 has shape (H, E, D)
    - O has shape (B * H, D)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # tl.static_print(stride_xh, stride_xb, stride_xd)
    pid = tl.program_id(0)
    pid_b = pid % B
    pid_h = pid // B
    TARGET_TYPE = x_ptr.type.element_ty
    x_ptrs = tl.make_block_ptr(
        base=x_ptr,
        shape=(H * B, D),
        strides=(stride_ob, stride_od),
        offsets=(pid_h * B + pid_b * BLOCK_SIZE_B, 0),
        block_shape=(BLOCK_SIZE_B, D),
        order=(1, 0),
    )
    w1_ptrs = tl.make_block_ptr(
        base=w1_ptr,
        shape=(H, D, E),
        strides=(stride_w1h, stride_w1d, stride_w1e),
        offsets=(pid_h, 0, 0),
        block_shape=(1, D, BLOCK_SIZE_E),
        order=(2, 1, 0),
    )
    w2_ptrs = tl.make_block_ptr(
        base=w2_ptr,
        shape=(H, E, D),
        strides=(stride_w2h, stride_w2e, stride_w2d),
        offsets=(pid_h, 0, 0),
        block_shape=(1, BLOCK_SIZE_E, D),
        order=(2, 1, 0),
    )
    o_ptrs = tl.make_block_ptr(
        base=o_ptr,
        shape=(H * B, D),
        strides=(stride_ob, stride_od),
        offsets=(pid_h * B + pid_b * BLOCK_SIZE_B, 0),
        block_shape=(BLOCK_SIZE_B, D),
        order=(1, 0),
    )
    x = tl.load(x_ptrs) # 1, BLOCK_SIZE_B, D  
    tl.static_print(x)
    o = tl.zeros((BLOCK_SIZE_B, D), dtype=tl.float32)
    tl.static_print(o)
    for e in range(0, tl.cdiv(E, BLOCK_SIZE_E)):
        z = tl.zeros((BLOCK_SIZE_B, BLOCK_SIZE_E), dtype=tl.float32)
        # loop over D
        w1 = tl.load(w1_ptrs).reshape(D, BLOCK_SIZE_E)       # D, BLOCK_SIZE_E
        w2 = tl.load(w2_ptrs).reshape(BLOCK_SIZE_E, D)       # BLOCK_SIZE_E, D
        # tl.static_print(x, w1, z)
        z = tl.dot(x, w1, z)        # BLOCK_SIZE_B, BLOCK_SIZE_E
        # tl.static_print(z)
        # activation
        if ACTIVATION == "leaky_relu":
            z = leaky_relu(z).to(TARGET_TYPE)
        # accumulate with o
        # tl.static_print(z, w2, o)
        o = tl.dot(z, w2, o)        # 1, BLOCK_SIZE_B, D
        # tl.device_print('o_ptrs', o)
        # advance w1 and w2
        w1_ptrs = tl.advance(w1_ptrs, (0, 0, BLOCK_SIZE_E))
        w2_ptrs = tl.advance(w2_ptrs, (0, BLOCK_SIZE_E, 0))

    o = o.to(TARGET_TYPE)
    # o = o.to(TARGET_TYPE).reshape(BLOCK_SIZE_B, D)
    # x = x.to(TARGET_TYPE).reshape(BLOCK_SIZE_B, D)
    # tl.static_print(o)
    # store o
    tl.static_print(o_ptrs, o, x)
    tl.device_print('o_ptrs', x)
    tl.store(o_ptrs, x)
    # tl.store(o_ptrs, o)
    # tl.static_print(x_ptrs, x)
    # tl.store(x_ptrs, o)
    '''
    '''


In [4]:
def mlp_wide_triton(x, w1, w2, activation=""):
    # Check constraints.
    assert x.shape[0] == w1.shape[0], "Incompatible dimensions"
    assert x.shape[0] == w2.shape[0], "Incompatible dimensions"
    assert x.shape[2] == w1.shape[1], "Incompatible dimensions"
    assert w1.shape[2] == w2.shape[1], "Incompatible dimensions"
    assert w1.shape[1] == w2.shape[2], "Incompatible dimensions"
    assert x.is_contiguous(), "Matrix X must be contiguous"
    assert w1.is_contiguous(), "Matrix W1 must be contiguous"
    assert w2.is_contiguous(), "Matrix W2 must be contiguous"
    H, B, D = x.shape
    E = w1.shape[2]

    # Allocates output.
    o = torch.empty((H, B, D), device=x.device, dtype=x.dtype)
    # print(x.shape, w1.shape, w2.shape, o.shape)

    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (
        triton.cdiv(B, META['BLOCK_SIZE_B']),
        H,
    )
    mlp_wide_kernel[grid](
        x, w1, w2, o,
        H, B, D, E,
        x.stride(0), x.stride(1), x.stride(2),
        w1.stride(0), w1.stride(1), w1.stride(2),
        w2.stride(0), w2.stride(1), w2.stride(2),
        o.stride(0), o.stride(1), o.stride(2),
        ACTIVATION=activation
    )

    # print(o.shape)
    return o


In [5]:
def mlp_torch(x, w1, w2, activation=""):
    z = torch.bmm(x, w1)
    if activation == "leaky_relu":
        z = torch.nn.functional.leaky_relu(z)
    o = torch.bmm(z, w2)
    return o

In [6]:
def unit_test_simple():
    # torch.manual_seed(115)
    dtype = torch.bfloat16
    B = 64
    D = 64
    E = 1024
    HEAD = 16
    x = torch.randn((HEAD, B, D), device='cuda', dtype=dtype)
    w1 = torch.randn((HEAD, D, E), device='cuda', dtype=dtype)
    w2 = torch.randn((HEAD, E, D), device='cuda', dtype=dtype)
    triton_output = mlp_wide_triton(x, w1, w2, activation="leaky_relu")
    torch_output = mlp_torch(x, w1, w2, activation="leaky_relu")
    print(f"triton_output={triton_output, triton_output[0].shape}")
    print(f"torch_output={torch_output, torch_output.shape}")
    if torch.allclose(triton_output, torch_output, atol=3e-2, rtol=1e-2):
        print("✅ Triton and Torch match")
    else:
        print("❌ Triton and Torch differ")

    diff = np.abs(triton_output.to(torch.float32).cpu().numpy() - torch_output.to(torch.float32).cpu().numpy())
    print("max diff:",np.max(diff))
    print("mean diff:",np.mean(diff))

unit_test_simple()


bf16[constexpr[32], constexpr[64]]
fp32[constexpr[32], constexpr[64]]
pointer<<[32, 64], bf16>>[] bf16[constexpr[32], constexpr[64]] bf16[constexpr[32], constexpr[64]]
pid (0, 1, 0) idx (25,  0) o_ptrs: 16421
pid (0, 1, 0) idx (25,  1) o_ptrs: 15908
pid (0, 1, 0) idx (25,  2) o_ptrs: 4294950925
pid (0, 1, 0) idx (25,  3) o_ptrs: 16285
pid (0, 1, 0) idx (25,  4) o_ptrs: 4294950149
pid (0, 1, 0) idx (25,  5) o_ptrs: 16385
pid (0, 1, 0) idx (25,  6) o_ptrs: 16143
pid (0, 1, 0) idx (25,  7) o_ptrs: 15994
pid (0, 1, 0) idx (25,  8) o_ptrs: 4294950625
pid (0, 1, 0) idx (25,  9) o_ptrs: 4294950585
pid (0, 1, 0) idx (25, 10) o_ptrs: 4294950784
pid (0, 1, 0) idx (25, 11) o_ptrs: 16164
pid (0, 1, 0) idx (25, 12) o_ptrs: 4294950809
pid (0, 1, 0) idx (25, 13) o_ptrs: 15807
pid (0, 1, 0) idx (25, 14) o_ptrs: 4294950785
pid (0, 1, 0) idx (25, 15) o_ptrs: 15989
pid (0, 1, 0) idx (25, 16) o_ptrs: 16184
pid (0, 1, 0) idx (25, 17) o_ptrs: 4294950697
pid (0, 1, 0) idx (25, 18) o_ptrs: 16233
pid (0, 1, 0)

In [7]:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['E'],  # Argument names to use as an x-axis for the plot
        x_vals=[
            # 2 ** i for i in range(5, 12)
            128 * i for i in range(2, 16)
        ],  # Different possible values for `x_name`
        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot
        # Possible values for `line_arg`
        line_vals=['torch-32', 'triton-32', 'torch-64', 'triton-64', 'torch-128', 'triton-128'],
        # Label name for the lines
        line_names=["Torch-32", "Triton-32", "Torch-64", "Triton-64", "Torch-128", "Triton-128"],
        # Line styles
        styles=[('green', ':'), ('blue', ':'), ('green', '--'), ('blue', '--'), ('green', '-'), ('blue', '-')],
        ylabel="TFLOPS",  # Label name for the y-axis
        plot_name="mlp-performance",  # Name for the plot, used also as a file name for saving the plot.
        args={},
    )
)
def benchmark(E, provider):
    dtype = torch.bfloat16
    D = int(provider[provider.find('-') + 1:])
    HEAD = 768 // D
    B = 1024 * HEAD
    x = torch.randn((HEAD, B, D), device='cuda', dtype=dtype)
    w1 = torch.randn((HEAD, D, E), device='cuda', dtype=dtype)
    w2 = torch.randn((HEAD, E, D), device='cuda', dtype=dtype)
    quantiles = [0.5, 0.2, 0.8]
    if provider.startswith('torch'):
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: mlp_torch(x, w1, w2, activation="leaky_relu"), quantiles=quantiles)
    if provider.startswith('triton'):
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: mlp_wide_triton(x, w1, w2, activation="leaky_relu"), quantiles=quantiles)
    perf = lambda ms: 4 * B * D * E * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)


In [7]:
benchmark.run(show_plots=True, print_data=True)


NameError: name 'benchmark' is not defined