diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 8273a9346e5dd..10360e448858c 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1079,23 +1079,38 @@ PyLocation &DefaultingPyLocation::resolve() { PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) : BaseContextObject(std::move(contextRef)), module(module) {} -PyModule::~PyModule() { mlirModuleDestroy(module); } +PyModule::~PyModule() { + nb::gil_scoped_acquire acquire; + auto &liveModules = getContext()->liveModules; + assert(liveModules.count(module.ptr) == 1 && + "destroying module not in live map"); + liveModules.erase(module.ptr); + mlirModuleDestroy(module); +} PyModuleRef PyModule::forModule(MlirModule module) { MlirContext context = mlirModuleGetContext(module); PyMlirContextRef contextRef = PyMlirContext::forContext(context); - // Create. - PyModule *unownedModule = new PyModule(std::move(contextRef), module); - // Note that the default return value policy on cast is `automatic_reference`, - // which means "does not take ownership, does not call delete/dtor". - // We use `take_ownership`, which means "Python will call the C++ destructor - // and delete operator when the Python wrapper is garbage collected", because - // MlirModule actually wraps OwningOpRef (see mlirModuleCreateParse - // etc). - nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership); - unownedModule->handle = pyRef; - return PyModuleRef(unownedModule, std::move(pyRef)); + nb::gil_scoped_acquire acquire; + auto &liveModules = contextRef->liveModules; + auto it = liveModules.find(module.ptr); + if (it == liveModules.end()) { + // Create. + PyModule *unownedModule = new PyModule(std::move(contextRef), module); + // Note that the default return value policy on cast is automatic_reference, + // which does not take ownership (delete will not be called). + // Just be explicit. + nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership); + unownedModule->handle = pyRef; + liveModules[module.ptr] = + std::make_pair(unownedModule->handle, unownedModule); + return PyModuleRef(unownedModule, std::move(pyRef)); + } + // Use existing. + PyModule *existing = it->second.second; + nb::object pyRef = nb::borrow(it->second.first); + return PyModuleRef(existing, std::move(pyRef)); } nb::object PyModule::createFromCapsule(nb::object capsule) { @@ -2084,6 +2099,8 @@ PyInsertionPoint PyInsertionPoint::after(PyOperationBase &op) { return PyInsertionPoint{block, std::move(nextOpRef)}; } +size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } + nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) { return PyThreadContextEntry::pushInsertionPoint(insertPoint); } @@ -2923,6 +2940,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { PyMlirContextRef ref = PyMlirContext::forContext(self.get()); return ref.releaseObject(); }) + .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) .def("__enter__", &PyMlirContext::contextEnter) diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 1d1ff29533f98..28b885f136fe0 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -218,6 +218,10 @@ class PyMlirContext { /// Gets the count of live context objects. Used for testing. static size_t getLiveCount(); + /// Gets the count of live modules associated with this context. + /// Used for testing. + size_t getLiveModuleCount(); + /// Enter and exit the context manager. static nanobind::object contextEnter(nanobind::object context); void contextExit(const nanobind::object &excType, @@ -244,6 +248,14 @@ class PyMlirContext { static nanobind::ft_mutex live_contexts_mutex; static LiveContextMap &getLiveContexts(); + // Interns all live modules associated with this context. Modules tracked + // in this map are valid. When a module is invalidated, it is removed + // from this map, and while it still exists as an instance, any + // attempt to access it will raise an error. + using LiveModuleMap = + llvm::DenseMap>; + LiveModuleMap liveModules; + bool emitErrorDiagnostics = false; MlirContext context; diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py index ad4c9340a6c82..33959bea9ffb6 100644 --- a/mlir/test/python/ir/module.py +++ b/mlir/test/python/ir/module.py @@ -121,6 +121,7 @@ def testRoundtripBinary(): def testModuleOperation(): ctx = Context() module = Module.parse(r"""module @successfulParse {}""", ctx) + assert ctx._get_live_module_count() == 1 op1 = module.operation # CHECK: module @successfulParse print(op1) @@ -145,6 +146,7 @@ def testModuleOperation(): op1 = None op2 = None gc.collect() + assert ctx._get_live_module_count() == 0 # CHECK-LABEL: TEST: testModuleCapsule @@ -152,17 +154,17 @@ def testModuleOperation(): def testModuleCapsule(): ctx = Context() module = Module.parse(r"""module @successfulParse {}""", ctx) + assert ctx._get_live_module_count() == 1 # CHECK: "mlir.ir.Module._CAPIPtr" module_capsule = module._CAPIPtr print(module_capsule) module_dup = Module._CAPICreate(module_capsule) - assert module is not module_dup + assert module is module_dup assert module == module_dup - module._clear_mlir_module() - assert module != module_dup assert module_dup.context is ctx # Gc and verify destructed. module = None module_capsule = None module_dup = None gc.collect() + assert ctx._get_live_module_count() == 0