diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 7d2fd89e8560f..14ccae650606a 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -309,6 +309,10 @@ MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateEmpty(MlirLocation location); MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module); +/// Parses a module from file and transfers ownership to the caller. +MLIR_CAPI_EXPORTED MlirModule +mlirModuleCreateParseFromFile(MlirContext context, MlirStringRef fileName); + /// Gets the context that a module was created with. MLIR_CAPI_EXPORTED MlirContext mlirModuleGetContext(MlirModule module); diff --git a/mlir/include/mlir/Bindings/Python/Nanobind.h b/mlir/include/mlir/Bindings/Python/Nanobind.h index ca942c83d3e2f..bc8bddf4caf7e 100644 --- a/mlir/include/mlir/Bindings/Python/Nanobind.h +++ b/mlir/include/mlir/Bindings/Python/Nanobind.h @@ -23,6 +23,7 @@ #endif #include #include +#include #include #include #include diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 8e351cb22eb94..b772c9a583a6b 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include #include #include @@ -299,7 +300,7 @@ struct PyAttrBuilderMap { return *builder; } static void dunderSetItemNamed(const std::string &attributeKind, - nb::callable func, bool replace) { + nb::callable func, bool replace) { PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func), replace); } @@ -3047,6 +3048,19 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, nb::arg("asm"), nb::arg("context").none() = nb::none(), kModuleParseDocstring) + .def_static( + "parse", + [](const std::filesystem::path &path, + DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirModule module = mlirModuleCreateParseFromFile( + context->get(), toMlirStringRef(path.string())); + if (mlirModuleIsNull(module)) + throw MLIRError("Unable to parse module assembly", errors.take()); + return PyModule::forModule(module).releaseObject(); + }, + nb::arg("asm"), nb::arg("context").none() = nb::none(), + kModuleParseDocstring) .def_static( "create", [](DefaultingPyLocation loc) { diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index f27af0ca9a2c7..999e8cbda1295 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/Location.h" #include "mlir/IR/Operation.h" #include "mlir/IR/OperationSupport.h" +#include "mlir/IR/OwningOpRef.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/IR/Verifier.h" @@ -328,6 +329,15 @@ MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) { return MlirModule{owning.release().getOperation()}; } +MlirModule mlirModuleCreateParseFromFile(MlirContext context, + MlirStringRef fileName) { + OwningOpRef owning = + parseSourceFile(unwrap(fileName), unwrap(context)); + if (!owning) + return MlirModule{nullptr}; + return MlirModule{owning.release().getOperation()}; +} + MlirContext mlirModuleGetContext(MlirModule module) { return wrap(unwrap(module).getContext()); } diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index fb7efb8cd28a5..096b87b362443 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -46,6 +46,7 @@ import abc import collections from collections.abc import Callable, Sequence import io +from pathlib import Path from typing import Any, ClassVar, TypeVar, overload __all__ = [ @@ -2123,7 +2124,7 @@ class Module: Creates an empty module """ @staticmethod - def parse(asm: str | bytes, context: Context | None = None) -> Module: + def parse(asm: str | bytes | Path, context: Context | None = None) -> Module: """ Parses a module's assembly format from a string. diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py index ecafcb46af217..441916b38ee73 100644 --- a/mlir/test/python/ir/module.py +++ b/mlir/test/python/ir/module.py @@ -1,6 +1,8 @@ # RUN: %PYTHON %s | FileCheck %s import gc +from pathlib import Path +from tempfile import NamedTemporaryFile from mlir.ir import * @@ -27,6 +29,24 @@ def testParseSuccess(): print(str(module)) +# Verify successful parse from file. +# CHECK-LABEL: TEST: testParseFromFileSuccess +# CHECK: module @successfulParse +@run +def testParseFromFileSuccess(): + ctx = Context() + with NamedTemporaryFile(mode="w") as tmp_file: + tmp_file.write(r"""module @successfulParse {}""") + tmp_file.flush() + module = Module.parse(Path(tmp_file.name), ctx) + assert module.context is ctx + print("CLEAR CONTEXT") + ctx = None # Ensure that module captures the context. + gc.collect() + module.operation.verify() + print(str(module)) + + # Verify parse error. # CHECK-LABEL: TEST: testParseError # CHECK: testParseError: <