Skip to content

Commit

Permalink
Add Python bindings for affine expressions with binary operators.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 283569325
  • Loading branch information
MLIR Team authored and tensorflower-gardener committed Dec 3, 2019
1 parent 06f6958 commit 343469c
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 2 deletions.
63 changes: 63 additions & 0 deletions bindings/python/pybind.cpp
Expand Up @@ -207,6 +207,9 @@ struct PythonMLIRModule {
// Creates an affine symbol expression.
PythonAffineExpr affineSymbolExpr(unsigned position);

// Creates an affine dimension expression.
PythonAffineExpr affineDimExpr(unsigned position);

// Creates a single constant result affine map.
PythonAffineMap affineConstantMap(int64_t value);

Expand Down Expand Up @@ -565,6 +568,8 @@ struct PythonAffineExpr {
operator AffineExpr() const { return affine_expr; }
operator AffineExpr &() { return affine_expr; }

AffineExpr get() const { return affine_expr; }

std::string str() const {
std::string res;
llvm::raw_string_ostream os(res);
Expand Down Expand Up @@ -724,6 +729,10 @@ PythonAffineExpr PythonMLIRModule::affineSymbolExpr(unsigned position) {
return PythonAffineExpr(getAffineSymbolExpr(position, &mlirContext));
}

PythonAffineExpr PythonMLIRModule::affineDimExpr(unsigned position) {
return PythonAffineExpr(getAffineDimExpr(position, &mlirContext));
}

PythonAffineMap PythonMLIRModule::affineConstantMap(int64_t value) {
return PythonAffineMap(AffineMap::getConstantMap(value, &mlirContext));
}
Expand Down Expand Up @@ -937,6 +946,8 @@ PYBIND11_MODULE(pybind, m) {
"Returns an affine constant expression.")
.def("affine_symbol_expr", &PythonMLIRModule::affineSymbolExpr,
"Returns an affine symbol expression.")
.def("affine_dim_expr", &PythonMLIRModule::affineDimExpr,
"Returns an affine dim expression.")
.def("affine_constant_map", &PythonMLIRModule::affineConstantMap,
"Returns an affine map with single constant result.")
.def("affine_map", &PythonMLIRModule::affineMap, "Returns an affine map.",
Expand Down Expand Up @@ -1054,6 +1065,58 @@ PYBIND11_MODULE(pybind, m) {
py::class_<PythonAffineExpr>(m, "AffineExpr",
"A wrapper around mlir::AffineExpr")
.def(py::init<PythonAffineExpr>())
.def("__add__",
[](PythonAffineExpr lhs, int64_t rhs) -> PythonAffineExpr {
return PythonAffineExpr(lhs.get() + rhs);
})
.def("__add__",
[](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr {
return PythonAffineExpr(lhs.get() + rhs.get());
})
.def("__neg__",
[](PythonAffineExpr lhs) -> PythonAffineExpr {
return PythonAffineExpr(-lhs.get());
})
.def("__sub__",
[](PythonAffineExpr lhs, int64_t rhs) -> PythonAffineExpr {
return PythonAffineExpr(lhs.get() - rhs);
})
.def("__sub__",
[](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr {
return PythonAffineExpr(lhs.get() - rhs.get());
})
.def("__mul__",
[](PythonAffineExpr lhs, int64_t rhs) -> PythonAffineExpr {
return PythonAffineExpr(lhs.get() * rhs);
})
.def("__mul__",
[](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr {
return PythonAffineExpr(lhs.get() * rhs.get());
})
.def("__floordiv__",
[](PythonAffineExpr lhs, uint64_t rhs) -> PythonAffineExpr {
return PythonAffineExpr(lhs.get().floorDiv(rhs));
})
.def("__floordiv__",
[](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr {
return PythonAffineExpr(lhs.get().floorDiv(rhs.get()));
})
.def("ceildiv",
[](PythonAffineExpr lhs, uint64_t rhs) -> PythonAffineExpr {
return PythonAffineExpr(lhs.get().ceilDiv(rhs));
})
.def("ceildiv",
[](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr {
return PythonAffineExpr(lhs.get().ceilDiv(rhs.get()));
})
.def("__mod__",
[](PythonAffineExpr lhs, uint64_t rhs) -> PythonAffineExpr {
return PythonAffineExpr(lhs.get() % rhs);
})
.def("__mod__",
[](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr {
return PythonAffineExpr(lhs.get() % rhs.get());
})
.def("__str__", &PythonAffineExpr::str);

py::class_<PythonAffineMap>(m, "AffineMap",
Expand Down
12 changes: 10 additions & 2 deletions bindings/python/test/test_py2and3.py
Expand Up @@ -289,22 +289,30 @@ def testFunctionDeclarationWithAffineAttr(self):
self.setUp()
a1 = self.module.affine_constant_expr(23)
a2 = self.module.affine_constant_expr(44)
a3 = self.module.affine_dim_expr(1)
s0 = self.module.affine_symbol_expr(0)
aMap1 = self.module.affine_map(2, 0, [a1, a2, s0])
aMap2 = self.module.affine_constant_map(42)
aMap3 = self.module.affine_map(
2, 0,
[a1 + a2 * a3, a1 // a3 % a2,
a1.ceildiv(a2), a1 - 2, a2 * 2, -a3])

affineAttr1 = self.module.affineMapAttr(aMap1)
affineAttr2 = self.module.affineMapAttr(aMap2)
affineAttr3 = self.module.affineMapAttr(aMap3)

t = self.module.make_memref_type(self.f32Type, [10])
t_with_attr = t({
"affine_attr_1": affineAttr1,
"affine_attr_2": affineAttr2
"affine_attr_2": affineAttr2,
"affine_attr_3": affineAttr3,
})

f = self.module.declare_function("foo", [t, t_with_attr], [])
printWithCurrentFunctionName(str(self.module))
# CHECK-LABEL: testFunctionDeclarationWithAffineAttr
# CHECK: func @foo(memref<10xf32>, memref<10xf32> {affine_attr_1 = (d0, d1) -> (23, 44, s0), affine_attr_2 = () -> (42)})
# CHECK: func @foo(memref<10xf32>, memref<10xf32> {affine_attr_1 = (d0, d1) -> (23, 44, s0), affine_attr_2 = () -> (42), affine_attr_3 = (d0, d1) -> (d1 * 44 + 23, (23 floordiv d1) mod 44, 1, 21, 88, -d1)})

def testFunctionDeclarationWithArrayAttr(self):
self.setUp()
Expand Down

0 comments on commit 343469c

Please sign in to comment.