From e7338689f65bdd37b99cea6c51cec9a98ecf9155 Mon Sep 17 00:00:00 2001 From: Prithayan Barua Date: Mon, 12 Jun 2023 21:34:28 -0700 Subject: [PATCH 1/6] Get fields from an object --- include/circt-c/Dialect/OM.h | 5 +++++ include/circt/Dialect/OM/Evaluator/Evaluator.h | 3 +++ lib/Bindings/Python/OMModule.cpp | 7 +++++++ lib/CAPI/Dialect/OM.cpp | 5 +++++ lib/Dialect/OM/Evaluator/Evaluator.cpp | 9 +++++++++ 5 files changed, 29 insertions(+) diff --git a/include/circt-c/Dialect/OM.h b/include/circt-c/Dialect/OM.h index a9dfef15663d..bf827ebbea73 100644 --- a/include/circt-c/Dialect/OM.h +++ b/include/circt-c/Dialect/OM.h @@ -105,6 +105,11 @@ MLIR_CAPI_EXPORTED MlirType omEvaluatorObjectGetType(OMObject object); MLIR_CAPI_EXPORTED OMObjectValue omEvaluatorObjectGetField(OMObject object, MlirAttribute name); +/// Get all the field names from an Object, can be empty if object has no +/// fields. +MLIR_CAPI_EXPORTED MlirAttribute +omEvaluatorObjectGetFieldNames(OMObject object); + //===----------------------------------------------------------------------===// // ObjectValue API. //===----------------------------------------------------------------------===// diff --git a/include/circt/Dialect/OM/Evaluator/Evaluator.h b/include/circt/Dialect/OM/Evaluator/Evaluator.h index 0d5faca7a5e0..fba430e35306 100644 --- a/include/circt/Dialect/OM/Evaluator/Evaluator.h +++ b/include/circt/Dialect/OM/Evaluator/Evaluator.h @@ -81,6 +81,9 @@ struct Object : std::enable_shared_from_this { /// Get a field of the Object by name. FailureOr getField(StringAttr name); + /// Get all the field names of the Object. + ArrayAttr getFieldNames(); + private: /// Allow the instantiate method as a friend to construct Objects. friend FailureOr> diff --git a/lib/Bindings/Python/OMModule.cpp b/lib/Bindings/Python/OMModule.cpp index 8b8de8ce238a..378ce709cfba 100644 --- a/lib/Bindings/Python/OMModule.cpp +++ b/lib/Bindings/Python/OMModule.cpp @@ -55,6 +55,11 @@ struct Object { return omEvaluatorObjectValueGetPrimitive(result); } + // Get an ArrayAttr with the names of all the fields in the object. + MlirAttribute getFieldNames() { + return omEvaluatorObjectGetFieldNames(object); + } + private: // The underlying CAPI OMObject. OMObject object; @@ -108,6 +113,8 @@ void circt::python::populateDialectOMSubmodule(py::module &m) { .def(py::init(), py::arg("object")) .def("__getattr__", &Object::getField, "Get a field from an Object", py::arg("name")) + .def("getFieldNames", &Object::getFieldNames, + "Get field names from an Object") .def_property_readonly("type", &Object::getType, "The Type of the Object"); diff --git a/lib/CAPI/Dialect/OM.cpp b/lib/CAPI/Dialect/OM.cpp index e0f457ed28e9..ca70e288bfbe 100644 --- a/lib/CAPI/Dialect/OM.cpp +++ b/lib/CAPI/Dialect/OM.cpp @@ -114,6 +114,11 @@ MlirType omEvaluatorObjectGetType(OMObject object) { return wrap(unwrap(object)->getType()); } +/// Get an ArrayAttr with the names of the fields in an Object. +MlirAttribute omEvaluatorObjectGetFieldNames(OMObject object) { + return wrap(unwrap(object)->getFieldNames()); +} + /// Get a field from an Object, which must contain a field of that name. OMObjectValue omEvaluatorObjectGetField(OMObject object, MlirAttribute name) { // Unwrap the Object and get the field of the name, which the client must diff --git a/lib/Dialect/OM/Evaluator/Evaluator.cpp b/lib/Dialect/OM/Evaluator/Evaluator.cpp index 7c6698887caf..5b255980b1a4 100644 --- a/lib/Dialect/OM/Evaluator/Evaluator.cpp +++ b/lib/Dialect/OM/Evaluator/Evaluator.cpp @@ -203,3 +203,12 @@ FailureOr circt::om::Object::getField(StringAttr name) { return cls.emitError("field ") << name << " does not exist"; return success(fields[name]); } + +/// Get an ArrayAttr with the names of the fields in the Object. +ArrayAttr circt::om::Object::getFieldNames() { + SmallVector fieldNames; + for (auto &f : fields) + fieldNames.push_back(f.first); + + return ArrayAttr::get(cls.getContext(), fieldNames); +} From 325bcd10f76400581835694534f5195308f8c271 Mon Sep 17 00:00:00 2001 From: Prithayan Barua Date: Wed, 14 Jun 2023 13:46:27 -0400 Subject: [PATCH 2/6] Add tests --- test/CAPI/om.c | 9 +++++++++ unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp | 13 +++++++++++++ 2 files changed, 22 insertions(+) diff --git a/test/CAPI/om.c b/test/CAPI/om.c index 3c48789c3ce4..5af022de8921 100644 --- a/test/CAPI/om.c +++ b/test/CAPI/om.c @@ -107,6 +107,10 @@ void testEvaluator(MlirContext ctx) { OMObjectValue childField = omEvaluatorObjectGetField(object, childFieldName); + MlirAttribute fieldNamesO = omEvaluatorObjectGetFieldNames(object); + // CHECK: ["field", "child"] + mlirAttributeDump(fieldNamesO); + OMObject child = omEvaluatorObjectValueGetObject(childField); // CHECK: 0 @@ -115,6 +119,11 @@ void testEvaluator(MlirContext ctx) { OMObjectValue foo = omEvaluatorObjectGetField( child, mlirStringAttrGet(ctx, mlirStringRefCreateFromCString("foo"))); + MlirAttribute fieldNamesC = omEvaluatorObjectGetFieldNames(child); + + // CHECK: ["foo"] + mlirAttributeDump(fieldNamesC); + // CHECK: child object field is primitive: 1 fprintf(stderr, "child object field is primitive: %d\n", omEvaluatorObjectValueIsAPrimitive(foo)); diff --git a/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp b/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp index da061511b3ec..800a32d97f71 100644 --- a/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp +++ b/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp @@ -13,6 +13,9 @@ #include "mlir/IR/DialectRegistry.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Location.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "gtest/gtest.h" #include @@ -378,6 +381,16 @@ TEST(EvaluatorTests, InstantiateObjectWithChildObjectMemoized) { auto field2Value = std::get>( result.value()->getField(builder.getStringAttr("field2")).value()); + auto fieldNames = result.value()->getFieldNames(); + + ASSERT_TRUE(fieldNames.size() == 2); + StringRef fieldNamesTruth[] = {"field1", "field2"}; + for (auto fieldName : llvm::enumerate(fieldNames)) { + auto str = llvm::dyn_cast_or_null(fieldName.value()); + ASSERT_TRUE(str); + ASSERT_EQ(str.getValue(), fieldNamesTruth[fieldName.index()]); + } + ASSERT_TRUE(field1Value); ASSERT_TRUE(field2Value); From fdc18871db79631e1a081ab9937d8a22309f8246 Mon Sep 17 00:00:00 2001 From: Mike Urbach Date: Wed, 14 Jun 2023 22:27:46 -0600 Subject: [PATCH 3/6] Sort fields before returning them. --- lib/Dialect/OM/Evaluator/Evaluator.cpp | 7 ++++++- test/CAPI/om.c | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/OM/Evaluator/Evaluator.cpp b/lib/Dialect/OM/Evaluator/Evaluator.cpp index 5b255980b1a4..b4ac0bbf1865 100644 --- a/lib/Dialect/OM/Evaluator/Evaluator.cpp +++ b/lib/Dialect/OM/Evaluator/Evaluator.cpp @@ -204,11 +204,16 @@ FailureOr circt::om::Object::getField(StringAttr name) { return success(fields[name]); } -/// Get an ArrayAttr with the names of the fields in the Object. +/// Get an ArrayAttr with the names of the fields in the Object. Sort the fields +/// so there is always a stable order. ArrayAttr circt::om::Object::getFieldNames() { SmallVector fieldNames; for (auto &f : fields) fieldNames.push_back(f.first); + llvm::sort(fieldNames, [](Attribute a, Attribute b) { + return cast(a).getValue() < cast(b).getValue(); + }); + return ArrayAttr::get(cls.getContext(), fieldNames); } diff --git a/test/CAPI/om.c b/test/CAPI/om.c index 5af022de8921..b2eda4e90f24 100644 --- a/test/CAPI/om.c +++ b/test/CAPI/om.c @@ -108,7 +108,7 @@ void testEvaluator(MlirContext ctx) { OMObjectValue childField = omEvaluatorObjectGetField(object, childFieldName); MlirAttribute fieldNamesO = omEvaluatorObjectGetFieldNames(object); - // CHECK: ["field", "child"] + // CHECK: ["child", "field"] mlirAttributeDump(fieldNamesO); OMObject child = omEvaluatorObjectValueGetObject(childField); From 155abcbbe1cc3fdab73b83fbec2968090cfa7f4f Mon Sep 17 00:00:00 2001 From: Mike Urbach Date: Wed, 14 Jun 2023 23:39:40 -0600 Subject: [PATCH 4/6] Use field names to support iterating over an Object. --- .../Bindings/Python/dialects/om.py | 5 +++++ lib/Bindings/Python/OMModule.cpp | 18 +++++++++++++----- lib/Bindings/Python/dialects/om.py | 4 ++++ 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/integration_test/Bindings/Python/dialects/om.py b/integration_test/Bindings/Python/dialects/om.py index 59b286449d7d..f2c874ca3712 100644 --- a/integration_test/Bindings/Python/dialects/om.py +++ b/integration_test/Bindings/Python/dialects/om.py @@ -58,3 +58,8 @@ print(obj.field) # CHECK: 14 print(obj.child.foo) + +for (name, field) in obj: + # CHECK: name: child, field: #include namespace py = pybind11; @@ -55,9 +56,16 @@ struct Object { return omEvaluatorObjectValueGetPrimitive(result); } - // Get an ArrayAttr with the names of all the fields in the object. - MlirAttribute getFieldNames() { - return omEvaluatorObjectGetFieldNames(object); + // Get a list with the names of all the fields in the Object. + std::vector getFieldNames() { + ArrayAttr fieldNames = + cast(unwrap(omEvaluatorObjectGetFieldNames(object))); + + std::vector slots; + for (auto fieldName : fieldNames.getAsRange()) + slots.push_back(fieldName.str()); + + return slots; } private: @@ -113,8 +121,8 @@ void circt::python::populateDialectOMSubmodule(py::module &m) { .def(py::init(), py::arg("object")) .def("__getattr__", &Object::getField, "Get a field from an Object", py::arg("name")) - .def("getFieldNames", &Object::getFieldNames, - "Get field names from an Object") + .def_property_readonly("field_names", &Object::getFieldNames, + "Get field names from an Object") .def_property_readonly("type", &Object::getType, "The Type of the Object"); diff --git a/lib/Bindings/Python/dialects/om.py b/lib/Bindings/Python/dialects/om.py index 9557b5f4df23..5b8ae2b6811c 100644 --- a/lib/Bindings/Python/dialects/om.py +++ b/lib/Bindings/Python/dialects/om.py @@ -37,6 +37,10 @@ def __getattr__(self, name: str): assert isinstance(field, BaseObject) return Object(field) + # Support iterating over an Object by yielding its fields. + def __iter__(self): + for name in self.field_names: + yield (name, getattr(self, name)) # Define the Evaluator class by inheriting from the base implementation in C++. class Evaluator(BaseEvaluator): From f4e877d515dc58fad2943d370cd58cbecb2dc661 Mon Sep 17 00:00:00 2001 From: Prithayan Barua Date: Thu, 15 Jun 2023 08:04:11 -0700 Subject: [PATCH 5/6] Remove extra whitespace --- lib/Bindings/Python/dialects/om.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Bindings/Python/dialects/om.py b/lib/Bindings/Python/dialects/om.py index 5b8ae2b6811c..ff3e73af2c24 100644 --- a/lib/Bindings/Python/dialects/om.py +++ b/lib/Bindings/Python/dialects/om.py @@ -38,7 +38,7 @@ def __getattr__(self, name: str): return Object(field) # Support iterating over an Object by yielding its fields. - def __iter__(self): + def __iter__(self): for name in self.field_names: yield (name, getattr(self, name)) From 82834c814e22f66c42c33a5753bc3edcca370a04 Mon Sep 17 00:00:00 2001 From: Prithayan Barua Date: Thu, 15 Jun 2023 08:11:35 -0700 Subject: [PATCH 6/6] Update om.py --- lib/Bindings/Python/dialects/om.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/Bindings/Python/dialects/om.py b/lib/Bindings/Python/dialects/om.py index ff3e73af2c24..2d8e3b5ccf26 100644 --- a/lib/Bindings/Python/dialects/om.py +++ b/lib/Bindings/Python/dialects/om.py @@ -42,6 +42,7 @@ def __iter__(self): for name in self.field_names: yield (name, getattr(self, name)) + # Define the Evaluator class by inheriting from the base implementation in C++. class Evaluator(BaseEvaluator):