diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h index 2b7f3da150cdf1..c7e331e856519d 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h @@ -19,6 +19,7 @@ namespace mlir { +class GlobalCreator; class RewritePatternSet; using OwningRewritePatternList = RewritePatternSet; @@ -31,6 +32,12 @@ std::unique_ptr createStdBufferizePass(); /// Creates an instance of func bufferization pass. std::unique_ptr createFuncBufferizePass(); +/// Add patterns to bufferize tensor constants into global memrefs to the given +/// pattern list. +void populateTensorConstantBufferizePatterns( + GlobalCreator &globalCreator, BufferizeTypeConverter &typeConverter, + RewritePatternSet &patterns); + /// Creates an instance of tensor constant bufferization pass. std::unique_ptr createTensorConstantBufferizePass(); diff --git a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp index b40e47c9441414..518405aabb49ff 100644 --- a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp @@ -81,6 +81,13 @@ class BufferizeTensorConstantOp : public OpConversionPattern { }; } // namespace +void mlir::populateTensorConstantBufferizePatterns( + GlobalCreator &globalCreator, BufferizeTypeConverter &typeConverter, + RewritePatternSet &patterns) { + patterns.add(globalCreator, typeConverter, + patterns.getContext()); +} + namespace { struct TensorConstantBufferizePass : public TensorConstantBufferizeBase { @@ -94,7 +101,7 @@ struct TensorConstantBufferizePass ConversionTarget target(*context); target.addLegalDialect(); - patterns.add(globals, typeConverter, context); + populateTensorConstantBufferizePatterns(globals, typeConverter, patterns); target.addDynamicallyLegalOp( [&](ConstantOp op) { return typeConverter.isLegal(op.getType()); }); if (failed(applyPartialConversion(module, target, std::move(patterns))))