### Installation

```bash
conda install -c jim22k mlir-bindings
```

In [1]:
# Only needed when using locally built MLIR Python bindings instead of the
# conda package: put the build's Python package directory on sys.path.
# import sys
# sys.path.append("<llvm-build-dir>/tools/mlir/python_packages/mlir_core")

# Not needed: none of the ExecutionEngine(...) calls in this notebook pass
# shared_libs, so the C runner utils library is never loaded explicitly.
# SHARED_LIB = "<llvm-build-dir>/lib/libmlir_c_runner_utils.dylib"

In [2]:
import ctypes
import numbers
import mlir
import numpy as np
from mlir import ir
from mlir import passmanager
from mlir import execution_engine
from mlir import runtime
from mlir import dialects

from mlir.dialects import arith
from mlir.dialects import bufferization
from mlir.dialects import func
from mlir.dialects import linalg
from mlir.dialects import sparse_tensor
from mlir.dialects import tensor
from mlir.dialects import scf

### Add hard-coded integers

In [3]:
# Build a module whose `main` adds two i32 constants, lower it to the LLVM
# dialect with an explicit pass list, JIT-compile it, and read the result back.
with ir.Context(), ir.Location.unknown():
    module = ir.Module.create()
    with ir.InsertionPoint(module.body):
        i32 = ir.IntegerType.get_signless(32)

        @func.FuncOp.from_py_func()
        def main():
            one = arith.ConstantOp(i32, 1)
            two = arith.ConstantOp(i32, 2)
            total = arith.AddIOp(one, two)
            return total

        # Emit a C-compatible wrapper so ExecutionEngine.invoke can call `main`.
        main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
    # It is okay to manually specify passes
    pm = passmanager.PassManager.parse("convert-arith-to-llvm,convert-func-to-llvm")
print(module)
# PassManager.run lowers `module` in place and returns None, so keep `pm` as is.
pm.run(module)
engine = execution_engine.ExecutionEngine(module)
# `main` returns an i32, so the result slot must be exactly 32 bits wide
# (ctypes.c_long is 64 bits on Linux/macOS but 32 bits on Windows).
arg_pointers = [
    ctypes.pointer(ctypes.c_int32(0)),
]
engine.invoke("main", *arg_pointers)
print('-'*30)
print(f"result = {arg_pointers[0].contents.value}")

module {
  func.func @main() -> i32 attributes {llvm.emit_c_interface} {
    %c1_i32 = arith.constant 1 : i32
    %c2_i32 = arith.constant 2 : i32
    %0 = arith.addi %c1_i32, %c2_i32 : i32
    return %0 : i32
  }
}

------------------------------
result = 3


**Notes:**

Operation objects (like `total`, which is an `arith.AddIOp`, not a value) expose:

- `total.attributes`
- `total.regions`
- `total.operands`
- `total.results`

### Multiply input floats

In [4]:
# Build `main(f64, f64) -> f64` returning the product of its arguments, lower
# it with the pre-built "sparse-compiler" pipeline, and JIT-compile it.
with ir.Context(), ir.Location.unknown():
    module = ir.Module.create()
    with ir.InsertionPoint(module.body):
        f64 = ir.F64Type.get()
        @func.FuncOp.from_py_func(f64, f64)
        def main(x, y):
            product = arith.MulFOp(x, y)
            return product
        main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
    pm = passmanager.PassManager.parse("sparse-compiler")  # pre-built pass pipeline
print(module)
pm.run(module)  # lowers `module` in place; returns None
engine = execution_engine.ExecutionEngine(module)
# Later cells rebind the global `engine`; pin mul2 to this engine so calling
# mul2 after those cells still runs the multiply kernel.
mul2_engine = engine

def mul2(x, y):
    """Multiply two floats using the JIT-compiled `main` (f64, f64) -> f64.

    Every argument is passed as a pointer; the scalar f64 result is written
    into the pointer placed after the inputs.
    """
    arg_pointers = [
        ctypes.pointer(ctypes.c_double(x)),
        ctypes.pointer(ctypes.c_double(y)),
        ctypes.pointer(ctypes.c_double(0)),  # scalar result slot goes at the end
    ]
    mul2_engine.invoke("main", *arg_pointers)
    return arg_pointers[-1].contents.value

result = mul2(-1.5, 21.25)
print('-'*30)
print(f"{result=}")

module {
  func.func @main(%arg0: f64, %arg1: f64) -> f64 attributes {llvm.emit_c_interface} {
    %0 = arith.mulf %arg0, %arg1 : f64
    return %0 : f64
  }
}

------------------------------
result=-31.875


In [5]:
mul2(x=1.1, y=2.2)  # float64 rounding: shows 2.4200000000000004, not 2.42

2.4200000000000004

### How to construct a sparse tensor type

**Note**: this cell only builds and prints the IR; nothing is compiled or executed.

In [6]:
# Build (but do not compile or run) an identity function over a 2-D i32 sparse
# tensor whose outer level is dense and inner level is compressed (CSR-like).
with ir.Context(), ir.Location.unknown():
    module = ir.Module.create()
    with ir.InsertionPoint(module.body):
        i32 = ir.IntegerType.get_signless(32)
        sp_encoding = sparse_tensor.EncodingAttr.get(
            [sparse_tensor.DimLevelType.dense, sparse_tensor.DimLevelType.compressed],
            ir.AffineMap.get_permutation([0, 1]),  # identity dimension ordering
            0,  # presumably the pointer bit width (0 = default) -- confirm
            0,  # presumably the index bit width (0 = default) -- confirm
        )
        # -1 marks a dynamic dimension; it prints as `?`.
        rtt = ir.RankedTensorType.get((-1, -1), i32, sp_encoding)
        @func.FuncOp.from_py_func(rtt)
        def main(x):
            return x
        main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
    # No PassManager is needed: this cell only prints the IR.
print(module)

module {
  func.func @main(%arg0: tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>) -> tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> attributes {llvm.emit_c_interface} {
    return %arg0 : tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>
  }
}



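# Full example: generic dispatching

`reduce_sum` compiles a kernel the first time it sees a given (array length, dtype) combination, caches it, and reuses the cached kernel on later calls.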

### Sum a passed-in NumPy array using `linalg.generic`

In [7]:
# numpy dtype name -> zero-argument factory for the matching MLIR element type.
# Factories defer type creation until they are called inside an active
# ir.Context.
NP_TYPE_TO_MLIR = {
    np.dtype(np_type).name: make_type
    for np_type, make_type in [
        (np.int8, lambda: ir.IntegerType.get_signless(8)),
        (np.int16, lambda: ir.IntegerType.get_signless(16)),
        (np.int32, lambda: ir.IntegerType.get_signless(32)),
        (np.int64, lambda: ir.IntegerType.get_signless(64)),
        (np.float32, lambda: ir.F32Type.get()),
        (np.float64, lambda: ir.F64Type.get()),
    ]
}
# numpy dtype name -> ctypes scalar type used for the kernel's result slot.
NP_TYPE_TO_CTYPE = {
    np.dtype(np_type).name: c_type
    for np_type, c_type in [
        (np.int8, ctypes.c_int8),
        (np.int16, ctypes.c_int16),
        (np.int32, ctypes.c_int32),
        (np.int64, ctypes.c_int64),
        (np.float32, ctypes.c_float),
        (np.float64, ctypes.c_double),
    ]
}

# Compiled engines keyed by (array length, dtype name); filled by reduce_sum.
memoized = dict()

def reduce_sum(arr):
    """Sum a 1-D NumPy array with a JIT-compiled MLIR kernel.

    One kernel is compiled per (length, dtype) pair and cached in `memoized`.
    The sum is accumulated in `arr`'s own dtype, so small integer types can
    wrap around (unlike ``arr.sum()``).

    Parameters
    ----------
    arr : np.ndarray
        1-D array whose dtype name is a key of NP_TYPE_TO_CTYPE / NP_TYPE_TO_MLIR.

    Returns
    -------
    int or float
        The sum, as the Python value of the matching ctypes scalar.

    Raises
    ------
    ValueError
        If `arr` is not 1-D.
    KeyError
        If `arr.dtype` is not one of the supported dtypes.
    """
    if arr.ndim != 1:
        # The kernel and the memo key (len(arr), dtype) only describe 1-D input.
        raise ValueError(f"reduce_sum expects a 1-D array, got ndim={arr.ndim}")
    # The compiled argument type is a plain static-shape tensor (default
    # layout), so copy strided views such as k[::2]; no-op when contiguous.
    arr = np.ascontiguousarray(arr)

    key = (len(arr), arr.dtype.name)
    if key not in memoized:
        memoized[key] = _build_reduce_sum(arr)
    engine = memoized[key]

    c_typ = NP_TYPE_TO_CTYPE[arr.dtype.name]
    arg_pointers = [
        # Memref arguments are passed as pointer -> pointer -> descriptor.
        ctypes.pointer(ctypes.pointer(runtime.get_ranked_memref_descriptor(arr))),
        # Scalar result slot, placed after the inputs.
        ctypes.pointer(c_typ(0)),
    ]
    engine.invoke("main", *arg_pointers)
    return arg_pointers[-1].contents.value

def _build_reduce_sum(arr):
    """Compile a kernel ``main(tensor<N x T>) -> T`` that sums its input.

    Only ``len(arr)`` (N) and ``arr.dtype`` (T) are used; the array contents
    are not read. Returns an ExecutionEngine whose ``main`` is exposed through
    the C interface (``llvm.emit_c_interface``).
    """
    is_integer = issubclass(arr.dtype.type, numbers.Integral)
    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        with ir.InsertionPoint(module.body):
            dtype = NP_TYPE_TO_MLIR[arr.dtype.name]()
            type_a = ir.RankedTensorType.get([len(arr)], dtype)
            type_out = ir.RankedTensorType.get([], dtype)
            @func.FuncOp.from_py_func(type_a)
            def main(x):
                # linalg.generic accumulates into its `outs` operand, so the
                # accumulator must start at 0. An uninitialized
                # bufferization.alloc_tensor would make the sum depend on
                # whatever memory the allocation happened to contain.
                zero = arith.ConstantOp(dtype, 0 if is_integer else 0.0)
                init = tensor.SplatOp(type_out, zero)
                generic_op = linalg.GenericOp(
                    [type_out],
                    [x],
                    [init],
                    ir.ArrayAttr.get([
                        # input map: (d0) -> (d0)
                        ir.AffineMapAttr.get(ir.AffineMap.get_permutation([0])),
                        # output map: (d0) -> (), every element folds into the scalar
                        ir.AffineMapAttr.get(ir.AffineMap.get(1, 0, [])),
                    ]),
                    ir.ArrayAttr.get([ir.StringAttr.get("reduction")]),
                )
                # Construct the linalg.generic body: `a` is the input element,
                # `b` the running sum held in the output.
                block = generic_op.regions[0].blocks.append(dtype, dtype)
                with ir.InsertionPoint(block):
                    a, b = block.arguments
                    res = arith.AddIOp(a, b) if is_integer else arith.AddFOp(a, b)
                    linalg.YieldOp([res])
                final_result = tensor.ExtractOp(dtype, generic_op.result, [])
                return final_result
            main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
        pm = passmanager.PassManager.parse("sparse-compiler")
    pm.run(module)  # lowers `module` in place; returns None
    return execution_engine.ExecutionEngine(module)

In [8]:
reduce_sum(np.arange(1, 5, dtype=np.float32))  # [1., 2., 3., 4.] -> 10.0

10.0

In [9]:
# Pin the dtype: np.arange's default integer type is platform dependent, which
# would change the memo key shown by `memoized.keys()` below.
reduce_sum(np.arange(50, dtype=np.int64))  # 0 + 1 + ... + 49 == 1225

1225

In [10]:
k = np.arange(50_000, dtype=np.float64)  # 0 + 1 + ... + 49_999 == 1_249_975_000
reduce_sum(k)

1249975000.0

In [11]:
np.sum(k)  # NumPy reference value for the reduce_sum(k) result above

1249975000.0

In [12]:
# One compiled engine per (array length, dtype name) pair reduce_sum has seen.
memoized.keys()

dict_keys([(4, 'float32'), (50, 'int64'), (50000, 'float64')])

### Add 1 to a passed-in NumPy array

In [13]:
# Build `main(tensor<4xf32>) -> tensor<4xf32>` that adds 1.0 to every element,
# print the IR before and after lowering, and JIT-compile it.
with ir.Context(), ir.Location.unknown():
    module = ir.Module.create()
    with ir.InsertionPoint(module.body):
        f32 = ir.F32Type.get()
        type_a = ir.RankedTensorType.get([4], f32)
        @func.FuncOp.from_py_func(type_a)
        def main(arr):
            # The output operand starts as all ones; the body adds each input
            # element to it, so the result is arr + 1.
            one = arith.ConstantOp(f32, 1.0)
            vv = tensor.SplatOp(type_a, one)
            generic_op = linalg.GenericOp(
                [type_a],
                [arr],
                [vv],
                ir.ArrayAttr.get([ir.AffineMapAttr.get(ir.AffineMap.get_permutation([0])),
                                  ir.AffineMapAttr.get(ir.AffineMap.get_permutation([0]))]),
                ir.ArrayAttr.get([ir.StringAttr.get("parallel")]),
            )
            # Construct the linalg.generic body: `a` is the input element,
            # `b` the corresponding element of the output operand.
            block = generic_op.regions[0].blocks.append(f32, f32)
            with ir.InsertionPoint(block):
                a, b = block.arguments
                res = arith.AddFOp(a, b)
                linalg.YieldOp([res])
            return generic_op.result
        main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
    pm = passmanager.PassManager.parse("sparse-compiler")
print(module)
pm.run(module)  # lowers `module` in place; returns None
print('-'*50)
print(module)
engine = execution_engine.ExecutionEngine(module)

#map = affine_map<(d0) -> (d0)>
module {
  func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> attributes {llvm.emit_c_interface} {
    %cst = arith.constant 1.000000e+00 : f32
    %0 = tensor.splat %cst : tensor<4xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<4xf32>) outs(%0 : tensor<4xf32>) {
    ^bb0(%arg1: f32, %arg2: f32):
      %2 = arith.addf %arg1, %arg2 : f32
      linalg.yield %2 : f32
    } -> tensor<4xf32>
    return %1 : tensor<4xf32>
  }
}

--------------------------------------------------
module attributes {llvm.data_layout = ""} {
  llvm.func @malloc(i64) -> !llvm.ptr<i8>
  llvm.mlir.global private constant @__constant_4xf32(dense<1.000000e+00> : tensor<4xf32>) {addr_space = 0 : i32, alignment = 128 : i64} : !llvm.array<4 x f32>
  llvm.func @main(%arg0: !llvm.ptr<f32>, %arg1: !llvm.ptr<f32>, %arg2: i64, %arg3: i64, %arg4: i64) -> !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> at

In [14]:
a = np.array([1., 2., 3., 4.], dtype=np.float32)
# With llvm.emit_c_interface, a function that returns a memref gets a C wrapper
# which writes the result descriptor through a pointer passed as its FIRST
# argument. This is unlike the scalar results above, whose slot goes last.
out = runtime.make_nd_memref_descriptor(1, ctypes.c_float)()
a_desc = runtime.get_ranked_memref_descriptor(a)
arg_pointers = [ctypes.pointer(ctypes.pointer(desc)) for desc in (out, a_desc)]
engine.invoke("main", *arg_pointers)
# arg_pointers[0][0] is pointer(out), i.e. the filled-in result descriptor.
result = runtime.ranked_memref_to_numpy(arg_pointers[0][0])
result

array([2., 3., 4., 5.], dtype=float32)

### SCF Example

In [15]:
# Build `main(f64) -> f64` that returns -x when x > 0 and x otherwise (using
# scf.if), JIT-compile it, and wrap it as `negabs`.
with ir.Context(), ir.Location.unknown():
    module = ir.Module.create()
    with ir.InsertionPoint(module.body):
        i1 = ir.IntegerType.get_signless(1)
        i64 = ir.IntegerType.get_signless(64)
        f64 = ir.F64Type.get()
        @func.FuncOp.from_py_func(f64)
        def main(x):
            zero = arith.ConstantOp(f64, 0.0)
            # arith.cmpf's predicate is an integer enum attribute rather than a
            # string, which is why ir.StringAttr.get("ogt") is rejected.
            # 2 selects "ogt" (ordered greater-than), as the printed IR shows.
            cmp = arith.CmpFOp(i1, ir.IntegerAttr.get(i64, 2), x, zero)
            if_ = scf.IfOp(cmp.result, [f64], hasElse=True)
            with ir.InsertionPoint(if_.then_block):
                negX = arith.NegFOp(x)
                scf.YieldOp([negX])
            with ir.InsertionPoint(if_.else_block):
                scf.YieldOp([x])
            return if_.result
        main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
    pm = passmanager.PassManager.parse("sparse-compiler")
print(module)
pm.run(module)  # lowers `module` in place; returns None
engine = execution_engine.ExecutionEngine(module)
# Pin negabs to this engine so a later rebinding of the global `engine`
# cannot change which kernel it calls.
negabs_engine = engine

def negabs(x):
    """Run the JIT-compiled `main`: return -x when x > 0, otherwise x."""
    arg_pointers = [
        ctypes.pointer(ctypes.c_double(x)),
        ctypes.pointer(ctypes.c_double(0)),  # scalar result slot goes last
    ]
    negabs_engine.invoke("main", *arg_pointers)
    return arg_pointers[-1].contents.value

print('Convert everything to negative')
print('-'*30)
for n in [2.3, -1.5, 0.0, -5.6, 5.9]:
    print(f"{' ' if n >= 0 else ''}{n} -> {negabs(n)}")

module {
  func.func @main(%arg0: f64) -> f64 attributes {llvm.emit_c_interface} {
    %cst = arith.constant 0.000000e+00 : f64
    %0 = arith.cmpf ogt, %arg0, %cst : f64
    %1 = scf.if %0 -> (f64) {
      %2 = arith.negf %arg0 : f64
      scf.yield %2 : f64
    } else {
      scf.yield %arg0 : f64
    }
    return %1 : f64
  }
}

Convert everything to negative
------------------------------
 2.3 -> -2.3
-1.5 -> -1.5
 0.0 -> 0.0
-5.6 -> -5.6
 5.9 -> -5.9
