Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4283,6 +4283,33 @@ void mlir::python::populateIRCore(nb::module_ &m) {
mlirValueReplaceAllUsesOfWith(self.get(), with.get());
},
kValueReplaceAllUsesWithDocstring)
.def(
"replace_all_uses_except",
[](MlirValue self, MlirValue with, PyOperation &exception) {
MlirOperation exceptedUser = exception.get();
mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
},
nb::arg("with_"), nb::arg("exceptions"),
nb::sig("def replace_all_uses_except(self, with_: Value, exceptions: "
"Operation) -> None"),
kValueReplaceAllUsesExceptDocstring)
.def(
"replace_all_uses_except",
[](MlirValue self, MlirValue with, nb::list exceptions) {
// Convert Python list to a SmallVector of MlirOperations
llvm::SmallVector<MlirOperation> exceptionOps;
for (nb::handle exception : exceptions) {
exceptionOps.push_back(nb::cast<PyOperation &>(exception).get());
}

mlirValueReplaceAllUsesExcept(
self, with, static_cast<intptr_t>(exceptionOps.size()),
exceptionOps.data());
},
nb::arg("with_"), nb::arg("exceptions"),
nb::sig("def replace_all_uses_except(self, with_: Value, exceptions: "
"Sequence[Operation]) -> None"),
kValueReplaceAllUsesExceptDocstring)
.def(
"replace_all_uses_except",
[](PyValue &self, PyValue &with, PyOperation &exception) {
Expand Down
26 changes: 26 additions & 0 deletions mlir/lib/Bindings/Python/IRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,18 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
},
nb::arg("elements"), nb::arg("context") = nb::none(),
"Create a tuple type");
c.def_static(
"get_tuple",
[](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
elements.data());
return PyTupleType(context->getRef(), t);
},
nb::arg("elements"), nb::arg("context") = nb::none(),
// clang-format off
nb::sig("def get_tuple(elements: Sequence[Type], context: mlir.ir.Context | None = None) -> TupleType"),
// clang-format on
"Create a tuple type");
c.def(
"get_type",
[](PyTupleType &self, intptr_t pos) {
Expand Down Expand Up @@ -944,6 +956,20 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
},
nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
"Gets a FunctionType from a list of input and result types");
c.def_static(
"get",
[](std::vector<MlirType> inputs, std::vector<MlirType> results,
DefaultingPyMlirContext context) {
MlirType t =
mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
results.size(), results.data());
return PyFunctionType(context->getRef(), t);
},
nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
// clang-format off
nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: mlir.ir.Context | None = None) -> FunctionType"),
// clang-format on
"Gets a FunctionType from a list of input and result types");
c.def_prop_ro(
"inputs",
[](PyFunctionType &self) {
Expand Down
41 changes: 41 additions & 0 deletions mlir/lib/Bindings/Python/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,16 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
nb::class_<PyPDLPatternModule>(m, "PDLModule")
.def(
"__init__",
[](PyPDLPatternModule &self, MlirModule module) {
new (&self)
PyPDLPatternModule(mlirPDLPatternModuleFromModule(module));
},
// clang-format off
nb::sig("def __init__(self, module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ") -> None"),
// clang-format on
"module"_a, "Create a PDL module from the given module.")
.def(
"__init__",
[](PyPDLPatternModule &self, PyModule &module) {
Expand Down Expand Up @@ -117,6 +127,22 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
// clang-format on
"Applys the given patterns to the given module greedily while folding "
"results.")
.def(
"apply_patterns_and_fold_greedily",
[](PyModule &module, MlirFrozenRewritePatternSet set) {
auto status =
mlirApplyPatternsAndFoldGreedily(module.get(), set, {});
if (mlirLogicalResultIsFailure(status))
throw std::runtime_error(
"pattern application failed to converge");
},
"module"_a, "set"_a,
// clang-format off
nb::sig("def apply_patterns_and_fold_greedily(module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ", set: FrozenRewritePatternSet) -> None"),
// clang-format on
"Applys the given patterns to the given module greedily while "
"folding "
"results.")
.def(
"apply_patterns_and_fold_greedily",
[](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
Expand All @@ -131,5 +157,20 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
// clang-format on
"Applys the given patterns to the given op greedily while folding "
"results.")
.def(
"apply_patterns_and_fold_greedily",
[](PyOperationBase &op, MlirFrozenRewritePatternSet set) {
auto status = mlirApplyPatternsAndFoldGreedilyWithOp(
op.getOperation(), set, {});
if (mlirLogicalResultIsFailure(status))
throw std::runtime_error(
"pattern application failed to converge");
},
"op"_a, "set"_a,
// clang-format off
nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
// clang-format on
"Applys the given patterns to the given op greedily while folding "
"results.");
}