# Welcome to `mlir-python-extras` enjoy your stay!

more at https://github.com/makslevental/mlir-python-extras

In [73]:
!pip install mlir-python-bindings -f https://makslevental.github.io/wheels &> /dev/null
!pip install git+https://github.com/makslevental/mlir-python-extras@$BRANCH &> /dev/null

# "Boiler plate"

In [74]:
import numpy as np

import mlir.extras.types as T
from mlir.extras.ast.canonicalize import canonicalize
from mlir.extras.context import mlir_mod_ctx
from mlir.extras.dialects.ext.arith import constant
from mlir.extras.dialects.ext.memref import S
from mlir.extras.dialects.ext.func import func
from mlir.extras.dialects.ext.scf import canonicalizer as scf, range_ as range
from mlir.extras.runtime.passes import Pipeline, run_pipeline
from mlir.extras.runtime.refbackend import LLVMJITBackend

# you need this to register the memref value caster
# noinspection PyUnresolvedReferences
import mlir.extras.dialects.ext.memref

ctx_man = mlir_mod_ctx()
ctx = ctx_man.__enter__()
backend = LLVMJITBackend()

# MWE

In [75]:
K = 10
memref_i64 = T.memref(K, K, T.i64)

@func(emit=True)
@canonicalize(using=scf)
def memfoo(A: memref_i64, B: memref_i64, C: memref_i64):
    one = constant(1)
    two = constant(2)
    if one > two:
        C[0, 0] = constant(3, T.i64)
    else:
        for i in range(0, K):
            for j in range(0, K):
                C[i, j] = A[i, j] * B[i, j]

## `func`, `memref`, `scf`, and `arith` dialects

In [76]:
run_pipeline(ctx.module, Pipeline().cse())
print(ctx.module)

module {
  func.func @memfoo(%arg0: memref<10x10xi64>, %arg1: memref<10x10xi64>, %arg2: memref<10x10xi64>) {
    %c1_i32 = arith.constant 1 : i32
    %c2_i32 = arith.constant 2 : i32
    %0 = arith.cmpi ugt, %c1_i32, %c2_i32 : i32
    scf.if %0 {
      %c3_i64 = arith.constant 3 : i64
      %c0 = arith.constant 0 : index
      %c0_0 = arith.constant 0 : index
      memref.store %c3_i64, %arg2[%c0, %c0_0] : memref<10x10xi64>
    } else {
      %c0 = arith.constant 0 : index
      %c10 = arith.constant 10 : index
      %c1 = arith.constant 1 : index
      scf.for %arg3 = %c0 to %c10 step %c1 {
        %c0_0 = arith.constant 0 : index
        %c10_1 = arith.constant 10 : index
        %c1_2 = arith.constant 1 : index
        scf.for %arg4 = %c0_0 to %c10_1 step %c1_2 {
          %1 = memref.load %arg0[%arg3, %arg4] : memref<10x10xi64>
          %2 = memref.load %arg1[%arg3, %arg4] : memref<10x10xi64>
          %3 = arith.muli %1, %2 : i64
          memref.store %3, %arg2[%arg3, %arg4] : m

## Lower to `llvm` dialect

In [77]:
module = backend.compile(
    ctx.module,
    kernel_name=memfoo.__name__,
    pipeline=Pipeline().bufferize().lower_to_llvm(),
)
print(module)

module {
  llvm.func @memfoo(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: !llvm.ptr, %arg8: !llvm.ptr, %arg9: i64, %arg10: i64, %arg11: i64, %arg12: i64, %arg13: i64, %arg14: !llvm.ptr, %arg15: !llvm.ptr, %arg16: i64, %arg17: i64, %arg18: i64, %arg19: i64, %arg20: i64) attributes {llvm.emit_c_interface} {
    %0 = llvm.mlir.constant(3 : i64) : i64
    %1 = llvm.mlir.constant(0 : index) : i64
    %2 = llvm.mlir.constant(10 : index) : i64
    %3 = llvm.mlir.constant(1 : index) : i64
    %4 = llvm.mlir.constant(false) : i1
    llvm.cond_br %4, ^bb1, ^bb2
  ^bb1:  // pred: ^bb0
    %5 = llvm.mul %1, %2  : i64
    %6 = llvm.add %5, %1  : i64
    %7 = llvm.getelementptr %arg15[%6] : (!llvm.ptr, i64) -> !llvm.ptr, i64
    llvm.store %0, %7 : i64, !llvm.ptr
    llvm.br ^bb9
  ^bb2:  // pred: ^bb0
    llvm.br ^bb3(%1 : i64)
  ^bb3(%8: i64):  // 2 preds: ^bb2, ^bb7
    %9 = llvm.icmp "slt" %8, %2 : i64
    llvm.cond_br %9, ^bb4, ^bb8
  ^b

## Run

In [78]:
A = np.random.randint(0, 10, (K, K)).astype(np.int64)
B = np.random.randint(0, 10, (K, K)).astype(np.int64)
C = np.zeros((K, K), dtype=np.int64)
backend.load(module).memfoo(A, B, C)

## Check the results

In [79]:
print(C)
assert np.array_equal(A * B, C)

[[35  0 24  0  0  0  2 63 48 15]
 [28 56 18  8  0  6 49 10 16  6]
 [16 16 24 18 54  0 42 24  0  8]
 [ 6  0 16 27 24  2 18 48  0 72]
 [ 4 27 28  5 16 42 27 63  6 35]
 [ 0 72  6 20 24 30 56 18 14  0]
 [ 6  3  0 30 32  0 21  8 27  0]
 [25 27 35 21 12  1  0  0 32 12]
 [ 5 30  9 27 18  0  4  8 12 54]
 [ 0  0  5 42  8 48 24  0 36  7]]


## Clean up after yourself

In [80]:
ctx_man.__exit__(None, None, None);

# Slightly more complicated example

In [81]:
ctx_man = mlir_mod_ctx()
ctx = ctx_man.__enter__()

K = 256
D = 32

F = K // D
ranked_memref_kxk_f32 = T.memref(K, K, T.f32)
ranked_memref_dxd_f32 = T.memref(D, D, T.f32, layout=((K, 1), S))

@func(emit=True)
@canonicalize(using=scf)
def tile(
    A: ranked_memref_dxd_f32, B: ranked_memref_dxd_f32, C: ranked_memref_dxd_f32
):
    for i in range(0, D):
        for j in range(0, D):
            C[i, j] = A[i, j] + B[i, j]

@func(emit=True)
@canonicalize(using=scf)
def tiled_memfoo(
    A: ranked_memref_kxk_f32, B: ranked_memref_kxk_f32, C: ranked_memref_kxk_f32
):
    for i in range(0, F):
        for j in range(0, F):
            l = lambda l: l * D
            r = lambda r: (r + 1) * D
            a, b, c = (
                A[l(i) : r(i), l(j) : r(j)],
                B[l(i) : r(i), l(j) : r(j)],
                C[l(i) : r(i), l(j) : r(j)],
            )
            tile(a, b, c)

## `func`, `memref`, `scf`, and `arith` dialects

In [82]:
print(ctx.module)
module = run_pipeline(ctx.module, str(Pipeline().cse()))
print(module)

module {
  func.func @tile(%arg0: memref<32x32xf32, strided<[256, 1], offset: ?>>, %arg1: memref<32x32xf32, strided<[256, 1], offset: ?>>, %arg2: memref<32x32xf32, strided<[256, 1], offset: ?>>) {
    %c0 = arith.constant 0 : index
    %c32 = arith.constant 32 : index
    %c1 = arith.constant 1 : index
    scf.for %arg3 = %c0 to %c32 step %c1 {
      %c0_0 = arith.constant 0 : index
      %c32_1 = arith.constant 32 : index
      %c1_2 = arith.constant 1 : index
      scf.for %arg4 = %c0_0 to %c32_1 step %c1_2 {
        %0 = memref.load %arg0[%arg3, %arg4] : memref<32x32xf32, strided<[256, 1], offset: ?>>
        %1 = memref.load %arg1[%arg3, %arg4] : memref<32x32xf32, strided<[256, 1], offset: ?>>
        %2 = arith.addf %0, %1 : f32
        memref.store %2, %arg2[%arg3, %arg4] : memref<32x32xf32, strided<[256, 1], offset: ?>>
      }
    }
    return
  }
  func.func @tiled_memfoo(%arg0: memref<256x256xf32>, %arg1: memref<256x256xf32>, %arg2: memref<256x256xf32>) {
    %c0 = arith.cons

## Run

In [83]:
module = backend.compile(
    module,
    kernel_name=tiled_memfoo.__name__,
    pipeline=Pipeline().bufferize().lower_to_llvm(),
)

A = np.random.randint(0, 10, (K, K)).astype(np.float32)
B = np.random.randint(0, 10, (K, K)).astype(np.float32)
C = np.zeros((K, K)).astype(np.float32)

backend.load(module).tiled_memfoo(A, B, C)

## Check your results

In [84]:
print(C)
assert np.array_equal(A + B, C)

[[ 0. 15. 16. ... 12.  8. 16.]
 [ 6. 14.  4. ...  8. 14. 11.]
 [12. 14.  3. ...  9.  9.  9.]
 ...
 [ 6.  3.  8. ...  5.  4.  9.]
 [14. 11. 12. ...  2.  8.  9.]
 [ 8. 16.  2. ...  1.  9.  9.]]


## Clean up after yourself

In [85]:
ctx_man.__exit__(None, None, None);

# Do it like the professionals

In [86]:
ctx_man = mlir_mod_ctx()
ctx = ctx_man.__enter__()

ranked_memref_kxk_f32 = T.memref(K, K, T.f32)
ranked_memref_dxd_f32 = T.memref(D, D, T.f32, layout=((K, 1), S))

from mlir.extras.dialects.ext import linalg

@func(emit=True)
@canonicalize(using=scf)
def linalg_memfoo(
    A: ranked_memref_kxk_f32, B: ranked_memref_kxk_f32, C: ranked_memref_kxk_f32
):
    for i in range(0, F):
        for j in range(0, F):
            l = lambda l: l * D
            r = lambda r: (r + 1) * D
            a, b, c = (
                A[l(i) : r(i), l(j) : r(j)],
                B[l(i) : r(i), l(j) : r(j)],
                C[l(i) : r(i), l(j) : r(j)],
            )
            linalg.add(a, b, c)

module = run_pipeline(ctx.module, str(Pipeline().cse()))
print(module)

module {
  func.func @linalg_memfoo(%arg0: memref<256x256xf32>, %arg1: memref<256x256xf32>, %arg2: memref<256x256xf32>) {
    %c0 = arith.constant 0 : index
    %c8 = arith.constant 8 : index
    %c1 = arith.constant 1 : index
    scf.for %arg3 = %c0 to %c8 step %c1 {
      scf.for %arg4 = %c0 to %c8 step %c1 {
        %c32 = arith.constant 32 : index
        %0 = arith.muli %arg3, %c32 : index
        %1 = arith.addi %arg3, %c1 : index
        %2 = arith.muli %arg4, %c32 : index
        %3 = arith.addi %arg4, %c1 : index
        %subview = memref.subview %arg0[%0, %2] [32, 32] [1, 1] : memref<256x256xf32> to memref<32x32xf32, strided<[256, 1], offset: ?>>
        %subview_0 = memref.subview %arg1[%0, %2] [32, 32] [1, 1] : memref<256x256xf32> to memref<32x32xf32, strided<[256, 1], offset: ?>>
        %subview_1 = memref.subview %arg2[%0, %2] [32, 32] [1, 1] : memref<256x256xf32> to memref<32x32xf32, strided<[256, 1], offset: ?>>
        linalg.add ins(%subview, %subview_0 : memref<32x3

## Run

In [87]:
module = backend.compile(
    module,
    kernel_name=linalg_memfoo.__name__,
    pipeline=Pipeline().convert_linalg_to_loops().bufferize().lower_to_llvm()
)
invoker = backend.load(module)
A = np.random.randint(0, 10, (K, K)).astype(np.float32)
B = np.random.randint(0, 10, (K, K)).astype(np.float32)
C = np.zeros((K, K)).astype(np.float32)

backend.load(module).linalg_memfoo(A, B, C)

## Check your results

In [88]:
print(C)
assert np.array_equal(A + B, C)

[[12.  4. 13. ...  9.  2. 11.]
 [ 9. 15. 16. ... 11.  3.  7.]
 [12. 11. 12. ... 11.  2.  8.]
 ...
 [12.  3.  6. ...  7. 11. 15.]
 [ 9. 11.  7. ... 17. 11.  7.]
 [10.  6. 12. ... 10. 17.  9.]]


## Clean up after yourself

In [89]:
ctx_man.__exit__(None, None, None);