In [1]:
import torch

import triton
import triton.language as tl
import matplotlib
import pandas as pd


In [2]:
# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`.
@triton.jit
def leaky_relu(x):
    # x = x + 1
    return tl.where(x >= 0, x, 0.01 * x)


In [68]:
# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
#   - A list of `triton.Config` objects that define different configurations of
#       meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
#   - An auto-tuning *key* whose change in values will trigger evaluation of all the
#       provided configs
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_B': 128, 'BLOCK_SIZE_D': 256, 'BLOCK_SIZE_E': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_B': 64, 'BLOCK_SIZE_D': 256, 'BLOCK_SIZE_E': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_B': 128, 'BLOCK_SIZE_D': 128, 'BLOCK_SIZE_E': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_B': 128, 'BLOCK_SIZE_D': 64, 'BLOCK_SIZE_E': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_B': 64, 'BLOCK_SIZE_D': 128, 'BLOCK_SIZE_E': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_B': 128, 'BLOCK_SIZE_D': 32, 'BLOCK_SIZE_E': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_B': 64, 'BLOCK_SIZE_D': 32, 'BLOCK_SIZE_E': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_B': 32, 'BLOCK_SIZE_D': 64, 'BLOCK_SIZE_E': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
    ],
    key=['B', 'D', 'E'],
)
@triton.jit
def mlp_kernel(
    # Pointers to matrices
    x_ptr, w1_ptr, b1_ptr, w2_ptr, b2_ptr, z_ptr, o_ptr,
    # Matrix dimensions
    B, D, E,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
    # by to get the element one row down (A has M rows).
    stride_xb, stride_xd,
    stride_w1d, stride_w1e,
    stride_b1e,
    stride_w2e, stride_w2d,
    stride_b2d,
    stride_zb, stride_ze,
    stride_ob, stride_od,
    # Meta-parameters
    BLOCK_SIZE_B: tl.constexpr, BLOCK_SIZE_D: tl.constexpr, BLOCK_SIZE_E: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    ACTIVATION: tl.constexpr,
):
    """Kernel for computing the mlp
    Z = X @ W1 + b1, H = f(Z), O = H @ W2 + b2.
    - X has shape (B, D),
    - W1 has shape (D, E), b1 has shape (E)
    - W2 has shape (E, D), b2 has shape (D)
    - Z has shape (B, E), H has shape (B, E)
    - O has shape (B, D)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    pid_b = tl.program_id(axis=0)
    pid_d = tl.program_id(axis=1)
    TARGET_TYPE = x_ptr.type.element_ty
    # ----------------------------------------------------------

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    o = tl.zeros((BLOCK_SIZE_B, BLOCK_SIZE_D), dtype=tl.float32)
    for e in range(0, tl.cdiv(E, BLOCK_SIZE_E)):
        z = tl.zeros((BLOCK_SIZE_B, BLOCK_SIZE_E), dtype=tl.float32)
        # loop over D
        x_ptrs = tl.make_block_ptr(
            base=x_ptr,
            shape=(B, D),
            strides=(stride_xb, stride_xd),
            offsets=(pid_b * BLOCK_SIZE_B, 0),
            block_shape=(BLOCK_SIZE_B, BLOCK_SIZE_K),
            order=(1, 0),
        )
        w1_ptrs = tl.make_block_ptr(
            base=w1_ptr,
            shape=(D, E),
            strides=(stride_w1d, stride_w1e),
            offsets=(0, e * BLOCK_SIZE_E),
            block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_E),
            order=(1, 0),
        )
        for k in range(0, tl.cdiv(D, BLOCK_SIZE_K)):
            x = tl.load(x_ptrs)
            w1 = tl.load(w1_ptrs)
            z = tl.dot(x, w1, z)
            x_ptrs = tl.advance(x_ptrs, (0, BLOCK_SIZE_K))
            w1_ptrs = tl.advance(w1_ptrs, (BLOCK_SIZE_K, 0))
        # add bias
        b1_ptrs = tl.make_block_ptr(
            base=b1_ptr,
            shape=(E,),
            strides=(stride_b1e,),
            offsets=(e * BLOCK_SIZE_E,),
            block_shape=(BLOCK_SIZE_E,),
            order=(0,),
        )
        b1 = tl.load(b1_ptrs)
        z = z + b1
        z = z.to(TARGET_TYPE)
        # store z
        z_ptrs = tl.make_block_ptr(
            base=z_ptr,
            shape=(B, E),
            strides=(stride_zb, stride_ze),
            offsets=(pid_b * BLOCK_SIZE_B, e * BLOCK_SIZE_E),
            block_shape=(BLOCK_SIZE_B, BLOCK_SIZE_E),
            order=(1, 0),
        )
        tl.store(z_ptrs, z)
        # You can fuse arbitrary activation functions here
        h = z
        if ACTIVATION == "leaky_relu":
            h = leaky_relu(z).to(TARGET_TYPE)
        # loop over W2
        o_d = tl.zeros((BLOCK_SIZE_B, BLOCK_SIZE_D), dtype=tl.float32)
        w2_ptrs = tl.make_block_ptr(
            base=w2_ptr,
            shape=(E, D),
            strides=(stride_w2e, stride_w2d),
            offsets=(e * BLOCK_SIZE_E, pid_d * BLOCK_SIZE_D),
            block_shape=(BLOCK_SIZE_E, BLOCK_SIZE_D),
            order=(1, 0),
        )
        w2 = tl.load(w2_ptrs)
        # tl.static_print(h)
        # tl.static_print(w2)
        o = tl.dot(h, w2, o)
        tl.static_print('o', o)
    # add bias
    b2_ptrs = tl.make_block_ptr(
        base=b2_ptr,
        shape=(D,),
        strides=(stride_b2d,),
        offsets=(pid_d * BLOCK_SIZE_D,),
        block_shape=(BLOCK_SIZE_D,),
        order=(0,),
    )
    b2 = tl.load(b2_ptrs)
    o = o + b2
    o = o.to(TARGET_TYPE)
    # store o
    o_ptrs = tl.make_block_ptr(
        base=o_ptr,
        shape=(B, D),
        strides=(stride_ob, stride_od),
        offsets=(pid_b * BLOCK_SIZE_B, pid_d * BLOCK_SIZE_D),
        block_shape=(BLOCK_SIZE_B, BLOCK_SIZE_D),
        order=(1, 0),
    )
    tl.store(o_ptrs, o)



In [69]:
# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
#   - A list of `triton.Config` objects that define different configurations of
#       meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
#   - An auto-tuning *key* whose change in values will trigger evaluation of all the
#       provided configs
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_B': 128, 'BLOCK_SIZE_D': 256, 'BLOCK_SIZE_E': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_B': 64, 'BLOCK_SIZE_D': 256, 'BLOCK_SIZE_E': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_B': 128, 'BLOCK_SIZE_D': 128, 'BLOCK_SIZE_E': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_B': 128, 'BLOCK_SIZE_D': 64, 'BLOCK_SIZE_E': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_B': 64, 'BLOCK_SIZE_D': 128, 'BLOCK_SIZE_E': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_B': 128, 'BLOCK_SIZE_D': 32, 'BLOCK_SIZE_E': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_B': 64, 'BLOCK_SIZE_D': 32, 'BLOCK_SIZE_E': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_B': 32, 'BLOCK_SIZE_D': 64, 'BLOCK_SIZE_E': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
    ],
    key=['B', 'D', 'E'],
)
@triton.jit
def mlp_simple_kernel(
    # Pointers to matrices
    x_ptr, w1_ptr, b1_ptr, w2_ptr, b2_ptr, z_ptr, o_ptr,
    # Matrix dimensions
    B, D: tl.constexpr, E,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
    # by to get the element one row down (A has M rows).
    stride_xb, stride_xd,
    stride_w1d, stride_w1e,
    stride_b1e,
    stride_w2e, stride_w2d,
    stride_b2d,
    stride_zb, stride_ze,
    stride_ob, stride_od,
    # Meta-parameters
    BLOCK_SIZE_B: tl.constexpr, BLOCK_SIZE_D: tl.constexpr, BLOCK_SIZE_E: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    ACTIVATION: tl.constexpr,
):
    """Kernel for computing the mlp
    Z = X @ W1 + b1, H = f(Z), O = H @ W2 + b2.
    - X has shape (B, D),
    - W1 has shape (D, E), b1 has shape (E)
    - W2 has shape (E, D), b2 has shape (D)
    - Z has shape (B, E), H has shape (B, E)
    - O has shape (B, D)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See above `L2 Cache Optimizations` section for details.
    pid = tl.program_id(axis=0)
    pid_b = pid
    # tl.static_print(x_ptr.dtype)
    # tl.static_print(x_ptr.__dict__)
    # tl.static_print(x_ptr.dtype.__dict__)
    # tl.static_print(x_ptr.dtype.element_ty)
    TARGET_TYPE = x_ptr.dtype.element_ty
    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    offs_b_blk = (pid_b * BLOCK_SIZE_B + tl.arange(0, BLOCK_SIZE_B))
    offs_d = tl.arange(0, D)
    # x_ptrs = x_ptr + (offs_b_blk[:, None] * stride_xb + offs_d[None, :] * stride_xd)
    # x = tl.load(x_ptrs, mask=(offs_b_blk[:, None] < B) and (offs_d[None, :] < D), other=0.0)
    # tl.static_print(x_ptrs)
    # tl.static_print(w1_ptrs)
    # tl.static_print(b1_ptrs)
    # tl.static_print(w2_ptrs)
    # tl.static_print(b2_ptrs)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    o = tl.zeros((BLOCK_SIZE_B, 0), dtype=tl.float32)
    o_d = tl.zeros((BLOCK_SIZE_B, BLOCK_SIZE_D), dtype=tl.float32)
    for e in range(0, tl.cdiv(E, BLOCK_SIZE_E)):
        z = tl.zeros((BLOCK_SIZE_B, BLOCK_SIZE_E), dtype=tl.float32)
        # loop over D
        offs_k = tl.arange(0, BLOCK_SIZE_K)
        offs_b_blk = pid_b * BLOCK_SIZE_B + tl.arange(0, BLOCK_SIZE_B)
        offs_e_blk = e * BLOCK_SIZE_E + tl.arange(0, BLOCK_SIZE_E)
        x_ptrs = x_ptr + (offs_b_blk[:, None] * stride_xb + offs_k[None, :] * stride_xd)
        w1_ptrs = w1_ptr + (offs_k[:, None] * stride_w1d + offs_e_blk[None, :] * stride_w1e)
        for k in range(0, tl.cdiv(D, BLOCK_SIZE_K)):
            x = tl.load(x_ptrs, mask=(offs_b_blk[:, None] < B) & (offs_k[None, :] + k * BLOCK_SIZE_K < D), other=0.0)
            w1 = tl.load(w1_ptrs, mask=(offs_k[:, None] + k * BLOCK_SIZE_K < D) & (offs_e_blk[None, :] < E), other=0.0)
            z = tl.dot(x, w1, z)
            x_ptrs += BLOCK_SIZE_K * stride_xd
            w1_ptrs += BLOCK_SIZE_K * stride_w1d
        # add bias
        b1_ptrs = b1_ptr + (offs_e_blk[:] * stride_b1e)
        b1 = tl.load(b1_ptrs, mask=(offs_e_blk[:] < E), other=0.0)
        z = z + b1[None, :]
        z = z.to(TARGET_TYPE)
        # store z
        offs_z_blk = pid_b * BLOCK_SIZE_B + tl.arange(0, BLOCK_SIZE_B)
        z_ptrs = z_ptr + (offs_z_blk[:, None] * stride_zb) + (offs_e_blk[None, :] * stride_ze)
        tl.store(z_ptrs, z, mask=(offs_b_blk[:, None] < B) and (offs_e_blk[None, :] < E))
        # You can fuse arbitrary activation functions here
        h = z
        if ACTIVATION == "leaky_relu":
            h = leaky_relu(z)
            h = h.to(TARGET_TYPE)
        # loop over W2
        # for k in range(0, tl.cdiv(D, BLOCK_SIZE_K)):
        #     offs_k = tl.arange(0, BLOCK_SIZE_K)
        #     w2_ptrs = w2_ptr + (offs_e_blk[:, None] * stride_w2e + offs_k[None, :] * stride_w2d)
        #     w2 = tl.load(w2_ptrs, mask=(offs_e_blk[:, None] < E) & (offs_k[None, :] + k * BLOCK_SIZE_K < D), other=0.0)
        #     o[:, offs_k] = tl.dot(h, w2, o[:, offs_k])
        w2_ptrs = w2_ptr + (offs_e_blk[:, None] * stride_w2e + offs_d[None, :] * stride_w2d)
        w2 = tl.load(w2_ptrs, mask=(offs_e_blk[:, None] < E) & (offs_d[None, :] < D), other=0.0)
        o = tl.dot(h, w2, o)
    # add bias
    b2_ptrs = b2_ptr + (offs_d[:] * stride_b2d)
    b2 = tl.load(b2_ptrs, mask=(offs_d[:] < D), other=0.0)
    o = o + b2[None, :]
    o = o.to(TARGET_TYPE)
    # store o
    offs_o_blk = pid_b * BLOCK_SIZE_B + tl.arange(0, BLOCK_SIZE_B)
    o_ptrs = o_ptr + (offs_o_blk[:, None] * stride_ob + offs_d[None, :] * stride_od)
    tl.store(o_ptrs, o, mask=(offs_o_blk[:, None] < B) & (offs_d[None, :] < D))



In [77]:
def mlp_triton(x, w1, b1, w2, b2, activation=""):
    # Check constraints.
    assert x.shape[1] == w1.shape[0], "Incompatible dimensions"
    assert w1.shape[1] == w2.shape[0], "Incompatible dimensions"
    assert w1.shape[1] == b1.shape[0], "Incompatible dimensions"
    assert w2.shape[1] == b2.shape[0], "Incompatible dimensions"
    assert w2.shape[1] == b2.shape[0], "Incompatible dimensions"
    assert x.is_contiguous(), "Matrix X must be contiguous"
    assert w1.is_contiguous(), "Matrix W1 must be contiguous"
    assert b1.is_contiguous(), "Matrix B1 must be contiguous"
    assert w2.is_contiguous(), "Matrix W2 must be contiguous"
    assert b2.is_contiguous(), "Matrix B2 must be contiguous"
    B, D = x.shape
    E = w1.shape[1]

    # Allocates output.
    z = torch.empty((B, E), device=x.device, dtype=x.dtype)
    o = torch.empty((B, D), device=x.device, dtype=x.dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (
        triton.cdiv(B, META['BLOCK_SIZE_B']),
        triton.cdiv(D, META['BLOCK_SIZE_D']),
    )
    mlp_kernel[grid](
        x, w1, b1, w2, b2, z, o,
        B, D, E,
        x.stride(0), x.stride(1),
        w1.stride(0), w1.stride(1),
        b1.stride(0),
        w2.stride(0), w2.stride(1),
        b2.stride(0),
        z.stride(0), z.stride(1),
        o.stride(0), o.stride(1),
        ACTIVATION=activation
    )
    return o, z


In [78]:
def mlp_torch(x, w1, b1, w2, b2, activation=""):
    z = torch.matmul(x, w1) + b1
    if activation == "leaky_relu":
        z = torch.nn.functional.leaky_relu(z)
    o = torch.matmul(z, w2) + b2
    return o

In [83]:
def unit_test_simple():
    torch.manual_seed(125)
    dtype = torch.bfloat16
    x = torch.randn((256, 768), device='cuda', dtype=dtype)
    w1 = torch.randn((768, 1024), device='cuda', dtype=dtype)
    b1 = torch.zeros(1024, device='cuda', dtype=dtype)
    w2 = torch.randn((1024, 768), device='cuda', dtype=dtype)
    b2 = torch.randn(768, device='cuda', dtype=dtype)
    triton_output = mlp_triton(x, w1, b1, w2, b2, activation="leaky_relu")
    torch_output = mlp_torch(x, w1, b1, w2, b2, activation="leaky_relu")
    print(f"triton_output={triton_output[0], triton_output[0].shape}")
    print(f"torch_output={torch_output, torch_output.shape}")
    if torch.allclose(triton_output[0], torch_output, atol=1e-2, rtol=1e-2):
        print("✅ Triton and Torch match")
    else:
        print("❌ Triton and Torch differ")

unit_test_simple()


triton_output=(tensor([[  286.0000,   328.0000,  -159.0000,  ..., -1224.0000,   242.0000,
           162.0000],
        [  888.0000,   584.0000,    23.7500,  ...,   219.0000,   428.0000,
           580.0000],
        [   20.7500,   940.0000,  -282.0000,  ...,   173.0000,   824.0000,
          1424.0000],
        ...,
        [   49.7500,   676.0000,  -370.0000,  ...,  -580.0000,  -440.0000,
           416.0000],
        [  552.0000,  -274.0000,  -444.0000,  ...,  -728.0000,   -19.8750,
           692.0000],
        [  968.0000,  1256.0000,   418.0000,  ...,  -524.0000,  -432.0000,
           394.0000]], device='cuda:0', dtype=torch.bfloat16), torch.Size([256, 768]))
torch_output=(tensor([[  286.0000,   328.0000,  -160.0000,  ..., -1224.0000,   242.0000,
           161.0000],
        [  888.0000,   584.0000,    23.7500,  ...,   220.0000,   428.0000,
           580.0000],
        [   20.8750,   940.0000,  -282.0000,  ...,   173.0000,   824.0000,
          1424.0000],
        ...,
       

In [74]:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['B', 'D', 'E'],  # Argument names to use as an x-axis for the plot
        x_vals=[
            64 * i for i in range(2, 32)
        ],  # Different possible values for `x_name`
        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot
        # Possible values for `line_arg`
        line_vals=['torch', 'triton'],
        # Label name for the lines
        line_names=["Torch", "Triton"],
        # Line styles
        styles=[('green', '-'), ('blue', '-')],
        ylabel="TFLOPS",  # Label name for the y-axis
        plot_name="mlp-performance",  # Name for the plot, used also as a file name for saving the plot.
        args={},
    )
)
def benchmark(B, D, E, provider):
    dtype = torch.bfloat16
    x = torch.randn((B, D), device='cuda', dtype=dtype)
    w1 = torch.randn((D, E), device='cuda', dtype=dtype)
    b1 = torch.zeros(E, device='cuda', dtype=dtype)
    w2 = torch.randn((E, D), device='cuda', dtype=dtype)
    b2 = torch.randn(D, device='cuda', dtype=dtype)
    quantiles = [0.5, 0.2, 0.8]
    if provider.startswith('torch'):
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: mlp_torch(x, w1, b1, w2, b2, activation="leaky_relu"), quantiles=quantiles)
    if provider.startswith('triton'):
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: mlp_simple(x, w1, b1, w2, b2, activation="leaky_relu"), quantiles=quantiles)
    perf = lambda ms: 2 * B * D * E * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)


In [65]:
benchmark.run(show_plots=True, print_data=True)


NameError: name 'mlp_simple' is not defined