diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h index 121a7222222d3..d3220ceb43f67 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h @@ -59,7 +59,7 @@ using DistributionMapFn = std::function; /// } /// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32> void populateDistributeTransferWriteOpPatterns( - RewritePatternSet &patterns, DistributionMapFn distributionMapFn); + RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn); /// Move scalar operations with no dependency on the warp op outside of the /// region. diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index e8602af2a9e56..08eced2bd935e 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -15,6 +15,8 @@ #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/Transforms/SideEffectUtils.h" +#include + using namespace mlir; using namespace mlir::vector; @@ -281,7 +283,7 @@ struct WarpOpTransferWrite : public OpRewritePattern { WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1) : OpRewritePattern(ctx, b), - distributionMapFn(fn) {} + distributionMapFn(std::move(fn)) {} /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that /// are multiples of the distribution ratio are supported at the moment. @@ -815,7 +817,7 @@ void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern( } void mlir::vector::populateDistributeTransferWriteOpPatterns( - RewritePatternSet &patterns, DistributionMapFn distributionMapFn) { + RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn) { patterns.add(patterns.getContext(), distributionMapFn); }