In [1]:
# Install the prebuilt MLIR Python bindings from the makslevental wheel index,
# then mlir-python-extras straight from git.
# NOTE(review): `$BRANCH` is expanded when this cell runs, so it must be defined
# (as a Python variable or environment variable) beforehand — TODO confirm which
# branch/ref this notebook is meant to pin.
!pip install -q mlir-python-bindings -f https://makslevental.github.io/wheels
!pip install -q git+https://github.com/makslevental/mlir-python-extras@$BRANCH

# Boilerplate

In [2]:
import numpy as np

import mlir.extras.types as T
from mlir.dialects import builtin
from mlir.dialects.transform import any_op_t
from mlir.dialects.transform.extras import named_sequence, apply_patterns
from mlir.extras.util import find_ops
from mlir.ir import StringAttr, UnitAttr

# you need this to register the memref value caster
# noinspection PyUnresolvedReferences
import mlir.extras.dialects.ext.memref
from mlir.extras.context import RAIIMLIRContext, ExplicitlyManagedModule
from mlir.dialects.bufferization import LayoutMapOption
from mlir.dialects.transform.vector import (
    VectorContractLowering,
    VectorMultiReductionLowering,
    VectorTransferSplit,
    VectorTransposeLowering,
)
from mlir.extras.dialects.ext import linalg
from mlir.extras.dialects.ext.func import func
from mlir.extras.dialects.ext.transform import (
    match,
    tile_to_scf_for,
    get_parent_op,
    transform_any_op_t,
)
from mlir.extras.dialects.ext import transform
from mlir.extras.runtime.passes import Pipeline, run_pipeline
from mlir.extras.runtime.refbackend import LLVMJITBackend


# Context

In [3]:
# Create the MLIR context that the later cells build IR in (presumably kept
# active for as long as `ctx` is alive — confirm), plus the top-level module
# that the kernel and transform-schedule cells emit into. The module is closed
# later with `module.finish()`.
ctx = RAIIMLIRContext()
module = ExplicitlyManagedModule()

# Kernel

In [4]:
# Problem size: A is M x K, B is K x N, and C (init/output) is M x N.
M, K, N = 2, 4, 6


# Tensor-level matmul kernel: a single linalg.matmul with C as the `outs`
# (init) operand. The result tensor is returned (see the printed IR below).
@func
def matmul_tensors(
    A: T.tensor(M, K, T.f32()),
    B: T.tensor(K, N, T.f32()),
    C: T.tensor(M, N, T.f32()),
):
    return linalg.matmul(A, B, C)

# Nested module tagged `transform.target_tag = "payload"`. The transform
# interpreter is pointed at this tag (debug_payload_root_tag="payload"), and
# the LLVM lowering cell extracts the module by this tag.
# NOTE(review): `force=True` presumably forces emission of the already-defined
# function into this module — confirm against mlir-python-extras.
@builtin.module(attrs={"transform.target_tag": StringAttr.get("payload")})
def payload():
    matmul_tensors.emit(force=True)

# Transform schedule (based on [transform-e2e.mlir](https://github.com/llvm/llvm-project/blob/375bd2201ce0d2c76cb47a02c87b8ca5ba8a3509/mlir/test/Dialect/LLVM/transform-e2e.mlir))

In [5]:
@builtin.module(attrs={"transform.with_named_sequence": UnitAttr.get()})
def mod_transform():
    # The transform interpreter runs the `main` named sequence. `module_op` is
    # presumably bound to the module tagged "payload" (via
    # debug_payload_root_tag) — confirm.
    @named_sequence("main", [any_op_t()], [])
    def main(module_op: any_op_t()):
        # Handle to the linalg.matmul op(s) in the payload.
        matmul = match(module_op, ops=["linalg.matmul"])
        # Tile the three matmul loops (M, N, K) by 2 each into nested scf.for
        # loops. Only the tiled op handle is used below; the loop handles
        # (including `inner_loop`) are unused.
        tiled_matmul, (_, _, inner_loop) = tile_to_scf_for(matmul, sizes=[2, 2, 2])
        # Vectorize everything under the nearest isolated-from-above ancestor
        # of the tiled op. In this payload, that is the enclosing func.func.
        transform.structured.vectorize_children_and_apply_patterns(
            get_parent_op(transform_any_op_t(), tiled_matmul, isolated_from_above=True)
        )
        # Bufferize the whole payload: tensors become memrefs, including the
        # function signature, which uses identity-layout memrefs.
        new_mod = transform.bufferization.one_shot_bufferize(
            module_op,
            function_boundary_type_conversion=LayoutMapOption.IdentityLayoutMap,
            bufferize_function_boundaries=True,
        )

        # Re-match the function in the bufferized module, which is returned as
        # a new handle. Pre-bufferization handles are presumably invalidated.
        func_op = match(new_mod, ops=["func.func"])

        # Vector lowering patterns (mirroring the linked transform-e2e.mlir),
        # applied together to `func_op`.
        @apply_patterns(func_op)
        def pats():
            # vector.contract -> outer products.
            transform.apply_patterns.vector.lower_contraction(
                lowering_strategy=VectorContractLowering.OuterProduct
            )
            transform.apply_patterns.vector.transfer_permutation_patterns()
            transform.apply_patterns.vector.lower_multi_reduction(
                lowering_strategy=VectorMultiReductionLowering.InnerParallel
            )
            transform.apply_patterns.vector.split_transfer_full_partial(
                split_transfer_strategy=VectorTransferSplit.LinalgCopy
            )
            # Transfers of rank > 1 are unrolled to scf; rank-1 transfers are
            # then lowered directly (e.g. to vector.load in the output below).
            transform.apply_patterns.vector.transfer_to_scf(
                max_transfer_rank=1, full_unroll=True
            )
            transform.apply_patterns.vector.lower_transfer(max_transfer_rank=1)
            transform.apply_patterns.vector.lower_shape_cast()
            transform.apply_patterns.vector.lower_transpose(
                lowering_strategy=VectorTransposeLowering.Shuffle1D
            )

# "Finish" the module

In [6]:
# Close the top-level module (payload + transform schedule) and rebind `module`
# to the finished result that the interpreter cell consumes.
# NOTE(review): the rebinding makes this cell non-idempotent. Re-running it
# calls `.finish()` on the already-finished result, so use Restart & Run All
# instead of re-executing this cell alone.
module = module.finish()
print(module)

module {
  module attributes {transform.target_tag = "payload"} {
    func.func @matmul_tensors(%arg0: tensor<2x4xf32>, %arg1: tensor<4x6xf32>, %arg2: tensor<2x6xf32>) -> tensor<2x6xf32> {
      %0 = linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<2x4xf32>, tensor<4x6xf32>) outs(%arg2 : tensor<2x6xf32>) -> tensor<2x6xf32>
      return %0 : tensor<2x6xf32>
    }
  }
  module attributes {transform.with_named_sequence} {
    transform.named_sequence @main(%arg0: !transform.any_op) {
      %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
      %tiled_linalg_op, %loops:3 = transform.structured.tile_using_for %0[2, 2, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
      %1 = transform.get_parent_op %tiled_linalg_op {isolated_from_above} : (!transform.any_op) -> !transform.any_op
      %2 = transform.structured.vectorize_children_and_apply_patterns 

# Vectorize (execute the transform schedule)

In [7]:
# Execute the transform schedule: the interpreter runs the `main` named
# sequence against the nested module tagged "payload". The printed result is
# the tiled, vectorized, bufferized payload.
transform_pipeline = Pipeline().transform_interpreter(
    entry_point="main", debug_payload_root_tag="payload"
)
vectorized_module = run_pipeline(module, pipeline=transform_pipeline)
print(vectorized_module)

#map = affine_map<(d0) -> (d0 + 1)>
module {
  module attributes {transform.target_tag = "payload"} {
    func.func @matmul_tensors(%arg0: memref<2x4xf32>, %arg1: memref<4x6xf32>, %arg2: memref<2x6xf32>) -> memref<2x6xf32> {
      %cst = arith.constant dense<0.000000e+00> : vector<4xf32>
      %cst_0 = arith.constant dense<0.000000e+00> : vector<2x2xf32>
      %c4 = arith.constant 4 : index
      %c6 = arith.constant 6 : index
      %c0 = arith.constant 0 : index
      %c2 = arith.constant 2 : index
      %0 = scf.for %arg3 = %c0 to %c2 step %c2 iter_args(%arg4 = %arg2) -> (memref<2x6xf32>) {
        %1 = scf.for %arg5 = %c0 to %c6 step %c2 iter_args(%arg6 = %arg4) -> (memref<2x6xf32>) {
          %2 = scf.for %arg7 = %c0 to %c4 step %c2 iter_args(%arg8 = %arg6) -> (memref<2x6xf32>) {
            %3 = vector.load %arg0[%arg3, %arg7] : memref<2x4xf32>, vector<2xf32>
            %4 = affine.apply #map(%arg3)
            %5 = vector.load %arg0[%4, %arg7] : memref<2x4xf32>, vector<2xf32>
 

# Lower to CPU (through LLVM, based on [TestLowerToLLVM.cpp](https://github.com/makslevental/llvm-project/blob/f6643263631bcb0d191ef923963ac1a5ca9ac5fd/mlir/test/lib/Dialect/LLVM/TestLowerToLLVM.cpp#L44))

In [8]:
# Pass pipeline that lowers the bufferized, vectorized payload to the LLVM
# dialect so it can be JIT-compiled. Pass order matters: the later conversions
# assume the earlier ones have already removed higher-level ops.
lower_to_llvm = (
    Pipeline()
    # `.Func(...)` nests these passes (presumably anchored on func.func ops).
    .Func(
        Pipeline()
        # Blanket-convert any remaining high-level vector ops to loops if any remain.
        .convert_vector_to_scf()
        # Blanket-convert any remaining linalg ops to loops if any remain.
        .convert_linalg_to_loops()
    )
    # Blanket-convert any remaining affine ops if any remain.
    .lower_affine()
    # Convert SCF to CF (always needed).
    .convert_scf_to_cf()
    # Sprinkle some cleanups.
    .canonicalize()
    .cse()
    # Convert vector to LLVM (always needed).
    .convert_vector_to_llvm()
    # Convert Math to LLVM (always needed).
    .Func(Pipeline().convert_math_to_llvm())
    # Expand complicated MemRef operations before lowering them.
    .expand_strided_metadata()
    # The expansion may create affine expressions. Get rid of them.
    .lower_affine()
    # Convert MemRef to LLVM (always needed).
    .finalize_memref_to_llvm()
    # Convert Func to LLVM (always needed).
    .convert_func_to_llvm()
    # Convert Arith and CF ops that remain after the earlier conversions.
    .convert_arith_to_llvm()
    .convert_cf_to_llvm()
    # Convert Index to LLVM (always needed).
    .convert_index_to_llvm()
    # Convert remaining unrealized_casts (always needed).
    .reconcile_unrealized_casts()
)

backend = LLVMJITBackend()


def _is_payload_module(op):
    """Return True for the op tagged `transform.target_tag = "payload"`."""
    return (
        "transform.target_tag" in op.attributes
        and op.attributes["transform.target_tag"].value == "payload"
    )


# Pull just the payload module out of the transformed IR (the transform-schedule
# module is not compiled), then lower it with `lower_to_llvm` and JIT it.
payload_module = find_ops(vectorized_module.operation, _is_payload_module, single=True)
compiled_module = backend.compile(
    payload_module,
    kernel_name=matmul_tensors.__name__,
    pipeline=lower_to_llvm,
)
print(compiled_module)

module attributes {transform.target_tag = "payload"} {
  llvm.func @matmul_tensors(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: !llvm.ptr, %arg8: !llvm.ptr, %arg9: i64, %arg10: i64, %arg11: i64, %arg12: i64, %arg13: i64, %arg14: !llvm.ptr, %arg15: !llvm.ptr, %arg16: i64, %arg17: i64, %arg18: i64, %arg19: i64, %arg20: i64) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> attributes {llvm.emit_c_interface} {
    %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
    %1 = llvm.insertvalue %arg0, %0[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %2 = llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %3 = llvm.insertvalue %arg2, %2[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
    %5 =

# Load, run, and compare against NumPy

In [9]:
# Integer-valued inputs (exactly representable in float32) drawn from a seeded
# generator, so Restart & Run All reproduces the same check every time.
rng = np.random.default_rng(0)
A = rng.integers(0, 10, (M, K)).astype(np.float32)
B = rng.integers(0, 10, (K, N)).astype(np.float32)
# Output buffer: the kernel's result is read back from C after the call (no
# return value is used), i.e. the compiled kernel writes into C in place.
C = np.zeros((M, N), dtype=np.float32)

# NOTE(review): `matmul_tensors_capi_wrapper` is presumably the entry point
# generated for `llvm.emit_c_interface` — confirm against LLVMJITBackend.
backend.load(compiled_module).matmul_tensors_capi_wrapper(A, B, C)
# Same rtol/atol as np.allclose, but reports the mismatching elements on failure.
np.testing.assert_allclose(C, A @ B, rtol=1e-5, atol=1e-8)