In [None]:
import torch
import triton
import triton.language as tl
import torch

# Tile sizes handed to the kernel as compile-time parameters.
BLOCK_M = 128  # tile extent along M (used to size grid dim-0)
BLOCK_N = 128  # tile extent along N (used to size grid dim-1)
BLOCK_K = 32   # tile extent along K
# Group size along M; per the original author this improves L2 locality.
GROUP_M = 8

def launch_matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Allocate an output tensor and launch the ``matmul`` Triton kernel.

    The kernel itself is defined outside this function; presumably it
    computes the matrix product of ``A`` and ``B`` -- confirm against the
    kernel source.

    Args:
        A: Two-dimensional tensor; its shape is unpacked as ``(M, K)``.
        B: Two-dimensional tensor; its shape is unpacked as ``(K, N)``.

    Returns:
        A new ``(M, N)`` ``torch.float32`` tensor on ``A.device`` that is
        passed to the kernel as its output buffer.

    Raises:
        ValueError: If either shape does not unpack into exactly two dims.
        AssertionError: If the inner dimensions of ``A`` and ``B`` differ.
    """
    rows, inner = A.shape
    inner_b, cols = B.shape
    assert inner == inner_b

    # The output is always float32, whatever the input dtypes are.
    out = torch.empty((rows, cols), device=A.device, dtype=torch.float32)

    # Number of tiles needed to cover each output dimension.
    tiles_m = triton.cdiv(rows, BLOCK_M)
    tiles_n = triton.cdiv(cols, BLOCK_N)

    # NOTE(review): dim-0 is scaled by GROUP_M; the original comment says it
    # "will be grouped back into pid_m/pid_n" inside the kernel. The kernel is
    # not visible here -- confirm its program-id decoding expects this size.
    launch_grid = (tiles_m * GROUP_M, tiles_n)

    # torch strides are expressed in elements (not bytes); each tensor's
    # (row, col) stride pair is forwarded positionally in A, B, C order.
    matmul[launch_grid](
        A, B, out, rows, cols, inner,
        *A.stride(),
        *B.stride(),
        *out.stride(),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_M=GROUP_M,
        # Starting points from the original author; autotune in practice.
        num_warps=4, num_stages=3,
    )
    return out

ModuleNotFoundError: No module named 'triton'