Skip to content

Commit 4a9df48

Browse files
authored
[MLIR][Python] restore APIs in terms of Mlir* types (#160203)
#157930 changed a few APIs from `Mlir*` to `Py*` and broke users that were using them (see #160183 (comment)). This PR restores those APIs.
1 parent 42b195e commit 4a9df48

File tree

3 files changed

+94
-0
lines changed

3 files changed

+94
-0
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4283,6 +4283,33 @@ void mlir::python::populateIRCore(nb::module_ &m) {
42834283
mlirValueReplaceAllUsesOfWith(self.get(), with.get());
42844284
},
42854285
kValueReplaceAllUsesWithDocstring)
4286+
.def(
4287+
"replace_all_uses_except",
4288+
[](MlirValue self, MlirValue with, PyOperation &exception) {
4289+
MlirOperation exceptedUser = exception.get();
4290+
mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
4291+
},
4292+
nb::arg("with_"), nb::arg("exceptions"),
4293+
nb::sig("def replace_all_uses_except(self, with_: Value, exceptions: "
4294+
"Operation) -> None"),
4295+
kValueReplaceAllUsesExceptDocstring)
4296+
.def(
4297+
"replace_all_uses_except",
4298+
[](MlirValue self, MlirValue with, nb::list exceptions) {
4299+
// Convert Python list to a SmallVector of MlirOperations
4300+
llvm::SmallVector<MlirOperation> exceptionOps;
4301+
for (nb::handle exception : exceptions) {
4302+
exceptionOps.push_back(nb::cast<PyOperation &>(exception).get());
4303+
}
4304+
4305+
mlirValueReplaceAllUsesExcept(
4306+
self, with, static_cast<intptr_t>(exceptionOps.size()),
4307+
exceptionOps.data());
4308+
},
4309+
nb::arg("with_"), nb::arg("exceptions"),
4310+
nb::sig("def replace_all_uses_except(self, with_: Value, exceptions: "
4311+
"Sequence[Operation]) -> None"),
4312+
kValueReplaceAllUsesExceptDocstring)
42864313
.def(
42874314
"replace_all_uses_except",
42884315
[](PyValue &self, PyValue &with, PyOperation &exception) {

mlir/lib/Bindings/Python/IRTypes.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -898,6 +898,18 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
898898
},
899899
nb::arg("elements"), nb::arg("context") = nb::none(),
900900
"Create a tuple type");
901+
c.def_static(
902+
"get_tuple",
903+
[](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
904+
MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
905+
elements.data());
906+
return PyTupleType(context->getRef(), t);
907+
},
908+
nb::arg("elements"), nb::arg("context") = nb::none(),
909+
// clang-format off
910+
nb::sig("def get_tuple(elements: Sequence[Type], context: mlir.ir.Context | None = None) -> TupleType"),
911+
// clang-format on
912+
"Create a tuple type");
901913
c.def(
902914
"get_type",
903915
[](PyTupleType &self, intptr_t pos) {
@@ -944,6 +956,20 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
944956
},
945957
nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
946958
"Gets a FunctionType from a list of input and result types");
959+
c.def_static(
960+
"get",
961+
[](std::vector<MlirType> inputs, std::vector<MlirType> results,
962+
DefaultingPyMlirContext context) {
963+
MlirType t =
964+
mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
965+
results.size(), results.data());
966+
return PyFunctionType(context->getRef(), t);
967+
},
968+
nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
969+
// clang-format off
970+
nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: mlir.ir.Context | None = None) -> FunctionType"),
971+
// clang-format on
972+
"Gets a FunctionType from a list of input and result types");
947973
c.def_prop_ro(
948974
"inputs",
949975
[](PyFunctionType &self) {

mlir/lib/Bindings/Python/Rewrite.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,16 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
8383
//----------------------------------------------------------------------------
8484
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
8585
nb::class_<PyPDLPatternModule>(m, "PDLModule")
86+
.def(
87+
"__init__",
88+
[](PyPDLPatternModule &self, MlirModule module) {
89+
new (&self)
90+
PyPDLPatternModule(mlirPDLPatternModuleFromModule(module));
91+
},
92+
// clang-format off
93+
nb::sig("def __init__(self, module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ") -> None"),
94+
// clang-format on
95+
"module"_a, "Create a PDL module from the given module.")
8696
.def(
8797
"__init__",
8898
[](PyPDLPatternModule &self, PyModule &module) {
@@ -117,6 +127,22 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
117127
// clang-format on
118128
"Applys the given patterns to the given module greedily while folding "
119129
"results.")
130+
.def(
131+
"apply_patterns_and_fold_greedily",
132+
[](PyModule &module, MlirFrozenRewritePatternSet set) {
133+
auto status =
134+
mlirApplyPatternsAndFoldGreedily(module.get(), set, {});
135+
if (mlirLogicalResultIsFailure(status))
136+
throw std::runtime_error(
137+
"pattern application failed to converge");
138+
},
139+
"module"_a, "set"_a,
140+
// clang-format off
141+
nb::sig("def apply_patterns_and_fold_greedily(module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ", set: FrozenRewritePatternSet) -> None"),
142+
// clang-format on
143+
"Applys the given patterns to the given module greedily while "
144+
"folding "
145+
"results.")
120146
.def(
121147
"apply_patterns_and_fold_greedily",
122148
[](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
@@ -131,5 +157,20 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
131157
nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
132158
// clang-format on
133159
"Applys the given patterns to the given op greedily while folding "
160+
"results.")
161+
.def(
162+
"apply_patterns_and_fold_greedily",
163+
[](PyOperationBase &op, MlirFrozenRewritePatternSet set) {
164+
auto status = mlirApplyPatternsAndFoldGreedilyWithOp(
165+
op.getOperation(), set, {});
166+
if (mlirLogicalResultIsFailure(status))
167+
throw std::runtime_error(
168+
"pattern application failed to converge");
169+
},
170+
"op"_a, "set"_a,
171+
// clang-format off
172+
nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
173+
// clang-format on
174+
"Applys the given patterns to the given op greedily while folding "
134175
"results.");
135176
}

0 commit comments

Comments
 (0)