diff --git a/include/circt-c/Dialect/OM.h b/include/circt-c/Dialect/OM.h index a9dfef15663d..bf827ebbea73 100644 --- a/include/circt-c/Dialect/OM.h +++ b/include/circt-c/Dialect/OM.h @@ -105,6 +105,11 @@ MLIR_CAPI_EXPORTED MlirType omEvaluatorObjectGetType(OMObject object); MLIR_CAPI_EXPORTED OMObjectValue omEvaluatorObjectGetField(OMObject object, MlirAttribute name); +/// Get all the field names from an Object, can be empty if object has no +/// fields. +MLIR_CAPI_EXPORTED MlirAttribute +omEvaluatorObjectGetFieldNames(OMObject object); + //===----------------------------------------------------------------------===// // ObjectValue API. //===----------------------------------------------------------------------===// diff --git a/include/circt/Dialect/OM/Evaluator/Evaluator.h b/include/circt/Dialect/OM/Evaluator/Evaluator.h index 0d5faca7a5e0..fba430e35306 100644 --- a/include/circt/Dialect/OM/Evaluator/Evaluator.h +++ b/include/circt/Dialect/OM/Evaluator/Evaluator.h @@ -81,6 +81,9 @@ struct Object : std::enable_shared_from_this { /// Get a field of the Object by name. FailureOr getField(StringAttr name); + /// Get all the field names of the Object. + ArrayAttr getFieldNames(); + private: /// Allow the instantiate method as a friend to construct Objects. friend FailureOr> diff --git a/integration_test/Bindings/Python/dialects/om.py b/integration_test/Bindings/Python/dialects/om.py index 59b286449d7d..f2c874ca3712 100644 --- a/integration_test/Bindings/Python/dialects/om.py +++ b/integration_test/Bindings/Python/dialects/om.py @@ -58,3 +58,8 @@ print(obj.field) # CHECK: 14 print(obj.child.foo) + +for (name, field) in obj: + # CHECK: name: child, field: #include namespace py = pybind11; @@ -55,6 +56,18 @@ struct Object { return omEvaluatorObjectValueGetPrimitive(result); } + // Get a list with the names of all the fields in the Object. + std::vector getFieldNames() { + ArrayAttr fieldNames = + cast(unwrap(omEvaluatorObjectGetFieldNames(object))); + + std::vector slots; + for (auto fieldName : fieldNames.getAsRange()) + slots.push_back(fieldName.str()); + + return slots; + } + private: // The underlying CAPI OMObject. OMObject object; @@ -108,6 +121,8 @@ void circt::python::populateDialectOMSubmodule(py::module &m) { .def(py::init(), py::arg("object")) .def("__getattr__", &Object::getField, "Get a field from an Object", py::arg("name")) + .def_property_readonly("field_names", &Object::getFieldNames, + "Get field names from an Object") .def_property_readonly("type", &Object::getType, "The Type of the Object"); diff --git a/lib/Bindings/Python/dialects/om.py b/lib/Bindings/Python/dialects/om.py index 9557b5f4df23..2d8e3b5ccf26 100644 --- a/lib/Bindings/Python/dialects/om.py +++ b/lib/Bindings/Python/dialects/om.py @@ -37,6 +37,11 @@ def __getattr__(self, name: str): assert isinstance(field, BaseObject) return Object(field) + # Support iterating over an Object by yielding its fields. + def __iter__(self): + for name in self.field_names: + yield (name, getattr(self, name)) + # Define the Evaluator class by inheriting from the base implementation in C++. class Evaluator(BaseEvaluator): diff --git a/lib/CAPI/Dialect/OM.cpp b/lib/CAPI/Dialect/OM.cpp index e0f457ed28e9..ca70e288bfbe 100644 --- a/lib/CAPI/Dialect/OM.cpp +++ b/lib/CAPI/Dialect/OM.cpp @@ -114,6 +114,11 @@ MlirType omEvaluatorObjectGetType(OMObject object) { return wrap(unwrap(object)->getType()); } +/// Get an ArrayAttr with the names of the fields in an Object. +MlirAttribute omEvaluatorObjectGetFieldNames(OMObject object) { + return wrap(unwrap(object)->getFieldNames()); +} + /// Get a field from an Object, which must contain a field of that name. OMObjectValue omEvaluatorObjectGetField(OMObject object, MlirAttribute name) { // Unwrap the Object and get the field of the name, which the client must diff --git a/lib/Dialect/OM/Evaluator/Evaluator.cpp b/lib/Dialect/OM/Evaluator/Evaluator.cpp index 7c6698887caf..b4ac0bbf1865 100644 --- a/lib/Dialect/OM/Evaluator/Evaluator.cpp +++ b/lib/Dialect/OM/Evaluator/Evaluator.cpp @@ -203,3 +203,17 @@ FailureOr circt::om::Object::getField(StringAttr name) { return cls.emitError("field ") << name << " does not exist"; return success(fields[name]); } + +/// Get an ArrayAttr with the names of the fields in the Object. Sort the fields +/// so there is always a stable order. +ArrayAttr circt::om::Object::getFieldNames() { + SmallVector fieldNames; + for (auto &f : fields) + fieldNames.push_back(f.first); + + llvm::sort(fieldNames, [](Attribute a, Attribute b) { + return cast(a).getValue() < cast(b).getValue(); + }); + + return ArrayAttr::get(cls.getContext(), fieldNames); +} diff --git a/test/CAPI/om.c b/test/CAPI/om.c index 3c48789c3ce4..b2eda4e90f24 100644 --- a/test/CAPI/om.c +++ b/test/CAPI/om.c @@ -107,6 +107,10 @@ void testEvaluator(MlirContext ctx) { OMObjectValue childField = omEvaluatorObjectGetField(object, childFieldName); + MlirAttribute fieldNamesO = omEvaluatorObjectGetFieldNames(object); + // CHECK: ["child", "field"] + mlirAttributeDump(fieldNamesO); + OMObject child = omEvaluatorObjectValueGetObject(childField); // CHECK: 0 @@ -115,6 +119,11 @@ void testEvaluator(MlirContext ctx) { OMObjectValue foo = omEvaluatorObjectGetField( child, mlirStringAttrGet(ctx, mlirStringRefCreateFromCString("foo"))); + MlirAttribute fieldNamesC = omEvaluatorObjectGetFieldNames(child); + + // CHECK: ["foo"] + mlirAttributeDump(fieldNamesC); + // CHECK: child object field is primitive: 1 fprintf(stderr, "child object field is primitive: %d\n", omEvaluatorObjectValueIsAPrimitive(foo)); diff --git a/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp b/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp index da061511b3ec..800a32d97f71 100644 --- a/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp +++ b/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp @@ -13,6 +13,9 @@ #include "mlir/IR/DialectRegistry.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Location.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "gtest/gtest.h" #include @@ -378,6 +381,16 @@ TEST(EvaluatorTests, InstantiateObjectWithChildObjectMemoized) { auto field2Value = std::get>( result.value()->getField(builder.getStringAttr("field2")).value()); + auto fieldNames = result.value()->getFieldNames(); + + ASSERT_TRUE(fieldNames.size() == 2); + StringRef fieldNamesTruth[] = {"field1", "field2"}; + for (auto fieldName : llvm::enumerate(fieldNames)) { + auto str = llvm::dyn_cast_or_null(fieldName.value()); + ASSERT_TRUE(str); + ASSERT_EQ(str.getValue(), fieldNamesTruth[fieldName.index()]); + } + ASSERT_TRUE(field1Value); ASSERT_TRUE(field2Value);