# Welcome to `mlir-python-utils`! Enjoy your stay!

More at https://github.com/makslevental/mlir-python-utils

In [None]:
# Install the package with its `mlir` extra; `-f` adds the given URL as an extra place to look for wheels.
# All output is sent to /dev/null (bash `&>` syntax), so a failed install is silent — drop the redirect when debugging.
!pip install mlir-python-utils[mlir] -f https://makslevental.github.io/wheels/ &> /dev/null

In [None]:
# Run the package's configuration command for `mlir`; output is discarded.
# NOTE(review): `-y` presumably auto-confirms any prompt — confirm against the tool's help text.
!configure-mlir-python-utils -y mlir &> /dev/null

# Boilerplate

In [None]:
import numpy as np

import mlir_utils.types as T
from mlir_utils.ast.canonicalize import canonicalize
from mlir_utils.context import MLIRContext, mlir_mod_ctx
from mlir_utils.dialects.ext.arith import constant
from mlir_utils.dialects.ext.memref import load, store, S
from mlir_utils.dialects.ext.func import func
from mlir_utils.dialects.ext.scf import canonicalizer as scf, range_ as range
from mlir_utils.runtime.passes import Pipeline, run_pipeline
from mlir_utils.runtime.refbackend import LLVMJITBackend

# you need this to register the memref value caster
# noinspection PyUnresolvedReferences
import mlir_utils.dialects.ext.memref

# Enter the module context by hand (rather than with a `with` block) so it stays
# open across the following notebook cells; the matching `__exit__` call lives
# in the "Clean up after yourself" cell.
ctx_man = mlir_mod_ctx()
ctx = ctx_man.__enter__()
# JIT backend used to compile and load kernels; the same instance is reused by
# every example in this notebook.
backend = LLVMJITBackend()

# Minimal working example (MWE)

In [None]:
# Side length of the square matrices; also the trip count of both loops below.
K = 10
# K x K memref of 64-bit integers (printed as memref<10x10xi64> in the IR below).
memref_i64 = T.memref(K, K, T.i64)

# With `canonicalize(using=scf)` applied, the plain Python `if`/`for` below show
# up as `scf.if`/`scf.for` in the emitted IR, and `func` makes the whole thing a
# `func.func @memfoo`.
@func
@canonicalize(using=scf)
def memfoo(A: memref_i64, B: memref_i64, C: memref_i64):
    # No explicit type is given: these print as i32 constants in the module below.
    one = constant(1)
    two = constant(2)
    # 1 > 2 never holds, so this branch is here only to demonstrate `scf.if`.
    if one > two:
        C[0, 0] = constant(3, T.i64)
    else:
        # Element-wise product: C[i, j] = A[i, j] * B[i, j].
        # NOTE: `range` is the `range_` imported from the scf extension above,
        # not the Python builtin.
        for i in range(0, K):
            for j in range(0, K):
                C[i, j] = A[i, j] * B[i, j]

## `func`, `memref`, `scf`, and `arith` dialects

In [None]:
# Emit `memfoo` into the current module, run the `cse` pipeline over that
# module, then print it.
memfoo.emit()
cse_pipeline = Pipeline().cse()
run_pipeline(ctx.module, cse_pipeline)
print(ctx.module)

module {
  func.func @memfoo(%arg0: memref<10x10xi64>, %arg1: memref<10x10xi64>, %arg2: memref<10x10xi64>) {
    %c1_i32 = arith.constant 1 : i32
    %c2_i32 = arith.constant 2 : i32
    %0 = arith.cmpi ugt, %c1_i32, %c2_i32 : i32
    scf.if %0 {
      %c3_i64 = arith.constant 3 : i64
      %c0 = arith.constant 0 : index
      memref.store %c3_i64, %arg2[%c0, %c0] : memref<10x10xi64>
    } else {
      %c0 = arith.constant 0 : index
      %c10 = arith.constant 10 : index
      %c1 = arith.constant 1 : index
      scf.for %arg3 = %c0 to %c10 step %c1 {
        scf.for %arg4 = %c0 to %c10 step %c1 {
          %1 = memref.load %arg0[%arg3, %arg4] : memref<10x10xi64>
          %2 = memref.load %arg1[%arg3, %arg4] : memref<10x10xi64>
          %3 = arith.muli %1, %2 : i64
          memref.store %3, %arg2[%arg3, %arg4] : memref<10x10xi64>
        }
      }
    }
    return
  }
}



## Lower to `llvm` dialect

In [None]:
# Compile the module with a bufferize + lower-to-LLVM pipeline and print what
# the backend hands back.
lowering = Pipeline().bufferize().lower_to_llvm()
module = backend.compile(ctx.module, kernel_name=memfoo.__name__, pipeline=lowering)
print(module)

module attributes {llvm.data_layout = ""} {
  llvm.func @memfoo(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: !llvm.ptr, %arg8: !llvm.ptr, %arg9: i64, %arg10: i64, %arg11: i64, %arg12: i64, %arg13: i64, %arg14: !llvm.ptr, %arg15: !llvm.ptr, %arg16: i64, %arg17: i64, %arg18: i64, %arg19: i64, %arg20: i64) attributes {llvm.emit_c_interface} {
    %0 = llvm.mlir.constant(3 : i64) : i64
    %1 = llvm.mlir.constant(0 : index) : i64
    %2 = llvm.mlir.constant(10 : index) : i64
    %3 = llvm.mlir.constant(1 : index) : i64
    %4 = llvm.mlir.constant(false) : i1
    llvm.cond_br %4, ^bb1, ^bb2
  ^bb1:  // pred: ^bb0
    %5 = llvm.mul %1, %2  : i64
    %6 = llvm.add %5, %1  : i64
    %7 = llvm.getelementptr %arg15[%6] : (!llvm.ptr, i64) -> !llvm.ptr, i64
    llvm.store %0, %7 : i64, !llvm.ptr
    llvm.br ^bb9
  ^bb2:  // pred: ^bb0
    llvm.br ^bb3(%1 : i64)
  ^bb3(%8: i64):  // 2 preds: ^bb2, ^bb7
    %9 = llvm.icmp "slt" %8, %2 : i64
 

## Run

In [None]:
# Two random integer inputs drawn from [0, 10) and a zero-filled output, all
# K x K int64 to match the kernel's memref<KxKxi64> arguments.
shape = (K, K)
A = np.random.randint(0, 10, shape).astype(np.int64)
B = np.random.randint(0, 10, shape).astype(np.int64)
C = np.zeros(shape, dtype=np.int64)
# Load the compiled module and invoke the kernel; C is passed as the output buffer.
backend.load(module).memfoo(A, B, C)

## Check the results

In [None]:
# Show the kernel's output and compare it with NumPy's element-wise product.
print(C)
expected = A * B
assert np.array_equal(expected, C)

[[ 9  8  7 81  0  0 48  0 12 40]
 [ 5 16  0  0 27 32 48 64 30 15]
 [16  0 36  4  0 63 12  8 18 42]
 [ 9 24  2 18  4  0 21  0 12 36]
 [15 27 18 18  0  6 45 45 28 24]
 [30 49  3  9  9  6  0 72  0 15]
 [ 0 30 49 15  0  0 28 64 14 15]
 [ 8  8 63 64  0  0 10  4 28  5]
 [ 9  4  5  5 14  0  0 24  9 48]
 [21  5 45  0 30 54 49 49 10 36]]


## Clean up after yourself

In [None]:
# Leave the context entered in the boilerplate cell (no exception to report, hence
# the three `None`s); the trailing `;` keeps the notebook from echoing the return value.
ctx_man.__exit__(None, None, None);

# Slightly more complicated example

In [None]:
# Fresh module context for this example (the previous one was exited above).
ctx_man = mlir_mod_ctx()
ctx = ctx_man.__enter__()

# K: side length of the full matrices; D: side length of one square tile.
K = 256
D = 32

# Number of tiles along each dimension (K is an exact multiple of D here).
F = K // D
# Full K x K f32 matrix type.
ranked_memref_kxk_f32 = T.memref(K, K, T.f32)
# D x D f32 tile type with strides (K, 1), i.e. a window into a K x K matrix.
# NOTE(review): `S` appears to stand for a dynamic offset — the subviews in the
# IR below print as strided<[256, 1], offset: ?>; confirm against the memref ext.
ranked_memref_dxd_f32 = T.memref(D, D, T.f32, layout=((K, 1), S))

# Element-wise sum of one D x D tile: C[i, j] = A[i, j] + B[i, j].
# All three arguments use the strided tile type, so they can be subviews of
# larger matrices (this is how `tiled_memfoo` calls it).
@func
@canonicalize(using=scf)
def tile(
    A: ranked_memref_dxd_f32, B: ranked_memref_dxd_f32, C: ranked_memref_dxd_f32
):
    for i in range(0, D):
        for j in range(0, D):
            C[i, j] = A[i, j] + B[i, j]

# Walks the F x F grid of tiles and calls `tile` on the matching D x D
# sub-blocks of A, B and C.
@func
@canonicalize(using=scf)
def tiled_memfoo(
    A: ranked_memref_kxk_f32, B: ranked_memref_kxk_f32, C: ranked_memref_kxk_f32
):
    for i in range(0, F):
        for j in range(0, F):
            # l(t): first row/column of tile t; r(t): one past its last row/column.
            l = lambda l: l * D
            r = lambda r: (r + 1) * D
            # Slicing a memref produces a subview (`memref.subview` in the IR
            # printed below), not a copy of the data.
            a, b, c = (
                A[l(i) : r(i), l(j) : r(j)],
                B[l(i) : r(i), l(j) : r(j)],
                C[l(i) : r(i), l(j) : r(j)],
            )
            tile(a, b, c)


## `func`, `memref`, `scf`, and `arith` dialects

In [None]:
# Emit `tiled_memfoo` into the current module, run the `cse` pipeline (passed
# in its string form) and print the module that `run_pipeline` returns.
tiled_memfoo.emit()
cse_pipeline = str(Pipeline().cse())
module = run_pipeline(ctx.module, cse_pipeline)
print(module)

module {
  func.func @tiled_memfoo(%arg0: memref<256x256xf32>, %arg1: memref<256x256xf32>, %arg2: memref<256x256xf32>) {
    %c0 = arith.constant 0 : index
    %c8 = arith.constant 8 : index
    %c1 = arith.constant 1 : index
    scf.for %arg3 = %c0 to %c8 step %c1 {
      scf.for %arg4 = %c0 to %c8 step %c1 {
        %c32 = arith.constant 32 : index
        %0 = arith.muli %arg3, %c32 : index
        %1 = arith.addi %arg3, %c1 : index
        %2 = arith.muli %arg4, %c32 : index
        %3 = arith.addi %arg4, %c1 : index
        %subview = memref.subview %arg0[%0, %2] [32, 32] [1, 1] : memref<256x256xf32> to memref<32x32xf32, strided<[256, 1], offset: ?>>
        %subview_0 = memref.subview %arg1[%0, %2] [32, 32] [1, 1] : memref<256x256xf32> to memref<32x32xf32, strided<[256, 1], offset: ?>>
        %subview_1 = memref.subview %arg2[%0, %2] [32, 32] [1, 1] : memref<256x256xf32> to memref<32x32xf32, strided<[256, 1], offset: ?>>
        func.call @tile(%subview, %subview_0, %subview_1) 

## Run

In [None]:
# Compile with the same bufferize + lower-to-LLVM pipeline as the first example.
module = backend.compile(
    module,
    kernel_name=tiled_memfoo.__name__,
    pipeline=Pipeline().bufferize().lower_to_llvm(),
)

# Random inputs drawn from [0, 10) and stored as float32 to match the kernel's
# memref<KxKxf32> arguments.
A = np.random.randint(0, 10, (K, K)).astype(np.float32)
B = np.random.randint(0, 10, (K, K)).astype(np.float32)
# Allocate the output directly as float32 (as the first example does with
# `dtype=`) instead of building a float64 array and converting it afterwards.
C = np.zeros((K, K), dtype=np.float32)

# Load the compiled module and invoke the kernel; C is passed as the output buffer.
backend.load(module).tiled_memfoo(A, B, C)

## Check your results

In [None]:
# Show the kernel's output and compare it with NumPy's element-wise sum.
print(C)
expected = A + B
assert np.array_equal(expected, C)

[[ 0.  3. 14. ...  7.  6.  0.]
 [ 9.  9.  8. ... 15.  9. 13.]
 [ 9.  5.  1. ... 10.  6.  8.]
 ...
 [ 2. 17.  9. ...  8.  9. 17.]
 [10.  3.  7. ...  5. 10.  8.]
 [15. 11. 10. ... 11.  8.  9.]]


## Clean up after yourself

In [None]:
# Leave the context entered at the start of this example; the trailing `;`
# keeps the notebook from echoing the return value.
ctx_man.__exit__(None, None, None);

# Do it like the professionals

In [None]:
# New module context for this example; the memref types are built again after
# entering it. K, D and F keep their values from the previous example.
ctx_man = mlir_mod_ctx()
ctx = ctx_man.__enter__()

ranked_memref_kxk_f32 = T.memref(K, K, T.f32)
# NOTE(review): `ranked_memref_dxd_f32` is not referenced anywhere else in this
# example — presumably carried over from the tiled one; confirm before removing.
ranked_memref_dxd_f32 = T.memref(D, D, T.f32, layout=((K, 1), S))

from mlir_utils.dialects import linalg

# Same tiling as `tiled_memfoo`, but each D x D tile is added with a single
# `linalg.add` op instead of a call to a hand-written loop nest.
@func
@canonicalize(using=scf)
def linalg_memfoo(
    A: ranked_memref_kxk_f32, B: ranked_memref_kxk_f32, C: ranked_memref_kxk_f32
):
    for i in range(0, F):
        for j in range(0, F):
            # l(t): first row/column of tile t; r(t): one past its last row/column.
            l = lambda l: l * D
            r = lambda r: (r + 1) * D
            # Slicing a memref produces a subview (`memref.subview` in the IR
            # printed below), not a copy of the data.
            a, b, c = (
                A[l(i) : r(i), l(j) : r(j)],
                B[l(i) : r(i), l(j) : r(j)],
                C[l(i) : r(i), l(j) : r(j)],
            )
            linalg.add(a, b, c)

# Emit `linalg_memfoo` into the current module, run the `cse` pipeline (passed
# in its string form) and print the module that `run_pipeline` returns.
linalg_memfoo.emit()
cse_pipeline = str(Pipeline().cse())
module = run_pipeline(ctx.module, cse_pipeline)
print(module)

module {
  func.func @linalg_memfoo(%arg0: memref<256x256xf32>, %arg1: memref<256x256xf32>, %arg2: memref<256x256xf32>) {
    %c0 = arith.constant 0 : index
    %c8 = arith.constant 8 : index
    %c1 = arith.constant 1 : index
    scf.for %arg3 = %c0 to %c8 step %c1 {
      scf.for %arg4 = %c0 to %c8 step %c1 {
        %c32 = arith.constant 32 : index
        %0 = arith.muli %arg3, %c32 : index
        %1 = arith.addi %arg3, %c1 : index
        %2 = arith.muli %arg4, %c32 : index
        %3 = arith.addi %arg4, %c1 : index
        %subview = memref.subview %arg0[%0, %2] [32, 32] [1, 1] : memref<256x256xf32> to memref<32x32xf32, strided<[256, 1], offset: ?>>
        %subview_0 = memref.subview %arg1[%0, %2] [32, 32] [1, 1] : memref<256x256xf32> to memref<32x32xf32, strided<[256, 1], offset: ?>>
        %subview_1 = memref.subview %arg2[%0, %2] [32, 32] [1, 1] : memref<256x256xf32> to memref<32x32xf32, strided<[256, 1], offset: ?>>
        linalg.add ins(%subview, %subview_0 : memref<32x3

## Run

In [None]:
# Compared with the earlier examples, the pipeline adds `convert_linalg_to_loops`
# ahead of bufferization and LLVM lowering.
module = backend.compile(
    module,
    kernel_name=linalg_memfoo.__name__,
    pipeline=Pipeline().convert_linalg_to_loops().bufferize().lower_to_llvm(),
)
invoker = backend.load(module)
# Random inputs drawn from [0, 10) and stored as float32 to match the kernel's
# memref<KxKxf32> arguments.
A = np.random.randint(0, 10, (K, K)).astype(np.float32)
B = np.random.randint(0, 10, (K, K)).astype(np.float32)
# Allocate the output directly as float32 instead of converting a float64 array.
C = np.zeros((K, K), dtype=np.float32)

# Reuse the invoker created above rather than loading the module a second time.
invoker.linalg_memfoo(A, B, C)

## Check your results

In [None]:
# Show the kernel's output and compare it with NumPy's element-wise sum.
print(C)
expected = A + B
assert np.array_equal(expected, C)

[[ 6.  4.  8. ...  7. 12.  3.]
 [10.  7.  8. ... 10. 11. 12.]
 [ 7.  7. 10. ...  6.  6. 11.]
 ...
 [ 4. 10.  7. ...  5.  5. 12.]
 [ 8. 12. 13. ... 15.  6. 18.]
 [ 9. 13. 12. ... 12. 12. 10.]]


## Clean up after yourself

In [None]:
# Leave the context entered at the start of this example; the trailing `;`
# keeps the notebook from echoing the return value.
ctx_man.__exit__(None, None, None);