
Stream-k matmul implementation very slow mostly because of if/else inside for loop #1393

Open
pommedeterresautee opened this issue Mar 23, 2023 · 10 comments

@pommedeterresautee (Contributor) commented Mar 23, 2023

Over the last few days I have tried to implement the paper Stream-K: Work-centric Parallel Decomposition for Dense Matrix-Matrix Multiplication on the GPU in Triton.

The main idea is to avoid wave quantization by evenly distributing work across SMs at a granularity finer than a full tile.
I suppose you have seen/read it; if not, I can share my notes about it to help explain what I hope to achieve.
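
To make the wave-quantization point concrete, here is a toy arithmetic sketch (the numbers are made up for illustration, they are not from this issue):

# toy illustration of wave quantization (hypothetical numbers)
tiles, sms, iters_per_tile = 33, 32, 512
waves = -(-tiles // sms)                              # ceil(33 / 32) = 2 waves in a tile-based schedule
last_wave_tiles = tiles % sms                         # only 1 tile in the last wave, so 31 SMs sit idle
streamk_iters_per_sm = tiles * iters_per_tile / sms   # ~528 K-iterations per SM with Stream-K
print(waves, last_wave_tiles, streamk_iters_per_sm)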

It has been implemented in CUTLASS, where it shows strong performance.

The implementation below in Triton works, i.e. it generates the right output (minus rounding issues when K is big because of half precision, nothing unexpected).

It corresponds to the two-tile Stream-K + data-parallel hybrid schedule from the paper:

  • during the first stage, each SM gets the equivalent of more than 1 full tile of work, plus the work that would otherwise have formed the quantization wave (the last, partial wave of a tile-based matmul), so that every SM does almost the same amount of work. It leverages atomic adds to combine partial results across SMs (at least that's how I implemented it; the paper is less precise on that point); see the sketch after this list.
  • during the following stages/waves, it does tile-based matmul (one full tile per SM per wave of work).
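
A minimal host-side sketch of how the first stage splits K-iterations across SMs (it mirrors the full / remaining / start_iter / end_iter computation in the kernel below; the helper name is mine):

def streamk_partition(total_iters_streamk: int, total_sm: int):
    # each program gets `full` iterations, and the first `remaining` programs get one extra
    full = total_iters_streamk // total_sm
    remaining = total_iters_streamk % total_sm
    ranges = []
    for pid in range(total_sm):
        start_iter = pid * full + min(pid, remaining)
        end_iter = (pid + 1) * full + min(pid + 1, remaining)
        ranges.append((start_iter, end_iter))
    return ranges

# e.g. 32 tiles x 512 iterations/tile spread over 32 SMs -> 512 iterations per program
print(streamk_partition(32 * 512, 32)[:2])  # [(0, 512), (512, 1024)]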

The performance of this implementation is bad, ranging from mediocre to really slow.

To make the issue easier to understand, we will focus on a setup that makes things comparable.

Using the g grid-size parameter (the number of SMs to use during the first stage) of the kernel below, we can make the kernel behave as a tile-based matmul (no Stream-K):

  • if g is set to 0, all the work goes to the else branch, and the performance is similar to what we get from the reference Triton matmul implementation (which makes sense, as the code is almost identical);
  • if we set g to the number of tiles of the problem, everything goes to the then branch and it also behaves as a tile-based matmul (each loop covers exactly one full tile, there is no atomic add, etc.).

Any other value gives us the hybrid Stream-K schedule.
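
Concretely, with the wrapper defined in the script below, the two degenerate runs are simply:

C = matmul(A, B, 32, False, 64, 64, 64)  # g = number of tiles: everything goes through the then branch
C = matmul(A, B, 0, False, 64, 64, 64)   # g = 0: everything goes through the else branch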

We work on the following problem:

problem size: 256, 512, 32768
block size: 64, 64, 64
#tiles: 32
GPU: RTX 3090
Triton: 2.1 from main (pulled this morning)
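
As a quick sanity check of those numbers (a small side computation, not part of the benchmark script below):

import math

M, N, K = 256, 512, 32768
BLK_M = BLK_N = BLK_K = 64
total_tiles = math.ceil(M / BLK_M) * math.ceil(N / BLK_N)  # 4 * 8 = 32 tiles
iters_per_tile = math.ceil(K / BLK_K)                      # 512 K-iterations per tile
print(total_tiles, iters_per_tile)  # 32 512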

Below we do just 2 runs: one with g=32 and one with g=0.
Since both are effectively tile-based schedules, we expect similar timings.
I include PyTorch for comparison; I guess it is doing some split-K, and its results are much better (expected).

The output of the script below is:

PyTorch 0.25088000297546387
hybrid stream-k (grid=32) 1.6752640008926392
tile matmul (grid=0) 0.5980160236358643

The reference Triton matmul implementation gives 0.57 on my GPU when forced to use 64,64,64 block sizes, so the grid=0 results (i.e. when we use the else branch of the kernel) are ok.
The second line should show similar timings since it does the same kind of computation, but as you can see it does not -> this is the issue!

After a bunch of experiments, I ruled out the following as causes of the slowness:

  • atomic adds: for testing (not worrying about the output), I replaced all of them with store() and the speed didn't change;
  • variable initialization inside the for loop: for testing (not worrying about the output), I moved everything out of the for loop and speed improved by only 10%;
  • the for loop not being "unrolled": for testing, I replaced start and end in range() by literals so the compiler could safely unroll and parallelize.

It appears that the biggest culprits are the ifs inside the for loop of the first stage:

if (current_iter + 1) % iters_per_tile == 0:  # (current_iter + 1) checks whether the next iteration starts a new tile
    C_ = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    if current_iter + 1 - iters_per_tile >= start_iter:
        tl.store(C_, acc)
    else:
        tl.atomic_add(C_, acc)
    if end_iter != current_iter:
        acc *= 0.

The second one saves the result, and the last one resets the accumulator when we move to a new tile.
They are entered at most 2 times over the whole loop, but if I move that part outside of the for loop, timings get 3X better (and the output is still correct when g = 32, of course).
It's also interesting to note that just commenting out the last if, which resets the accumulator (and is never entered in this setup), leads to a significant speedup too (1.67 -> 1.06).
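
For what it's worth, one way to keep those ifs out of the inner loop is to split each program's [start_iter, end_iter) range into sub-ranges that never cross a tile boundary, so each sub-range runs a plain loop followed by a single store/atomic_add (this is the direction the later versions in this thread take). A minimal host-side sketch of that split, with made-up names:

def split_at_tile_boundaries(start_iter: int, end_iter: int, iters_per_tile: int):
    # split [start_iter, end_iter) into sub-ranges that each stay within a single tile
    ranges = []
    cur = start_iter
    while cur < end_iter:
        nxt = min(cur + (iters_per_tile - cur % iters_per_tile), end_iter)
        ranges.append((cur, nxt))
        cur = nxt
    return ranges

# e.g. with 512 iterations per tile, a program covering iterations [300, 1400) touches 3 tiles
print(split_at_tile_boundaries(300, 1400, 512))  # [(300, 512), (512, 1024), (1024, 1400)]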

My 2 questions:

  • is there a way to write the code so as to avoid this issue with the ifs?
  • is it a missing fast path that you plan to implement?
# for reproducible experiments
# sudo nvidia-smi -pm 1 -i 0
# sudo nvidia-smi -i 0 -pl 350  # 400 for A100
# sudo nvidia-smi -i 0 -lgc 1005

import torch
import triton
import triton.language as tl
from triton.compiler import init_cuda_utils

torch.manual_seed(123)

# ---------------------------------------------------------------------------
# Triton kernel
# ---------------------------------------------------------------------------


@triton.jit()
def _kernel(
        # input and output matrices
        A, B, C,
        # matrix dimensions
        M, N, K,
        # strides
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        # total number of iterations for Stream-K and other variables
        total_sm, total_iters_streamk, total_tiles_streamk, iters_per_tile,
        # block dimensions and accumulator type
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ACC_TYPE: tl.constexpr,
):
    pid = tl.program_id(0)
    # First wave: each SM/triton program does more than one tile of work
    if pid < total_sm:
        process_first_wave(
            A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
            pid, total_sm, total_iters_streamk, total_tiles_streamk, iters_per_tile,
            BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
        )
    else:
        # After first wave: classic blocking matmul
        process_classic_blocking(
            A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
            pid, total_sm, total_tiles_streamk, iters_per_tile,
            BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
        )


@triton.jit()
def process_first_wave(
        A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        pid, total_sm, total_iters_streamk, total_tiles_streamk, iters_per_tile,
        BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
):
    full = total_iters_streamk // total_sm
    remaining = total_iters_streamk % total_sm
    start_iter = pid * full + tl.minimum(pid, remaining)
    end_iter = (pid + 1) * full + tl.minimum(pid + 1, remaining)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)

    for current_iter in range(start_iter, end_iter):  # iterate over K axis, M/N may change during iteration
        tile_id = current_iter // iters_per_tile
        pid_m = tile_id // tl.cdiv(N, BLOCK_N)
        pid_n = tile_id % tl.cdiv(N, BLOCK_N)
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        ram = tl.max_contiguous(tl.multiple_of(rm, BLOCK_M), BLOCK_M)
        rbn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N)
        rk = tl.arange(0, BLOCK_K)
        A_ = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + BLOCK_K * stride_ak * (
                current_iter % iters_per_tile)
        B_ = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + BLOCK_K * stride_bk * (
                current_iter % iters_per_tile)
        a = tl.load(A_)
        b = tl.load(B_)
        acc += tl.dot(a, b)
        if (current_iter + 1) % iters_per_tile == 0:  # (current_iter + 1) checks whether the next iteration starts a new tile
            C_ = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
            if current_iter + 1 - iters_per_tile >= start_iter:
                tl.store(C_, acc)
            else:
                tl.atomic_add(C_, acc)
            if end_iter != current_iter:
                acc *= 0.

    # save the last tile if there are leftover iterations
    if end_iter % iters_per_tile != 0:
        tile_id = tl.cdiv(end_iter, iters_per_tile) - 1
        pid_m = tile_id // tl.cdiv(N, BLOCK_N)
        pid_n = tile_id % tl.cdiv(N, BLOCK_N)
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        C_ = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
        tl.atomic_add(C_, acc)


@triton.jit()
def process_classic_blocking(
        A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        pid, total_sm, total_tiles_streamk, iters_per_tile,
        BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
):
    pid = pid + (total_tiles_streamk - total_sm)  # first wave has done more tiles than there are SMs, we adjust pid
    pid_m = pid // tl.cdiv(N, BLOCK_N)
    pid_n = pid % tl.cdiv(N, BLOCK_N)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    # pointers
    A_ = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B_ = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc_ = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        a = tl.load(A_)
        b = tl.load(B_)
        acc_ += tl.dot(a, b)
        A_ += BLOCK_K * stride_ak
        B_ += BLOCK_K * stride_bk
    acc_ = acc_.to(tl.float16)  # restore C.dtype.element_ty
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C_ = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    tl.store(C_, acc_)


class _matmul(torch.autograd.Function):
    kernel = _kernel

    @staticmethod
    def _call(a: torch.Tensor, b: torch.Tensor, grid_to_use: int, debug: bool, BLK_M: int, BLK_N: int, BLK_K: int):
        device = a.device
        # handle non-contiguous inputs if necessary
        if a.stride(0) > 1 and a.stride(1) > 1:
            a = a.contiguous()
        if b.stride(0) > 1 and b.stride(1) > 1:
            b = b.contiguous()
        # checks constraints
        assert a.shape[1] == b.shape[0], "incompatible dimensions"
        M, K = a.shape
        _, N = b.shape
        # accumulator types
        ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
        # compute grid (work to do per SM on the first wave)
        total_blocks_M = triton.cdiv(M, BLK_M)
        total_blocks_N = triton.cdiv(N, BLK_N)
        iters_per_tile = triton.cdiv(K, BLK_K)
        total_tiles = total_blocks_M * total_blocks_N
        total_iters = total_tiles * iters_per_tile  # total work to do
        # tiles to be computed using classical blocking approach (data parallel in the paper)
        # if more SMs than tiles, will be 0
        total_blocking_tiles = (total_tiles // grid_to_use) * grid_to_use if grid_to_use > 0 else total_tiles
        if total_tiles >= grid_to_use:
            # for two-tile Stream-K + data-parallel in the paper
            total_blocking_tiles -= grid_to_use
        total_tiles_streamk = total_tiles - total_blocking_tiles
        total_iters_streamk = total_tiles_streamk * iters_per_tile

        total_programs = grid_to_use + (total_tiles - total_tiles_streamk)  # grid

        if debug:
            print(f"m,n,k={M},{N},{K} ; BLK_M,BLK_N,BLK_K={BLK_M},{BLK_N},{BLK_K}")
            print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}")
            print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}")
            print(f"{total_tiles=} * {iters_per_tile=} = {total_iters=}")
            print(f"{total_iters_streamk=}")
            print(f"{total_programs=}")
        # allocates output
        if total_tiles_streamk > 0:
            # atomic add requires zero-initialized output
            c = torch.zeros((M, N), device=device, dtype=a.dtype)
        else:
            c = torch.empty((M, N), device=device, dtype=a.dtype)
        assert c.dtype == torch.float16
        _kernel[(total_programs,)](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            total_iters_streamk=total_iters_streamk,
            total_tiles_streamk=total_tiles_streamk,
            iters_per_tile=iters_per_tile,
            total_sm=grid_to_use,
            BLOCK_M=BLK_M,
            BLOCK_N=BLK_N,
            BLOCK_K=BLK_K,
            ACC_TYPE=ACC_TYPE,
            num_warps=16,
        )
        return c

    @staticmethod
    def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, debug: bool = False, BLK_M=64, BLK_N=64, BLK_K=64):
        return _matmul._call(a=a, b=b, grid_to_use=grid, debug=debug, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K)


matmul = _matmul.apply

# ---------------------------------------------------------------------------
# Example and Benchmark
# ---------------------------------------------------------------------------

device = torch.cuda.current_device()
# init_cuda_utils()
# total_sm = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"]
total_sm = 32  # number of tiles
m, n, k = 256, 512, 32768
A = torch.randn(m, k, device="cuda", dtype=torch.float16)
B = torch.randn(k, n, device="cuda", dtype=torch.float16)

debug = False
C = matmul(A, B, total_sm, debug, 64, 64, 64)
expected = A @ B

assert torch.allclose(C, expected, atol=5e-1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}"

if not debug:
    ms, *_ = triton.testing.do_bench(lambda: torch.matmul(A, B))
    print("PyTorch", ms)

    ms, *_ = triton.testing.do_bench(lambda: matmul(A, B, total_sm, debug))
    print(f"hybrid stream-k (grid={total_sm})", ms)

    ms, *_ = triton.testing.do_bench(lambda: matmul(A, B, 0, debug))
    print("tile matmul (grid=0)", ms)

CC @Chillee, who seems to have read the paper and may have run some experiments.

@pommedeterresautee (Contributor Author)

Discussion is happening on Slack.

Following @ptillet's advice, the kernel below has been split into 2 independent kernels to avoid register spilling:

# for reproducible experiments
# sudo nvidia-smi -pm 1 -i 0
# sudo nvidia-smi -i 0 -pl 350  # 400 for A100
# sudo nvidia-smi -i 0 -lgc 1005

import torch
import triton
import triton.language as tl
from triton.compiler import init_cuda_utils


torch.manual_seed(123)

# ---------------------------------------------------------------------------
# Triton kernels
# ---------------------------------------------------------------------------


@triton.jit()
def process_first_wave(
        A, B, C,
        M, N, K,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ACC_TYPE: tl.constexpr,
):
    pid = tl.program_id(0)
    full = total_iters_streamk // total_programs_streamk
    remaining = total_iters_streamk % total_programs_streamk
    start_iter = pid * full + tl.minimum(pid, remaining)
    end_iter = (pid + 1) * full + tl.minimum(pid + 1, remaining)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)

    for current_iter in range(start_iter, end_iter):  # iterate over K axis, M/N may change during iteration
        tile_id = current_iter // iters_per_tile
        pid_m = tile_id // tl.cdiv(N, BLOCK_N)
        pid_n = tile_id % tl.cdiv(N, BLOCK_N)
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        ram = tl.max_contiguous(tl.multiple_of(rm, BLOCK_M), BLOCK_M)
        rbn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N)
        rk = tl.arange(0, BLOCK_K)
        A_ = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + BLOCK_K * stride_ak * (
                current_iter % iters_per_tile)
        B_ = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + BLOCK_K * stride_bk * (
                current_iter % iters_per_tile)
        a = tl.load(A_)
        b = tl.load(B_)
        acc += tl.dot(a, b)
        if (current_iter + 1) % iters_per_tile == 0:  # (current_iter + 1) checks whether the next iteration starts a new tile
            C_ = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
            if current_iter + 1 - iters_per_tile >= start_iter:
                tl.store(C_, acc)
            else:
                tl.atomic_add(C_, acc)
            if end_iter != current_iter:
                acc *= 0.

    # save the last tile if there are leftover iterations
    if end_iter % iters_per_tile != 0:
        tile_id = tl.cdiv(end_iter, iters_per_tile) - 1
        pid_m = tile_id // tl.cdiv(N, BLOCK_N)
        pid_n = tile_id % tl.cdiv(N, BLOCK_N)
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        C_ = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
        tl.atomic_add(C_, acc)


@triton.jit()
def process_classic_blocking(
        A, B, C,
        M, N, K,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ACC_TYPE: tl.constexpr,
):
    pid = tl.program_id(0) + total_programs_streamk
    pid = pid + (total_tiles_streamk - total_programs_streamk) # first wave has done more tiles than there are SMs, we adjust pid
    pid_m = pid // tl.cdiv(N, BLOCK_N)
    pid_n = pid % tl.cdiv(N, BLOCK_N)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    # pointers
    A_ = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B_ = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc_ = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        a = tl.load(A_)
        b = tl.load(B_)
        acc_ += tl.dot(a, b)
        A_ += BLOCK_K * stride_ak
        B_ += BLOCK_K * stride_bk
    acc_ = acc_.to(tl.float16)  # restore C.dtype.element_ty
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C_ = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    tl.store(C_, acc_)


class _matmul(torch.autograd.Function):

    @staticmethod
    def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, debug: bool, BLK_M: int, BLK_N: int, BLK_K: int):
        device = a.device
        # handle non-contiguous inputs if necessary
        if a.stride(0) > 1 and a.stride(1) > 1:
            a = a.contiguous()
        if b.stride(0) > 1 and b.stride(1) > 1:
            b = b.contiguous()
        # checks constraints
        assert a.shape[1] == b.shape[0], "incompatible dimensions"
        M, K = a.shape
        _, N = b.shape
        # accumulator types
        ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
        # compute grid (work to do per SM on the first wave)
        total_blocks_M = triton.cdiv(M, BLK_M)
        total_blocks_N = triton.cdiv(N, BLK_N)
        iters_per_tile = triton.cdiv(K, BLK_K)
        total_tiles = total_blocks_M * total_blocks_N
        total_iters = total_tiles * iters_per_tile  # total work to do
        # tiles to be computed using classical blocking approach (data parallel in the paper)
        # if more SMs than tiles, will be 0
        total_blocking_tiles = (total_tiles // total_programs_streamk) * total_programs_streamk if total_programs_streamk > 0 else total_tiles
        if total_tiles >= total_programs_streamk:
            # for two-tile Stream-K + data-parallel in the paper
            total_blocking_tiles -= total_programs_streamk
        total_tiles_streamk = total_tiles - total_blocking_tiles
        total_iters_streamk = total_tiles_streamk * iters_per_tile

        total_programs = total_programs_streamk + (total_tiles - total_tiles_streamk)  # grid
        total_programs_classic = total_programs - total_programs_streamk
        if debug:
            print(f"m,n,k={M},{N},{K} ; BLK_M,BLK_N,BLK_K={BLK_M},{BLK_N},{BLK_K}")
            print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}")
            print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}")
            print(f"{total_tiles=} * {iters_per_tile=} = {total_iters=}")
            print(f"{total_iters_streamk=}")
            print(f"{total_programs_streamk=} + {total_programs_classic=} = {total_programs=}")

        # allocates output
        if total_tiles_streamk > 0:
            # atomic add requires zero-initialized output
            c = torch.zeros((M, N), device=device, dtype=a.dtype)
        else:
            c = torch.empty((M, N), device=device, dtype=a.dtype)
        assert c.dtype == torch.float16
        k1 = process_first_wave[(total_programs_streamk,)](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            total_iters_streamk=total_iters_streamk,
            total_tiles_streamk=total_tiles_streamk,
            iters_per_tile=iters_per_tile,
            total_programs_streamk=total_programs_streamk,
            BLOCK_M=BLK_M,
            BLOCK_N=BLK_N,
            BLOCK_K=BLK_K,
            ACC_TYPE=ACC_TYPE,
            num_stages=1,
            num_warps=16,
        )
        assert k1.n_spills == 0, f"register spilling detected: {k1.n_spills}"
        k2 = process_classic_blocking[(total_programs_classic,)](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            total_iters_streamk=total_iters_streamk,
            total_tiles_streamk=total_tiles_streamk,
            iters_per_tile=iters_per_tile,
            total_programs_streamk=total_programs_streamk,
            BLOCK_M=BLK_M,
            BLOCK_N=BLK_N,
            BLOCK_K=BLK_K,
            ACC_TYPE=ACC_TYPE,
            num_stages=3,
            num_warps=4,
        )
        assert k2.n_spills == 0, f"register spilling detected: {k2.n_spills}"
        return c

    @staticmethod
    def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, debug: bool = False, BLK_M=128, BLK_N=128, BLK_K=32):
        return _matmul._call(a=a, b=b, total_programs_streamk=grid, debug=debug, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K)


matmul = _matmul.apply

# ---------------------------------------------------------------------------
# Example and Benchmark
# ---------------------------------------------------------------------------

device = torch.cuda.current_device()
# init_cuda_utils()
# total_sm = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"]
total_programs_streamk = 32  # number of tiles
m, n, k = 256, 512, 32768
A = torch.randn(m, k, device="cuda", dtype=torch.float16)
B = torch.randn(k, n, device="cuda", dtype=torch.float16)

debug = False
C = matmul(A, B, total_programs_streamk, debug, 64, 64, 64)
expected = A @ B

assert torch.allclose(C, expected, atol=5e-1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}"

if not debug:
    ms, *_ = triton.testing.do_bench(lambda: torch.matmul(A, B))
    print("PyTorch", ms)

    ms, *_ = triton.testing.do_bench(lambda: matmul(A, B, total_programs_streamk, debug))
    print(f"hybrid stream-k (grid={total_programs_streamk})", ms)

    ms, *_ = triton.testing.do_bench(lambda: matmul(A, B, 0, debug))
    print("tile matmul (grid=0)", ms)

It doesn't improve timings, but at least we can set dedicated num_warps / num_stages parameters for each kernel that way.

@Jokeren (Contributor) commented Mar 23, 2023

cc @Chillee

@Chillee (Contributor) commented Mar 24, 2023

I don't have anything concrete to add since my implementation is also slow :P But here's my implementation (I implemented the scratchpad + barrier approach outlined in the paper to avoid atomic adds, although it ends up being a bit slower than the atomic adds IIRC).

http://ix.io/4qVV
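
For readers of this thread, here is a toy, self-contained sketch of the scratchpad + flag idea (it is not the linked implementation; fixup_demo, partials and flags are made-up names, and it assumes both programs are co-resident, which Stream-K requires anyway): the non-owning program parks its partial result in a scratchpad and raises a flag, and the owner spins on the flag and then folds the partial in with a plain store instead of an atomic_add on the output.

import torch
import triton
import triton.language as tl


@triton.jit
def fixup_demo(x_ptr, out_ptr, partials_ptr, flags_ptr, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    # each program computes a partial result over its half of the input
    part = tl.load(x_ptr + pid * BLOCK + offs)
    if pid == 1:
        # non-owner: park the partial in the scratchpad, then raise the flag
        tl.store(partials_ptr + offs, part)
        tl.atomic_xchg(flags_ptr, 1)
    else:
        # owner: spin until the peer's partial is published, then write the final result with a store
        while tl.atomic_cas(flags_ptr, 1, 1) != 1:
            pass
        peer = tl.load(partials_ptr + offs, volatile=True)  # volatile: do not cache/hoist this read
        tl.store(out_ptr + offs, part + peer)


x = torch.randn(2 * 128, device="cuda", dtype=torch.float32)
out = torch.empty(128, device="cuda", dtype=torch.float32)
partials = torch.empty(128, device="cuda", dtype=torch.float32)
flags = torch.zeros(1, device="cuda", dtype=torch.int32)
fixup_demo[(2,)](x, out, partials, flags, BLOCK=128)
assert torch.allclose(out, x[:128] + x[128:])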

@pommedeterresautee (Contributor Author)

Following @ptillet's advice on Slack, I moved the pointer update out of the for loop in the first kernel.

# for reproducible experiments
# sudo nvidia-smi -pm 1 -i 0
# sudo nvidia-smi -i 0 -pl 350  # 400 for A100
# sudo nvidia-smi -i 0 -lgc 1005

import torch
import triton
import triton.language as tl
from triton.compiler import init_cuda_utils

from kernl.debugger.debugger import triton_debug

torch.manual_seed(123)

# ---------------------------------------------------------------------------
# Triton kernels
# ---------------------------------------------------------------------------

#@triton_debug
@triton.jit()
def process_first_wave(
        A, B, C,
        M, N, K,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ACC_TYPE: tl.constexpr,
):
    pid = tl.program_id(0)
    full = total_iters_streamk // total_programs_streamk
    remaining = total_iters_streamk % total_programs_streamk
    start_iter = pid * full + tl.minimum(pid, remaining)
    end_iter = (pid + 1) * full + tl.minimum(pid + 1, remaining)

    tile_id = start_iter // iters_per_tile
    pid_m = tile_id // tl.cdiv(N, BLOCK_N)
    pid_n = tile_id % tl.cdiv(N, BLOCK_N)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    #ram = tl.max_contiguous(tl.multiple_of(rm, BLOCK_M), BLOCK_M)
    #rbn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A_ = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak) + BLOCK_K * stride_ak * (
            start_iter % iters_per_tile)
    B_ = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn) + BLOCK_K * stride_bk * (
            start_iter % iters_per_tile)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for current_iter in range(start_iter, end_iter):  # iterate over K axis, M/N may change during iteration
        a = tl.load(A_)
        b = tl.load(B_)
        acc += tl.dot(a, b)
        A_ += BLOCK_K * stride_ak
        B_ += BLOCK_K * stride_bk

        if (current_iter + 1) % iters_per_tile == 0:  # (current_iter + 1) checks whether the next iteration starts a new tile
            C_ = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
            if current_iter + 1 - iters_per_tile >= start_iter:
                tl.store(C_, acc)
            else:
                tl.atomic_add(C_, acc)
            if end_iter != current_iter:
                acc *= 0.
            # update pointers
            tile_id_new = (current_iter+1) // iters_per_tile
            pid_m_new = tile_id_new // tl.cdiv(N, BLOCK_N)
            pid_n_new = tile_id_new % tl.cdiv(N, BLOCK_N)
            rm = rm + (pid_m_new - pid_m) * BLOCK_M
            rn = rn + (pid_n_new - pid_n) * BLOCK_N

            rk = tl.arange(0, BLOCK_K)
            A_ = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
            B_ = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)

    # save the last tile if there are leftover iterations
    if end_iter % iters_per_tile != 0:
        tile_id = tl.cdiv(end_iter, iters_per_tile) - 1
        pid_m = tile_id // tl.cdiv(N, BLOCK_N)
        pid_n = tile_id % tl.cdiv(N, BLOCK_N)
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        C_ = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
        tl.atomic_add(C_, acc)


@triton.jit()
def process_classic_blocking(
        A, B, C,
        M, N, K,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ACC_TYPE: tl.constexpr,
):
    pid = tl.program_id(0) + total_programs_streamk
    pid = pid + (total_tiles_streamk - total_programs_streamk) # first wave has done more tiles than there are SMs, we adjust pid
    pid_m = pid // tl.cdiv(N, BLOCK_N)
    pid_n = pid % tl.cdiv(N, BLOCK_N)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    # pointers
    A_ = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B_ = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc_ = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        a = tl.load(A_)
        b = tl.load(B_)
        acc_ += tl.dot(a, b)
        A_ += BLOCK_K * stride_ak
        B_ += BLOCK_K * stride_bk
    acc_ = acc_.to(tl.float16)  # restore C.dtype.element_ty
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C_ = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    tl.store(C_, acc_)


class _matmul(torch.autograd.Function):

    @staticmethod
    def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, debug: bool, BLK_M: int, BLK_N: int, BLK_K: int):
        device = a.device
        # handle non-contiguous inputs if necessary
        if a.stride(0) > 1 and a.stride(1) > 1:
            a = a.contiguous()
        if b.stride(0) > 1 and b.stride(1) > 1:
            b = b.contiguous()
        # checks constraints
        assert a.shape[1] == b.shape[0], "incompatible dimensions"
        M, K = a.shape
        _, N = b.shape
        # accumulator types
        ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
        # compute grid (work to do per SM on the first wave)
        total_blocks_M = triton.cdiv(M, BLK_M)
        total_blocks_N = triton.cdiv(N, BLK_N)
        iters_per_tile = triton.cdiv(K, BLK_K)
        total_tiles = total_blocks_M * total_blocks_N
        total_iters = total_tiles * iters_per_tile  # total work to do
        # tiles to be computed using classical blocking approach (data parallel in the paper)
        # if more SMs than tiles, will be 0
        total_blocking_tiles = (total_tiles // total_programs_streamk) * total_programs_streamk if total_programs_streamk > 0 else total_tiles
        if total_tiles >= total_programs_streamk:
            # for two-tile Stream-K + data-parallel in the paper
            total_blocking_tiles -= total_programs_streamk
        total_tiles_streamk = total_tiles - total_blocking_tiles
        total_iters_streamk = total_tiles_streamk * iters_per_tile

        total_programs = total_programs_streamk + (total_tiles - total_tiles_streamk)  # grid
        total_programs_classic = total_programs - total_programs_streamk
        if debug:
            print(f"m,n,k={M},{N},{K} ; BLK_M,BLK_N,BLK_K={BLK_M},{BLK_N},{BLK_K}")
            print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}")
            print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}")
            print(f"{total_tiles=} * {iters_per_tile=} = {total_iters=}")
            print(f"{total_iters_streamk=}")
            print(f"{total_programs_streamk=} + {total_programs_classic=} = {total_programs=}")

        # allocates output
        if total_tiles_streamk > 0:
            # atomic add requires zero-initialized output
            c = torch.zeros((M, N), device=device, dtype=a.dtype)
        else:
            c = torch.empty((M, N), device=device, dtype=a.dtype)
        assert c.dtype == torch.float16
        k1 = process_first_wave[(total_programs_streamk,)](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            total_iters_streamk=total_iters_streamk,
            total_tiles_streamk=total_tiles_streamk,
            iters_per_tile=iters_per_tile,
            total_programs_streamk=total_programs_streamk,
            BLOCK_M=BLK_M,
            BLOCK_N=BLK_N,
            BLOCK_K=BLK_K,
            ACC_TYPE=ACC_TYPE,
            num_stages=1,
            num_warps=16,
        )
        #assert k1.n_spills == 0, f"register spilling detected: {k1.n_spills}"
        k2 = process_classic_blocking[(total_programs_classic,)](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            total_iters_streamk=total_iters_streamk,
            total_tiles_streamk=total_tiles_streamk,
            iters_per_tile=iters_per_tile,
            total_programs_streamk=total_programs_streamk,
            BLOCK_M=BLK_M,
            BLOCK_N=BLK_N,
            BLOCK_K=BLK_K,
            ACC_TYPE=ACC_TYPE,
            num_stages=3,
            num_warps=4,
        )
        assert k2.n_spills == 0, f"register spilling detected: {k2.n_spills}"
        return c

    @staticmethod
    def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, debug: bool = False, BLK_M=128, BLK_N=128, BLK_K=32):
        return _matmul._call(a=a, b=b, total_programs_streamk=grid, debug=debug, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K)


matmul = _matmul.apply

# ---------------------------------------------------------------------------
# Example and Benchmark
# ---------------------------------------------------------------------------

device = torch.cuda.current_device()
# init_cuda_utils()
# total_sm = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"]
total_programs_streamk = 82  # number of programs to use for the Stream-K first wave
m, n, k = 256, 512, 32768//16
A = torch.randn(m, k, device="cuda", dtype=torch.float16)
B = torch.randn(k, n, device="cuda", dtype=torch.float16)

debug = True
C = matmul(A, B, total_programs_streamk, debug, 64, 64, 64)
expected = A @ B

assert torch.allclose(C, expected, atol=5e-1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}"

if not debug:
    ms, *_ = triton.testing.do_bench(lambda: torch.matmul(A, B))
    print("PyTorch", ms)

    ms, *_ = triton.testing.do_bench(lambda: matmul(A, B, total_programs_streamk, debug))
    print(f"hybrid stream-k (grid={total_programs_streamk})", ms)

    ms, *_ = triton.testing.do_bench(lambda: matmul(A, B, 0, debug))
    print("tile matmul (grid=0)", ms)

This version works with the Kernl interpreter but doesn't compile on Triton (it fails without any crash).

@pommedeterresautee (Contributor Author)

This version, where the tile change has been moved out of the for loop, brings around a 4% improvement over PyTorch 2.0 matmul on an RTX 3090.
For reasons I have not yet understood, I can't reproduce those results on A100 (the Triton kernel is then slower than PyTorch on most shapes).

# sudo apt-get install zlib1g-dev
# for reproducible experiments
# sudo nvidia-smi -pm 1 -i 0
# sudo nvidia-smi -i 0 -pl 350  # 400 for A100
# sudo nvidia-smi -i 0 -lgc 1005

import torch
import triton
import triton.language as tl
import random
from triton.compiler import init_cuda_utils


torch.manual_seed(123)
random.seed(123)

# ---------------------------------------------------------------------------
# Triton kernels
# ---------------------------------------------------------------------------

# iterate, multiply and accumulate over K axis
@triton.jit()
def mac_loop(
        A, B, C,
        M, N, K,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile,
        start_iter, end_iter,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ACC_TYPE: tl.constexpr,
):
    # If no work to do, return early
    if end_iter - start_iter == 0:  # TODO try == instead of == 0
        return
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    # where are we in the grid
    tile_id = start_iter // iters_per_tile
    pid_m = tile_id // tl.cdiv(N, BLOCK_N)
    pid_n = tile_id % tl.cdiv(N, BLOCK_N)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A_ = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + BLOCK_K * stride_ak * (start_iter % iters_per_tile)
    B_ = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + BLOCK_K * stride_bk * (start_iter % iters_per_tile)

    for current_iter in range(start_iter, end_iter):
        a = tl.load(A_)
        b = tl.load(B_)
        acc += tl.dot(a, b)
        A_ += BLOCK_K * stride_ak
        B_ += BLOCK_K * stride_bk

    C_ = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    if end_iter - start_iter == iters_per_tile and start_iter % iters_per_tile == 0:
        tl.store(C_, acc)
    else:
        tl.atomic_add(C_, acc)


@triton.jit()
def first_wave(
        A, B, C,
        M, N, K,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ACC_TYPE: tl.constexpr,
):
    pid = tl.program_id(0)
    full = total_iters_streamk // total_programs_streamk  # iterations related to full wave
    remaining = total_iters_streamk % total_programs_streamk  # iterations related to last (partial) wave
    start_iter_1 = pid * full + tl.minimum(pid, remaining)
    end_iter = (pid + 1) * full + tl.minimum(pid + 1, remaining)
    start_iter_2 = start_iter_1 + (iters_per_tile - start_iter_1 % iters_per_tile)
    start_iter_2 = tl.minimum(start_iter_2, end_iter)
    start_iter_3 = start_iter_2 + (iters_per_tile - start_iter_2 % iters_per_tile)
    start_iter_3 = tl.minimum(start_iter_3, end_iter)

    # finish tile another SM was working on (may be a complete tile)
    mac_loop(
        A, B, C,
        M, N, K,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile,
        start_iter_1, start_iter_2,
        BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
    )

    # do full tile (if there is enough work for that)
    mac_loop(
        A, B, C,
        M, N, K,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile,
        start_iter_2, start_iter_3,
        BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
    )

    # start a new tile (may be incomplete)
    mac_loop(
        A, B, C,
        M, N, K,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile,
        start_iter_3, end_iter,
        BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
    )

# similar to the reference matmul kernel
@triton.jit()
def full_tiles(
        A, B, C,
        M, N, K,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ACC_TYPE: tl.constexpr,
):
    pid = tl.program_id(0) + total_programs_streamk
    pid = pid + (total_tiles_streamk - total_programs_streamk) # first wave has done more tiles than there are SMs, we adjust pid
    pid_m = pid // tl.cdiv(N, BLOCK_N)
    pid_n = pid % tl.cdiv(N, BLOCK_N)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    # pointers
    A_ = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B_ = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc_ = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        a = tl.load(A_)
        b = tl.load(B_)
        acc_ += tl.dot(a, b)
        A_ += BLOCK_K * stride_ak
        B_ += BLOCK_K * stride_bk
    acc_ = acc_.to(tl.float16)  # restore C.dtype.element_ty
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C_ = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    tl.store(C_, acc_)

# ---------------------------------------------------------------------------
# Wrapper
# ---------------------------------------------------------------------------

class _matmul(torch.autograd.Function):

    @staticmethod
    def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, debug: bool, BLK_M: int, BLK_N: int, BLK_K: int, num_stages: int, num_warps: int):
        device = a.device
        # handle non-contiguous inputs if necessary
        if a.stride(0) > 1 and a.stride(1) > 1:
            a = a.contiguous()
        if b.stride(0) > 1 and b.stride(1) > 1:
            b = b.contiguous()
        # checks constraints
        assert a.shape[1] == b.shape[0], "incompatible dimensions"
        M, K = a.shape
        _, N = b.shape
        # accumulator types
        ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
        # compute grid (work to do per SM on the first wave)
        total_blocks_M = triton.cdiv(M, BLK_M)
        total_blocks_N = triton.cdiv(N, BLK_N)
        iters_per_tile = triton.cdiv(K, BLK_K)
        total_tiles = total_blocks_M * total_blocks_N
        total_iters = total_tiles * iters_per_tile  # total work to do
        # tiles to be computed using classical blocking approach (data parallel in the paper)
        # if more SMs than tiles, will be 0
        total_blocking_tiles = (total_tiles // total_programs_streamk) * total_programs_streamk if total_programs_streamk > 0 else total_tiles
        if total_tiles >= total_programs_streamk:
            # for two-tile Stream-K + data-parallel in the paper
            total_blocking_tiles -= total_programs_streamk
        total_tiles_streamk = total_tiles - total_blocking_tiles
        total_iters_streamk = total_tiles_streamk * iters_per_tile

        total_programs = total_programs_streamk + (total_tiles - total_tiles_streamk)  # grid
        total_programs_classic = total_programs - total_programs_streamk
        if debug:
            print(f"m,n,k={M},{N},{K} ; BLK_M,BLK_N,BLK_K={BLK_M},{BLK_N},{BLK_K}")
            print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}")
            print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}")
            print(f"{total_tiles=} * {iters_per_tile=} = {total_iters=}")
            print(f"{total_iters_streamk=}")
            print(f"{total_programs_streamk=} + {total_programs_classic=} = {total_programs=}")

        # allocates output
        if total_tiles_streamk > 0:
            # atomic add requires zero-initialized output
            c = torch.zeros((M, N), device=device, dtype=a.dtype)
        else:
            c = torch.empty((M, N), device=device, dtype=a.dtype)
        assert c.dtype == torch.float16
        k1 = first_wave[(total_programs_streamk,)](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            total_iters_streamk=total_iters_streamk,
            total_tiles_streamk=total_tiles_streamk,
            iters_per_tile=iters_per_tile,
            total_programs_streamk=total_programs_streamk,
            BLOCK_M=BLK_M,
            BLOCK_N=BLK_N,
            BLOCK_K=BLK_K,
            ACC_TYPE=ACC_TYPE,
            num_stages=num_stages,
            num_warps=num_warps,
        )
        assert k1.n_spills == 0, f"register spilling detected: {k1.n_spills}"
        k2 = full_tiles[(total_programs_classic,)](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            total_iters_streamk=total_iters_streamk,
            total_tiles_streamk=total_tiles_streamk,
            iters_per_tile=iters_per_tile,
            total_programs_streamk=total_programs_streamk,
            BLOCK_M=BLK_M,
            BLOCK_N=BLK_N,
            BLOCK_K=BLK_K,
            ACC_TYPE=ACC_TYPE,
            num_stages=3,
            num_warps=4,
        )
        assert k2.n_spills == 0, f"register spilling detected: {k2.n_spills}"
        return c

    @staticmethod
    def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, debug: bool = False, BLK_M=128, BLK_N=128, BLK_K=32, num_stages=3, num_warps=4):
        return _matmul._call(a=a, b=b, total_programs_streamk=grid, debug=debug, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, num_warps=num_warps, num_stages=num_stages)


matmul = _matmul.apply

# ---------------------------------------------------------------------------
# Example and Benchmark
# ---------------------------------------------------------------------------

device = torch.cuda.current_device()
init_cuda_utils()
total_sm = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"]
print(f"total SMs: {total_sm}")
total_programs_streamk = total_sm  # number of programs to use in the Stream-K first wave (one per SM)
m, n, k = 1536, 1792, 6016  # some problem size to test
A = torch.randn(m, k, device="cuda", dtype=torch.float16)
B = torch.randn(k, n, device="cuda", dtype=torch.float16)

debug = False
C = matmul(A, B, total_programs_streamk, debug, 64, 64, 64)
expected = A @ B

assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}"

if not debug:
    triton_ms, *_ = triton.testing.do_bench(lambda: torch.matmul(A, B))
    print("PyTorch", triton_ms)

    triton_ms, *_ = triton.testing.do_bench(lambda: matmul(A, B, total_programs_streamk, debug, 64, 64, 64, 3, 4))
    print(f"hybrid stream-k (grid={total_programs_streamk}, block=64,64,64, warp=8)", triton_ms)

    triton_ms, *_ = triton.testing.do_bench(lambda: matmul(A, B, total_programs_streamk, debug, 128, 128, 32, 3, 4))
    print(f"hybrid stream-k (grid={total_programs_streamk}, block=128,128,32, warp=8)", triton_ms)

    triton_ms, *_ = triton.testing.do_bench(lambda: matmul(A, B, 0, debug))
    print("tile matmul (grid=0)", triton_ms)


# ---------------------------------------------------------------------------
# Log-sampled benchmark
# ---------------------------------------------------------------------------

# tried to reproduce the tests described in the paper
# also check 2 block sizes, 64x64x64 and 128x128x32, just in case...
num_samples = 32768
values = ((torch.logspace(torch.tensor(128).log2(), torch.tensor(8192).log2(), num_samples,
                          base=2) / 128).round() * 128).unique().tolist()
shapes = [(int(m), int(n), int(k)) for m in values for n in values for k in values]
shapes = random.sample(shapes, num_samples)
assert len(shapes) == num_samples

g = total_sm
results = []
for idx, (m, n, k) in enumerate(shapes):
    A = torch.randn(m, k, device="cuda", dtype=torch.float16)
    B = torch.randn(k, n, device="cuda", dtype=torch.float16)
    triton_ms_128_128_32, *_ = triton.testing.do_bench(lambda: matmul(A, B, g, False, 128, 128, 32, 3, 4))
    triton_ms_64_64_64, *_ = triton.testing.do_bench(lambda: matmul(A, B, g, False, 64, 64, 64, 3, 4))
    triton_ms = min(triton_ms_128_128_32, triton_ms_64_64_64)
    pytorch_ms, *_ = triton.testing.do_bench(lambda: A @ B)

    expected = A @ B
    C = matmul(A, B, g, False, 64, 64, 64)
    max_disc = (C - expected).abs().max().item()
    # for very large K, rounding due to half precision requires a large tolerance. We set it to 1.
    assert max_disc < 1, f"max: {max_disc}\n{C}\n{expected}"
    print(f"problem size: {m} {n} {k}")
    print(f"{triton_ms=:.3f} | {pytorch_ms=:.3f} with g={g}. ratio: {pytorch_ms / triton_ms:.3f}. best config: {'64' if triton_ms_64_64_64 < triton_ms_128_128_32 else '128'}")

    results.append((m, n, k, max_disc, pytorch_ms, triton_ms_128_128_32, triton_ms_64_64_64, pytorch_ms / triton_ms))

results.sort(key=lambda x: x[-1], reverse=False)

# ---------------------------------------------------------------------------
# Benchmark export
# ---------------------------------------------------------------------------

import json
with open("results_back_1000.json", "w") as f:
    json.dump(results, f, indent=4)

# compute the average speedup
speedups = [pytorch_ms / triton_ms_128_128_32 for _, _, _, _, pytorch_ms, triton_ms_128_128_32, triton_ms_64_64_64, _ in results]
print(f"average speedup: {sum(speedups) / len(speedups)}")

@pommedeterresautee (Contributor Author) commented Mar 28, 2023

For documentation purposes, here is a version which uses locks so the output does not need to be zero-filled.
It's approximately as fast as the previous version.

# sudo apt-get install zlib1g-dev
# for reproducible experiments
# sudo nvidia-smi -pm 1 -i 0
# sudo nvidia-smi -i 0 -pl 350  # 400 for A100
# sudo nvidia-smi -i 0 -lgc 1005


import torch
import triton
import triton.language as tl
import random
from triton.compiler import init_cuda_utils

torch.manual_seed(123)
random.seed(123)

# ---------------------------------------------------------------------------
# Triton kernels
# ---------------------------------------------------------------------------

# iterate, multiply and accumulate over K axis
@triton.jit()
def mac_loop(
        A, B, C,
        M, N, K,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile,
        start_iter, end_iter, locks, GROUP_M,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ACC_TYPE: tl.constexpr,
):
    # If no work to do, return early
    if end_iter == start_iter:
        return
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    # where are we in the grid
    tile_id = start_iter // iters_per_tile
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = tile_id // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (tile_id % group_size)
    pid_n = (tile_id % width) // group_size

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A_ = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + BLOCK_K * stride_ak * (start_iter % iters_per_tile)
    B_ = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + BLOCK_K * stride_bk * (start_iter % iters_per_tile)

    for current_iter in range(start_iter, end_iter):
        a = tl.load(A_)
        b = tl.load(B_)
        acc += tl.dot(a, b)
        A_ += BLOCK_K * stride_ak
        B_ += BLOCK_K * stride_bk

    C_ = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    if end_iter - start_iter == iters_per_tile:
        # this program computed the whole tile: no synchronization needed
        tl.store(C_, acc)
    elif tl.atomic_cas(locks + tile_id, 0, 1) == 0:
        # first program to finish a partial result for this tile: store it, then mark the tile ready
        tl.store(C_, acc)
        tl.atomic_xchg(locks + tile_id, 2)
    else:
        # another program stored first: wait until its result is visible, then accumulate on top of it
        while tl.atomic_cas(locks + tile_id, 2, 2) != 2:
            pass
        tl.atomic_add(C_, acc)


@triton.jit()
def first_wave(
        A, B, C,
        M, N, K,
        locks,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ACC_TYPE: tl.constexpr,
        GROUP_M: tl.constexpr,
):
    pid = tl.program_id(0)
    full = total_iters_streamk // total_programs_streamk  # iterations related to full wave
    remaining = total_iters_streamk % total_programs_streamk  # iterations related to last (partial) wave
    start_iter_1 = pid * full + tl.minimum(pid, remaining)
    end_iter = (pid + 1) * full + tl.minimum(pid + 1, remaining)
    start_iter_2 = start_iter_1 + (iters_per_tile - start_iter_1 % iters_per_tile)
    start_iter_2 = tl.minimum(start_iter_2, end_iter)
    start_iter_3 = start_iter_2 + (iters_per_tile - start_iter_2 % iters_per_tile)
    start_iter_3 = tl.minimum(start_iter_3, end_iter)

    # finish tile another SM was working on (may be a complete tile)
    mac_loop(
        A, B, C,
        M, N, K,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile,
        start_iter_1, start_iter_2, locks, GROUP_M,
        BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
    )

    # do full tile (if there is enough work for that)
    mac_loop(
        A, B, C,
        M, N, K,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile,
        start_iter_2, start_iter_3, locks, GROUP_M,
        BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
    )

    # start a new tile (may be incomplete)
    mac_loop(
        A, B, C,
        M, N, K,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile,
        start_iter_3, end_iter, locks, GROUP_M,
        BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
    )

# similar to the reference matmul kernel
@triton.jit()
def full_tiles(
        A, B, C,
        M, N, K,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ACC_TYPE: tl.constexpr,
        GROUP_M: tl.constexpr,
):
    # first wave has done more tiles than there are SMs, we adjust pid
    tile_id = tl.program_id(0) + total_tiles_streamk

    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = tile_id // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (tile_id % group_size)
    pid_n = (tile_id % width) // group_size

    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    # pointers
    A_ = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B_ = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc_ = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        a = tl.load(A_)
        b = tl.load(B_)
        acc_ += tl.dot(a, b)
        A_ += BLOCK_K * stride_ak
        B_ += BLOCK_K * stride_bk
    acc_ = acc_.to(tl.float16)  # restore C.dtype.element_ty
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C_ = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    tl.store(C_, acc_)

# ---------------------------------------------------------------------------
# Wrapper
# ---------------------------------------------------------------------------

class _matmul(torch.autograd.Function):

    @staticmethod
    def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, debug: bool, BLK_M: int, BLK_N: int, BLK_K: int, num_stages: int, num_warps: int):
        device = a.device
        # handle non-contiguous inputs if necessary
        if a.stride(0) > 1 and a.stride(1) > 1:
            a = a.contiguous()
        if b.stride(0) > 1 and b.stride(1) > 1:
            b = b.contiguous()
        # checks constraints
        assert a.shape[1] == b.shape[0], "incompatible dimensions"
        M, K = a.shape
        _, N = b.shape
        # accumulator types
        ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
        # compute grid (work to do per SM on the first wave)
        total_blocks_M = triton.cdiv(M, BLK_M)
        total_blocks_N = triton.cdiv(N, BLK_N)
        iters_per_tile = triton.cdiv(K, BLK_K)
        total_tiles = total_blocks_M * total_blocks_N
        total_iters = total_tiles * iters_per_tile  # total work to do
        # tiles to be computed using classical blocking approach (data parallel in the paper)
        # if more SMs than tiles, will be 0
        total_blocking_tiles = (total_tiles // total_programs_streamk) * total_programs_streamk if total_programs_streamk > 0 else total_tiles
        if total_tiles >= total_programs_streamk:
            # for two-tile Stream-K + data-parallel in the paper
            total_blocking_tiles -= total_programs_streamk
        total_tiles_streamk = total_tiles - total_blocking_tiles
        total_iters_streamk = total_tiles_streamk * iters_per_tile

        total_programs = total_programs_streamk + (total_tiles - total_tiles_streamk)  # grid
        total_programs_classic = total_programs - total_programs_streamk
        if debug:
            print(f"m,n,k={M},{N},{K} ; BLK_M,BLK_N,BLK_K={BLK_M},{BLK_N},{BLK_K}")
            print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}")
            print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}")
            print(f"{total_tiles=} * {iters_per_tile=} = {total_iters=}")
            print(f"{total_iters_streamk=}")
            print(f"{total_programs_streamk=} + {total_programs_classic=} = {total_programs=}")

        # allocates output
        c = torch.empty((M, N), device=device, dtype=a.dtype)
        locks = torch.zeros((total_tiles_streamk,), device=device, dtype=torch.int32)
        assert c.dtype == torch.float16
        k1 = first_wave[(total_programs_streamk,)](
            a,
            b,
            c,
            M,
            N,
            K,
            locks,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            total_iters_streamk=total_iters_streamk,
            total_tiles_streamk=total_tiles_streamk,
            iters_per_tile=iters_per_tile,
            total_programs_streamk=total_programs_streamk,
            BLOCK_M=BLK_M,
            BLOCK_N=BLK_N,
            BLOCK_K=BLK_K,
            ACC_TYPE=ACC_TYPE,
            GROUP_M=8,
            num_stages=num_stages,
            num_warps=num_warps,
        )
        if debug:
            print(f"{k1.n_regs} registers used, {k1.n_spills} spills")
        # assert k1.n_spills == 0, f"register spilling detected: {k1.n_spills}"
        k2 = full_tiles[(total_programs_classic,)](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            total_iters_streamk=total_iters_streamk,
            total_tiles_streamk=total_tiles_streamk,
            iters_per_tile=iters_per_tile,
            total_programs_streamk=total_programs_streamk,
            BLOCK_M=BLK_M,
            BLOCK_N=BLK_N,
            BLOCK_K=BLK_K,
            ACC_TYPE=ACC_TYPE,
            GROUP_M=8,
            num_stages=3,
            num_warps=4,
        )
        # assert k2.n_spills == 0, f"register spilling detected: {k2.n_spills}"
        return c

    @staticmethod
    def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, debug: bool = False, BLK_M=128, BLK_N=128, BLK_K=32, num_stages=3, num_warps=4):
        return _matmul._call(a=a, b=b, total_programs_streamk=grid, debug=debug, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, num_warps=num_warps, num_stages=num_stages)


matmul = _matmul.apply

# ---------------------------------------------------------------------------
# Example and Benchmark
# ---------------------------------------------------------------------------

device = torch.cuda.current_device()
init_cuda_utils()
total_sm = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"]
print(f"total SMs: {total_sm}")
total_programs_streamk = total_sm  # number of tiles to use in Stream-K first wave
m, n, k = 1536, 1792, 6016  # some problem size to test
A = torch.randn(m, k, device="cuda", dtype=torch.float16)
B = torch.randn(k, n, device="cuda", dtype=torch.float16)

debug = False
C = matmul(A, B, total_programs_streamk, debug, 64, 64, 64)
expected = A @ B

assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}"

if not debug:
    triton_ms, *_ = triton.testing.do_bench(lambda: torch.matmul(A, B))
    print("PyTorch", triton_ms)

    triton_ms, *_ = triton.testing.do_bench(lambda: matmul(A, B, total_programs_streamk, debug, 128, 128, 32, 3, 4))
    print(f"hybrid stream-k (grid={total_programs_streamk})", triton_ms)

    triton_ms, *_ = triton.testing.do_bench(lambda: matmul(A, B, total_programs_streamk, debug, 128, 256, 32, 3, 8))
    print(f"hybrid stream-k (grid={total_programs_streamk})", triton_ms)

    total_programs_streamk *= 2
    triton_ms, *_ = triton.testing.do_bench(lambda: matmul(A, B, total_programs_streamk, debug, 128, 128, 32, 3, 4))
    print(f"hybrid stream-k (grid={total_programs_streamk})", triton_ms)

    triton_ms, *_ = triton.testing.do_bench(lambda: matmul(A, B, 0, debug))
    print("tile matmul (grid=0)", triton_ms)

if debug:
    exit(0)
# ---------------------------------------------------------------------------
# Log-sampled benchmark
# ---------------------------------------------------------------------------

# tried to reproduce the tests described in the paper
num_samples = 32768  # 32768
step = 256
values = ((torch.logspace(torch.tensor(step).log2(), torch.tensor(8192).log2(), num_samples,
                          base=2) / step).round() * step).unique().tolist()
shapes = [(int(m), int(n), int(k)) for m in values for n in values for k in values]
shapes = random.sample(shapes, num_samples)
assert len(shapes) == num_samples

results = []
for idx, (m, n, k) in enumerate(shapes):
    # print progress bar
    if idx % 10 == 0 and idx > 0:
        speedups = [ratio for *_, ratio in results]
        print(f"{idx}/{num_samples} - average speedup: {sum(speedups) / len(speedups):.3f}")

    A = torch.randn(m, k, device="cuda", dtype=torch.float16)
    B = torch.randn(k, n, device="cuda", dtype=torch.float16)
    triton_ms_1sm, *_ = triton.testing.do_bench(lambda: matmul(A, B, total_sm, False, 128, 128, 32, 3, 4))
    triton_2sm_ms, *_ = triton.testing.do_bench(lambda: matmul(A, B, total_sm*2, False, 128, 128, 32, 3, 4))
    pytorch_ms, *_ = triton.testing.do_bench(lambda: A @ B)

    expected = A @ B
    C = matmul(A, B, total_sm, False, 64, 64, 64)
    max_disc = (C - expected).abs().max().item()
    # for very large K, rounding due to half precision requires a large tolerance. We set it to 1.
    assert max_disc <= 1., f"max: {max_disc}\n{C}\n{expected}"

    results.append((m, n, k, max_disc, pytorch_ms, triton_ms_1sm, triton_2sm_ms, triton_ms_1sm < triton_2sm_ms, pytorch_ms / triton_ms_1sm))


results.sort(key=lambda x: x[-1], reverse=False)

# ---------------------------------------------------------------------------
# Benchmark export
# ---------------------------------------------------------------------------

import json
with open("results.json", "w") as f:
    json.dump(results, f, indent=4)

# speedup: 22740/32768 - average speedup: 1.052

@pommedeterresautee

This version has a simplified lock (one less atomic op) and gets a 10% speedup compared to the previous one (with the same block sizes).
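
To make the lock change easier to follow, here is a small host-side analogy (a sketch of my own, not part of the kernel below) of the hand-off the simplified lock implements: the program that computes the last K-segment of a tile writes its result with a plain store and flips the per-tile lock to 1; programs that computed earlier segments spin until the lock reads 1, then accumulate with an atomic add. The threading primitives are only stand-ins for the tl.store / tl.atomic_xchg / tl.atomic_cas / tl.atomic_add calls in the listing that follows.

import threading

lock_state = 0      # mirrors locks[tile_id]: 0 = tile not stored yet, 1 = safe to atomic_add
tile_value = 0.0    # mirrors the C tile
cv = threading.Condition()

def finishing_program(partial):
    # owns the end of the tile: plain store, then unlock (tl.store + tl.atomic_xchg)
    global lock_state, tile_value
    with cv:
        tile_value = partial
        lock_state = 1
        cv.notify_all()

def contributing_program(partial):
    # computed an earlier K-segment: wait for the lock, then accumulate (tl.atomic_add)
    global tile_value
    with cv:
        cv.wait_for(lambda: lock_state == 1)
        tile_value += partial

threads = [threading.Thread(target=contributing_program, args=(2.0,)),
           threading.Thread(target=finishing_program, args=(3.0,))]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(tile_value)  # 5.0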

# sudo apt-get install zlib1g-dev
# for reproducible experiments
# sudo nvidia-smi -pm 1 -i 0
# sudo nvidia-smi -i 0 -pl 350  # 400 for A100
# sudo nvidia-smi -i 0 -lgc 1005


import torch
import triton
import triton.language as tl
import random
from triton.runtime.driver.cuda import get_cuda_utils

torch.manual_seed(123)
random.seed(123)

# ---------------------------------------------------------------------------
# Triton kernels
# ---------------------------------------------------------------------------

# iterate, multiply and accumulate over K axis
@triton.jit()
def mac_loop(
        A, B, C,
        M, N, K,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile,
        start_iter, end_iter, locks, GROUP_M,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ACC_TYPE: tl.constexpr,
):
    # If no work to do, return early
    if end_iter == start_iter:
        return
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    # where are we in the grid
    tile_id = start_iter // iters_per_tile
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = tile_id // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (tile_id % group_size)
    pid_n = (tile_id % width) // group_size

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + BLOCK_K * stride_ak * (start_iter % iters_per_tile)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + BLOCK_K * stride_bk * (start_iter % iters_per_tile)

    for current_iter in range(start_iter, end_iter):
        a = tl.load(A)
        b = tl.load(B)
        acc += tl.dot(a, b)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    if end_iter % iters_per_tile == 0:  # last iteration of the tile always happens before its start on another SM
        tl.store(C, acc)
        tl.atomic_xchg(locks + tile_id, 1)
    else:
        while tl.atomic_cas(locks + tile_id, 0, 0) != 1:
            pass
        tl.atomic_add(C, acc)


@triton.jit()
def first_wave(
        A, B, C,
        M, N, K,
        locks,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ACC_TYPE: tl.constexpr,
        GROUP_M: tl.constexpr,
):
    pid = tl.program_id(0)
    full = total_iters_streamk // total_programs_streamk  # iterations related to full wave
    remaining = total_iters_streamk % total_programs_streamk  # iterations related to last (partial) wave
    start_iter_1 = pid * full + tl.minimum(pid, remaining)
    end_iter = (pid + 1) * full + tl.minimum(pid + 1, remaining)
    start_iter_2 = start_iter_1 + (iters_per_tile - start_iter_1 % iters_per_tile)
    start_iter_2 = tl.minimum(start_iter_2, end_iter)
    start_iter_3 = start_iter_2 + (iters_per_tile - start_iter_2 % iters_per_tile)
    start_iter_3 = tl.minimum(start_iter_3, end_iter)

    # finish tile another SM was working on (may be a complete tile)
    mac_loop(
        A, B, C,
        M, N, K,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile,
        start_iter_1, start_iter_2, locks, GROUP_M,
        BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
    )

    # do full tile (if there is enough work for that)
    mac_loop(
        A, B, C,
        M, N, K,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile,
        start_iter_2, start_iter_3, locks, GROUP_M,
        BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
    )

    # start a new tile (may be incomplete)
    mac_loop(
        A, B, C,
        M, N, K,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile,
        start_iter_3, end_iter, locks, GROUP_M,
        BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
    )

# similar to the reference matmul kernel
@triton.jit()
def full_tiles(
        A, B, C,
        M, N, K,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ACC_TYPE: tl.constexpr,
        GROUP_M: tl.constexpr,
):
    # first wave has done more tiles than there are SMs, we adjust pid
    tile_id = tl.program_id(0) + total_tiles_streamk

    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = tile_id // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (tile_id % group_size)
    pid_n = (tile_id % width) // group_size

    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    # pointers
    A_ = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B_ = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc_ = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        a = tl.load(A_)
        b = tl.load(B_)
        acc_ += tl.dot(a, b)
        A_ += BLOCK_K * stride_ak
        B_ += BLOCK_K * stride_bk
    acc_ = acc_.to(tl.float16)  # restore C.dtype.element_ty
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C_ = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    tl.store(C_, acc_)

# ---------------------------------------------------------------------------
# Wrapper
# ---------------------------------------------------------------------------

class _matmul(torch.autograd.Function):

    @staticmethod
    def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, debug: bool, BLK_M: int, BLK_N: int, BLK_K: int, num_stages: int, num_warps: int):
        device = a.device
        # handle non-contiguous inputs if necessary
        if a.stride(0) > 1 and a.stride(1) > 1:
            a = a.contiguous()
        if b.stride(0) > 1 and b.stride(1) > 1:
            b = b.contiguous()
        # checks constraints
        assert a.shape[1] == b.shape[0], "incompatible dimensions"
        M, K = a.shape
        _, N = b.shape
        # accumulator types
        ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
        # compute grid (work to do per SM on the first wave)
        total_blocks_M = triton.cdiv(M, BLK_M)
        total_blocks_N = triton.cdiv(N, BLK_N)
        iters_per_tile = triton.cdiv(K, BLK_K)
        total_tiles = total_blocks_M * total_blocks_N
        total_iters = total_tiles * iters_per_tile  # total work to do
        # tiles to be computed using classical blocking approach (data parallel in the paper)
        # if more SMs than tiles, will be 0
        total_blocking_tiles = (total_tiles // total_programs_streamk) * total_programs_streamk if total_programs_streamk > 0 else total_tiles
        if total_tiles >= total_programs_streamk:
            # for two-tile Stream-K + data-parallel in the paper
            total_blocking_tiles -= total_programs_streamk
        total_tiles_streamk = total_tiles - total_blocking_tiles
        total_iters_streamk = total_tiles_streamk * iters_per_tile

        total_programs = total_programs_streamk + (total_tiles - total_tiles_streamk)  # grid
        total_programs_classic = total_programs - total_programs_streamk
        if debug:
            print(f"m,n,k={M},{N},{K} ; BLK_M,BLK_N,BLK_K={BLK_M},{BLK_N},{BLK_K}")
            print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}")
            print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}")
            print(f"{total_tiles=} * {iters_per_tile=} = {total_iters=}")
            print(f"{total_iters_streamk=}")
            print(f"{total_programs_streamk=} + {total_programs_classic=} = {total_programs=}")

        # allocates output
        c = torch.empty((M, N), device=device, dtype=a.dtype)
        locks = torch.zeros((total_tiles_streamk,), device=device, dtype=torch.int32)
        assert c.dtype == torch.float16
        k1 = first_wave[(total_programs_streamk,)](
            a,
            b,
            c,
            M,
            N,
            K,
            locks,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            total_iters_streamk=total_iters_streamk,
            total_tiles_streamk=total_tiles_streamk,
            iters_per_tile=iters_per_tile,
            total_programs_streamk=total_programs_streamk,
            BLOCK_M=BLK_M,
            BLOCK_N=BLK_N,
            BLOCK_K=BLK_K,
            ACC_TYPE=ACC_TYPE,
            GROUP_M=8,
            num_stages=num_stages,
            num_warps=num_warps,
        )
        if debug:
            print(f"{k1.n_regs} registers used, {k1.n_spills} spills")
        #assert k1.n_spills == 0, f"register spilling detected: {k1.n_spills}"
        k2 = full_tiles[(total_programs_classic,)](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            total_iters_streamk=total_iters_streamk,
            total_tiles_streamk=total_tiles_streamk,
            iters_per_tile=iters_per_tile,
            total_programs_streamk=total_programs_streamk,
            BLOCK_M=BLK_M,
            BLOCK_N=BLK_N,
            BLOCK_K=BLK_K,
            ACC_TYPE=ACC_TYPE,
            GROUP_M=8,
            num_stages=3,
            num_warps=8,
        )
        if debug:
            print(f"{k2.n_regs} registers used, {k2.n_spills} spills")
        assert k2.n_spills == 0, f"register spilling detected: {k2.n_spills}"
        return c

    @staticmethod
    def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, debug: bool = False, BLK_M=128, BLK_N=128, BLK_K=32, num_stages=3, num_warps=4):
        return _matmul._call(a=a, b=b, total_programs_streamk=grid, debug=debug, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, num_warps=num_warps, num_stages=num_stages)


matmul = _matmul.apply

# ---------------------------------------------------------------------------
# Example and Benchmark
# ---------------------------------------------------------------------------

device = torch.cuda.current_device()
cuda_utils = get_cuda_utils()
total_sm = cuda_utils.get_device_properties(device)["multiprocessor_count"]
print(f"total SMs: {total_sm}")
total_programs_streamk = total_sm  # number of tiles to use in Stream-K first wave
m, n, k = 1536, 1792*2, 6016  # some problem size to test
A = torch.randn(m, k, device="cuda", dtype=torch.float16)
B = torch.randn(k, n, device="cuda", dtype=torch.float16)

debug = False
C = matmul(A, B, total_programs_streamk, debug, 128, 256, 64, 3, 8)
expected = A @ B

assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}"

if not debug:
    triton_ms, *_ = triton.testing.do_bench(lambda: torch.matmul(A, B))
    print("PyTorch", triton_ms)

    triton_ms, *_ = triton.testing.do_bench(lambda: matmul(A, B, total_programs_streamk, debug, 128, 256, 64, 3, 8))
    print(f"hybrid stream-k (grid={total_programs_streamk})", triton_ms)

    triton_ms, *_ = triton.testing.do_bench(lambda: matmul(A, B, total_programs_streamk * 2, debug, 128, 128, 64, 3, 8))
    print(f"hybrid stream-k (grid={total_programs_streamk})", triton_ms)

    triton_ms, *_ = triton.testing.do_bench(lambda: matmul(A, B, 0, debug))
    print("tile matmul (grid=0)", triton_ms)

if debug:
    exit(0)
# ---------------------------------------------------------------------------
# Log-sampled benchmark
# ---------------------------------------------------------------------------

# tried to reproduce the tests described in the paper
num_samples = 32768  # 32768
step = 256
values = ((torch.logspace(torch.tensor(step).log2(), torch.tensor(8192).log2(), num_samples,
                          base=2) / step).round() * step).unique().tolist()
shapes = [(int(m), int(n), int(k)) for m in values for n in values for k in values]
shapes = random.sample(shapes, num_samples)
assert len(shapes) == num_samples

results = []
for idx, (m, n, k) in enumerate(shapes):
    # print progress bar
    if idx % 10 == 0 and idx > 0:
        speedups = [ratio for *_, ratio in results]
        print(f"{idx}/{num_samples} - average speedup: {sum(speedups) / len(speedups):.3f}")

    A = torch.randn(m, k, device="cuda", dtype=torch.float16)
    B = torch.randn(k, n, device="cuda", dtype=torch.float16)
    torch.cuda.synchronize()
    triton_ms_1sm, *_ = triton.testing.do_bench(lambda: matmul(A, B, total_sm, False, 128, 256, 64, 3, 8))
    torch.cuda.synchronize()
    triton_ms_2sm, *_ = triton.testing.do_bench(lambda: matmul(A, B, total_sm * 2, False, 128, 256, 64, 3, 8))
    torch.cuda.synchronize()
    triton_ms = min(triton_ms_1sm, triton_ms_2sm)
    pytorch_ms, *_ = triton.testing.do_bench(lambda: A @ B)
    torch.cuda.synchronize()
    expected = A @ B
    C = matmul(A, B, total_sm, False, 64, 64, 64)
    max_disc = (C - expected).abs().max().item()
    if pytorch_ms / triton_ms >= 1.1:
        print(f"{m}x{n}x{k}, speedup: {pytorch_ms / triton_ms_1sm:.3f}")
    # for very large K, rounding due to half precision requires a large tolerance. We set it to 1.
    assert max_disc <= 1., f"max: {max_disc}\n{C}\n{expected}"

    results.append((m, n, k, max_disc, pytorch_ms, triton_ms_1sm, triton_ms_2sm, triton_ms_1sm < triton_ms_2sm, pytorch_ms / triton_ms))


results.sort(key=lambda x: x[-1], reverse=False)

# ---------------------------------------------------------------------------
# Benchmark export
# ---------------------------------------------------------------------------

import json
with open("results.json", "w") as f:
    json.dump(results, f, indent=4)

# speedup: 22740/32768 - average speedup: 1.052

@pommedeterresautee

Another, even more simplified lock which leverages the design of this kernel (it removes one more atomic op). I am at just 0.9x PyTorch on average on A100 and do much better on a 3090 RTX (according to the Nsight profiler, I am memory-bound on the 3090 RTX but not on the A100, which explains the gap). There is still some register spilling on the first-wave kernel, even with a 128x128x32 block size (22 spills).
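
To give a sense of why the same kernel can show up as memory-bound on a 3090 RTX but not on an A100, here is a rough back-of-the-envelope (my own estimate, assuming 128x128 tiles; not profiler numbers): each output tile re-reads a BLOCK_M x K strip of A and a K x BLOCK_N strip of B, so actual DRAM traffic depends on how much of that re-use the L2 absorbs, and the A100 has a much larger L2 (40 MB vs 6 MB) plus higher bandwidth (>1.5 TB/s vs ~936 GB/s).

m, n, k = 1536, 1792, 6016          # problem size used below
BLK_M, BLK_N = 128, 128             # assumed tile shape
bytes_fp16 = 2
tiles = (m // BLK_M) * (n // BLK_N)
ideal = bytes_fp16 * (m * k + k * n + m * n)                      # read A and B once, write C once
per_tile = bytes_fp16 * (BLK_M * k + k * BLK_N + BLK_M * BLK_N)   # bytes touched by one tile
no_reuse = tiles * per_tile                                       # if L2 caught none of the re-reads
print(f"ideal traffic ~ {ideal / 1e6:.0f} MB, zero-reuse traffic ~ {no_reuse / 1e6:.0f} MB")
# prints roughly: ideal traffic ~ 46 MB, zero-reuse traffic ~ 523 MB

Where the real kernel lands between those two numbers is a matter of L2 hit rate, which is exactly where the two GPUs differ.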

# (echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"') | sudo tee -a /etc/modprobe.d/RestrictedProfiling.conf >/dev/null
# sudo update-initramfs -u -k all
# cat /proc/driver/nvidia/params | grep RmProfilingAdminOnly
# sudo apt-get install zlib1g-dev
# for reproducible experiments
# sudo nvidia-smi -pm 1 -i 0
# sudo nvidia-smi -i 0 -pl 350  # 400 for A100
# sudo nvidia-smi -i 0 -lgc 1005


import torch
import triton
import triton.language as tl
import random
from triton.runtime.driver.cuda import get_cuda_utils

torch.manual_seed(123)
random.seed(123)

# ---------------------------------------------------------------------------
# Triton kernels
# ---------------------------------------------------------------------------


@triton.jit()
def tile_swizzling(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M: tl.constexpr):
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = tile_id // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (tile_id % group_size)
    pid_n = (tile_id % width) // group_size
    return pid_m, pid_n


@triton.jit()
def tile_classic(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M: tl.constexpr):
    pid_m = tile_id // tl.cdiv(N, BLOCK_N)
    pid_n = tile_id % tl.cdiv(N, BLOCK_N)
    return pid_m, pid_n


# iterate, multiply and accumulate over K axis
@triton.jit()
def mac_loop(A, B, C,
             M, N, K,
             locks,
             stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
             total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile,
             start_iter, end_iter,
             BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
             ACC_TYPE: tl.constexpr, GROUP_M: tl.constexpr, SWIZZLING: tl.constexpr):
    # If no work to do, return early
    if end_iter == start_iter:
        return
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    # where are we in the grid
    tile_id = start_iter // iters_per_tile
    if SWIZZLING:
        pid_m, pid_n = tile_swizzling(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M)
    else:
        pid_m, pid_n = tile_classic(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + BLOCK_K * stride_ak * (start_iter % iters_per_tile)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + BLOCK_K * stride_bk * (start_iter % iters_per_tile)

    for current_iter in range(start_iter, end_iter):
        a = tl.load(A)
        b = tl.load(B)
        acc += tl.dot(a, b)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    if end_iter % iters_per_tile == 0:  # last iteration of the tile always happens before its start on another SM
        tl.store(C, acc)
        if start_iter % iters_per_tile != 0:  # only if tile has been partially processed
            tl.store(locks + tile_id, 1)
    else:
        while tl.atomic_min(locks + tile_id, 1) != 1:
            pass
        tl.atomic_add(C, acc)


@triton.jit()
def first_wave(
        A, B, C,
        M, N, K,
        locks,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
        ACC_TYPE: tl.constexpr, GROUP_M: tl.constexpr, SWIZZLING: tl.constexpr,
):
    pid = tl.program_id(0)
    full = total_iters_streamk // total_programs_streamk  # iterations related to full wave
    remaining = total_iters_streamk % total_programs_streamk  # iterations related to last (partial) wave
    start_iter_1 = pid * full + tl.minimum(pid, remaining)
    end_iter = (pid + 1) * full + tl.minimum(pid + 1, remaining)
    start_iter_2 = start_iter_1 + (iters_per_tile - start_iter_1 % iters_per_tile)
    start_iter_2 = tl.minimum(start_iter_2, end_iter)
    start_iter_3 = start_iter_2 + (iters_per_tile - start_iter_2 % iters_per_tile)
    start_iter_3 = tl.minimum(start_iter_3, end_iter)

    # finish tile another SM was working on (may be a complete tile)
    mac_loop(A, B, C, M, N, K, locks, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
             total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile, start_iter_1,
             start_iter_2, BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE, GROUP_M, SWIZZLING)

    # do full tile (if there is enough work for that)
    mac_loop(A, B, C, M, N, K, locks, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
             total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile, start_iter_2,
             start_iter_3, BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE, GROUP_M, SWIZZLING)

    # start a new tile (may be incomplete)
    mac_loop(A, B, C, M, N, K, locks, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
             total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile, start_iter_3, end_iter,
             BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE, GROUP_M, SWIZZLING)

# similar to the reference matmul kernel
@triton.jit()
def full_tiles(
        A, B, C,
        M, N, K,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        total_programs_streamk, total_iters_streamk, total_tiles_streamk, iters_per_tile,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ACC_TYPE: tl.constexpr,
        GROUP_M: tl.constexpr, SWIZZLING: tl.constexpr,
):
    # first wave has done more tiles than there are SMs, we adjust pid
    tile_id = tl.program_id(0) + total_tiles_streamk
    if SWIZZLING:
        pid_m, pid_n = tile_swizzling(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M)
    else:
        pid_m, pid_n = tile_classic(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M)

    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        a = tl.load(A)
        b = tl.load(B)
        acc += tl.dot(a, b)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    acc = acc.to(tl.float16)  # restore C.dtype.element_ty
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    tl.store(C, acc)

# ---------------------------------------------------------------------------
# Wrapper
# ---------------------------------------------------------------------------

class _matmul(torch.autograd.Function):

    @staticmethod
    def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, debug: bool, BLK_M: int, BLK_N: int, BLK_K: int, swizzling: bool, num_stages: int, num_warps: int):
        device = a.device
        # handle non-contiguous inputs if necessary
        if a.stride(0) > 1 and a.stride(1) > 1:
            a = a.contiguous()
        if b.stride(0) > 1 and b.stride(1) > 1:
            b = b.contiguous()
        # checks constraints
        assert a.shape[1] == b.shape[0], "incompatible dimensions"
        M, K = a.shape
        _, N = b.shape
        # accumulator types
        ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
        # compute grid (work to do per SM on the first wave)
        total_blocks_M = triton.cdiv(M, BLK_M)
        total_blocks_N = triton.cdiv(N, BLK_N)
        iters_per_tile = triton.cdiv(K, BLK_K)
        total_tiles = total_blocks_M * total_blocks_N
        total_iters = total_tiles * iters_per_tile  # total work to do
        # tiles to be computed using classical blocking approach (data parallel in the paper)
        # if more SMs than tiles, will be 0
        total_blocking_tiles = (total_tiles // total_programs_streamk) * total_programs_streamk if total_programs_streamk > 0 else total_tiles
        if total_tiles >= total_programs_streamk:
            # for two-tile Stream-K + data-parallel in the paper
            total_blocking_tiles -= total_programs_streamk
        total_tiles_streamk = total_tiles - total_blocking_tiles
        total_iters_streamk = total_tiles_streamk * iters_per_tile

        total_programs = total_programs_streamk + (total_tiles - total_tiles_streamk)  # grid
        total_programs_classic = total_programs - total_programs_streamk
        if debug:
            print(f"m,n,k={M},{N},{K} ; BLK_M,BLK_N,BLK_K={BLK_M},{BLK_N},{BLK_K}")
            print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}")
            print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}")
            print(f"{total_tiles=} * {iters_per_tile=} = {total_iters=}")
            print(f"{total_iters_streamk=}")
            print(f"{total_programs_streamk=} + {total_programs_classic=} = {total_programs=}")

        # allocates output
        c = torch.empty((M, N), device=device, dtype=a.dtype)
        locks = torch.zeros((total_tiles_streamk,), device=device, dtype=torch.int32)
        assert c.dtype == torch.float16
        k1 = first_wave[(total_programs_streamk,)](
            a,
            b,
            c,
            M,
            N,
            K,
            locks,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            total_iters_streamk=total_iters_streamk,
            total_tiles_streamk=total_tiles_streamk,
            iters_per_tile=iters_per_tile,
            total_programs_streamk=total_programs_streamk,
            BLOCK_M=BLK_M,
            BLOCK_N=BLK_N,
            BLOCK_K=BLK_K,
            ACC_TYPE=ACC_TYPE,
            GROUP_M=8,
            SWIZZLING=swizzling,
            num_stages=num_stages,
            num_warps=num_warps,
        )
        if debug:
            print(f"{k1.n_regs} registers used, {k1.n_spills} spills")
        k2 = full_tiles[(total_programs_classic,)](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            total_iters_streamk=total_iters_streamk,
            total_tiles_streamk=total_tiles_streamk,
            iters_per_tile=iters_per_tile,
            total_programs_streamk=total_programs_streamk,
            BLOCK_M=BLK_M,
            BLOCK_N=BLK_N,
            BLOCK_K=BLK_K,
            ACC_TYPE=ACC_TYPE,
            GROUP_M=8,
            SWIZZLING=swizzling,
            num_stages=num_stages,
            num_warps=num_warps,
        )
        if debug:
            print(f"{k2.n_regs} registers used, {k2.n_spills} spills")
        return c

    @staticmethod
    def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, debug: bool = False, BLK_M=128, BLK_N=128, BLK_K=32, swizzling=True, num_stages=3, num_warps=4):
        return _matmul._call(a=a, b=b, total_programs_streamk=grid, debug=debug, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, swizzling=swizzling, num_warps=num_warps, num_stages=num_stages)


matmul = _matmul.apply

# ---------------------------------------------------------------------------
# Example and Benchmark
# ---------------------------------------------------------------------------

device = torch.cuda.current_device()
cuda_utils = get_cuda_utils()
total_sm = cuda_utils.get_device_properties(device)["multiprocessor_count"]
print(f"total SMs: {total_sm}")
total_programs_streamk = total_sm  # number of tiles to use in Stream-K first wave
m, n, k = 1536, 1792, 6016  # some problem size to test
A = torch.randn(m, k, device="cuda", dtype=torch.float16)
B = torch.randn(k, n, device="cuda", dtype=torch.float16)

debug = True
C = matmul(A, B, total_programs_streamk, debug, 128, 128, 32, True, 3, 4)
expected = A @ B

assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}"

if not debug:
    triton_ms, *_ = triton.testing.do_bench(lambda: torch.matmul(A, B))
    print("PyTorch", triton_ms)

    triton_ms, *_ = triton.testing.do_bench(lambda: matmul(A, B, total_programs_streamk, debug, 128, 128, 32, True, 3, 4))
    print(f"hybrid stream-k (grid={total_programs_streamk})", triton_ms)

    total_programs_streamk *= 2
    triton_ms, *_ = triton.testing.do_bench(lambda: matmul(A, B, total_programs_streamk, debug, 128, 128, 32, True, 3, 4))
    print(f"hybrid stream-k (grid={total_programs_streamk})", triton_ms)

    triton_ms, *_ = triton.testing.do_bench(lambda: matmul(A, B, 0, debug))
    print("tile matmul (grid=0)", triton_ms)

if debug:
    exit(0)
# ---------------------------------------------------------------------------
# Log-sampled benchmark
# ---------------------------------------------------------------------------

# tried to reproduce the tests described in the paper
num_samples = 32768  # 32768
step = 256
values = ((torch.logspace(torch.tensor(step).log2(), torch.tensor(8192).log2(), num_samples,
                          base=2) / step).round() * step).unique().tolist()
shapes = [(int(m), int(n), int(k)) for m in values for n in values for k in values]
shapes = random.sample(shapes, num_samples)
assert len(shapes) == num_samples

results = []
for idx, (m, n, k) in enumerate(shapes):
    # print progress bar
    if idx % 10 == 0 and idx > 0:
        speedups = [ratio for *_, ratio in results]
        print(f"{idx}/{num_samples} - average speedup: {sum(speedups) / len(speedups):.3f}")

    A = torch.randn(m, k, device="cuda", dtype=torch.float16)
    B = torch.randn(k, n, device="cuda", dtype=torch.float16)
    torch.cuda.synchronize()
    triton_ms_1sm, *_ = triton.testing.do_bench(lambda: matmul(A, B, total_sm, False, 128, 128, 32, True, 3, 4))
    torch.cuda.synchronize()
    triton_ms_2sm, *_ = triton.testing.do_bench(lambda: matmul(A, B, total_sm * 2, False, 128, 128, 32, True, 3, 4))
    torch.cuda.synchronize()
    triton_ms = min(triton_ms_1sm, triton_ms_2sm)
    pytorch_ms, *_ = triton.testing.do_bench(lambda: A @ B)
    torch.cuda.synchronize()

    expected = A @ B
    C = matmul(A, B, total_sm, False, 64, 64, 64)
    max_disc = (C - expected).abs().max().item()

    # for very large K, rounding due to half precision requires a large tolerance. We set it to 1.
    assert max_disc <= 1., f"max: {max_disc}\n{C}\n{expected}"

    results.append((m, n, k, max_disc, pytorch_ms, triton_ms_1sm, triton_ms_2sm, triton_ms_1sm < triton_ms_2sm, pytorch_ms / triton_ms_1sm))


results.sort(key=lambda x: x[-1], reverse=False)

# ---------------------------------------------------------------------------
# Benchmark export
# ---------------------------------------------------------------------------

import json
with open("results.json", "w") as f:
    json.dump(results, f, indent=4)

# speedup: 22740/32768 - average speedup: 1.052

@pommedeterresautee

New version, with performance at 0.96x PyTorch (32K benchmarks).
No more register spilling 🥳
I am quite sure the for loop in mac_loop is not unrolled (there is no way to know at compile time how many iterations it will run).
I have tried to decompose it into 2 loops, one being "unrollable" and one doing the remaining work, but that brings huge register spilling and is slow. Open to new ideas.
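
For reference, here is a minimal host-side sketch of one way such a two-loop decomposition could split the work (the UNROLL factor and helper name are illustrative assumptions, not the exact code I tried): the first chunk has a trip count that is a multiple of a fixed UNROLL factor, which a compile-time inner loop could cover, and a short data-dependent tail handles the rest.

UNROLL = 4  # hypothetical unroll factor

def split_iters(start_iter, end_iter, unroll=UNROLL):
    # number of K-steps this program must do for the current tile segment
    n = end_iter - start_iter
    unrollable = (n // unroll) * unroll   # steps a fixed-trip-count inner loop could cover
    remainder = n - unrollable            # data-dependent tail handled by a second loop
    return unrollable, remainder

print(split_iters(3, 17))   # (12, 2): three groups of 4 iterations, then 2 leftover iterations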

# (echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"') | sudo tee -a /etc/modprobe.d/RestrictedProfiling.conf >/dev/null
# sudo update-initramfs -u -k all
# cat /proc/driver/nvidia/params | grep RmProfilingAdminOnly
# sudo apt-get install zlib1g-dev
# for reproducible experiments
# sudo nvidia-smi -pm 1 -i 0
# sudo nvidia-smi -i 0 -pl 350  # 400 for A100
# sudo nvidia-smi -i 0 -lgc 1005
from typing import Optional

import torch
import triton
import triton.language as tl
import random

from triton.runtime.driver import CudaUtils
import json

torch.manual_seed(123)
random.seed(123)

device = torch.cuda.current_device()
cuda_utils = CudaUtils()
total_sm = cuda_utils.get_device_properties(device)["multiprocessor_count"]
print(f"total SMs: {total_sm}")

# ---------------------------------------------------------------------------
# Triton kernels
# ---------------------------------------------------------------------------


@triton.jit()
def swizzle_tile(tile_id,
                 M, N, K,
                 BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
                 GROUP_M: tl.constexpr
                 ):
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = tile_id // width
    group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (tile_id % group_size)
    pid_n = (tile_id % width) // group_size
    return pid_m, pid_n


@triton.jit()
def linear_tile(tile_id,
                M, N, K,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
                GROUP_M: tl.constexpr
                ):
    pid_m = tile_id // tl.cdiv(N, BLOCK_N)
    pid_n = tile_id % tl.cdiv(N, BLOCK_N)
    return pid_m, pid_n


# iterate, multiply and accumulate over K axis
@triton.jit()
def mac_loop(A, B, C,
             M, N, K,
             locks,
             stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
             iters_per_tile,
             start_iter, end_iter,
             BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
             ACC_TYPE: tl.constexpr, GROUP_M: tl.constexpr):

    # where are we in the grid
    tile_id = start_iter // iters_per_tile
    if GROUP_M > 0:
        pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M)
    else:
        pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak) + BLOCK_K * stride_ak * (start_iter % iters_per_tile)
    B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn) + BLOCK_K * stride_bk * (start_iter % iters_per_tile)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)

    for current_iter in range(start_iter, end_iter):
        a = tl.load(A)
        b = tl.load(B)
        acc += tl.dot(a, b)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    if end_iter % iters_per_tile == 0:  # last iteration of the tile always happens before its start on another SM
        C_ = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)  # compute inside the if/else to avoid spilling!
        tl.store(C_, acc)
        if start_iter % iters_per_tile != 0:  # only if tile has been partially processed
            tl.atomic_xchg(locks + tile_id, 1)
    else:
        while tl.atomic_cas(locks + tile_id, 1, 1) != 1:
            pass
        C_ = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)  # compute inside the if/else to avoid spilling!
        tl.atomic_add(C_, acc)


@triton.jit()
def first_wave(
        A, B, C,
        M, N, K,
        locks,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        total_full_tiles_streamk, total_partial_tiles_streamk, iters_per_tile,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ACC_TYPE: tl.constexpr,
        GROUP_M: tl.constexpr,
):
    pid = tl.program_id(0)
    start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk)
    last_iter = (pid + 1) * total_full_tiles_streamk + tl.minimum(pid + 1, total_partial_tiles_streamk)

    while start_iter < last_iter:
        end_iter = tl.minimum(start_iter + (iters_per_tile - start_iter % iters_per_tile), last_iter)
        mac_loop(A, B, C,
                 M, N, K,
                 locks,
                 stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
                 iters_per_tile,
                 start_iter, end_iter,
                 BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
                 GROUP_M,
                 )

        start_iter = end_iter


# similar to the reference matmul kernel
@triton.jit()
def full_tiles(
        A, B, C,
        M, N, K,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        total_tiles_streamk,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ACC_TYPE: tl.constexpr,
        GROUP_M: tl.constexpr,
):
    # first wave has done more tiles than there are SMs, we adjust pid
    tile_id = tl.program_id(0) + total_tiles_streamk
    if GROUP_M > 0:
        pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M)
    else:
        pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M)

    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    # pointers
    A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(A)
        b = tl.load(B)
        acc += tl.dot(a, b)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    acc = acc.to(tl.float16)  # restore C.dtype.element_ty
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    tl.store(C, acc)


# ---------------------------------------------------------------------------
# Wrapper
# ---------------------------------------------------------------------------

class matmul(torch.autograd.Function):

    _debug = False

    @staticmethod
    def set_debug(debug: bool):
        matmul._debug = debug

    @staticmethod
    def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M: int, BLK_N: int, BLK_K: int, two_tiles: bool, num_stages: int, num_warps: int):
        device = a.device

        assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported"
        # checks constraints
        assert a.shape[1] == b.shape[0], "incompatible dimensions"
        M, K = a.shape
        _, N = b.shape
        # accumulator types
        ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
        # compute grid (work to do per SM on the first wave)
        total_blocks_M = triton.cdiv(M, BLK_M)
        total_blocks_N = triton.cdiv(N, BLK_N)
        iters_per_tile = triton.cdiv(K, BLK_K)
        GROUP_M = 8  # 0 to disable swizzling
        total_tiles = total_blocks_M * total_blocks_N

        if total_programs_streamk > 0:  # Stream-K
            # last wave may occupy less than total_programs_streamk SMs
            total_tiles_streamk = total_tiles % total_programs_streamk
            # for two-tile Stream-K + data-parallel from original paper
            if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk:
                total_tiles_streamk += total_programs_streamk
            # remaining tiles are computed using classical blocking
            total_blocking_tiles = total_tiles - total_tiles_streamk
            total_iters_streamk = total_tiles_streamk * iters_per_tile
            # iterations related to full waves
            total_full_tiles_streamk = total_iters_streamk // total_programs_streamk
            # iterations related to last (partial) wave
            total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk

        else:  # all tiles are computed using classical blocking
            total_blocking_tiles = total_tiles
            total_tiles_streamk = 0
            total_full_tiles_streamk = 0
            total_partial_tiles_streamk = 0
            total_iters_streamk = 0

        if matmul._debug:
            print(f"M,N,K={M},{N},{K} ; BLK_M,N,K={BLK_M},{BLK_N},{BLK_K}")
            print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}")
            print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}")
            print(f"{total_programs_streamk=}")
            print(f"{total_blocking_tiles=}")
            print(f"{iters_per_tile=}")
            print(f"{total_iters_streamk=}")

        # allocates output
        c = torch.empty((M, N), device=device, dtype=a.dtype)
        # allocates locks to sync work across SMs
        locks = torch.zeros((total_tiles_streamk,), device=device, dtype=torch.int32)
        k1 = first_wave[(total_programs_streamk,)](
            a,
            b,
            c,
            M,
            N,
            K,
            locks,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            total_full_tiles_streamk=total_full_tiles_streamk,
            total_partial_tiles_streamk=total_partial_tiles_streamk,
            iters_per_tile=iters_per_tile,
            BLOCK_M=BLK_M,
            BLOCK_N=BLK_N,
            BLOCK_K=BLK_K,
            ACC_TYPE=ACC_TYPE,
            GROUP_M=GROUP_M,
            num_stages=num_stages,
            num_warps=num_warps,
        )
        if matmul._debug:
            print(f"{k1.n_regs} registers used, {k1.n_spills} spills")
        k2 = full_tiles[(total_blocking_tiles,)](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            total_tiles_streamk=total_tiles_streamk,
            BLOCK_M=BLK_M,
            BLOCK_N=BLK_N,
            BLOCK_K=BLK_K,
            ACC_TYPE=ACC_TYPE,
            GROUP_M=GROUP_M,
            num_stages=num_stages,
            num_warps=num_warps,
        )
        if matmul._debug:
            print(f"{k2.n_regs} registers used, {k2.n_spills} spills")
        return c

    @staticmethod
    def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=128, BLK_K=32, two_tiles=True, num_stages=3, num_warps=4):
        return matmul._call(a=a, b=b, total_programs_streamk=grid, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages)


# ---------------------------------------------------------------------------
# Example and Benchmark
# ---------------------------------------------------------------------------


m, n, k = 1536, 1792, 6016  # some problem size to test
A = torch.randn(m, k, device="cuda", dtype=torch.float16)
B = torch.randn(k, n, device="cuda", dtype=torch.float16)

matmul.set_debug(True)
C = matmul.apply(A, B, total_sm, 128, 128, 32, True, 4, 4)
matmul.set_debug(False)
expected = A @ B

assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}"

# for debugging, uncomment the following line
# exit(0)

triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B))
print("PyTorch", triton_ms)

triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, 128, 128, 32, True, 4, 4))
print(f"hybrid stream-k (grid={total_sm})", triton_ms)

triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, 128, 128, 32, True, 4, 4))
print(f"hybrid stream-k (grid={total_sm * 2})", triton_ms)

triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, 128, 128, 32, True, 4, 4))
print("tile matmul (grid=0)", triton_ms)

# ---------------------------------------------------------------------------
# Log-sampled benchmark
# ---------------------------------------------------------------------------

# tried to reproduce the tests described in the paper
num_samples = 1000  # 32768
step = 256
values = ((torch.logspace(torch.tensor(step).log2(), torch.tensor(8192).log2(), num_samples, base=2) / step).round() * step).unique().tolist()
shapes = [(int(m), int(n), int(k)) for m in values for n in values for k in values]
shapes = random.sample(shapes, num_samples)
assert len(shapes) == num_samples

results = []
for idx, (m, n, k) in enumerate(shapes):
    # print progress bar
    if idx % 10 == 0 and idx > 0:
        speedups = [r["speedup"] for r in results]
        print(f"{idx}/{num_samples} - average speedup: {sum(speedups) / len(speedups):.3f}")

    A = torch.randn(m, k, device="cuda", dtype=torch.float16)
    B = torch.randn(k, n, device="cuda", dtype=torch.float16)
    output: Optional[torch.Tensor] = None


    def wrapper_matmul(*args, **kwargs):
        global output
        output = matmul.apply(*args, **kwargs)
        return output


    expected = A @ B
    pytorch_ms = triton.testing.do_bench(lambda: A @ B)
    measures = list()
    for two_tiles in [True, False]:
        nb_sm = [total_sm, total_sm * 2]
        total_tile = (m // 128) * (n // 128)
        if total_tile < total_sm * 2:
            nb_sm.append(total_tile)
        nb_sm += random.sample(range(2, total_sm * 2, 2), 10)
        for sm in nb_sm:
            triton_ms = triton.testing.do_bench(lambda: wrapper_matmul(A, B, sm, 128, 128, 32, two_tiles, 4, 4))
            max_disc = (output - expected).abs().max().item()
            # large tolerance to accommodate large K (rounding due to half precision), we just want to catch bugs.
            assert max_disc <= 5., f"pb size: {m}x{n}x{k} - max discrepancy: {max_disc} - sm: {sm}, 2 tiles: {two_tiles}\n{output}\n{expected}"
            info = {
                "2 tiles": two_tiles,
                "sm": sm,
                "disc": max_disc,
                "triton_ms": triton_ms,
            }
            measures.append(info)
    best_triton_ms = min([m["triton_ms"] for m in measures])
    d = {
        "m": m,
        "n": n,
        "k": k,
        "triton": measures,
        "pytorch_ms": pytorch_ms,
        "speedup": pytorch_ms / best_triton_ms,
    }
    results.append(d)
    measures = list()

results.sort(key=lambda x: x["speedup"], reverse=False)

# ---------------------------------------------------------------------------
# Benchmark export
# ---------------------------------------------------------------------------

with open("results.json", "w") as f:
    json.dump(results, f, indent=4)

# 32760/32768 - average speedup: 0.962 (A100)
# 990/1000 - average speedup: 1.063 (3090 RTX with while loop and 2 tiles disabled / enabled)

@Ther-nullptr

Will OpenAI Triton officially add support for the Stream-K algorithm?

zhanglx13 added a commit to ROCm/triton that referenced this issue Nov 29, 2023
LiyangLingIntel added a commit to intel/intel-xpu-backend-for-triton that referenced this issue Jul 11, 2024
This PR implemented a Stream-K GEMM kernel and ported the Split-K kernel from
[triton-lang/kernels](https://github.com/triton-lang/kernels/blob/main/kernels/matmul.py).
The Stream-K implementation refers to the original paper
(https://arxiv.org/pdf/2407.00044) and the example code in
triton-lang/triton#1393.