diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h index 61d3446317550..374d2fb78de88 100644 --- a/mlir/include/mlir-c/Rewrite.h +++ b/mlir/include/mlir-c/Rewrite.h @@ -301,6 +301,10 @@ mlirFreezeRewritePattern(MlirRewritePatternSet op); MLIR_CAPI_EXPORTED void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op); +MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp( + MlirOperation op, MlirFrozenRewritePatternSet patterns, + MlirGreedyRewriteDriverConfig); + MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily( MlirModule op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig); diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 0373f9c7affe9..5b7de50f02e6a 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -99,14 +99,25 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyFrozenRewritePatternSet::createFromCapsule); m.def( - "apply_patterns_and_fold_greedily", - [](MlirModule module, MlirFrozenRewritePatternSet set) { - auto status = mlirApplyPatternsAndFoldGreedily(module, set, {}); - if (mlirLogicalResultIsFailure(status)) - // FIXME: Not sure this is the right error to throw here. - throw nb::value_error("pattern application failed to converge"); - }, - "module"_a, "set"_a, - "Applys the given patterns to the given module greedily while folding " - "results."); + "apply_patterns_and_fold_greedily", + [](PyModule &module, MlirFrozenRewritePatternSet set) { + auto status = mlirApplyPatternsAndFoldGreedily(module.get(), set, {}); + if (mlirLogicalResultIsFailure(status)) + throw std::runtime_error("pattern application failed to converge"); + }, + "module"_a, "set"_a, + "Applys the given patterns to the given module greedily while folding " + "results.") + .def( + "apply_patterns_and_fold_greedily", + [](PyOperationBase &op, MlirFrozenRewritePatternSet set) { + auto status = mlirApplyPatternsAndFoldGreedilyWithOp( + op.getOperation(), set, {}); + if (mlirLogicalResultIsFailure(status)) + throw std::runtime_error( + "pattern application failed to converge"); + }, + "op"_a, "set"_a, + "Applys the given patterns to the given op greedily while folding " + "results."); } diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index a4df97f7beace..6f85357a14a18 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -294,6 +294,13 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op, return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns))); } +MlirLogicalResult +mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op, + MlirFrozenRewritePatternSet patterns, + MlirGreedyRewriteDriverConfig) { + return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns))); +} + //===----------------------------------------------------------------------===// /// PDLPatternModule API //===----------------------------------------------------------------------===// diff --git a/mlir/test/python/integration/dialects/pdl.py b/mlir/test/python/integration/dialects/pdl.py index 923af29a71ad7..dd6c74ce622c8 100644 --- a/mlir/test/python/integration/dialects/pdl.py +++ b/mlir/test/python/integration/dialects/pdl.py @@ -16,20 +16,7 @@ def construct_and_print_in_module(f): print(module) return f - -# CHECK-LABEL: TEST: test_add_to_mul -# CHECK: arith.muli -@construct_and_print_in_module -def test_add_to_mul(module_): - index_type = IndexType.get() - - # Create a test case. - @module(sym_name="ir") - def ir(): - @func.func(index_type, index_type) - def add_func(a, b): - return arith.addi(a, b) - +def get_pdl_patterns(): # Create a rewrite from add to mul. This will match # - operation name is arith.addi # - operands are index types. @@ -61,7 +48,41 @@ def rew(): # not yet captured Python side/has sharp edges. So best to construct the # module and PDL module in same scope. # FIXME: This should be made more robust. - frozen = PDLModule(m).freeze() + return PDLModule(m).freeze() + + +# CHECK-LABEL: TEST: test_add_to_mul +# CHECK: arith.muli +@construct_and_print_in_module +def test_add_to_mul(module_): + index_type = IndexType.get() + + # Create a test case. + @module(sym_name="ir") + def ir(): + @func.func(index_type, index_type) + def add_func(a, b): + return arith.addi(a, b) + + frozen = get_pdl_patterns() # Could apply frozen pattern set multiple times. apply_patterns_and_fold_greedily(module_, frozen) return module_ + + +# CHECK-LABEL: TEST: test_add_to_mul_with_op +# CHECK: arith.muli +@construct_and_print_in_module +def test_add_to_mul_with_op(module_): + index_type = IndexType.get() + + # Create a test case. + @module(sym_name="ir") + def ir(): + @func.func(index_type, index_type) + def add_func(a, b): + return arith.addi(a, b) + + frozen = get_pdl_patterns() + apply_patterns_and_fold_greedily(module_.operation, frozen) + return module_