Skip to content

Commit

Permalink
[OM] Simplify the Python instantiate API to just return Objects. (llv…
Browse files Browse the repository at this point in the history
…m#5400)

This API previously accepted a Python dataclass as input, and used it
as a template to guide the logic to pull fields out of the
instantiated Object and build an instance of the requested dataclass.

This allowed the API to be strongly typed, by accepting a dataclass
type as input and returning an instance of that dataclass as
output. However, this requires the user to specify a-priori what
fields should be present in the resulting Object, and this may not
always be known.

By simply returning an Object, and adding the appropriate Python
conversions to Object's __getattr__, we can generically return Objects
that behave just like the previous dataclasses, without specifying the
fields up front. This also avoids rewrapping the Objects in
dataclasses.

In the future, we can add back a similar form of type safety when this
would be useful, potentially using Protocols similarly to how
dataclasses were used as a template before.
  • Loading branch information
mikeurbach committed Jun 14, 2023
1 parent 49f73eb commit c5aff9a
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 73 deletions.
37 changes: 8 additions & 29 deletions integration_test/Bindings/Python/dialects/om.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,8 @@

# Test instantiate failure.


@dataclass
class Test:
field: int


try:
obj = evaluator.instantiate(Test)
obj = evaluator.instantiate("Test")
except ValueError as e:
# CHECK: actual parameter list length (0) does not match
# CHECK: actual parameters:
Expand All @@ -47,14 +41,9 @@ class Test:

# Test get field failure.


@dataclass
class Test:
foo: int


try:
obj = evaluator.instantiate(Test, 42)
obj = evaluator.instantiate("Test", 42)
obj.foo
except ValueError as e:
# CHECK: field "foo" does not exist
# CHECK: see current operation:
Expand All @@ -63,19 +52,9 @@ class Test:

# Test instantiate success.

obj = evaluator.instantiate("Test", 42)

@dataclass
class Child:
foo: int


@dataclass
class Test:
field: int
child: Child


obj = evaluator.instantiate(Test, 42)

# CHECK: Test(field=42, child=Child(foo=14))
print(obj)
# CHECK: 42
print(obj.field)
# CHECK: 14
print(obj.child.foo)
14 changes: 11 additions & 3 deletions lib/Bindings/Python/OMModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "DialectModules.h"
#include "circt-c/Dialect/OM.h"
#include "circt/Support/LLVM.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include <pybind11/pybind11.h>
Expand All @@ -23,15 +24,21 @@ using namespace mlir::python::adaptors;
struct Object {
// Instantiate an Object with a reference to the underlying OMObject.
Object(OMObject object) : object(object) {}
Object(const Object &object) : object(object.object) {}

/// Get the Type from an Object, which will be a ClassType.
MlirType getType() { return omEvaluatorObjectGetType(object); }

// Get a field from the Object, using pybind's support for variant to return a
// Python object that is either an Object or Attribute.
std::variant<Object, MlirAttribute> getField(MlirAttribute name) {
std::variant<Object, MlirAttribute> getField(const std::string &name) {
// Wrap the requested field name in an attribute.
MlirContext context = mlirTypeGetContext(omEvaluatorObjectGetType(object));
MlirStringRef cName = mlirStringRefCreateFromCString(name.c_str());
MlirAttribute nameAttr = mlirStringAttrGet(context, cName);

// Get the field's ObjectValue via the CAPI.
OMObjectValue result = omEvaluatorObjectGetField(object, name);
OMObjectValue result = omEvaluatorObjectGetField(object, nameAttr);

// If the ObjectValue is null, something failed. Diagnostic handling is
// implemented in pure Python, so nothing to do here besides throwing an
Expand Down Expand Up @@ -98,7 +105,8 @@ void circt::python::populateDialectOMSubmodule(py::module &m) {

// Add the Object class definition.
py::class_<Object>(m, "Object")
.def("get_field", &Object::getField, "Get a field from an Object",
.def(py::init<Object>(), py::arg("object"))
.def("__getattr__", &Object::getField, "Get a field from an Object",
py::arg("name"))
.def_property_readonly("type", &Object::getType,
"The Type of the Object");
Expand Down
67 changes: 26 additions & 41 deletions lib/Bindings/Python/dialects/om.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from __future__ import annotations

from ._om_ops_gen import *
from .._mlir_libs._circt._om import Evaluator as BaseEvaluator, Object, ClassType
from .._mlir_libs._circt._om import Evaluator as BaseEvaluator, Object as BaseObject, ClassType

from circt.ir import Attribute, Diagnostic, DiagnosticSeverity, Module, StringAttr
from circt.support import attribute_to_var, var_to_attribute
Expand All @@ -19,6 +19,25 @@
from _typeshed.stdlib.dataclass import DataclassInstance


# Define the Object class by inheriting from the base implementation in C++.
class Object(BaseObject):

def __init__(self, obj: BaseObject) -> None:
super().__init__(obj)

def __getattr__(self, name: str):
# Call the base method to get a field.
field = super().__getattr__(name)

# For primitives, return a Python value.
if isinstance(field, Attribute):
return attribute_to_var(field)

# For objects, return an Object, wrapping the base implementation.
assert isinstance(field, BaseObject)
return Object(field)


# Define the Evaluator class by inheriting from the base implementation in C++.
class Evaluator(BaseEvaluator):

Expand All @@ -40,15 +59,14 @@ def __init__(self, mod: Module) -> None:
# Attach our Diagnostic handler.
mod.context.attach_diagnostic_handler(self._handle_diagnostic)

def instantiate(self, cls: type["DataclassInstance"],
*args: Any) -> "DataclassInstance":
"""Instantiate an Object with a dataclass type and actual parameters."""
def instantiate(self, cls: str, *args: Any) -> Object:
"""Instantiate an Object with a class name and actual parameters."""

# Convert the class name and actual parameters to Attributes within the
# Evaluator's context.
with self.module.context:
# Get the class name from the provided dataclass name.
class_name = StringAttr.get(cls.__name__)
# Get the class name from the class name.
class_name = StringAttr.get(cls)

# Get the actual parameter Attributes from the supplied variadic
# arguments. This relies on the circt.support helpers to convert from
Expand All @@ -58,41 +76,8 @@ def instantiate(self, cls: type["DataclassInstance"],
# Call the base instantiate method.
obj = super().instantiate(class_name, actual_params)

# Wrap the Object in the provided dataclass.
return self._instantiate_dataclass(cls, obj)

def _instantiate_dataclass(self, cls: type["DataclassInstance"],
obj: Object) -> "DataclassInstance":
# Convert the field names of the class we are instantiating to StringAttrs
# within the Evaluator's context.
with self.module.context:
class_fields = [
(StringAttr.get(field.name), field.type) for field in fields(cls)
]

# Convert the instantiated Object fields to Python objects.
object_fields = {}

for field_name, field_type in class_fields:
# Get the field from the object.
field = obj.get_field(field_name)

# Handle primitives represented as Attributes and nested Objects.
if isinstance(field, Attribute):
# Convert the field value to a Python object. This relies on the
# circt.support helpers to convert from Attribute to Python objects.
field_value = attribute_to_var(field)
else:
# Convert the field value to a Python dataclass for the Object.
assert isinstance(field, Object)
field_value = self._instantiate_dataclass(field_type, field)

# Save this field in the keyword argument dictionary that will be passed
# to the dataclass constructor.
object_fields[field_name.value] = field_value

# Instantiate a Python object of the requested class.
return cls(**object_fields)
# Return the Object, wrapping the base implementation.
return Object(obj)

def _handle_diagnostic(self, diagnostic: Diagnostic) -> bool:
"""Handle MLIR Diagnostics by logging them."""
Expand Down

0 comments on commit c5aff9a

Please sign in to comment.