# Based on [transform-mma-sync-matmul-f16-f16-accum.mlir](https://github.com/llvm/llvm-project/blob/9cc2122bf5a81f7063c2a32b2cb78c8d615578a1/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir#L6)

# Download mlir-python-bindings with CUDA support

In [None]:
import os

# Branch of mlir-python-extras to install from. It can be overridden through
# the BRANCH environment variable and defaults to "main".
BRANCH = os.getenv("BRANCH", "main")
# Export the resolved value: the following %%bash cell expands $BRANCH, and
# without this export the default would never reach the shell, leaving the
# pip VCS URL with an empty revision after "@".
os.environ["BRANCH"] = BRANCH
# Location of the helper script that the %%bash cell downloads with curl.
os.environ["SCRIPT_ADDRESS"] = f"https://raw.githubusercontent.com/makslevental/mlir-python-extras/refs/heads/{BRANCH}/scripts/get_latest_bindings.py"

In [None]:
%%bash
# Stop at the first failing command (and on unset variables) so that a failed
# download or version lookup is not silently followed by further installs.
set -euo pipefail

# -f: fail on HTTP errors instead of saving the error page as the script,
# -sS: no progress meter but still print errors, -L: follow redirects.
curl -fsSL "$SCRIPT_ADDRESS" -o get_latest_bindings.py

# Ask the helper script for the newest CUDA-enabled wheel version, then
# install exactly that version from the wheel index.
latest_cuda_version=$(python get_latest_bindings.py "cuda")
pip install -q "mlir_python_bindings==$latest_cuda_version" -f https://makslevental.github.io/wheels

# Fall back to "main" when BRANCH is not exported; pip rejects a VCS URL
# whose revision after "@" is empty.
pip install -q "git+https://github.com/makslevental/mlir-python-extras@${BRANCH:-main}"

# Boilerplate

In [None]:
from pathlib import Path

import mlir.extras.types as T
from mlir.dialects import builtin
from mlir.dialects.transform import any_op_t
from mlir.dialects.transform.extras import named_sequence
from mlir.dialects.transform.structured import MatchInterfaceEnum
from mlir.ir import StringAttr, UnitAttr

from mlir import _mlir_libs
from mlir.extras.ast.canonicalize import canonicalize
from mlir.extras.context import RAIIMLIRContext, ExplicitlyManagedModule
from mlir.extras.dialects.ext import arith, memref, scf, gpu
from mlir.extras.dialects.ext import linalg
from mlir.extras.dialects.ext import transform
from mlir.extras.dialects.ext.func import func
from mlir.extras.runtime.passes import Pipeline, run_pipeline
from mlir.extras.runtime.refbackend import LLVMJITBackend
from mlir.extras.util import find_ops

# CUDA runtime wrapper library located in the same directory as the
# `_mlir_libs` package; it is handed to LLVMJITBackend in a later cell.
CUDA_RUNTIME_LIB_PATH = Path(_mlir_libs.__file__).parent / "libmlir_cuda_runtime.so"
# Fail early, and say which path was checked, when the installed bindings do
# not contain the library.
assert CUDA_RUNTIME_LIB_PATH.exists(), (
    f"CUDA runtime library not found: {CUDA_RUNTIME_LIB_PATH}"
)

# Context

In [None]:
# Create the MLIR context and the module that the following cells populate.
# NOTE(review): presumably RAIIMLIRContext keeps the context active for as
# long as `ctx` is alive — confirm against mlir-python-extras.
ctx = RAIIMLIRContext()
# This module is completed explicitly in a later cell via `module.finish()`.
module = ExplicitlyManagedModule()

# Kernel and helper code

In [None]:
# Matrix sizes used below: lhs is M x K, rhs is K x N and the result is M x N.
M = 16
K = 16
N = 8

# Short alias for the scf range helper used in the `for` loops of this cell.
range_ = scf.range_

# Forward reference: the print helpers call `printMemrefF32_[0]`, but the
# declaration they need is only created (and appended here) inside `payload`.
# TODO(max): figure out closures...
printMemrefF32_ = []


# Returns (cidx + ridx * stride_cidx) / 64 as an f16 value. The index
# arithmetic is carried out in i32 before the conversion to floating point.
@func
def compute_linspace_val(ridx: T.index(), cidx: T.index(), stride_cidx: T.index()):
    row = arith.index_cast(ridx, to=T.i32())
    col = arith.index_cast(cidx, to=T.i32())
    row_stride = arith.index_cast(stride_cidx, to=T.i32())
    # Linear position of the element: row * stride + column.
    linear = col + row * row_stride
    linear_f16 = arith.sitofp(T.f16(), linear)
    scale = arith.constant(64.0, T.f16())
    return arith.divf(linear_f16, scale)


# Prints the f16 `lhs` matrix: every element is widened to f32 into a
# temporary buffer, which is then passed to the printMemrefF32 declaration.
@func
@canonicalize(using=scf.canonicalizer)
def print_lhs_as_memref_32(lhs: T.memref(M, K, T.f16())):
    rows = memref.dim(lhs, 0)
    cols = memref.dim(lhs, 1)
    scratch = memref.alloc((rows, cols), T.f32())
    for i in range_(0, rows):
        for j in range_(0, cols):
            # Load the f16 element, extend it to f32 and store it.
            scratch[i, j] = arith.extf(T.f32(), lhs[i, j])

    # Cast to the memref type that printMemrefF32 is declared with.
    printable = memref.cast(T.memref(T.f32()), scratch)
    printMemrefF32_[0](printable)
    # The temporary buffer is only needed for printing.
    memref.dealloc(scratch)


# Prints the f16 `rhs` matrix: every element is widened to f32 into a
# temporary buffer, which is then passed to the printMemrefF32 declaration.
@func
@canonicalize(using=scf.canonicalizer)
def print_rhs_as_memref_32(rhs: T.memref(K, N, T.f16())):
    rows = memref.dim(rhs, 0)
    cols = memref.dim(rhs, 1)
    scratch = memref.alloc((rows, cols), T.f32())
    for i in range_(0, rows):
        for j in range_(0, cols):
            # Load the f16 element, extend it to f32 and store it.
            scratch[i, j] = arith.extf(T.f32(), rhs[i, j])

    # Cast to the memref type that printMemrefF32 is declared with.
    printable = memref.cast(T.memref(T.f32()), scratch)
    printMemrefF32_[0](printable)
    # The temporary buffer is only needed for printing.
    memref.dealloc(scratch)


# Prints the f16 `res` matrix: every element is widened to f32 into a
# temporary buffer, which is then passed to the printMemrefF32 declaration.
@func
@canonicalize(using=scf.canonicalizer)
def print_res_as_memref_32(res: T.memref(M, N, T.f16())):
    # Unlike the lhs/rhs helpers, the dimension numbers are passed as
    # explicit index constants here.
    dim0 = arith.constant(0, index=True)
    dim1 = arith.constant(1, index=True)
    rows = memref.dim(res, dim0)
    cols = memref.dim(res, dim1)
    scratch = memref.alloc((rows, cols), T.f32())
    for i in range_(0, rows):
        for j in range_(0, cols):
            # Load the f16 element, extend it to f32 and store it.
            scratch[i, j] = arith.extf(T.f32(), res[i, j])

    # Cast to the memref type that printMemrefF32 is declared with.
    printable = memref.cast(T.memref(T.f32()), scratch)
    printMemrefF32_[0](printable)
    # The temporary buffer is only needed for printing.
    memref.dealloc(scratch)


# Entry point of the payload: allocates the M x K, K x N and M x N f16
# buffers, fills all three with compute_linspace_val, registers them with
# the GPU, prints the two inputs, performs linalg.matmul inside a gpu.launch
# region and finally prints the result buffer.
@func
@canonicalize(using=scf.canonicalizer)
def main():
    lhs = memref.alloc((M, K), T.f16())
    rhs = memref.alloc((K, N), T.f16())
    res = memref.alloc((M, N), T.f16())

    # Loop bounds are read back from the allocated buffers.
    M_ = memref.dim(res, 0)
    N_ = memref.dim(res, 1)
    K_ = memref.dim(lhs, 1)

    # NOTE(review): these three constants are not referenced again in this
    # function.
    _f1 = arith.constant(1.0e00, T.f16())
    _f0 = arith.constant(0.0e00, T.f16())
    _c32 = arith.constant(32, T.index())

    # Initialize the lhs matrix with a linspace function.
    for r in range_(0, M_):
        for c in range_(0, K_):
            idx = compute_linspace_val(r, c, K_)
            lhs[r, c] = idx

    # Initialize the rhs matrix with a linspace function.
    for r in range_(0, K_):
        for c in range_(0, N_):
            idx = compute_linspace_val(r, c, N_)
            rhs[r, c] = idx

    # Initialize the res matrix with a linspace function.
    for r in range_(0, M_):
        for c in range_(0, N_):
            idx = compute_linspace_val(r, c, N_)
            res[r, c] = idx

    # gpu.host_register is applied to casts of the buffers to the shape-less
    # memref type T.memref(T.f16()).
    ulhs = memref.cast(T.memref(T.f16()), lhs)
    urhs = memref.cast(T.memref(T.f16()), rhs)
    ures = memref.cast(T.memref(T.f16()), res)
    gpu.host_register(ulhs)
    gpu.host_register(urhs)
    gpu.host_register(ures)

    # Print the inputs before the kernel runs.
    print_lhs_as_memref_32(lhs)
    print_rhs_as_memref_32(rhs)

    # Launch with a 1x1x1 grid and a 32x1x1 block. The linalg.matmul in the
    # body is the op that the transform schedule matches and rewrites with
    # rewrite_matmul_as_mma_sync.
    @gpu.launch(grid_size=[1, 1, 1], block_size=[32, 1, 1])
    def kernel(bx, by, bz, tx, ty, tz, *grid_block_sizes):
        linalg.matmul(lhs, rhs, res)

    # Print the result buffer after the kernel.
    print_res_as_memref_32(res)


# Nested module tagged "payload" through the transform.target_tag attribute;
# the transform interpreter and the lowering cell later look it up by that
# tag. All functions defined above are emitted into it here.
@builtin.module(attrs={"transform.target_tag": StringAttr.get("payload")})
def payload():
    compute_linspace_val.emit()

    # Bodiless function taking a T.memref(T.f32()) argument.
    # NOTE(review): presumably this becomes an external declaration that is
    # resolved by a runtime library — confirm.
    @func
    def printMemrefF32(x: T.memref(T.f32())):
        ...

    # Fill in the forward reference before the print helpers are emitted:
    # their bodies call printMemrefF32_[0].
    printMemrefF32_.append(printMemrefF32)

    print_lhs_as_memref_32.emit()
    print_rhs_as_memref_32.emit()
    print_res_as_memref_32.emit()
    main.emit()

# Transform schedule


In [None]:
# Module holding the transform schedule; it carries the
# transform.with_named_sequence unit attribute.
@builtin.module(attrs={"transform.with_named_sequence": UnitAttr.get()})
def mod_transform():
    # Named sequence "main": takes one any_op argument (marked
    # transform.readonly) and returns nothing.
    @named_sequence(
        "main", [any_op_t()], [], arg_attrs=[{"transform.readonly": UnitAttr.get()}]
    )
    def main(module: any_op_t()):
        # Match the linalg.matmul ops under `module` and rewrite them as
        # mma.sync operations.
        matmul = transform.match(module, ["linalg.matmul"])
        transform.nvgpu.rewrite_matmul_as_mma_sync(matmul)
        # clean up to simplify test below...
        # Apply loop-invariant code motion to every op that implements
        # LoopLikeInterface, then CSE on the whole target.
        all_loops = transform.match(
            module, interface=MatchInterfaceEnum.LoopLikeInterface
        )
        transform.apply_licm(all_loops)
        transform.apply_cse(module)

# "Finish" the module

In [None]:
# Finish the explicitly managed module and rebind `module` to the result,
# which is what the remaining cells operate on.
module = module.finish()
print(module)
# Check that the assembled module (payload plus transform schedule) verifies.
assert module.operation.verify()

# Execute the transform schedule

In [None]:
# Build a pipeline consisting of the transform interpreter, configured with
# the entry point "main" (the named sequence defined above) and the payload
# root tag "payload" (the tag attached to the payload module).
interpreter_pipeline = Pipeline().transform_interpreter(
    entry_point="main",
    debug_payload_root_tag="payload",
)
# Run it over the finished module and show the transformed IR.
mod = run_pipeline(module, interpreter_pipeline)
print(mod)

# Lower to NVVM (and LLVM)

In [None]:
# The lowering and execution cells only run when a CUDA installation is
# present at the conventional location.
CUDA_RUNTIME_EXISTS = Path("/usr/local/cuda").exists()


def _is_payload_module(op):
    # True for the op carrying transform.target_tag == "payload".
    attrs = op.attributes
    return (
        "transform.target_tag" in attrs
        and attrs["transform.target_tag"].value == "payload"
    )


if CUDA_RUNTIME_EXISTS:
    backend = LLVMJITBackend([CUDA_RUNTIME_LIB_PATH])
    # Only the payload module is compiled, not the transform schedule.
    payload_op = find_ops(mod.operation, _is_payload_module, single=True)
    # Lower with gpu-lower-to-nvvm-pipeline, targeting sm_80 with PTX 7.6
    # features and producing a fatbin.
    nvvm_options = {
        "cubin-chip": "sm_80",
        "cubin-features": "+ptx76",
        "cubin-format": "fatbin",
    }
    nvvm_pipeline = Pipeline().add_pass("gpu-lower-to-nvvm-pipeline", **nvvm_options)
    # NOTE(review): compile presumably also generates the C API wrappers that
    # the run cell calls (main_capi_wrapper) — confirm.
    compiled_module = backend.compile(payload_op, nvvm_pipeline)
    print(compiled_module)

# Load and run

In [None]:
# Only attempt to run when the CUDA installation check above succeeded;
# `backend` and `compiled_module` are defined only in that case.
if CUDA_RUNTIME_EXISTS:
    # Load the compiled module and call its `main_capi_wrapper` entry.
    # NOTE(review): presumably the generated C-interface wrapper for the
    # payload's `main` function — confirm.
    backend.load(compiled_module).main_capi_wrapper()