Skip to content

Commit

Permalink
[mlir][python] Add python binding for AffineMapAttribute.
Browse files Browse the repository at this point in the history
Differential Revision: https://reviews.llvm.org/D96815
  • Loading branch information
stellaraccident committed Feb 16, 2021
1 parent 60d71a2 commit 4c3f1be
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 5 deletions.
28 changes: 23 additions & 5 deletions mlir/lib/Bindings/Python/IRModules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1763,6 +1763,23 @@ class PyConcreteAttribute : public BaseTy {
static void bindDerived(ClassTy &m) {}
};

class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
static constexpr const char *pyClassName = "AffineMapAttr";
using PyConcreteAttribute::PyConcreteAttribute;

static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](PyAffineMap &affineMap) {
MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
return PyAffineMapAttribute(affineMap.getContext(), attr);
},
py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
}
};

class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
Expand Down Expand Up @@ -3994,17 +4011,18 @@ void mlir::python::populateIRSubmodule(py::module &m) {
"The underlying generic attribute of the NamedAttribute binding");

// Builtin attribute bindings.
PyFloatAttribute::bind(m);
PyAffineMapAttribute::bind(m);
PyArrayAttribute::bind(m);
PyArrayAttribute::PyArrayAttributeIterator::bind(m);
PyIntegerAttribute::bind(m);
PyBoolAttribute::bind(m);
PyFlatSymbolRefAttribute::bind(m);
PyStringAttribute::bind(m);
PyDenseElementsAttribute::bind(m);
PyDenseIntElementsAttribute::bind(m);
PyDenseFPElementsAttribute::bind(m);
PyDenseIntElementsAttribute::bind(m);
PyDictAttribute::bind(m);
PyFlatSymbolRefAttribute::bind(m);
PyFloatAttribute::bind(m);
PyIntegerAttribute::bind(m);
PyStringAttribute::bind(m);
PyTypeAttribute::bind(m);
PyUnitAttribute::bind(m);

Expand Down
18 changes: 18 additions & 0 deletions mlir/test/Bindings/Python/ir_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,24 @@ def testStandardAttrCasts():
run(testStandardAttrCasts)


# CHECK-LABEL: TEST: testAffineMapAttr
def testAffineMapAttr():
with Context() as ctx:
d0 = AffineDimExpr.get(0)
d1 = AffineDimExpr.get(1)
c2 = AffineConstantExpr.get(2)
map0 = AffineMap.get(2, 3, [])

# CHECK: affine_map<(d0, d1)[s0, s1, s2] -> ()>
attr_built = AffineMapAttr.get(map0)
print(str(attr_built))

attr_parsed = Attribute.parse(str(attr_built))
assert attr_built == attr_parsed

run(testAffineMapAttr)


# CHECK-LABEL: TEST: testFloatAttr
def testFloatAttr():
with Context(), Location.unknown():
Expand Down

0 comments on commit 4c3f1be

Please sign in to comment.