Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OM] Get the field names from an Object #5402

Merged
merged 6 commits into from
Jun 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/circt-c/Dialect/OM.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ MLIR_CAPI_EXPORTED MlirType omEvaluatorObjectGetType(OMObject object);
MLIR_CAPI_EXPORTED OMObjectValue omEvaluatorObjectGetField(OMObject object,
MlirAttribute name);

/// Get all the field names from an Object, can be empty if object has no
/// fields.
MLIR_CAPI_EXPORTED MlirAttribute
omEvaluatorObjectGetFieldNames(OMObject object);

//===----------------------------------------------------------------------===//
// ObjectValue API.
//===----------------------------------------------------------------------===//
Expand Down
3 changes: 3 additions & 0 deletions include/circt/Dialect/OM/Evaluator/Evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ struct Object : std::enable_shared_from_this<Object> {
/// Get a field of the Object by name.
FailureOr<ObjectValue> getField(StringAttr name);

/// Get all the field names of the Object.
ArrayAttr getFieldNames();

private:
/// Allow the instantiate method as a friend to construct Objects.
friend FailureOr<std::shared_ptr<Object>>
Expand Down
5 changes: 5 additions & 0 deletions integration_test/Bindings/Python/dialects/om.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,8 @@
print(obj.field)
# CHECK: 14
print(obj.child.foo)

for (name, field) in obj:
# CHECK: name: child, field: <circt.dialects.om.Object object
# CHECK: name: field, field: 42
print(f"name: {name}, field: {field}")
15 changes: 15 additions & 0 deletions lib/Bindings/Python/OMModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "mlir/CAPI/IR.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11;
Expand Down Expand Up @@ -55,6 +56,18 @@ struct Object {
return omEvaluatorObjectValueGetPrimitive(result);
}

// Get a list with the names of all the fields in the Object.
std::vector<std::string> getFieldNames() {
ArrayAttr fieldNames =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You cannot use ArrayAttr in pybind11 code. You must use the CAPI for interacting with MLIR objects/classes... 'cause linking issues.

cast<ArrayAttr>(unwrap(omEvaluatorObjectGetFieldNames(object)));

std::vector<std::string> slots;
for (auto fieldName : fieldNames.getAsRange<StringAttr>())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same for StringAttr.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh sorry, didn't realize that, will create a new PR to fix it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@prithayan no worries, it was my bad, and fixed in 80531eb

slots.push_back(fieldName.str());

return slots;
}

private:
// The underlying CAPI OMObject.
OMObject object;
Expand Down Expand Up @@ -108,6 +121,8 @@ void circt::python::populateDialectOMSubmodule(py::module &m) {
.def(py::init<Object>(), py::arg("object"))
.def("__getattr__", &Object::getField, "Get a field from an Object",
py::arg("name"))
.def_property_readonly("field_names", &Object::getFieldNames,
"Get field names from an Object")
.def_property_readonly("type", &Object::getType,
"The Type of the Object");

Expand Down
5 changes: 5 additions & 0 deletions lib/Bindings/Python/dialects/om.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ def __getattr__(self, name: str):
assert isinstance(field, BaseObject)
return Object(field)

# Support iterating over an Object by yielding its fields.
def __iter__(self):
for name in self.field_names:
yield (name, getattr(self, name))


# Define the Evaluator class by inheriting from the base implementation in C++.
class Evaluator(BaseEvaluator):
Expand Down
5 changes: 5 additions & 0 deletions lib/CAPI/Dialect/OM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ MlirType omEvaluatorObjectGetType(OMObject object) {
return wrap(unwrap(object)->getType());
}

/// Get an ArrayAttr with the names of the fields in an Object.
MlirAttribute omEvaluatorObjectGetFieldNames(OMObject object) {
return wrap(unwrap(object)->getFieldNames());
}

/// Get a field from an Object, which must contain a field of that name.
OMObjectValue omEvaluatorObjectGetField(OMObject object, MlirAttribute name) {
// Unwrap the Object and get the field of the name, which the client must
Expand Down
14 changes: 14 additions & 0 deletions lib/Dialect/OM/Evaluator/Evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,17 @@ FailureOr<ObjectValue> circt::om::Object::getField(StringAttr name) {
return cls.emitError("field ") << name << " does not exist";
return success(fields[name]);
}

/// Get an ArrayAttr with the names of the fields in the Object. Sort the fields
/// so there is always a stable order.
ArrayAttr circt::om::Object::getFieldNames() {
SmallVector<Attribute> fieldNames;
for (auto &f : fields)
fieldNames.push_back(f.first);

llvm::sort(fieldNames, [](Attribute a, Attribute b) {
return cast<StringAttr>(a).getValue() < cast<StringAttr>(b).getValue();
});

return ArrayAttr::get(cls.getContext(), fieldNames);
}
9 changes: 9 additions & 0 deletions test/CAPI/om.c
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ void testEvaluator(MlirContext ctx) {

OMObjectValue childField = omEvaluatorObjectGetField(object, childFieldName);

MlirAttribute fieldNamesO = omEvaluatorObjectGetFieldNames(object);
// CHECK: ["child", "field"]
mlirAttributeDump(fieldNamesO);

OMObject child = omEvaluatorObjectValueGetObject(childField);

// CHECK: 0
Expand All @@ -115,6 +119,11 @@ void testEvaluator(MlirContext ctx) {
OMObjectValue foo = omEvaluatorObjectGetField(
child, mlirStringAttrGet(ctx, mlirStringRefCreateFromCString("foo")));

MlirAttribute fieldNamesC = omEvaluatorObjectGetFieldNames(child);

// CHECK: ["foo"]
mlirAttributeDump(fieldNamesC);

// CHECK: child object field is primitive: 1
fprintf(stderr, "child object field is primitive: %d\n",
omEvaluatorObjectValueIsAPrimitive(foo));
Expand Down
13 changes: 13 additions & 0 deletions unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "gtest/gtest.h"
#include <mlir/IR/BuiltinAttributes.h>
Expand Down Expand Up @@ -378,6 +381,16 @@ TEST(EvaluatorTests, InstantiateObjectWithChildObjectMemoized) {
auto field2Value = std::get<std::shared_ptr<Object>>(
result.value()->getField(builder.getStringAttr("field2")).value());

auto fieldNames = result.value()->getFieldNames();

ASSERT_TRUE(fieldNames.size() == 2);
StringRef fieldNamesTruth[] = {"field1", "field2"};
for (auto fieldName : llvm::enumerate(fieldNames)) {
auto str = llvm::dyn_cast_or_null<StringAttr>(fieldName.value());
ASSERT_TRUE(str);
ASSERT_EQ(str.getValue(), fieldNamesTruth[fieldName.index()]);
}

ASSERT_TRUE(field1Value);
ASSERT_TRUE(field2Value);

Expand Down