diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp index b78c4510ff885..85878aff2701f 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -187,6 +187,8 @@ static Value getMemRefOperand(nvgpu::LdMatrixOp op) { static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); } +static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); } + static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); } static Value getMemRefOperand(vector::TransferWriteOp op) { @@ -557,6 +559,10 @@ LogicalResult StoreOpOfSubViewOpFolder::matchAndRewrite( subViewOp.getDroppedDims())), op.getMask(), op.getInBoundsAttr()); }) + .Case([&](vector::StoreOp op) { + rewriter.replaceOpWithNewOp( + op, op.getValueToStore(), subViewOp.getSource(), sourceIndices); + }) .Case([&](gpu::SubgroupMmaStoreMatrixOp op) { rewriter.replaceOpWithNewOp( op, op.getSrc(), subViewOp.getSource(), sourceIndices, @@ -698,6 +704,7 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) { StoreOpOfSubViewOpFolder, StoreOpOfSubViewOpFolder, StoreOpOfSubViewOpFolder, + StoreOpOfSubViewOpFolder, StoreOpOfSubViewOpFolder, LoadOpOfExpandShapeOpFolder, LoadOpOfExpandShapeOpFolder,