diff --git a/mlir/include/mlir/CAPI/Rewrite.h b/mlir/include/mlir/CAPI/Rewrite.h index 8cd51edf0837b..9c96d354d4fc9 100644 --- a/mlir/include/mlir/CAPI/Rewrite.h +++ b/mlir/include/mlir/CAPI/Rewrite.h @@ -18,9 +18,19 @@ #include "mlir-c/Rewrite.h" #include "mlir/CAPI/Wrap.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" DEFINE_C_API_PTR_METHODS(MlirRewriterBase, mlir::RewriterBase) DEFINE_C_API_PTR_METHODS(MlirRewritePattern, const mlir::RewritePattern) DEFINE_C_API_PTR_METHODS(MlirRewritePatternSet, mlir::RewritePatternSet) +DEFINE_C_API_PTR_METHODS(MlirFrozenRewritePatternSet, + mlir::FrozenRewritePatternSet) +DEFINE_C_API_PTR_METHODS(MlirPatternRewriter, mlir::PatternRewriter) + +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH +DEFINE_C_API_PTR_METHODS(MlirPDLPatternModule, mlir::PDLPatternModule) +DEFINE_C_API_PTR_METHODS(MlirPDLResultList, mlir::PDLResultList) +DEFINE_C_API_PTR_METHODS(MlirPDLValue, const mlir::PDLValue) +#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH #endif // MLIR_CAPIREWRITER_H diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index d506b7fc9bc7b..47685567d5355 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -261,7 +261,6 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { //---------------------------------------------------------------------------- // Mapping of the RewritePatternSet //---------------------------------------------------------------------------- - nb::class_(m, "RewritePattern"); nb::class_(m, "RewritePatternSet") .def( "__init__", diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index 70dee598c9535..46c329d8433b4 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -270,17 +270,6 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) { /// RewritePatternSet and FrozenRewritePatternSet API //===----------------------------------------------------------------------===// -static inline mlir::FrozenRewritePatternSet * -unwrap(MlirFrozenRewritePatternSet module) { - assert(module.ptr && "unexpected null module"); - return static_cast(module.ptr); -} - -static inline MlirFrozenRewritePatternSet -wrap(mlir::FrozenRewritePatternSet *module) { - return {module}; -} - MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet set) { auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(set))); @@ -311,15 +300,6 @@ mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op, /// PatternRewriter API //===----------------------------------------------------------------------===// -inline mlir::PatternRewriter *unwrap(MlirPatternRewriter rewriter) { - assert(rewriter.ptr && "unexpected null rewriter"); - return static_cast(rewriter.ptr); -} - -inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) { - return {rewriter}; -} - MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) { return wrap(static_cast(unwrap(rewriter))); } @@ -400,15 +380,6 @@ void mlirRewritePatternSetAdd(MlirRewritePatternSet set, //===----------------------------------------------------------------------===// #if MLIR_ENABLE_PDL_IN_PATTERNMATCH -static inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) { - assert(module.ptr && "unexpected null module"); - return static_cast(module.ptr); -} - -static inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) { - return {module}; -} - MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) { return wrap(new mlir::PDLPatternModule( mlir::OwningOpRef(unwrap(op)))); @@ -426,22 +397,6 @@ mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) { return wrap(m); } -inline const mlir::PDLValue *unwrap(MlirPDLValue value) { - assert(value.ptr && "unexpected null PDL value"); - return static_cast(value.ptr); -} - -inline MlirPDLValue wrap(const mlir::PDLValue *value) { return {value}; } - -inline mlir::PDLResultList *unwrap(MlirPDLResultList results) { - assert(results.ptr && "unexpected null PDL results"); - return static_cast(results.ptr); -} - -inline MlirPDLResultList wrap(mlir::PDLResultList *results) { - return {results}; -} - MlirValue mlirPDLValueAsValue(MlirPDLValue value) { return wrap(unwrap(value)->dyn_cast()); }