diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 0f0ed22c50fa9..f7557c3f7f768 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -277,15 +277,37 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { "add", [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn, unsigned benefit) { - std::string opName = - nb::cast(root.attr("OPERATION_NAME")); + std::string opName; + if (root.is_type()) { + opName = nb::cast(root.attr("OPERATION_NAME")); + } else if (nb::isinstance(root)) { + opName = nb::cast(root); + } else { + throw nb::type_error( + "the root argument must be a type or a string"); + } self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit, fn); }, "root"_a, "fn"_a, "benefit"_a = 1, - "Add a new rewrite pattern on the given root operation with the " - "callable as the matching and rewriting function and the given " - "benefit.") + // clang-format off + nb::sig("def add(self, root: type | str, fn: typing.Callable[[" MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", PatternRewriter], typing.Any], benefit: int = 1) -> None"), + // clang-format on + R"( + Add a new rewrite pattern on the specified root operation, using the provided callable + for matching and rewriting, and assign it the given benefit. + + Args: + root: The root operation to which this pattern applies. + This may be either an OpView subclass (e.g., ``arith.AddIOp``) or + an operation name string (e.g., ``"arith.addi"``). + fn: The callable to use for matching and rewriting, + which takes an operation and a pattern rewriter as arguments. + The match is considered successful iff the callable returns + a value where ``bool(value)`` is ``False`` (e.g. ``None``). + If possible, the operation is cast to its corresponding OpView subclass + before being passed to the callable. + benefit: The benefit of the pattern, defaulting to 1.)") .def("freeze", &PyRewritePatternSet::freeze, "Freeze the pattern set into a frozen one."); diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py index 821e47085a5bd..a6027161f29db 100644 --- a/mlir/test/python/rewrite.py +++ b/mlir/test/python/rewrite.py @@ -32,7 +32,7 @@ def constant_1_to_2(op, rewriter): with Context(): patterns = RewritePatternSet() patterns.add(arith.AddIOp, to_muli) - patterns.add(arith.ConstantOp, constant_1_to_2) + patterns.add("arith.constant", constant_1_to_2) frozen = patterns.freeze() module = ModuleOp.parse(