Add Operation to python bindings.
* Fixes a rather egregious bug with respect to the inability to return arbitrary objects from py::init (was causing aliasing of multiple py::object -> native instance).
* Makes Modules and Operations referenceable types so that they can be reliably depended on.
* Uniques python operation instances within a context. Opens the door for further accounting.
* Next I will retrofit region and block to be dependent on the operation, and I will attempt to model the API to avoid detached regions/blocks, which will simplify things a lot (in that world, only operations can be detached).
* Added quite a bit of test coverage to check for leaks and reference issues.
* Supersedes: https://reviews.llvm.org/D87213

Differential Revision: https://reviews.llvm.org/D87958
stellaraccident committed Sep 23, 2020
1 parent bd8b50c commit 7abb0ff
Showing 8 changed files with 382 additions and 58 deletions.
22 changes: 22 additions & 0 deletions mlir/docs/Bindings/Python.md
@@ -161,6 +161,28 @@ issues that arise when combining RTTI-based modules (which pybind derived things
are) with non-RTTI polymorphic C++ code (the default compilation mode of LLVM).


### Ownership in the Core IR

There are several top-level types in the core IR that are strongly owned by their python-side reference:

* `PyContext` (`mlir.ir.Context`)
* `PyModule` (`mlir.ir.Module`)
* `PyOperation` (`mlir.ir.Operation`) - but with caveats

All other objects are dependent. All objects maintain a back-reference (keep-alive) to their closest containing top-level object. Further, dependent objects fall into two categories: a) uniqued (which live for the lifetime of the context) and b) mutable. Mutable objects need additional machinery for keeping track of when the C++ instance that backs their Python object is no longer valid (typically due to some specific mutation of the IR, deletion, or bulk operation).
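
The keep-alive rule can be illustrated with a small pure-Python model. This is only a sketch, not the bindings' implementation: the `Context` and `Dependent` classes below are hypothetical stand-ins, and the example assumes CPython's reference-counting behavior.

```python
import gc
import weakref


class Context:
    """Hypothetical top-level object, owned by its Python-side reference."""


class Dependent:
    """Hypothetical dependent object holding a strong back-reference."""

    def __init__(self, owner):
        # The back-reference (keep-alive) pins the owner while we exist.
        self._keep_alive = owner


ctx = Context()
ctx_alive = weakref.ref(ctx)  # observes the context without owning it
dep = Dependent(ctx)

del ctx  # the user drops the direct reference to the context
gc.collect()
alive_while_dependent_exists = ctx_alive() is not None

del dep  # the last dependent object goes away
gc.collect()
alive_after_release = ctx_alive() is not None
```

In this model the context survives for as long as any dependent object still refers to it, and is released only once the last dependent is gone.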

#### Operation hierarchy

As mentioned above, `PyOperation` is special because it can exist in either a top-level or dependent state. The life-cycle is unidirectional: operations can be created detached (top-level), and once added to another operation, they are dependent for the remainder of their lifetime. The situation is more complicated in construction scenarios where an operation is added to a transitive parent that is still detached, which necessitates further accounting at such transition points: all such added children are initially added to the IR with a parent of their outermost detached operation, but once that operation is added to an attached operation, they need to be re-parented to the containing module.
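
A minimal sketch of the unidirectional life-cycle, assuming a hypothetical pure-Python `Operation` class with an `attach_to` method (neither is part of the bindings). It models only the detached-to-attached transition, not the re-parenting of children under a still-detached parent:

```python
class Operation:
    """Hypothetical model of an operation that starts out detached."""

    def __init__(self, name):
        self.name = name
        self.attached = False          # created detached (top-level)
        self.parent_keep_alive = None  # set once, when attached

    def attach_to(self, parent):
        # The transition is one-way: an attached operation stays attached.
        if self.attached:
            raise RuntimeError(f"{self.name} is already attached")
        self.attached = True
        self.parent_keep_alive = parent


parent = Operation("parent")
op = Operation("child")
op.attach_to(parent)  # from here on, `op` is dependent on `parent`
```

A second call to `op.attach_to(...)` raises in this model, which mirrors the rule that an operation never returns to the detached state.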

Due to the validity and parenting accounting needs, `PyOperation` is the owner for regions and blocks and needs to be a top-level type that we can count on not aliasing. This lets us do things like selectively invalidating instances when mutations occur without worrying that there is some alias to the same operation in the hierarchy. Operations are also the only entities that are allowed to be in a detached state, and they are interned at the context level so that there is never more than one Python `mlir.ir.Operation` object for a unique `MlirOperation`, regardless of how it is obtained.
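
Context-level interning can be sketched in pure Python as follows. This is an illustrative model under stated assumptions, not the C++ code in this change: the class and method names are hypothetical, a plain integer stands in for the native `MlirOperation` pointer, and a `WeakValueDictionary` plays the role of the non-owning live-operations map (CPython reference counting is assumed).

```python
import weakref


class Operation:
    """Hypothetical wrapper around a native operation handle."""

    def __init__(self, context, native_ptr):
        self.context = context      # keep-alive back-reference
        self.native_ptr = native_ptr
        self.valid = True           # cleared when the IR invalidates us


class Context:
    """Hypothetical context that interns operation wrappers."""

    def __init__(self):
        # native pointer -> live wrapper; entries vanish with the wrapper
        self._live_operations = weakref.WeakValueDictionary()

    def for_operation(self, native_ptr):
        existing = self._live_operations.get(native_ptr)
        if existing is not None:
            return existing         # reuse: never two wrappers per handle
        created = Operation(self, native_ptr)
        self._live_operations[native_ptr] = created
        return created

    def live_operation_count(self):
        return len(self._live_operations)


ctx = Context()
a = ctx.for_operation(0x1000)
b = ctx.for_operation(0x1000)
same_wrapper = a is b
count_while_live = ctx.live_operation_count()

del a, b  # dropping the last references removes the map entry
count_after_release = ctx.live_operation_count()
```

Because only one wrapper exists per handle, marking it invalid (for example `op.valid = False`) is visible through every reference to that operation, which is the property the paragraph above relies on.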

The C/C++ API allows for Region/Block to also be detached, but it simplifies the ownership model a lot to eliminate that possibility in this API, allowing the Region/Block to be completely dependent on its owning operation for accounting. The aliasing of Python `Region`/`Block` instances to underlying `MlirRegion`/`MlirBlock` is considered benign and these objects are not interned in the context (unlike operations).

If we ever want to re-introduce detached regions/blocks, we could do so with a new "DetachedRegion" class or similar and also avoid the complexity of accounting. With the way it is now, we can avoid having a global live list for regions and blocks. We may end up needing an op-local one at some point (TBD), depending on how hard it is to guarantee how mutations interact with their Python peer objects. We can cross that bridge easily when we get there.

Module, when used purely from the Python API, can't alias anyway, so we can use it as a top-level ref type without a live-list for interning. If the API ever changes such that this cannot be guaranteed (i.e. by letting you marshal a native-defined Module in), then there would need to be a live table for it too.

## Style

In general, for the core parts of MLIR, the Python bindings should be largely
178 changes: 152 additions & 26 deletions mlir/lib/Bindings/Python/IRModules.cpp
@@ -174,13 +174,12 @@ int mlirTypeIsAIntegerOrFloat(MlirType type) {
// PyMlirContext
//------------------------------------------------------------------------------

PyMlirContext *PyMlirContextRef::release() {
object.release();
return &referrent;
PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
py::gil_scoped_acquire acquire;
auto &liveContexts = getLiveContexts();
liveContexts[context.ptr] = this;
}

PyMlirContext::PyMlirContext(MlirContext context) : context(context) {}

PyMlirContext::~PyMlirContext() {
// Note that the only public way to construct an instance is via the
// forContext method, which always puts the associated handle into
@@ -190,6 +189,11 @@ PyMlirContext::~PyMlirContext() {
mlirContextDestroy(context);
}

PyMlirContext *PyMlirContext::createNewContextForInit() {
MlirContext context = mlirContextCreate();
return new PyMlirContext(context);
}

PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
py::gil_scoped_acquire acquire;
auto &liveContexts = getLiveContexts();
@@ -198,14 +202,13 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
// Create.
PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
py::object pyRef = py::cast(unownedContextWrapper);
unownedContextWrapper->handle = pyRef;
liveContexts[context.ptr] = std::make_pair(pyRef, unownedContextWrapper);
return PyMlirContextRef(*unownedContextWrapper, std::move(pyRef));
} else {
// Use existing.
py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
return PyMlirContextRef(*it->second.second, std::move(pyRef));
assert(pyRef && "cast to py::object failed");
liveContexts[context.ptr] = unownedContextWrapper;
return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
}
// Use existing.
py::object pyRef = py::cast(it->second);
return PyMlirContextRef(it->second, std::move(pyRef));
}

PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
@@ -215,8 +218,99 @@ PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {

size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }

size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }

//------------------------------------------------------------------------------
// PyModule
//------------------------------------------------------------------------------

PyModuleRef PyModule::create(PyMlirContextRef contextRef, MlirModule module) {
PyModule *unownedModule = new PyModule(std::move(contextRef), module);
// Note that the default return value policy on cast is automatic_reference,
// which does not take ownership (delete will not be called).
// Just be explicit.
py::object pyRef =
py::cast(unownedModule, py::return_value_policy::take_ownership);
unownedModule->handle = pyRef;
return PyModuleRef(unownedModule, std::move(pyRef));
}

//------------------------------------------------------------------------------
// PyOperation
//------------------------------------------------------------------------------

PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
: BaseContextObject(std::move(contextRef)), operation(operation) {}

PyOperation::~PyOperation() {
auto &liveOperations = getContext()->liveOperations;
assert(liveOperations.count(operation.ptr) == 1 &&
"destroying operation not in live map");
liveOperations.erase(operation.ptr);
if (!isAttached()) {
mlirOperationDestroy(operation);
}
}

PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
MlirOperation operation,
py::object parentKeepAlive) {
auto &liveOperations = contextRef->liveOperations;
// Create.
PyOperation *unownedOperation =
new PyOperation(std::move(contextRef), operation);
// Note that the default return value policy on cast is automatic_reference,
// which does not take ownership (delete will not be called).
// Just be explicit.
py::object pyRef =
py::cast(unownedOperation, py::return_value_policy::take_ownership);
unownedOperation->handle = pyRef;
if (parentKeepAlive) {
unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
}
liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
return PyOperationRef(unownedOperation, std::move(pyRef));
}

PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
MlirOperation operation,
py::object parentKeepAlive) {
auto &liveOperations = contextRef->liveOperations;
auto it = liveOperations.find(operation.ptr);
if (it == liveOperations.end()) {
// Create.
return createInstance(std::move(contextRef), operation,
std::move(parentKeepAlive));
}
// Use existing.
PyOperation *existing = it->second.second;
assert(existing->parentKeepAlive.is(parentKeepAlive));
py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
return PyOperationRef(existing, std::move(pyRef));
}

PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
MlirOperation operation,
py::object parentKeepAlive) {
auto &liveOperations = contextRef->liveOperations;
assert(liveOperations.count(operation.ptr) == 0 &&
"cannot create detached operation that already exists");
(void)liveOperations;

PyOperationRef created = createInstance(std::move(contextRef), operation,
std::move(parentKeepAlive));
created->attached = false;
return created;
}

void PyOperation::checkValid() {
if (!valid) {
throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
}
}

//------------------------------------------------------------------------------
// PyBlock, PyRegion, and PyOperation.
// PyBlock, PyRegion.
//------------------------------------------------------------------------------

void PyRegion::attachToParent() {
@@ -865,29 +959,27 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
void mlir::python::populateIRSubmodule(py::module &m) {
// Mapping of MlirContext
py::class_<PyMlirContext>(m, "Context")
.def(py::init<>([]() {
MlirContext context = mlirContextCreate();
auto contextRef = PyMlirContext::forContext(context);
return contextRef.release();
}))
.def(py::init<>(&PyMlirContext::createNewContextForInit))
.def_static("_get_live_count", &PyMlirContext::getLiveCount)
.def("_get_context_again",
[](PyMlirContext &self) {
auto ref = PyMlirContext::forContext(self.get());
return ref.release();
PyMlirContextRef ref = PyMlirContext::forContext(self.get());
return ref.releaseObject();
})
.def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
.def(
"parse_module",
[](PyMlirContext &self, const std::string module) {
auto moduleRef = mlirModuleCreateParse(self.get(), module.c_str());
[](PyMlirContext &self, const std::string moduleAsm) {
MlirModule module =
mlirModuleCreateParse(self.get(), moduleAsm.c_str());
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirModuleIsNull(moduleRef)) {
if (mlirModuleIsNull(module)) {
throw SetPyError(
PyExc_ValueError,
"Unable to parse module assembly (see diagnostics)");
}
return PyModule(self.getRef(), moduleRef);
return PyModule::create(self.getRef(), module).releaseObject();
},
kContextParseDocstring)
.def(
@@ -975,23 +1067,57 @@ void mlir::python::populateIRSubmodule(py::module &m) {

// Mapping of Module
py::class_<PyModule>(m, "Module")
.def_property_readonly(
"operation",
[](PyModule &self) {
return PyOperation::forOperation(self.getContext(),
mlirModuleGetOperation(self.get()),
self.getRef().releaseObject())
.releaseObject();
},
"Accesses the module as an operation")
.def(
"dump",
[](PyModule &self) {
mlirOperationDump(mlirModuleGetOperation(self.module));
mlirOperationDump(mlirModuleGetOperation(self.get()));
},
kDumpDocstring)
.def(
"__str__",
[](PyModule &self) {
auto operation = mlirModuleGetOperation(self.module);
MlirOperation operation = mlirModuleGetOperation(self.get());
PyPrintAccumulator printAccum;
mlirOperationPrint(operation, printAccum.getCallback(),
printAccum.getUserData());
return printAccum.join();
},
kOperationStrDunderDocstring);

// Mapping of Operation.
py::class_<PyOperation>(m, "Operation")
.def_property_readonly(
"first_region",
[](PyOperation &self) {
self.checkValid();
if (mlirOperationGetNumRegions(self.get()) == 0) {
throw SetPyError(PyExc_IndexError, "Operation has no regions");
}
return PyRegion(self.getContext()->get(),
mlirOperationGetRegion(self.get(), 0),
/*detached=*/false);
},
py::keep_alive<0, 1>(), "Gets the operation's first region")
.def(
"__str__",
[](PyOperation &self) {
self.checkValid();
PyPrintAccumulator printAccum;
mlirOperationPrint(self.get(), printAccum.getCallback(),
printAccum.getUserData());
return printAccum.join();
},
kTypeStrDunderDocstring);

// Mapping of PyRegion.
py::class_<PyRegion>(m, "Region")
.def(