# Based on [transform-mma-sync-matmul-f16-f16-accum.mlir](https://github.com/llvm/llvm-project/blob/9cc2122bf5a81f7063c2a32b2cb78c8d615578a1/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir#L6)

In [11]:
# Install a pinned MLIR Python bindings wheel (a "+cuda" build, fetched from the
# extra -f index) and mlir-python-extras straight from its GitHub repository.
!pip install -q mlir_python_bindings==19.0.0.2024020206+cuda.374a600d -f https://makslevental.github.io/wheels
!pip install -q git+https://github.com/makslevental/mlir-python-extras.git

# Boilerplate

In [12]:
from pathlib import Path

import mlir.extras.types as T
from mlir.dialects import builtin
from mlir.dialects.transform import any_op_t
from mlir.dialects.transform.extras import named_sequence
from mlir.dialects.transform.structured import MatchInterfaceEnum
from mlir.ir import StringAttr, UnitAttr

from mlir import _mlir_libs
from mlir.extras.ast.canonicalize import canonicalize
from mlir.extras.context import RAIIMLIRContext, ExplicitlyManagedModule
from mlir.extras.dialects.ext import arith, memref, scf, gpu
from mlir.extras.dialects.ext import linalg
from mlir.extras.dialects.ext import transform
from mlir.extras.dialects.ext.func import func
from mlir.extras.runtime.passes import Pipeline, run_pipeline
from mlir.extras.runtime.refbackend import LLVMJITBackend
from mlir.extras.util import find_ops

# CUDA runtime support library that ships next to the MLIR Python extension
# modules; it is passed to the LLVM JIT backend used to run the payload.
CUDA_RUNTIME_LIB_PATH = Path(_mlir_libs.__file__).parent / "libmlir_cuda_runtime.so"
if not CUDA_RUNTIME_LIB_PATH.exists():
    # An explicit raise (unlike ``assert``) is not stripped under ``python -O``
    # and tells the user what is missing instead of a bare AssertionError.
    raise FileNotFoundError(
        f"CUDA runtime library not found at {CUDA_RUNTIME_LIB_PATH}; "
        "a CUDA-enabled mlir_python_bindings wheel is required"
    )

# Context

In [13]:
# Global MLIR context for the whole notebook (RAII-style wrapper — presumably
# entered on construction; confirm in mlir-python-extras) and the top-level
# module that the later cells populate until ``module.finish()`` is called.
ctx = RAIIMLIRContext()
module = ExplicitlyManagedModule()

# Kernel and helper code

In [14]:
# Short alias for the scf.for range builder used in the @canonicalize'd bodies.
range_ = scf.range_

# Static problem size: lhs is MxK, rhs is KxN, res is MxN (all f16).
M, K, N = 16, 16, 8

# forward reference...
# TODO(max): figure out closures...
# Filled in by ``payload`` with the ``printMemrefF32`` declaration; the print
# helpers call ``printMemrefF32_[0]`` when their bodies are emitted, so the
# declaration must be appended before they are emitted.
printMemrefF32_ = []


@func
def compute_linspace_val(ridx: T.index(), cidx: T.index(), stride_cidx: T.index()):
    """Emit IR computing the f16 linspace value for element (ridx, cidx).

    All three index operands are cast to i32 and combined into the row-major
    offset ``cidx + ridx * stride_cidx``. That offset is converted to f16 with
    ``sitofp`` and divided by 64.
    """
    row, col, stride = (
        arith.index_cast(idx, to=T.i32()) for idx in (ridx, cidx, stride_cidx)
    )
    offset = col + row * stride
    offset_f16 = arith.sitofp(T.f16(), offset)
    return arith.divf(offset_f16, arith.constant(64.0, T.f16()))


@func
@canonicalize(using=scf.canonicalizer)
def print_lhs_as_memref_32(lhs: T.memref(M, K, T.f16())):
    """Widen the f16 ``lhs`` matrix into a scratch f32 buffer and print it.

    ``printMemrefF32`` is declared to take an unranked f32 memref. Each element
    is therefore copied through ``arith.extf`` into a temporary allocation,
    which is cast to ``memref<*xf32>``, printed, and then freed.
    """
    rows = memref.dim(lhs, 0)
    cols = memref.dim(lhs, 1)
    widened = memref.alloc(rows, cols, T.f32())
    for i in range_(0, rows):
        for j in range_(0, cols):
            widened[i, j] = arith.extf(T.f32(), lhs[i, j])

    printMemrefF32_[0](memref.cast(T.memref(T.f32()), widened))
    memref.dealloc(widened)


@func
@canonicalize(using=scf.canonicalizer)
def print_rhs_as_memref_32(rhs: T.memref(K, N, T.f16())):
    """Widen the f16 ``rhs`` matrix into a scratch f32 buffer and print it.

    ``printMemrefF32`` is declared to take an unranked f32 memref. Each element
    is therefore copied through ``arith.extf`` into a temporary allocation,
    which is cast to ``memref<*xf32>``, printed, and then freed.
    """
    rows = memref.dim(rhs, 0)
    cols = memref.dim(rhs, 1)
    widened = memref.alloc(rows, cols, T.f32())
    for i in range_(0, rows):
        for j in range_(0, cols):
            widened[i, j] = arith.extf(T.f32(), rhs[i, j])

    printMemrefF32_[0](memref.cast(T.memref(T.f32()), widened))
    memref.dealloc(widened)


@func
@canonicalize(using=scf.canonicalizer)
def print_res_as_memref_32(res: T.memref(M, N, T.f16())):
    """Widen the f16 ``res`` matrix into a scratch f32 buffer and print it.

    ``printMemrefF32`` is declared to take an unranked f32 memref. Each element
    is therefore copied through ``arith.extf`` into a temporary allocation,
    which is cast to ``memref<*xf32>``, printed, and then freed.
    """
    # Pass the dimension indices as plain ints, exactly like the lhs/rhs
    # printers do, instead of hand-building index constants. Use local names
    # that do not shadow the module-level M/N problem-size constants.
    rows = memref.dim(res, 0)
    cols = memref.dim(res, 1)
    tmp_alloc = memref.alloc(rows, cols, T.f32())
    for m in range_(0, rows):
        for n in range_(0, cols):
            f16 = res[m, n]
            f32 = arith.extf(T.f32(), f16)
            tmp_alloc[m, n] = f32

    casted = memref.cast(T.memref(T.f32()), tmp_alloc)
    printMemrefF32_[0](casted)
    memref.dealloc(tmp_alloc)


@func
@canonicalize(using=scf.canonicalizer)
def main():
    """Host entry point: build the test matrices, run the matmul kernel, print.

    Allocates ``lhs`` (MxK), ``rhs`` (KxN) and ``res`` (MxN) as f16 memrefs and
    fills each one with ``compute_linspace_val``. It then prints ``lhs`` and
    ``rhs``, launches a GPU kernel whose body is ``linalg.matmul(lhs, rhs, res)``,
    and finally prints ``res``.

    NOTE(review): the three buffers are never ``memref.dealloc``'d or
    host-unregistered before returning. That is fine for a run-once test, but
    it would leak if ``main`` were invoked repeatedly.
    """
    lhs = memref.alloc(M, K, T.f16())
    rhs = memref.alloc(K, N, T.f16())
    res = memref.alloc(M, N, T.f16())

    # Extents as runtime index values; used as loop bounds and row strides below.
    M_ = memref.dim(res, 0)
    N_ = memref.dim(res, 1)
    K_ = memref.dim(lhs, 1)

    # Not referenced anywhere below (hence the leading underscore); they only
    # emit dead constant ops.
    _f1 = arith.constant(1.0e00, T.f16())
    _f0 = arith.constant(0.0e00, T.f16())
    _c32 = arith.constant(32, T.index())

    # Initialize the lhs matrix with a linspace function.
    for r in range_(0, M_):
        for c in range_(0, K_):
            idx = compute_linspace_val(r, c, K_)
            lhs[r, c] = idx

    # Initialize the rhs matrix with a linspace function.
    for r in range_(0, K_):
        for c in range_(0, N_):
            idx = compute_linspace_val(r, c, N_)
            rhs[r, c] = idx

    # Initialize the res matrix with a linspace function.
    # (res is pre-filled rather than zeroed, so it also feeds the accumulation.)
    for r in range_(0, M_):
        for c in range_(0, N_):
            idx = compute_linspace_val(r, c, N_)
            res[r, c] = idx

    # Unranked views of the buffers for gpu.host_register. Registration makes
    # the host allocations usable from the kernel launched below.
    ulhs = memref.cast(T.memref(T.f16()), lhs)
    urhs = memref.cast(T.memref(T.f16()), rhs)
    ures = memref.cast(T.memref(T.f16()), res)
    gpu.host_register(ulhs)
    gpu.host_register(urhs)
    gpu.host_register(ures)

    # Print the inputs before the kernel runs.
    print_lhs_as_memref_32(lhs)
    print_rhs_as_memref_32(rhs)

    # A single block of 32 threads. That is one warp — presumably what the
    # mma.sync rewrite performed by the transform schedule expects; confirm.
    @gpu.launch(grid_size=[1, 1, 1], block_size=[32, 1, 1])
    def kernel(bx, by, bz, tx, ty, tz, *grid_block_sizes):
        # linalg.matmul accumulates into its output operand: res += lhs @ rhs.
        linalg.matmul(lhs, rhs, res)

    # Print the result. Because res started from linspace values rather than
    # zeros, this is lhs @ rhs plus those initial values.
    print_res_as_memref_32(res)


@builtin.module(attrs={"transform.target_tag": StringAttr.get("payload")})
def payload():
    """Populate the nested module tagged ``transform.target_tag = "payload"``.

    The transform interpreter is pointed at this module through
    ``debug_payload_root_tag="payload"``. The compile step later locates it by
    the same tag.
    """
    compute_linspace_val.emit()

    # Body-less declaration of the runtime printer (emitted as a private
    # func.func). The print helpers look it up through ``printMemrefF32_``.
    @func
    def printMemrefF32(x: T.memref(T.f32())):
        ...

    printMemrefF32_.append(printMemrefF32)

    # The declaration is now registered, so the helpers that call it can be
    # emitted; keep this order so ``main`` comes last.
    for helper in (
        print_lhs_as_memref_32,
        print_rhs_as_memref_32,
        print_res_as_memref_32,
        main,
    ):
        helper.emit()

# Transform schedule


In [15]:
@builtin.module(attrs={"transform.with_named_sequence": UnitAttr.get()})
def mod_transform():
    """Build the transform-dialect schedule that is applied to the payload.

    Defines a single named sequence ``main``. It rewrites every
    ``linalg.matmul`` with ``rewrite_matmul_as_mma_sync``, then hoists loop
    invariants out of all loop-like ops and runs CSE so the output is simpler.
    """
    root_arg_attrs = [{"transform.readonly": UnitAttr.get()}]

    @named_sequence("main", [any_op_t()], [], arg_attrs=root_arg_attrs)
    def main(module: any_op_t()):
        # Step 1: rewrite the matmul into nvgpu mma.sync form.
        matmul_handle = transform.match(module, ["linalg.matmul"])
        transform.nvgpu.rewrite_matmul_as_mma_sync(matmul_handle)

        # Step 2: clean up to simplify the resulting IR.
        loop_handles = transform.match(
            module, interface=MatchInterfaceEnum.LoopLikeInterface
        )
        transform.apply_licm(loop_handles)
        transform.apply_cse(module)

# "Finish" the module

In [16]:
# Finalize the explicitly managed module and print the assembled IR.
# NOTE(review): this rebinds ``module`` to the finished result, so re-running
# this cell alone (without the cells above) probably won't work.
module = module.finish()
print(module)

module {
  module attributes {transform.target_tag = "payload"} {
    func.func @compute_linspace_val(%arg0: index, %arg1: index, %arg2: index) -> f16 {
      %0 = arith.index_cast %arg0 : index to i32
      %1 = arith.index_cast %arg1 : index to i32
      %2 = arith.index_cast %arg2 : index to i32
      %3 = arith.muli %0, %2 : i32
      %4 = arith.addi %1, %3 : i32
      %5 = arith.sitofp %4 : i32 to f16
      %cst = arith.constant 6.400000e+01 : f16
      %6 = arith.divf %5, %cst : f16
      return %6 : f16
    }
    func.func private @printMemrefF32(memref<*xf32>)
    func.func @print_lhs_as_memref_32(%arg0: memref<16x16xf16>) {
      %c0 = arith.constant 0 : index
      %dim = memref.dim %arg0, %c0 : memref<16x16xf16>
      %c1 = arith.constant 1 : index
      %dim_0 = memref.dim %arg0, %c1 : memref<16x16xf16>
      %alloc = memref.alloc(%dim, %dim_0) : memref<?x?xf32>
      %c0_1 = arith.constant 0 : index
      %c1_2 = arith.constant 1 : index
      scf.for %arg1 = %c0_1 to %dim

# Execute the transform schedule

In [17]:
# Run the transform interpreter. It executes the named sequence "main" against
# the nested module whose transform.target_tag is "payload".
interpreter_pipeline = Pipeline().transform_interpreter(
    entry_point="main", debug_payload_root_tag="payload"
)
mod = run_pipeline(module, interpreter_pipeline)
print(mod)

#map = affine_map<(d0) -> (d0 floordiv 4)>
#map1 = affine_map<(d0) -> (d0 * 2 - (d0 floordiv 4) * 8)>
#map2 = affine_map<(d0) -> (d0 * 2 - (d0 floordiv 4) * 8 + 1)>
#map3 = affine_map<(d0) -> (d0 floordiv 4 + 8)>
#map4 = affine_map<(d0) -> (d0 * 2 - (d0 floordiv 4) * 8 + 8)>
#map5 = affine_map<(d0) -> (d0 * 2 - (d0 floordiv 4) * 8 + 9)>
module {
  module attributes {transform.target_tag = "payload"} {
    func.func @compute_linspace_val(%arg0: index, %arg1: index, %arg2: index) -> f16 {
      %0 = arith.index_cast %arg0 : index to i32
      %1 = arith.index_cast %arg1 : index to i32
      %2 = arith.index_cast %arg2 : index to i32
      %3 = arith.muli %0, %2 : i32
      %4 = arith.addi %1, %3 : i32
      %5 = arith.sitofp %4 : i32 to f16
      %cst = arith.constant 6.400000e+01 : f16
      %6 = arith.divf %5, %cst : f16
      return %6 : f16
    }
    func.func private @printMemrefF32(memref<*xf32>)
    func.func @print_lhs_as_memref_32(%arg0: memref<16x16xf16>) {
      %c0 = arith.co

# Lower to NVVM (and LLVM)

In [18]:
backend = LLVMJITBackend([CUDA_RUNTIME_LIB_PATH])


def _is_payload_module(op):
    """Return True for the op tagged ``transform.target_tag = "payload"``."""
    attrs = op.attributes
    return (
        "transform.target_tag" in attrs
        and attrs["transform.target_tag"].value == "payload"
    )


payload_module = find_ops(mod.operation, _is_payload_module, single=True)

# Lower the payload (host and GPU code) to NVVM/LLVM for sm_80, embedding the
# device code as a fatbin. compile() also generates the C-API wrappers used by
# ``main_capi_wrapper`` when the module is loaded.
nvvm_lowering = Pipeline().add_pass(
    "gpu-lower-to-nvvm-pipeline",
    **{
        "cubin-chip": "sm_80",
        "cubin-features": "+ptx76",
        "cubin-format": "fatbin",
    },
)
compiled_module = backend.compile(payload_module, nvvm_lowering)
print(compiled_module)

module attributes {gpu.container_module, transform.target_tag = "payload"} {
  llvm.func @free(!llvm.ptr)
  llvm.func @malloc(i64) -> !llvm.ptr
  llvm.func @compute_linspace_val(%arg0: i64, %arg1: i64, %arg2: i64) -> f16 {
    %0 = llvm.mlir.constant(6.400000e+01 : f16) : f16
    %1 = llvm.trunc %arg0 : i64 to i32
    %2 = llvm.trunc %arg1 : i64 to i32
    %3 = llvm.trunc %arg2 : i64 to i32
    %4 = llvm.mul %1, %3  : i32
    %5 = llvm.add %2, %4  : i32
    %6 = llvm.sitofp %5 : i32 to f16
    %7 = llvm.fdiv %6, %0  : f16
    llvm.return %7 : f16
  }
  llvm.func @printMemrefF32(i64, !llvm.ptr) attributes {sym_visibility = "private"}
  llvm.func @print_lhs_as_memref_32(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) {
    %0 = llvm.mlir.constant(2 : index) : i64
    %1 = llvm.mlir.constant(16 : index) : i64
    %2 = llvm.mlir.constant(0 : index) : i64
    %3 = llvm.mlir.constant(1 : index) : i64
    %4 = llvm.mul %1, %1  : i64
    %5 = llv

# Load and run

In [19]:
# JIT-load the lowered payload and call the C-API wrapper generated for @main.
# It prints lhs, rhs, and the matmul result.
backend.load(compiled_module).main_capi_wrapper()

Unranked Memref base@ = 0x560303453d50 rank = 2 offset = 0 sizes = [16, 16] strides = [16, 1] data = 
[[0,   0.015625,   0.03125,   0.046875,   0.0625,   0.078125,   0.09375,   0.109375,   0.125,   0.140625,   0.15625,   0.171875,   0.1875,   0.203125,   0.21875,   0.234375], 
 [0.25,   0.265625,   0.28125,   0.296875,   0.3125,   0.328125,   0.34375,   0.359375,   0.375,   0.390625,   0.40625,   0.421875,   0.4375,   0.453125,   0.46875,   0.484375], 
 [0.5,   0.515625,   0.53125,   0.546875,   0.5625,   0.578125,   0.59375,   0.609375,   0.625,   0.640625,   0.65625,   0.671875,   0.6875,   0.703125,   0.71875,   0.734375], 
 [0.75,   0.765625,   0.78125,   0.796875,   0.8125,   0.828125,   0.84375,   0.859375,   0.875,   0.890625,   0.90625,   0.921875,   0.9375,   0.953125,   0.96875,   0.984375], 
 [1,   1.01562,   1.03125,   1.04688,   1.0625,   1.07812,   1.09375,   1.10938,   1.125,   1.14062,   1.15625,   1.17188,   1.1875,   1.20312,   1.21875,   1.23438], 
 [1.25,   1.26562,