-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][Python] restore APIs in terms of Mlir* types #160203
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][Python] restore APIs in terms of Mlir* types #160203
Conversation
6d2d416
to
6646380
Compare
6646380
to
a3f99e8
Compare
@llvm/pr-subscribers-mlir Author: Maksim Levental (makslevental) Changes#157930 changed a few APIs from Full diff: https://github.com/llvm/llvm-project/pull/160203.diff 3 Files Affected:
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 81386f2227a7f..4b238e11c7fff 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -4283,6 +4283,33 @@ void mlir::python::populateIRCore(nb::module_ &m) {
mlirValueReplaceAllUsesOfWith(self.get(), with.get());
},
kValueReplaceAllUsesWithDocstring)
+ .def(
+ "replace_all_uses_except",
+ [](MlirValue self, MlirValue with, PyOperation &exception) {
+ MlirOperation exceptedUser = exception.get();
+ mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
+ },
+ nb::arg("with_"), nb::arg("exceptions"),
+ nb::sig("def replace_all_uses_except(self, with_: Value, exceptions: "
+ "Operation) -> None"),
+ kValueReplaceAllUsesExceptDocstring)
+ .def(
+ "replace_all_uses_except",
+ [](MlirValue self, MlirValue with, nb::list exceptions) {
+ // Convert Python list to a SmallVector of MlirOperations
+ llvm::SmallVector<MlirOperation> exceptionOps;
+ for (nb::handle exception : exceptions) {
+ exceptionOps.push_back(nb::cast<PyOperation &>(exception).get());
+ }
+
+ mlirValueReplaceAllUsesExcept(
+ self, with, static_cast<intptr_t>(exceptionOps.size()),
+ exceptionOps.data());
+ },
+ nb::arg("with_"), nb::arg("exceptions"),
+ nb::sig("def replace_all_uses_except(self, with_: Value, exceptions: "
+ "Sequence[Operation]) -> None"),
+ kValueReplaceAllUsesExceptDocstring)
.def(
"replace_all_uses_except",
[](PyValue &self, PyValue &with, PyOperation &exception) {
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index a7aa1c65c6c43..cab3bf549295b 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -898,6 +898,18 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
},
nb::arg("elements"), nb::arg("context") = nb::none(),
"Create a tuple type");
+ c.def_static(
+ "get_tuple",
+ [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
+ MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
+ elements.data());
+ return PyTupleType(context->getRef(), t);
+ },
+ nb::arg("elements"), nb::arg("context") = nb::none(),
+ // clang-format off
+ nb::sig("def get_tuple(elements: Sequence[Type], context: mlir.ir.Context | None = None) -> TupleType"),
+ // clang-format on
+ "Create a tuple type");
c.def(
"get_type",
[](PyTupleType &self, intptr_t pos) {
@@ -944,6 +956,20 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
},
nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
"Gets a FunctionType from a list of input and result types");
+ c.def_static(
+ "get",
+ [](std::vector<MlirType> inputs, std::vector<MlirType> results,
+ DefaultingPyMlirContext context) {
+ MlirType t =
+ mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
+ results.size(), results.data());
+ return PyFunctionType(context->getRef(), t);
+ },
+ nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
+ // clang-format off
+ nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: mlir.ir.Context | None = None) -> FunctionType"),
+ // clang-format on
+ "Gets a FunctionType from a list of input and result types");
c.def_prop_ro(
"inputs",
[](PyFunctionType &self) {
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 96f00cface64f..3476793369907 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -83,6 +83,16 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
nb::class_<PyPDLPatternModule>(m, "PDLModule")
+ .def(
+ "__init__",
+ [](PyPDLPatternModule &self, MlirModule module) {
+ new (&self)
+ PyPDLPatternModule(mlirPDLPatternModuleFromModule(module));
+ },
+ // clang-format off
+ nb::sig("def __init__(self, module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ") -> None"),
+ // clang-format on
+ "module"_a, "Create a PDL module from the given module.")
.def(
"__init__",
[](PyPDLPatternModule &self, PyModule &module) {
@@ -117,6 +127,22 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
// clang-format on
"Applys the given patterns to the given module greedily while folding "
"results.")
+ .def(
+ "apply_patterns_and_fold_greedily",
+ [](PyModule &module, MlirFrozenRewritePatternSet set) {
+ auto status =
+ mlirApplyPatternsAndFoldGreedily(module.get(), set, {});
+ if (mlirLogicalResultIsFailure(status))
+ throw std::runtime_error(
+ "pattern application failed to converge");
+ },
+ "module"_a, "set"_a,
+ // clang-format off
+ nb::sig("def apply_patterns_and_fold_greedily(module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ", set: FrozenRewritePatternSet) -> None"),
+ // clang-format on
+ "Applys the given patterns to the given module greedily while "
+ "folding "
+ "results.")
.def(
"apply_patterns_and_fold_greedily",
[](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
@@ -131,5 +157,20 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
// clang-format on
"Applys the given patterns to the given op greedily while folding "
+ "results.")
+ .def(
+ "apply_patterns_and_fold_greedily",
+ [](PyOperationBase &op, MlirFrozenRewritePatternSet set) {
+ auto status = mlirApplyPatternsAndFoldGreedilyWithOp(
+ op.getOperation(), set, {});
+ if (mlirLogicalResultIsFailure(status))
+ throw std::runtime_error(
+ "pattern application failed to converge");
+ },
+ "op"_a, "set"_a,
+ // clang-format off
+ nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
+ // clang-format on
+ "Applys the given patterns to the given op greedily while folding "
"results.");
}
|
@Ahajha if this PR fixes you, I'll merge. |
Yup, I think things are mostly fixed on my end, hopefully the one remaining issue is within our codebase. |
Thanks for the quick followup! |
llvm#157930 changed a few APIs from `Mlir*` to `Py*` and broke users that were using them (see llvm#160183 (comment)). This PR restores those APIs.
llvm#157930 changed a few APIs from `Mlir*` to `Py*` and broke users that were using them (see llvm#160183 (comment)). This PR restores those APIs.
…8738) LLVM has also moved to generated stubfiles, so we need to generate these ourselves (we could omit them, but they're nice to have). See llvm/llvm-project#157930. Also pulls in the following followup fixes, these should be removed when bumping again. llvm/llvm-project#160183 llvm/llvm-project#160203 llvm/llvm-project#160221 This fixes some mypy lint errors, and causes a few more, I fixed a few but mostly just ignored them. MODULAR_ORIG_COMMIT_REV_ID: 524aaf2ab047e5185703c44ab3edd7754c67fa26
#157930 changed a few APIs from
Mlir*
toPy*
and broke users that were using them (see #160183 (comment)). This PR restores those APIs.