diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 7b1710656243a..06d0256e6287b 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2730,6 +2730,17 @@ class PyOpAttributeMap { operation->get(), toMlirStringRef(name))); } + static void + forEachAttr(MlirOperation op, + llvm::function_ref fn) { + intptr_t n = mlirOperationGetNumAttributes(op); + for (intptr_t i = 0; i < n; ++i) { + MlirNamedAttribute na = mlirOperationGetAttribute(op, i); + MlirStringRef name = mlirIdentifierStr(na.name); + fn(name, na.attribute); + } + } + static void bind(nb::module_ &m) { nb::class_(m, "OpAttributeMap") .def("__contains__", &PyOpAttributeMap::dunderContains) @@ -2737,7 +2748,50 @@ class PyOpAttributeMap { .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) .def("__setitem__", &PyOpAttributeMap::dunderSetItem) - .def("__delitem__", &PyOpAttributeMap::dunderDelItem); + .def("__delitem__", &PyOpAttributeMap::dunderDelItem) + .def("__iter__", + [](PyOpAttributeMap &self) { + nb::list keys; + PyOpAttributeMap::forEachAttr( + self.operation->get(), + [&](MlirStringRef name, MlirAttribute) { + keys.append(nb::str(name.data, name.length)); + }); + return nb::iter(keys); + }) + .def("keys", + [](PyOpAttributeMap &self) { + nb::list out; + PyOpAttributeMap::forEachAttr( + self.operation->get(), + [&](MlirStringRef name, MlirAttribute) { + out.append(nb::str(name.data, name.length)); + }); + return out; + }) + .def("values", + [](PyOpAttributeMap &self) { + nb::list out; + PyOpAttributeMap::forEachAttr( + self.operation->get(), + [&](MlirStringRef, MlirAttribute attr) { + out.append(PyAttribute(self.operation->getContext(), attr) + .maybeDownCast()); + }); + return out; + }) + .def("items", [](PyOpAttributeMap &self) { + nb::list out; + PyOpAttributeMap::forEachAttr( + self.operation->get(), + [&](MlirStringRef name, MlirAttribute attr) { + out.append(nb::make_tuple( + nb::str(name.data, name.length), + PyAttribute(self.operation->getContext(), attr) + .maybeDownCast())); + }); + return out; + }); } private: diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index cb4cfc8c8a6ec..1d4ede1ee336a 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -569,12 +569,30 @@ def testOperationAttributes(): # CHECK: Attribute value b'text' print(f"Attribute value {sattr.value_bytes}") + # Python dict-style iteration # We don't know in which order the attributes are stored. - # CHECK-DAG: NamedAttribute(dependent="text") - # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64) - # CHECK-DAG: NamedAttribute(some.attribute=1 : i8) - for attr in op.attributes: - print(str(attr)) + # CHECK-DAG: dependent + # CHECK-DAG: other.attribute + # CHECK-DAG: some.attribute + for name in op.attributes: + print(name) + + # Basic dict-like introspection + # CHECK: True + print("some.attribute" in op.attributes) + # CHECK: False + print("missing" in op.attributes) + # CHECK: Keys: ['dependent', 'other.attribute', 'some.attribute'] + print("Keys:", sorted(op.attributes.keys())) + # CHECK: Values count 3 + print("Values count", len(op.attributes.values())) + # CHECK: Items count 3 + print("Items count", len(op.attributes.items())) + + # Dict() conversion test + d = {k: v.value for k, v in dict(op.attributes).items()} + # CHECK: Dict mapping {'dependent': 'text', 'other.attribute': 3.0, 'some.attribute': 1} + print("Dict mapping", d) # Check that exceptions are raised as expected. try: