579 changes: 554 additions & 25 deletions mlir/lib/Bindings/Python/IRModules.cpp

Large diffs are not rendered by default.

29 changes: 29 additions & 0 deletions mlir/lib/Bindings/Python/IRModules.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "PybindUtils.h"

#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
#include "llvm/ADT/DenseMap.h"
Expand Down Expand Up @@ -668,6 +669,34 @@ class PyValue {
MlirValue value;
};

/// Wrapper around MlirAffineExpr. Affine expressions are owned by the context.
class PyAffineExpr : public BaseContextObject {
public:
PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
: BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {}
bool operator==(const PyAffineExpr &other);
operator MlirAffineExpr() const { return affineExpr; }
MlirAffineExpr get() const { return affineExpr; }

/// Gets a capsule wrapping the void* within the MlirAffineExpr.
pybind11::object getCapsule();

/// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule.
/// Note that PyAffineExpr instances are uniqued, so the returned object
/// may be a pre-existing object. Ownership of the underlying MlirAffineExpr
/// is taken by calling this function.
static PyAffineExpr createFromCapsule(pybind11::object capsule);

PyAffineExpr add(const PyAffineExpr &other) const;
PyAffineExpr mul(const PyAffineExpr &other) const;
PyAffineExpr floorDiv(const PyAffineExpr &other) const;
PyAffineExpr ceilDiv(const PyAffineExpr &other) const;
PyAffineExpr mod(const PyAffineExpr &other) const;

private:
MlirAffineExpr affineExpr;
};

class PyAffineMap : public BaseContextObject {
public:
PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap)
Expand Down
20 changes: 20 additions & 0 deletions mlir/lib/CAPI/IR/AffineExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ MlirContext mlirAffineExprGetContext(MlirAffineExpr affineExpr) {
return wrap(unwrap(affineExpr).getContext());
}

bool mlirAffineExprEqual(MlirAffineExpr lhs, MlirAffineExpr rhs) {
return unwrap(lhs) == unwrap(rhs);
}

void mlirAffineExprPrint(MlirAffineExpr affineExpr, MlirStringCallback callback,
void *userData) {
mlir::detail::CallbackOstream stream(callback, userData);
Expand Down Expand Up @@ -56,6 +60,10 @@ bool mlirAffineExprIsFunctionOfDim(MlirAffineExpr affineExpr,
// Affine Dimension Expression.
//===----------------------------------------------------------------------===//

bool mlirAffineExprIsADim(MlirAffineExpr affineExpr) {
return unwrap(affineExpr).isa<AffineDimExpr>();
}

MlirAffineExpr mlirAffineDimExprGet(MlirContext ctx, intptr_t position) {
return wrap(getAffineDimExpr(position, unwrap(ctx)));
}
Expand All @@ -68,6 +76,10 @@ intptr_t mlirAffineDimExprGetPosition(MlirAffineExpr affineExpr) {
// Affine Symbol Expression.
//===----------------------------------------------------------------------===//

bool mlirAffineExprIsASymbol(MlirAffineExpr affineExpr) {
return unwrap(affineExpr).isa<AffineSymbolExpr>();
}

MlirAffineExpr mlirAffineSymbolExprGet(MlirContext ctx, intptr_t position) {
return wrap(getAffineSymbolExpr(position, unwrap(ctx)));
}
Expand All @@ -80,6 +92,10 @@ intptr_t mlirAffineSymbolExprGetPosition(MlirAffineExpr affineExpr) {
// Affine Constant Expression.
//===----------------------------------------------------------------------===//

bool mlirAffineExprIsAConstant(MlirAffineExpr affineExpr) {
return unwrap(affineExpr).isa<AffineConstantExpr>();
}

MlirAffineExpr mlirAffineConstantExprGet(MlirContext ctx, int64_t constant) {
return wrap(getAffineConstantExpr(constant, unwrap(ctx)));
}
Expand Down Expand Up @@ -159,6 +175,10 @@ MlirAffineExpr mlirAffineCeilDivExprGet(MlirAffineExpr lhs,
// Affine Binary Operation Expression.
//===----------------------------------------------------------------------===//

bool mlirAffineExprIsABinary(MlirAffineExpr affineExpr) {
return unwrap(affineExpr).isa<AffineBinaryOpExpr>();
}

MlirAffineExpr mlirAffineBinaryOpExprGetLHS(MlirAffineExpr affineExpr) {
return wrap(unwrap(affineExpr).cast<AffineBinaryOpExpr>().getLHS());
}
Expand Down
17 changes: 15 additions & 2 deletions mlir/lib/CAPI/IR/AffineMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/AffineExpr.h"
#include "mlir/CAPI/AffineMap.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Utils.h"
Expand Down Expand Up @@ -37,11 +38,19 @@ MlirAffineMap mlirAffineMapEmptyGet(MlirContext ctx) {
return wrap(AffineMap::get(unwrap(ctx)));
}

MlirAffineMap mlirAffineMapGet(MlirContext ctx, intptr_t dimCount,
intptr_t symbolCount) {
MlirAffineMap mlirAffineMapZeroResultGet(MlirContext ctx, intptr_t dimCount,
intptr_t symbolCount) {
return wrap(AffineMap::get(dimCount, symbolCount, unwrap(ctx)));
}

MlirAffineMap mlirAffineMapGet(MlirContext ctx, intptr_t dimCount,
intptr_t symbolCount, intptr_t nAffineExprs,
MlirAffineExpr *affineExprs) {
SmallVector<AffineExpr, 4> exprs;
ArrayRef<AffineExpr> exprList = unwrapList(nAffineExprs, affineExprs, exprs);
return wrap(AffineMap::get(dimCount, symbolCount, exprList, unwrap(ctx)));
}

MlirAffineMap mlirAffineMapConstantGet(MlirContext ctx, int64_t val) {
return wrap(AffineMap::getConstantMap(val, unwrap(ctx)));
}
Expand Down Expand Up @@ -94,6 +103,10 @@ intptr_t mlirAffineMapGetNumResults(MlirAffineMap affineMap) {
return unwrap(affineMap).getNumResults();
}

MlirAffineExpr mlirAffineMapGetResult(MlirAffineMap affineMap, intptr_t pos) {
return wrap(unwrap(affineMap).getResult(static_cast<unsigned>(pos)));
}

intptr_t mlirAffineMapGetNumInputs(MlirAffineMap affineMap) {
return unwrap(affineMap).getNumInputs();
}
Expand Down
11 changes: 11 additions & 0 deletions mlir/lib/CAPI/IR/BuiltinTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,17 @@ MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank,
unwrap(elementType), maps, memorySpace));
}

MlirType mlirMemRefTypeGetChecked(MlirType elementType, intptr_t rank,
const int64_t *shape, intptr_t numMaps,
MlirAffineMap const *affineMaps,
unsigned memorySpace, MlirLocation loc) {
SmallVector<AffineMap, 1> maps;
(void)unwrapList(numMaps, affineMaps, maps);
return wrap(MemRefType::getChecked(
unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
unwrap(elementType), maps, memorySpace));
}

MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
const int64_t *shape,
unsigned memorySpace) {
Expand Down
275 changes: 275 additions & 0 deletions mlir/test/Bindings/Python/ir_affine_expr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
# RUN: %PYTHON %s | FileCheck %s

import gc
from mlir.ir import *

def run(f):
print("\nTEST:", f.__name__)
f()
gc.collect()
assert Context._get_live_count() == 0


# CHECK-LABEL: TEST: testAffineExprCapsule
def testAffineExprCapsule():
with Context() as ctx:
affine_expr = AffineExpr.get_constant(42)

affine_expr_capsule = affine_expr._CAPIPtr
# CHECK: capsule object
# CHECK: mlir.ir.AffineExpr._CAPIPtr
print(affine_expr_capsule)

affine_expr_2 = AffineExpr._CAPICreate(affine_expr_capsule)
assert affine_expr == affine_expr_2
assert affine_expr_2.context == ctx

run(testAffineExprCapsule)


# CHECK-LABEL: TEST: testAffineExprEq
def testAffineExprEq():
with Context():
a1 = AffineExpr.get_constant(42)
a2 = AffineExpr.get_constant(42)
a3 = AffineExpr.get_constant(43)
# CHECK: True
print(a1 == a1)
# CHECK: True
print(a1 == a2)
# CHECK: False
print(a1 == a3)
# CHECK: False
print(a1 == None)
# CHECK: False
print(a1 == "foo")

run(testAffineExprEq)


# CHECK-LABEL: TEST: testAffineExprContext
def testAffineExprContext():
with Context():
a1 = AffineExpr.get_constant(42)
with Context():
a2 = AffineExpr.get_constant(42)

# CHECK: False
print(a1 == a2)

run(testAffineExprContext)


# CHECK-LABEL: TEST: testAffineExprConstant
def testAffineExprConstant():
with Context():
a1 = AffineExpr.get_constant(42)
# CHECK: 42
print(a1.value)
# CHECK: 42
print(a1)

a2 = AffineConstantExpr.get(42)
# CHECK: 42
print(a2.value)
# CHECK: 42
print(a2)

assert a1 == a2

run(testAffineExprConstant)


# CHECK-LABEL: TEST: testAffineExprDim
def testAffineExprDim():
with Context():
d1 = AffineExpr.get_dim(1)
d11 = AffineDimExpr.get(1)
d2 = AffineDimExpr.get(2)

# CHECK: 1
print(d1.position)
# CHECK: d1
print(d1)

# CHECK: 2
print(d2.position)
# CHECK: d2
print(d2)

assert d1 == d11
assert d1 != d2

run(testAffineExprDim)


# CHECK-LABEL: TEST: testAffineExprSymbol
def testAffineExprSymbol():
with Context():
s1 = AffineExpr.get_symbol(1)
s11 = AffineSymbolExpr.get(1)
s2 = AffineSymbolExpr.get(2)

# CHECK: 1
print(s1.position)
# CHECK: s1
print(s1)

# CHECK: 2
print(s2.position)
# CHEKC: s2
print(s2)

assert s1 == s11
assert s1 != s2

run(testAffineExprSymbol)


# CHECK-LABEL: TEST: testAffineAddExpr
def testAffineAddExpr():
with Context():
d1 = AffineDimExpr.get(1)
d2 = AffineDimExpr.get(2)
d12 = AffineExpr.get_add(d1, d2)
# CHECK: d1 + d2
print(d12)

d12op = d1 + d2
# CHECK: d1 + d2
print(d12op)

assert d12 == d12op
assert d12.lhs == d1
assert d12.rhs == d2

run(testAffineAddExpr)


# CHECK-LABEL: TEST: testAffineMulExpr
def testAffineMulExpr():
with Context():
d1 = AffineDimExpr.get(1)
c2 = AffineConstantExpr.get(2)
expr = AffineExpr.get_mul(d1, c2)
# CHECK: d1 * 2
print(expr)

# CHECK: d1 * 2
op = d1 * c2
print(op)

assert expr == op
assert expr.lhs == d1
assert expr.rhs == c2

run(testAffineMulExpr)


# CHECK-LABEL: TEST: testAffineModExpr
def testAffineModExpr():
with Context():
d1 = AffineDimExpr.get(1)
c2 = AffineConstantExpr.get(2)
expr = AffineExpr.get_mod(d1, c2)
# CHECK: d1 mod 2
print(expr)

# CHECK: d1 mod 2
op = d1 % c2
print(op)

assert expr == op
assert expr.lhs == d1
assert expr.rhs == c2

run(testAffineModExpr)


# CHECK-LABEL: TEST: testAffineFloorDivExpr
def testAffineFloorDivExpr():
with Context():
d1 = AffineDimExpr.get(1)
c2 = AffineConstantExpr.get(2)
expr = AffineExpr.get_floor_div(d1, c2)
# CHECK: d1 floordiv 2
print(expr)

assert expr.lhs == d1
assert expr.rhs == c2

run(testAffineFloorDivExpr)


# CHECK-LABEL: TEST: testAffineCeilDivExpr
def testAffineCeilDivExpr():
with Context():
d1 = AffineDimExpr.get(1)
c2 = AffineConstantExpr.get(2)
expr = AffineExpr.get_ceil_div(d1, c2)
# CHECK: d1 ceildiv 2
print(expr)

assert expr.lhs == d1
assert expr.rhs == c2

run(testAffineCeilDivExpr)


# CHECK-LABEL: TEST: testAffineExprSub
def testAffineExprSub():
with Context():
d1 = AffineDimExpr.get(1)
d2 = AffineDimExpr.get(2)
expr = d1 - d2
# CHECK: d1 - d2
print(expr)

assert expr.lhs == d1
rhs = AffineMulExpr(expr.rhs)
# CHECK: d2
print(rhs.lhs)
# CHECK: -1
print(rhs.rhs)

run(testAffineExprSub)


def testClassHierarchy():
with Context():
d1 = AffineDimExpr.get(1)
c2 = AffineConstantExpr.get(2)
add = AffineAddExpr.get(d1, c2)
mul = AffineMulExpr.get(d1, c2)
mod = AffineModExpr.get(d1, c2)
floor_div = AffineFloorDivExpr.get(d1, c2)
ceil_div = AffineCeilDivExpr.get(d1, c2)

# CHECK: False
print(isinstance(d1, AffineBinaryExpr))
# CHECK: False
print(isinstance(c2, AffineBinaryExpr))
# CHECK: True
print(isinstance(add, AffineBinaryExpr))
# CHECK: True
print(isinstance(mul, AffineBinaryExpr))
# CHECK: True
print(isinstance(mod, AffineBinaryExpr))
# CHECK: True
print(isinstance(floor_div, AffineBinaryExpr))
# CHECK: True
print(isinstance(ceil_div, AffineBinaryExpr))

try:
AffineBinaryExpr(d1)
except ValueError as e:
# CHECK: Cannot cast affine expression to AffineBinaryExpr
print(e)

try:
AffineBinaryExpr(c2)
except ValueError as e:
# CHECK: Cannot cast affine expression to AffineBinaryExpr
print(e)

run(testClassHierarchy)
148 changes: 148 additions & 0 deletions mlir/test/Bindings/Python/ir_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,151 @@ def testAffineMapCapsule():
assert am2.context is ctx

run(testAffineMapCapsule)


# CHECK-LABEL: TEST: testAffineMapGet
def testAffineMapGet():
with Context() as ctx:
d0 = AffineDimExpr.get(0)
d1 = AffineDimExpr.get(1)
c2 = AffineConstantExpr.get(2)

# CHECK: (d0, d1)[s0, s1, s2] -> ()
map0 = AffineMap.get(2, 3, [])
print(map0)

# CHECK: (d0, d1)[s0, s1, s2] -> (d1, 2)
map1 = AffineMap.get(2, 3, [d1, c2])
print(map1)

# CHECK: () -> (2)
map2 = AffineMap.get(0, 0, [c2])
print(map2)

# CHECK: (d0, d1) -> (d0, d1)
map3 = AffineMap.get(2, 0, [d0, d1])
print(map3)

# CHECK: (d0, d1) -> (d1)
map4 = AffineMap.get(2, 0, [d1])
print(map4)

# CHECK: (d0, d1, d2) -> (d2, d0, d1)
map5 = AffineMap.get_permutation([2, 0, 1])
print(map5)

assert map1 == AffineMap.get(2, 3, [d1, c2])
assert AffineMap.get(0, 0, []) == AffineMap.get_empty()
assert map2 == AffineMap.get_constant(2)
assert map3 == AffineMap.get_identity(2)
assert map4 == AffineMap.get_minor_identity(2, 1)

try:
AffineMap.get(1, 1, [1])
except RuntimeError as e:
# CHECK: Invalid expression when attempting to create an AffineMap
print(e)

try:
AffineMap.get(1, 1, [None])
except RuntimeError as e:
# CHECK: Invalid expression (None?) when attempting to create an AffineMap
print(e)

try:
map3.get_submap([42])
except ValueError as e:
# CHECK: result position out of bounds
print(e)

try:
map3.get_minor_submap(42)
except ValueError as e:
# CHECK: number of results out of bounds
print(e)

try:
map3.get_major_submap(42)
except ValueError as e:
# CHECK: number of results out of bounds
print(e)

run(testAffineMapGet)


# CHECK-LABEL: TEST: testAffineMapDerive
def testAffineMapDerive():
with Context() as ctx:
map5 = AffineMap.get_identity(5)

# CHECK: (d0, d1, d2, d3, d4) -> (d1, d2, d3)
map123 = map5.get_submap([1,2,3])
print(map123)

# CHECK: (d0, d1, d2, d3, d4) -> (d0, d1)
map01 = map5.get_major_submap(2)
print(map01)

# CHECK: (d0, d1, d2, d3, d4) -> (d3, d4)
map34 = map5.get_minor_submap(2)
print(map34)

run(testAffineMapDerive)


# CHECK-LABEL: TEST: testAffineMapProperties
def testAffineMapProperties():
with Context():
d0 = AffineDimExpr.get(0)
d1 = AffineDimExpr.get(1)
d2 = AffineDimExpr.get(2)
map1 = AffineMap.get(3, 0, [d2, d0])
map2 = AffineMap.get(3, 0, [d2, d0, d1])
map3 = AffineMap.get(3, 1, [d2, d0, d1])
# CHECK: False
print(map1.is_permutation)
# CHECK: True
print(map1.is_projected_permutation)
# CHECK: True
print(map2.is_permutation)
# CHECK: True
print(map2.is_projected_permutation)
# CHECK: False
print(map3.is_permutation)
# CHECK: False
print(map3.is_projected_permutation)

run(testAffineMapProperties)


# CHECK-LABEL: TEST: testAffineMapExprs
def testAffineMapExprs():
with Context():
d0 = AffineDimExpr.get(0)
d1 = AffineDimExpr.get(1)
d2 = AffineDimExpr.get(2)
map3 = AffineMap.get(3, 1, [d2, d0, d1])

# CHECK: 3
print(map3.n_dims)
# CHECK: 4
print(map3.n_inputs)
# CHECK: 1
print(map3.n_symbols)
assert map3.n_inputs == map3.n_dims + map3.n_symbols

# CHECK: 3
print(len(map3.results))
for expr in map3.results:
# CHECK: d2
# CHECK: d0
# CHECK: d1
print(expr)
for expr in map3.results[-1:-4:-1]:
# CHECK: d1
# CHECK: d0
# CHECK: d2
print(expr)
assert list(map3.results) == [d2, d0, d1]

run(testAffineMapExprs)
16 changes: 13 additions & 3 deletions mlir/test/Bindings/Python/ir_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,17 +326,27 @@ def testMemRefType():
f32 = F32Type.get()
shape = [2, 3]
loc = Location.unknown()
memref = MemRefType.get_contiguous_memref(f32, shape, 2)
memref = MemRefType.get(f32, shape, memory_space=2)
# CHECK: memref type: memref<2x3xf32, 2>
print("memref type:", memref)
# CHECK: number of affine layout maps: 0
print("number of affine layout maps:", memref.num_affine_maps)
print("number of affine layout maps:", len(memref.layout))
# CHECK: memory space: 2
print("memory space:", memref.memory_space)

layout = AffineMap.get_permutation([1, 0])
memref_layout = MemRefType.get(f32, shape, [layout])
# CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
print("memref type:", memref_layout)
assert len(memref_layout.layout) == 1
# CHECK: memref layout: (d0, d1) -> (d1, d0)
print("memref layout:", memref_layout.layout[0])
# CHECK: memory space: 0
print("memory space:", memref_layout.memory_space)

none = NoneType.get()
try:
memref_invalid = MemRefType.get_contiguous_memref(none, shape, 2)
memref_invalid = MemRefType.get(none, shape)
except ValueError as e:
# CHECK: invalid 'Type(none)' and expected floating point, integer, vector
# CHECK: or complex type.
Expand Down
50 changes: 48 additions & 2 deletions mlir/test/CAPI/ir.c
Original file line number Diff line number Diff line change
Expand Up @@ -1007,7 +1007,7 @@ int printBuiltinAttributes(MlirContext ctx) {

int printAffineMap(MlirContext ctx) {
MlirAffineMap emptyAffineMap = mlirAffineMapEmptyGet(ctx);
MlirAffineMap affineMap = mlirAffineMapGet(ctx, 3, 2);
MlirAffineMap affineMap = mlirAffineMapZeroResultGet(ctx, 3, 2);
MlirAffineMap constAffineMap = mlirAffineMapConstantGet(ctx, 2);
MlirAffineMap multiDimIdentityAffineMap =
mlirAffineMapMultiDimIdentityGet(ctx, 3);
Expand Down Expand Up @@ -1251,6 +1251,50 @@ int printAffineExpr(MlirContext ctx) {
if (!mlirAffineExprIsACeilDiv(affineCeilDivExpr))
return 13;

if (!mlirAffineExprIsABinary(affineAddExpr))
return 14;

// Test other 'IsA' method on affine expressions.
if (!mlirAffineExprIsAConstant(affineConstantExpr))
return 15;

if (!mlirAffineExprIsADim(affineDimExpr))
return 16;

if (!mlirAffineExprIsASymbol(affineSymbolExpr))
return 17;

// Test equality and nullity.
MlirAffineExpr otherDimExpr = mlirAffineDimExprGet(ctx, 5);
if (!mlirAffineExprEqual(affineDimExpr, otherDimExpr))
return 18;

if (mlirAffineExprIsNull(affineDimExpr))
return 19;

return 0;
}

int affineMapFromExprs(MlirContext ctx) {
MlirAffineExpr affineDimExpr = mlirAffineDimExprGet(ctx, 0);
MlirAffineExpr affineSymbolExpr = mlirAffineSymbolExprGet(ctx, 1);
MlirAffineExpr exprs[] = {affineDimExpr, affineSymbolExpr};
MlirAffineMap map = mlirAffineMapGet(ctx, 3, 3, 2, exprs);

// CHECK-LABEL: @affineMapFromExprs
fprintf(stderr, "@affineMapFromExprs");
// CHECK: (d0, d1, d2)[s0, s1, s2] -> (d0, s1)
mlirAffineMapDump(map);

if (mlirAffineMapGetNumResults(map) != 2)
return 1;

if (!mlirAffineExprEqual(mlirAffineMapGetResult(map, 0), affineDimExpr))
return 2;

if (!mlirAffineExprEqual(mlirAffineMapGetResult(map, 1), affineSymbolExpr))
return 3;

return 0;
}

Expand Down Expand Up @@ -1354,8 +1398,10 @@ int main() {
return 4;
if (printAffineExpr(ctx))
return 5;
if (registerOnlyStd())
if (affineMapFromExprs(ctx))
return 6;
if (registerOnlyStd())
return 7;

mlirContextDestroy(ctx);

Expand Down