# Welcome to `mlir-python-extras` — enjoy your stay!

More at https://github.com/makslevental/mlir-python-extras

In [3]:
# If you are running this notebook on your own, you can use this instead:
# !pip install -q mlir-python-extras -f https://makslevental.github.io/wheels
import os

# Branch to install from: honour an existing BRANCH environment variable,
# otherwise fall back to "main" and export that value for the %%bash cell.
BRANCH = os.environ.setdefault("BRANCH", "main")

# URL of the helper script that the %%bash cell downloads via $SCRIPT_ADDRESS.
os.environ["SCRIPT_ADDRESS"] = (
    "https://raw.githubusercontent.com/makslevental/mlir-python-extras"
    f"/refs/heads/{BRANCH}/scripts/get_latest_bindings.py"
)

In [12]:
%%bash
# Stop at the first failing command or unset variable, so a failed download or
# version lookup cannot fall through to the pip installs below.
set -euo pipefail

# Download the helper script. -f makes HTTP errors exit non-zero instead of
# saving the error page, -sS hides the progress meter but keeps error
# messages, and -L follows redirects.
curl -fsSL "$SCRIPT_ADDRESS" -o get_latest_bindings.py

# The script's stdout is used as the version pin for mlir_python_bindings.
# The meaning of the "none" argument is defined by that script.
latest_version=$(python get_latest_bindings.py "none")
: "${latest_version:?get_latest_bindings.py printed no version}"

pip install "mlir_python_bindings==${latest_version}" -f https://makslevental.github.io/wheels
pip install "git+https://github.com/makslevental/mlir-python-extras@${BRANCH}"

# Boilerplate

In [None]:
import numpy as np

import mlir.extras.types as T
from mlir.extras.ast.canonicalize import canonicalize
from mlir.extras.context import mlir_mod_ctx
from mlir.extras.dialects.ext.arith import constant
from mlir.extras.dialects.ext.memref import S
from mlir.extras.dialects.ext.func import func
from mlir.extras.dialects.ext.scf import canonicalizer as scf, range_
from mlir.extras.runtime.passes import Pipeline, run_pipeline
from mlir.extras.runtime.refbackend import LLVMJITBackend
from mlir.ir import StridedLayoutAttr

# you need this to register the memref value caster
# noinspection PyUnresolvedReferences
import mlir.extras.dialects.ext.memref

# Enter the module context by hand (rather than with a `with` block) so that it
# stays open across the following cells; it is closed explicitly in the
# "Clean up after yourself" cell.
ctx_man = mlir_mod_ctx()
ctx = ctx_man.__enter__()
# Backend object used later to compile (`backend.compile`) and load
# (`backend.load`) the generated module.
backend = LLVMJITBackend()

# MWE (minimal working example)

In [None]:
# Side length of the square matrices used in this example.
K = 10
# K x K memref type with i64 elements, used to annotate the kernel arguments.
memref_i64 = T.memref(K, K, T.i64())

# Kernel taking three K x K i64 memrefs; results are written into C.
# NOTE(review): `canonicalize(using=scf)` presumably rewrites the Python
# `if`/`for` statements into `scf` ops and `func(emit=True)` emits the function
# into the active module right away -- confirm against the library docs.
@func(emit=True)
@canonicalize(using=scf)
def memfoo(A: memref_i64, B: memref_i64, C: memref_i64):
    # Two constants whose comparison selects the branch below.
    one = constant(1)
    two = constant(2)
    if one > two:
        # `1 > 2` branch: store the i64 constant 3 into C[0, 0].
        C[0, 0] = constant(3, T.i64())
    else:
        # Otherwise write the element-wise product of A and B into C.
        for i in range_(0, K):
            for j in range_(0, K):
                C[i, j] = A[i, j] * B[i, j]

## `func`, `memref`, `scf`, and `arith` dialects

In [None]:
# Run the `cse` pass over the module, then show the resulting IR.
cse_pipeline = Pipeline().cse()
run_pipeline(ctx.module, cse_pipeline)
print(ctx.module)

## Lower to `llvm` dialect

In [None]:
# Bufferize and lower to the LLVM dialect while compiling the `memfoo` kernel,
# then show the lowered module.
lowering = Pipeline().bufferize().lower_to_llvm()
module = backend.compile(ctx.module, kernel_name=memfoo.__name__, pipeline=lowering)
print(module)

## Run

In [None]:
# Random integer inputs in [0, 10) and a zero-filled output buffer, all int64.
shape = (K, K)
A = np.random.randint(0, 10, size=shape).astype(np.int64)
B = np.random.randint(0, 10, size=shape).astype(np.int64)
C = np.zeros(shape, dtype=np.int64)

# Load the compiled module and call the kernel on the three arrays.
loaded = backend.load(module)
loaded.memfoo(A, B, C)

## Check the results

In [None]:
print(C)
# The kernel should have written the element-wise product into C.
# np.testing reports which elements differ on failure and, unlike a bare
# `assert`, is not stripped when Python runs with -O.
np.testing.assert_array_equal(C, A * B)

## Clean up after yourself

In [None]:
# Leave the context that was entered with `ctx_man.__enter__()`; the trailing
# `;` keeps the notebook from echoing the call's return value.
ctx_man.__exit__(None, None, None);

# Slightly more complicated example

In [None]:
# Problem sizes: K x K matrices processed in D x D tiles, F tiles per side.
K = 256
D = 32

# The tiled kernels only cover the whole matrix when D divides K evenly.
# Check this before entering the context so a bad size does not leave a
# context open.
if K % D != 0:
    raise ValueError(f"tile size D={D} must evenly divide K={K}")
F = K // D

ctx_man = mlir_mod_ctx()
ctx = ctx_man.__enter__()

# Type of the full K x K f32 operands.
ranked_memref_kxk_f32 = T.memref(K, K, T.f32())
# Type of a D x D f32 tile with strides (K, 1), i.e. a window into a K x K
# row-major buffer.
# NOTE(review): `S` is presumably the "dynamic" marker for the offset -- confirm.
layout = StridedLayoutAttr.get(S, (K, 1))
ranked_memref_dxd_f32 = T.memref(D, D, T.f32(), layout=layout)

# Element-wise add on a single D x D tile: C[i, j] = A[i, j] + B[i, j].
# All three arguments use the strided D x D type `ranked_memref_dxd_f32`.
@func(emit=True)
@canonicalize(using=scf)
def tile(
    A: ranked_memref_dxd_f32, B: ranked_memref_dxd_f32, C: ranked_memref_dxd_f32
):
    for i in range_(0, D):
        for j in range_(0, D):
            C[i, j] = A[i, j] + B[i, j]

# Driver over the full K x K matrices: walks the F x F grid of tiles and calls
# `tile` on the matching D x D slices of A, B and C.
@func(emit=True)
@canonicalize(using=scf)
def tiled_memfoo(
    A: ranked_memref_kxk_f32, B: ranked_memref_kxk_f32, C: ranked_memref_kxk_f32
):
    for i in range_(0, F):
        for j in range_(0, F):
            # Start (`l`) and end (`r`) index of tile number n along one axis,
            # giving the slice n * D : (n + 1) * D.
            l = lambda l: l * D
            r = lambda r: (r + 1) * D
            # Slice out tile (i, j) of each operand.
            a, b, c = (
                A[l(i) : r(i), l(j) : r(j)],
                B[l(i) : r(i), l(j) : r(j)],
                C[l(i) : r(i), l(j) : r(j)],
            )
            tile(a, b, c)

## `func`, `memref`, `scf`, and `arith` dialects

In [None]:
# Show the IR as built, then again after running the `cse` pass over it.
print(ctx.module)
cse_pipeline = str(Pipeline().cse())
module = run_pipeline(ctx.module, cse_pipeline)
print(module)

## Run

In [None]:
# Bufferize and lower to the LLVM dialect while compiling `tiled_memfoo`.
lowering = Pipeline().bufferize().lower_to_llvm()
module = backend.compile(module, kernel_name=tiled_memfoo.__name__, pipeline=lowering)

# Random integer-valued f32 inputs in [0, 10) and a zero-filled f32 output.
shape = (K, K)
A = np.random.randint(0, 10, size=shape).astype(np.float32)
B = np.random.randint(0, 10, size=shape).astype(np.float32)
C = np.zeros(shape, dtype=np.float32)

# Load the compiled module and call the kernel on the three arrays.
loaded = backend.load(module)
loaded.tiled_memfoo(A, B, C)

## Check your results

In [None]:
print(C)
# The kernel should have written the element-wise sum into C.
# np.testing reports which elements differ on failure and, unlike a bare
# `assert`, is not stripped when Python runs with -O.
np.testing.assert_array_equal(C, A + B)

## Clean up after yourself

In [None]:
# Leave the context entered for the tiled example; the trailing `;` keeps the
# notebook from echoing the call's return value.
ctx_man.__exit__(None, None, None);

# Do it like the professionals

In [None]:
# Enter a fresh context for this example.
ctx_man = mlir_mod_ctx()
ctx = ctx_man.__enter__()

# Rebuild the memref types after entering the new context. K and D are reused
# from the "Slightly more complicated example" cell.
ranked_memref_kxk_f32 = T.memref(K, K, T.f32())
# NOTE(review): `layout` and `ranked_memref_dxd_f32` are not referenced by
# `linalg_memfoo` below -- possibly left over from the previous example.
layout = StridedLayoutAttr.get(S, (K, 1))
ranked_memref_dxd_f32 = T.memref(D, D, T.f32(), layout=layout)

# Provides `linalg.add`, used by the kernel below.
from mlir.extras.dialects.ext import linalg

# Same tiling driver as `tiled_memfoo`, except that each D x D tile is added
# with `linalg.add(a, b, c)` instead of the hand-written `tile` loop nest.
# F (tiles per side) and D (tile size) are reused from the earlier cell.
@func(emit=True)
@canonicalize(using=scf)
def linalg_memfoo(
    A: ranked_memref_kxk_f32, B: ranked_memref_kxk_f32, C: ranked_memref_kxk_f32
):
    for i in range_(0, F):
        for j in range_(0, F):
            # Start (`l`) and end (`r`) index of tile number n along one axis,
            # giving the slice n * D : (n + 1) * D.
            l = lambda l: l * D
            r = lambda r: (r + 1) * D
            # Slice out tile (i, j) of each operand.
            a, b, c = (
                A[l(i) : r(i), l(j) : r(j)],
                B[l(i) : r(i), l(j) : r(j)],
                C[l(i) : r(i), l(j) : r(j)],
            )
            linalg.add(a, b, c)

# Run the `cse` pass over the module and show the resulting IR.
cse_pipeline = str(Pipeline().cse())
module = run_pipeline(ctx.module, cse_pipeline)
print(module)

## Run

In [None]:
# Convert linalg ops to loops, bufferize and lower to the LLVM dialect while
# compiling the `linalg_memfoo` kernel.
module = backend.compile(
    module,
    kernel_name=linalg_memfoo.__name__,
    pipeline=Pipeline().convert_linalg_to_loops().bufferize().lower_to_llvm(),
)
# Load the compiled module once and keep the handle for the call below.
invoker = backend.load(module)

# Random integer-valued f32 inputs in [0, 10) and a zero-filled f32 output.
A = np.random.randint(0, 10, (K, K)).astype(np.float32)
B = np.random.randint(0, 10, (K, K)).astype(np.float32)
C = np.zeros((K, K), dtype=np.float32)

# Reuse `invoker` instead of loading the same module a second time.
invoker.linalg_memfoo(A, B, C)

## Check your results

In [None]:
print(C)
# The kernel should have written the element-wise sum into C.
# np.testing reports which elements differ on failure and, unlike a bare
# `assert`, is not stripped when Python runs with -O.
np.testing.assert_array_equal(C, A + B)

## Clean up after yourself

In [None]:
# Leave the context entered for the linalg example; the trailing `;` keeps the
# notebook from echoing the call's return value.
ctx_man.__exit__(None, None, None);