diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td index 44da2965e6892..7cd8da6643400 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -145,7 +145,10 @@ def LinalgFoldUnitExtentDimsPass : Pass<"linalg-fold-unit-extent-dims", ""> { let options = [ Option<"useRankReducingSlices", "use-rank-reducing-slices", "bool", /*default=*/"false", - "Generate rank-reducing slices instead of reassociative reshapes"> + "Generate rank-reducing slices instead of reassociative reshapes">, + Option<"enableMoveInitOperandsToInput", "enable-move-init-operands-to-input", "bool", + /*default=*/"true", + "Enable MoveInitOperandsToInputPattern transformation"> ]; let dependentDialects = [ "linalg::LinalgDialect", "affine::AffineDialect", "memref::MemRefDialect" diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 22690daa4f9e1..ba466ed3df5cd 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -868,7 +868,9 @@ struct LinalgFoldUnitExtentDimsPass RankReductionStrategy::ExtractInsertSlice; } linalg::populateFoldUnitExtentDimsPatterns(patterns, options); - populateMoveInitOperandsToInputPattern(patterns); + if (enableMoveInitOperandsToInput) { + populateMoveInitOperandsToInputPattern(patterns); + } (void)applyPatternsGreedily(op, std::move(patterns)); } };