diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 81386f2227a7f..4b238e11c7fff 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -4283,6 +4283,33 @@ void mlir::python::populateIRCore(nb::module_ &m) { mlirValueReplaceAllUsesOfWith(self.get(), with.get()); }, kValueReplaceAllUsesWithDocstring) + .def( + "replace_all_uses_except", + [](MlirValue self, MlirValue with, PyOperation &exception) { + MlirOperation exceptedUser = exception.get(); + mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser); + }, + nb::arg("with_"), nb::arg("exceptions"), + nb::sig("def replace_all_uses_except(self, with_: Value, exceptions: " + "Operation) -> None"), + kValueReplaceAllUsesExceptDocstring) + .def( + "replace_all_uses_except", + [](MlirValue self, MlirValue with, nb::list exceptions) { + // Convert Python list to a SmallVector of MlirOperations + llvm::SmallVector exceptionOps; + for (nb::handle exception : exceptions) { + exceptionOps.push_back(nb::cast(exception).get()); + } + + mlirValueReplaceAllUsesExcept( + self, with, static_cast(exceptionOps.size()), + exceptionOps.data()); + }, + nb::arg("with_"), nb::arg("exceptions"), + nb::sig("def replace_all_uses_except(self, with_: Value, exceptions: " + "Sequence[Operation]) -> None"), + kValueReplaceAllUsesExceptDocstring) .def( "replace_all_uses_except", [](PyValue &self, PyValue &with, PyOperation &exception) { diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index a7aa1c65c6c43..cab3bf549295b 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -898,6 +898,18 @@ class PyTupleType : public PyConcreteType { }, nb::arg("elements"), nb::arg("context") = nb::none(), "Create a tuple type"); + c.def_static( + "get_tuple", + [](std::vector elements, DefaultingPyMlirContext context) { + MlirType t = mlirTupleTypeGet(context->get(), elements.size(), + elements.data()); + return PyTupleType(context->getRef(), t); + }, + nb::arg("elements"), nb::arg("context") = nb::none(), + // clang-format off + nb::sig("def get_tuple(elements: Sequence[Type], context: mlir.ir.Context | None = None) -> TupleType"), + // clang-format on + "Create a tuple type"); c.def( "get_type", [](PyTupleType &self, intptr_t pos) { @@ -944,6 +956,20 @@ class PyFunctionType : public PyConcreteType { }, nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(), "Gets a FunctionType from a list of input and result types"); + c.def_static( + "get", + [](std::vector inputs, std::vector results, + DefaultingPyMlirContext context) { + MlirType t = + mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(), + results.size(), results.data()); + return PyFunctionType(context->getRef(), t); + }, + nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(), + // clang-format off + nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: mlir.ir.Context | None = None) -> FunctionType"), + // clang-format on + "Gets a FunctionType from a list of input and result types"); c.def_prop_ro( "inputs", [](PyFunctionType &self) { diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 96f00cface64f..3476793369907 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -83,6 +83,16 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { //---------------------------------------------------------------------------- #if MLIR_ENABLE_PDL_IN_PATTERNMATCH nb::class_(m, "PDLModule") + .def( + "__init__", + [](PyPDLPatternModule &self, MlirModule module) { + new (&self) + PyPDLPatternModule(mlirPDLPatternModuleFromModule(module)); + }, + // clang-format off + nb::sig("def __init__(self, module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ") -> None"), + // clang-format on + "module"_a, "Create a PDL module from the given module.") .def( "__init__", [](PyPDLPatternModule &self, PyModule &module) { @@ -117,6 +127,22 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { // clang-format on "Applys the given patterns to the given module greedily while folding " "results.") + .def( + "apply_patterns_and_fold_greedily", + [](PyModule &module, MlirFrozenRewritePatternSet set) { + auto status = + mlirApplyPatternsAndFoldGreedily(module.get(), set, {}); + if (mlirLogicalResultIsFailure(status)) + throw std::runtime_error( + "pattern application failed to converge"); + }, + "module"_a, "set"_a, + // clang-format off + nb::sig("def apply_patterns_and_fold_greedily(module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ", set: FrozenRewritePatternSet) -> None"), + // clang-format on + "Applys the given patterns to the given module greedily while " + "folding " + "results.") .def( "apply_patterns_and_fold_greedily", [](PyOperationBase &op, PyFrozenRewritePatternSet &set) { @@ -131,5 +157,20 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"), // clang-format on "Applys the given patterns to the given op greedily while folding " + "results.") + .def( + "apply_patterns_and_fold_greedily", + [](PyOperationBase &op, MlirFrozenRewritePatternSet set) { + auto status = mlirApplyPatternsAndFoldGreedilyWithOp( + op.getOperation(), set, {}); + if (mlirLogicalResultIsFailure(status)) + throw std::runtime_error( + "pattern application failed to converge"); + }, + "op"_a, "set"_a, + // clang-format off + nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"), + // clang-format on + "Applys the given patterns to the given op greedily while folding " "results."); }