In [1]:
import torch
import triton
import triton.language as tl


@triton.jit
def matrix_multiplication_kernel(
    a, b, c, M, N, K, stride_am, stride_an, stride_bn, stride_bk, stride_cm, stride_ck
):
    """Scalar matmul kernel: each program computes exactly one element of C.

    Program id 0 selects the output row and program id 1 the output column.
    The program walks the shared dimension (length ``N``), multiplying one
    element of A by one element of B per step, and stores the float32 running
    sum at the matching position of C.

    All ``stride_*`` arguments are element strides used for the pointer
    arithmetic. ``M`` and ``K`` are accepted but not read here; the launch
    grid is what bounds the row/column indices.
    """
    out_row = tl.program_id(0)
    out_col = tl.program_id(1)

    # Addresses of A[out_row, 0] and B[0, out_col]; the loop only adds the
    # offset along the reduction dimension.
    a_row_start = a + out_row * stride_am
    b_col_start = b + out_col * stride_bk

    total = tl.zeros((), dtype=tl.float32)
    for idx in range(N):
        lhs = tl.load(a_row_start + idx * stride_an)
        rhs = tl.load(b_col_start + idx * stride_bn)
        total += lhs * rhs

    tl.store(c + out_row * stride_cm + out_col * stride_ck, total)


# a, b, c are tensors on the GPU
def solve(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, M: int, N: int, K: int):
    """Compute ``c = a @ b`` with the one-program-per-element Triton kernel.

    ``a`` is read as an (M, N) matrix, ``b`` as (N, K) and the product is
    written into ``c`` as (M, K). Nothing is returned.

    Strides are taken from each tensor when it is 2-D with the stated shape,
    so non-contiguous views (for example a transpose) are indexed correctly.
    Any other layout, such as a flat buffer, keeps the row-major contiguous
    strides this function has always used.
    """

    def _element_strides(tensor, rows, cols):
        # Trust the tensor's own strides only when its shape matches the
        # logical matrix; otherwise assume densely packed row-major storage.
        if tensor.dim() == 2 and tuple(tensor.shape) == (rows, cols):
            return tensor.stride(0), tensor.stride(1)
        return cols, 1

    stride_am, stride_an = _element_strides(a, M, N)
    stride_bn, stride_bk = _element_strides(b, N, K)
    stride_cm, stride_ck = _element_strides(c, M, K)

    # One program per output element: axis 0 -> row of C, axis 1 -> column of C.
    grid = (M, K)
    matrix_multiplication_kernel[grid](
        a, b, c, M, N, K, stride_am, stride_an, stride_bn, stride_bk, stride_cm, stride_ck
    )


Why the current kernel is slow:
Launches one Triton program per output element (grid=(M, K)), so each program computes only 1 scalar.
No tiling/blocking, so global memory reuse is poor (same a/b values reloaded many times).
No vectorized loads/stores or tensor-core-friendly structure.
No masking for edge tiles: it is only safe because the grid matches the output shape exactly and the tensors are assumed to be contiguous.
No autotuning (BLOCK_M/N/K, num_warps, num_stages) for your GPU.

A better kernel should:
Use block tiling (e.g. BLOCK_M x BLOCK_N output tile, reduce over BLOCK_K chunks).
Keep an accumulator tile in registers (tl.zeros((BLOCK_M, BLOCK_N), tl.float32)).
Use masked loads for edge tiles.
Use grouped program ordering for better L2 locality.
Add @triton.autotune(...) configs per shape regime.

In [3]:
# Autotune tries multiple launch/block configurations and caches the best one
# for each (M, N, K) shape triple.
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": 32, "GROUP_M": 8},
            num_warps=warps,
            num_stages=stages,
        )
        for block_m, block_n, warps, stages in (
            (64, 64, 4, 3),     # default candidate for medium shapes
            (128, 64, 8, 3),    # taller tile, for outputs with many rows
            (64, 128, 8, 3),    # wider tile, for outputs with many columns
            (128, 128, 8, 4),   # large 2D tile, for bigger matrices
        )
    ],
    # A new (M, N, K) triple triggers a fresh tuning run.
    key=["M", "N", "K"],
)
@triton.jit
def matrix_multiplication_kernel_tiled(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_am,
    stride_an,
    stride_bn,
    stride_bk,
    stride_cm,
    stride_ck,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """Tiled matmul C[M, K] = A[M, N] @ B[N, K] launched on a 1D grid.

    Each program produces one BLOCK_M x BLOCK_N tile of C. Note the naming:
    BLOCK_N tiles the columns of C (the ``K`` dimension), while BLOCK_K is the
    chunk size along the reduction dimension ``N``.

    The flat program id is mapped to tile coordinates in groups of GROUP_M
    tile-rows; the grouping is intended to improve cache reuse of B tiles.
    """
    pid = tl.program_id(axis=0)

    # Tile counts along the rows and columns of C.
    tiles_m = tl.cdiv(M, BLOCK_M)
    tiles_n = tl.cdiv(K, BLOCK_N)

    # Grouped ordering: a group spans up to GROUP_M tile-rows and every
    # tile-column; the final group may hold fewer tile-rows.
    tiles_per_group = GROUP_M * tiles_n
    group_first_m = (pid // tiles_per_group) * GROUP_M
    rows_in_group = tl.minimum(tiles_m - group_first_m, GROUP_M)
    pid_in_group = pid % tiles_per_group
    tile_m = group_first_m + (pid_in_group % rows_in_group)
    tile_n = pid_in_group // rows_in_group

    # Element indices covered by this tile, plus the in-chunk reduction lanes.
    rows = tile_m * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tile_n * BLOCK_N + tl.arange(0, BLOCK_N)
    lanes = tl.arange(0, BLOCK_K)

    # These do not change across reduction steps, so build them once.
    row_in_bounds = rows[:, None] < M
    col_in_bounds = cols[None, :] < K
    a_row_ptrs = a_ptr + rows[:, None] * stride_am
    b_col_ptrs = b_ptr + cols[None, :] * stride_bk

    # float32 accumulator for the whole output tile.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for chunk_start in range(0, N, BLOCK_K):
        red = chunk_start + lanes

        # A sub-tile is [BLOCK_M, BLOCK_K]; B sub-tile is [BLOCK_K, BLOCK_N].
        # Out-of-range positions load as 0.0 and therefore add nothing.
        a_tile = tl.load(
            a_row_ptrs + red[None, :] * stride_an,
            mask=row_in_bounds & (red[None, :] < N),
            other=0.0,
        )
        b_tile = tl.load(
            b_col_ptrs + red[:, None] * stride_bn,
            mask=(red[:, None] < N) & col_in_bounds,
            other=0.0,
        )

        acc += tl.dot(a_tile, b_tile)

    # Write the finished tile, skipping positions past the edges of C.
    out_ptrs = c_ptr + rows[:, None] * stride_cm + cols[None, :] * stride_ck
    tl.store(out_ptrs, acc, mask=row_in_bounds & col_in_bounds)


# Faster tiled matmul; keeps your original solve(...) untouched.
def solve_tiled(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, M: int, N: int, K: int):
    """Launch the grouped 1D-grid tiled kernel to compute ``c = a @ b`` in place.

    Expects CUDA, contiguous tensors shaped ``a: (M, N)``, ``b: (N, K)`` and
    ``c: (M, K)``; an ``AssertionError`` with a descriptive message is raised
    otherwise. Nothing is returned.
    """
    operands = (a, b, c)

    # Validate up front so bad inputs fail here rather than inside the kernel.
    assert all(t.is_cuda for t in operands), "All tensors must be CUDA tensors"
    assert all(t.is_contiguous() for t in operands), "Tensors must be contiguous"
    for label, tensor, dims in (("a", a, (M, N)), ("b", b, (N, K)), ("c", c, (M, K))):
        assert tensor.shape == dims, f"Expected {label} shape {dims}, got {tuple(tensor.shape)}"

    def grid(meta):
        # One program per output tile, flattened onto a single launch axis.
        tiles_m = triton.cdiv(M, meta["BLOCK_M"])
        tiles_n = triton.cdiv(K, meta["BLOCK_N"])
        return (tiles_m * tiles_n,)

    # Arguments: pointers, logical sizes, then (row, col) element strides of
    # a, b and c in that order.
    matrix_multiplication_kernel_tiled[grid](
        a, b, c, M, N, K, *a.stride(), *b.stride(), *c.stride()
    )

1. The Big Picture: Why do we use "Blocks"?

Imagine you have two huge walls of sticky notes (Matrix A and Matrix B) and you want to calculate a third wall (Matrix C).
If you use your original code (the "scalar" approach), you hire one worker for every single sticky note on Wall C. Each worker walks all the way over to Wall A, reads one note, walks to Wall B, reads one note, multiplies them, and adds the product to a running total. They repeat this once for every position along the shared dimension N, and only then write the total onto their sticky note on Wall C. This is incredibly slow because walking back and forth to the walls (fetching data from GPU memory) takes forever.

Triton's Tiled Approach:
Instead of one worker per sticky note, we divide Wall C into large squares (e.g., $64 \times 64$ blocks). We assign a team of workers (a Triton program) to compute that entire block.

The team grabs a chunk of Wall A and a chunk of Wall B, brings it to their desk (ultra-fast GPU registers/SRAM), does a ton of math very quickly, and then goes back for the next chunks. Because they grab big chunks at once, they spend much less time walking and much more time calculating.

In [None]:
# 2D-program-id version.
# Easier to reason about than 1D grouped mapping because launch axes directly
# correspond to output-tile row/column coordinates.
# A -> MxN
# B -> NxK
# C -> MxK
@triton.autotune(
    configs=[
        # Balanced default tile for many medium-size shapes.
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=4, num_stages=3),
        # Taller output tile: may help when M is relatively large.
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=8, num_stages=3),
        # Wider output tile: may help when K is relatively large.
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8, num_stages=3),
        # Large square-ish tile for bigger workloads.
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8, num_stages=4),
    ],
    # Triton selects and caches the best config per runtime shape tuple.
    key=["M", "N", "K"],
)
@triton.jit
def matrix_multiplication_kernel_tiled_2d(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_am,
    stride_an,
    stride_bn,
    stride_bk,
    stride_cm,
    stride_ck,
    BLOCK_M: tl.constexpr,  # tile height: rows of C handled by one program
    BLOCK_N: tl.constexpr,  # tile width: columns of C (the K dimension) handled by one program
    BLOCK_K: tl.constexpr,  # reduction chunk: how much of N is consumed per loop step
):
    """Tiled matmul C[M, K] = A[M, N] @ B[N, K] launched on a 2D grid.

    Grid axis 0 picks the row-tile of C and grid axis 1 picks the column-tile.

    Walkthrough (no backslashes are used here so the text survives as a
    normal, non-raw docstring):

    Step 1 - Which tile am I?
        Many identical programs run at once. tl.program_id(0) tells this
        program which row block of C it owns and tl.program_id(1) tells it
        which column block.

    Step 2 - Exact coordinates.
        offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        Example: with BLOCK_M = 64 and pid_m = 2, 2 * 64 = 128, so this
        program covers rows 128 through 191 and offs_m is [128, 129, ..., 191].

    Step 3 - The accumulator.
        A zero-filled BLOCK_M x BLOCK_N float32 tile holds the running totals
        and is added to on every loop step.

    Step 4 - The loop.
        Each element of C is a row of A multiplied against a column of B. The
        length of that walk is N, covered BLOCK_K positions at a time.

    Step 5 - Loading the data.
        a_ptrs = a_ptr + offs_m[:, None] * stride_am + k_offsets[None, :] * stride_an
        This is plain address arithmetic. Memory is one-dimensional, so moving
        down one row means jumping ahead by the row stride. The result
        addresses a BLOCK_M x BLOCK_K chunk of A; the matching expression for
        B addresses a BLOCK_K x BLOCK_N chunk.

    Step 6 - The math.
        tl.dot multiplies the A chunk by the B chunk (the original notes
        describe this as running on the GPU's Tensor Cores) and the product
        is added into the accumulator.

    Step 7 - Writing back.
        After the loop has covered all of N, the accumulator holds the final
        values. Output addresses are computed, positions beyond the edges of
        C are masked off, and the tile is stored with tl.store.
    """
    pid_m = tl.program_id(axis=0)  # which tile-row of C
    pid_n = tl.program_id(axis=1)  # which tile-column of C (tiles the K dimension)

    # Global indices covered by this program:
    # offs_m -> rows of C, offs_n -> columns of C,
    # offs_k -> positions inside one chunk of the reduction dimension (N).
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    # Row/column bounds do not depend on the reduction step, so they are
    # computed once and reused for the loads and the final store.
    row_mask = offs_m[:, None] < M
    col_mask = offs_n[None, :] < K

    # Accumulate in float32 for better numerical stability.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # The reduction axis is N (A: [M, N], B: [N, K]); step through it in
    # BLOCK_K-sized chunks.
    for k_start in range(0, N, BLOCK_K):
        k_offsets = k_start + offs_k

        # Pointer grids for the current sub-tiles.
        # A sub-tile: [BLOCK_M, BLOCK_K]; B sub-tile: [BLOCK_K, BLOCK_N].
        a_ptrs = a_ptr + offs_m[:, None] * stride_am + k_offsets[None, :] * stride_an
        b_ptrs = b_ptr + k_offsets[:, None] * stride_bn + offs_n[None, :] * stride_bk

        # Masked-off positions load as 0.0, so partial edge tiles and a
        # partial last chunk contribute nothing to the sum.
        a_tile = tl.load(a_ptrs, mask=row_mask & (k_offsets[None, :] < N), other=0.0)
        b_tile = tl.load(b_ptrs, mask=(k_offsets[:, None] < N) & col_mask, other=0.0)

        # Multiply-accumulate this chunk into the output tile.
        acc += tl.dot(a_tile, b_tile)

    # Store the finished tile, masking positions outside C.
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_ck
    tl.store(c_ptrs, acc, mask=row_mask & col_mask)


# Python wrapper for 2D tiled kernel.
def solve_tiled_2d(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, M: int, N: int, K: int):
    """Launch the 2D-grid tiled kernel to compute ``c = a @ b`` in place.

    Expects CUDA, contiguous tensors shaped ``a: (M, N)``, ``b: (N, K)`` and
    ``c: (M, K)``; an ``AssertionError`` with a descriptive message is raised
    otherwise. Nothing is returned.
    """
    operands = (a, b, c)

    # Check the assumptions this tutorial-style kernel relies on.
    assert all(t.is_cuda for t in operands), "All tensors must be CUDA tensors"
    assert all(t.is_contiguous() for t in operands), "Tensors must be contiguous"
    for label, tensor, dims in (("a", a, (M, N)), ("b", b, (N, K)), ("c", c, (M, K))):
        assert tensor.shape == dims, f"Expected {label} shape {dims}, got {tuple(tensor.shape)}"

    def grid(meta):
        # Launch axis 0 counts row tiles of C, launch axis 1 counts column tiles.
        row_tiles = triton.cdiv(M, meta["BLOCK_M"])
        col_tiles = triton.cdiv(K, meta["BLOCK_N"])
        return (row_tiles, col_tiles)

    # Arguments: pointers, logical sizes, then (row, col) element strides of
    # a, b and c in that order.
    matrix_multiplication_kernel_tiled_2d[grid](
        a, b, c, M, N, K, *a.stride(), *b.stride(), *c.stride()
    )

In [6]:
def run_matrix_mult_tests(solve_fn=None, atol: float = 1e-4, rtol: float = 1e-4):
    """Check a matmul implementation against ``torch.matmul`` on small shapes.

    Args:
        solve_fn: Callable with the signature ``solve_fn(a, b, c, M, N, K)``
            that writes ``a @ b`` into ``c``. Defaults to ``solve`` so the
            existing zero-argument call keeps its behavior; pass
            ``solve_tiled`` or ``solve_tiled_2d`` to validate those instead.
        atol, rtol: Tolerances for the random-input comparisons.
            NOTE(review): ``tl.dot``-based kernels may need looser values if
            float32 dot products run at reduced (TF32) precision on the
            target GPU - confirm on your hardware.

    Raises:
        RuntimeError: if no CUDA device is available.
        AssertionError: if any result differs from the reference.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA GPU is required to run Triton tests")

    # Resolved here rather than in the signature so the default is looked up
    # at call time.
    if solve_fn is None:
        solve_fn = solve

    torch.manual_seed(0)

    # (M, N, K) test cases: square and non-square
    test_shapes = [
        (1, 1, 1),
        (2, 3, 4),
        (4, 4, 4),
        (7, 5, 3),
        (16, 32, 8),
        (31, 17, 29),
    ]

    for m, n, k in test_shapes:
        a = torch.randn((m, n), device="cuda", dtype=torch.float32)
        b = torch.randn((n, k), device="cuda", dtype=torch.float32)
        # NaN-filled output: any element the kernel fails to write makes the
        # comparison fail instead of passing on stale memory contents.
        c = torch.full((m, k), float("nan"), device="cuda", dtype=torch.float32)

        solve_fn(a, b, c, m, n, k)
        expected = torch.matmul(a, b)

        assert torch.allclose(c, expected, atol=atol, rtol=rtol), (
            f"Mismatch for shape A=({m},{n}), B=({n},{k})"
        )

    # Deterministic sanity case
    a = torch.tensor([[1.0, 2.0], [3.0, 4.0]], device="cuda")
    b = torch.tensor([[5.0, 6.0], [7.0, 8.0]], device="cuda")
    c = torch.full((2, 2), float("nan"), device="cuda")
    solve_fn(a, b, c, 2, 2, 2)
    expected = torch.tensor([[19.0, 22.0], [43.0, 50.0]], device="cuda")
    assert torch.allclose(c, expected, atol=1e-5, rtol=1e-5)

    print("All matrix multiplication tests passed.")


# Run the checks when this cell executes; raises if CUDA is unavailable or a result mismatches.
run_matrix_mult_tests()

All matrix multiplication tests passed.
