In [None]:
import triton
import torch
import os
import triton.language as tl
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack


# Device used by benchmark() to allocate its inputs; triton_add asserts its tensors live on it.
DEVICE = torch.device("cuda:0")

In [None]:
# Here we define the torch code

def torch_add(x, y):
    """Eager baseline: return ``x + y`` using the operands' own ``+`` operator."""
    total = x + y
    return total

In [None]:
# Here we define the torch compile code

@torch.compile
def torch_compile_add(x, y):
    """Return ``x + y``; the decorator hands this function to ``torch.compile``."""
    compiled_sum = x + y
    return compiled_sum

In [None]:
# Here we define the triton code

@triton.jit
def add_kernel(x_ptr,
               y_ptr,
               output_ptr,
               n_elements,
               BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one contiguous chunk of BLOCK_SIZE elements.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # The last chunk may run past the end of the data; mask those lanes out.
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(output_ptr + offsets, x + y, mask=mask)


def triton_add(x, y):
    """Add two same-shaped tensors on ``DEVICE`` with the Triton ``add_kernel``.

    Args:
        x: First operand; must live on ``DEVICE``.
        y: Second operand; must live on ``DEVICE`` and have the same shape as ``x``.

    Returns:
        A new tensor with the shape and dtype of ``x`` holding ``x + y``.

    Raises:
        AssertionError: If the shapes differ or a tensor is not on ``DEVICE``.
    """
    assert x.shape == y.shape, "triton_add expects operands of identical shape"
    assert x.device == DEVICE and y.device == DEVICE
    # The kernel indexes memory linearly, so all three buffers must share the
    # same contiguous element order (.contiguous() is a no-op when already so).
    x = x.contiguous()
    y = y.contiguous()
    output = torch.empty_like(x)
    assert output.device == DEVICE
    n_elements = output.numel()
    if n_elements == 0:
        # Nothing to compute; avoid launching an empty grid.
        return output
    # 1-D launch grid: one program per BLOCK_SIZE-sized chunk.
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output

In [None]:
# Here we define the CuTe DSL code

# NOTE(review): CPython reads PYTHONUNBUFFERED at interpreter startup, so setting it here
# presumably only affects child processes or tools that inspect the environment — confirm it is needed.
os.environ["PYTHONUNBUFFERED"] = "1"

@cute.kernel
def vectorized_elementwise_add_kernel(
    gA: cute.Tensor,
    gB: cute.Tensor,
    gC: cute.Tensor,
):
    # gA/gB/gC arrive tiled by the host function below: mode 0 is one (1, 4) tile,
    # mode 1 is the 2-D grid of tiles. Each thread adds exactly one tile.
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    bdim, _, _ = cute.arch.block_dim()

    # Flat thread id across the whole launch.
    thread_idx = bidx * bdim + tidx

    # Map the flat thread id to a (row, column) coordinate in the grid of tiles.
    _, n = gA.shape[1]
    ni = thread_idx % n
    mi = thread_idx // n

    # Slice out this thread's tile and load it as a vector value.
    a_val = gA[(None, (mi, ni))].load()
    b_val = gB[(None, (mi, ni))].load()

    # Element-wise sum of the two tiles, written to the matching output tile.
    gC[(None, (mi, ni))] = a_val + b_val

@cute.jit
def vectorized_elementwise_add(
    mA: cute.Tensor,
    mB: cute.Tensor,
    mC: cute.Tensor
):
    # Host-side entry point: tiles the rank-2 operands and launches the kernel
    # with one thread per tile.
    threads_per_block = 256

    # Split each tensor into (1, 4) tiles -> ((tile), (grid of tiles)).
    gA = cute.zipped_divide(mA, (1, 4))
    gB = cute.zipped_divide(mB, (1, 4))
    gC = cute.zipped_divide(mC, (1, 4))

    # NOTE: the floor division means the tile count must be a multiple of
    # threads_per_block and the column count a multiple of 4; any leftover
    # tiles would not be processed.
    vectorized_elementwise_add_kernel(gA, gB, gC).launch(
        grid=(cute.size(gC, mode=[1]) // threads_per_block, 1, 1),
        block=(threads_per_block, 1, 1),
    )

# Sample operands used only to compile the CuTe add once: three 2048x2048 float32 CUDA tensors.
a = torch.randn((2048, 2048), device="cuda", dtype=torch.float32)
b = torch.randn((2048, 2048), device="cuda", dtype=torch.float32)
c = torch.zeros((2048, 2048), device="cuda", dtype=torch.float32)
# Wrap each torch tensor for CuTe via DLPack, declaring 16-byte alignment and marking the
# layout dynamic (presumably so the compiled function is not tied to 2048x2048 — TODO confirm).
a_ = from_dlpack(a, assumed_align=16).mark_layout_dynamic()
b_ = from_dlpack(b, assumed_align=16).mark_layout_dynamic()
c_ = from_dlpack(c, assumed_align=16).mark_layout_dynamic()
# Compile at cell-execution time; cute_dsl_add reuses this callable on every call.
# NOTE(review): the samples are rank-2, so callers are presumably expected to pass rank-2 tensors — confirm.
cute_dsl_add_compiled = cute.compile(vectorized_elementwise_add, a_, b_, c_)

def cute_dsl_add(x, y):
    """Run the pre-compiled CuTe add on ``x`` and ``y`` and return the output tensor.

    The result buffer is a zero-filled tensor shaped like ``x``. The operands and
    the buffer are each wrapped with ``from_dlpack`` (16-byte assumed alignment,
    dynamic layout) before being passed to ``cute_dsl_add_compiled``.

    NOTE(review): ``cute_dsl_add_compiled`` is built from 2-D sample tensors —
    confirm that the shapes passed here are compatible with it.
    """
    result = torch.zeros_like(x)
    wrapped = [
        from_dlpack(tensor, assumed_align=16).mark_layout_dynamic()
        for tensor in (x, y, result)
    ]
    cute_dsl_add_compiled(*wrapped)
    return result

In [None]:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['size'],  # Argument names to use as an x-axis for the plot.
        x_vals=[2**i for i in range(12, 28, 1)],  # Different possible values for `x_name`.
        x_log=True,  # x axis is logarithmic.
        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot.
        line_vals=['triton', 'torch', 'torch_compile', 'cute_dsl'],  # Possible values for `line_arg`.
        line_names=['Triton', 'Torch', 'Torch Compile', 'Cute dsl'],  # Label name for the lines.
        ylabel='GB/s',  # Label name for the y-axis.
        plot_name='vector-add-performance',  # Name for the plot. Used also as a file name for saving the plot.
        args={},  # Values for function arguments not in `x_names` and `y_name`.
    ))
def benchmark(size, provider):
    """Time one ``<provider>_add`` implementation on two random float32 tensors.

    Args:
        size: Shape argument passed to ``torch.rand`` for both operands.
        provider: Prefix of the implementation to run; ``provider + "_add"`` is
            looked up in the notebook globals.

    Returns:
        A tuple of throughputs in GB/s computed from the median, 80th-percentile
        and 20th-percentile runtimes, in that order (i.e. typical, lowest and
        highest throughput).
    """
    x = torch.rand(size, device=DEVICE, dtype=torch.float32)
    y = torch.rand(size, device=DEVICE, dtype=torch.float32)
    quantiles = [0.5, 0.2, 0.8]
    target_function = globals()[provider + "_add"]

    # Sanity check against the eager result; report (but do not abort on) a mismatch
    # so the remaining providers are still benchmarked.
    expected = x + y
    z = target_function(x, y)
    error = torch.norm(z - expected)
    if error > 1e-7:
        print(provider, error, x.shape)
        print(z)
        print(expected)

    ms, min_ms, max_ms = triton.testing.do_bench(lambda: target_function(x, y), quantiles=quantiles)

    def gbps(elapsed_ms):
        # Two reads plus one write per element, converted from bytes/ms to GB/s.
        # Depends only on numel() and element_size(), so it does not assume 1-D inputs.
        return 3 * x.numel() * x.element_size() * 1e-9 / (elapsed_ms * 1e-3)

    # TODO (carried over from the original): matrix add.
    # NOTE(review): cute_dsl_add_compiled is compiled from 2-D samples — confirm it accepts these inputs.
    # A longer runtime means lower throughput, so max_ms yields the lower bound.
    return gbps(ms), gbps(max_ms), gbps(min_ms)

In [None]:
# Run the benchmark sweep, printing the result table and showing the plot.
benchmark.run(print_data=True, show_plots=True)