From 34187afd9f63c700fd7442f8d9040124ef6707dd Mon Sep 17 00:00:00 2001 From: Max Dawkins Date: Fri, 10 Nov 2023 18:52:06 -0500 Subject: [PATCH] [mlir] Add MemRefAliasOp folding for `vector.store` --- mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp index b78c4510ff885..85878aff2701f 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -187,6 +187,8 @@ static Value getMemRefOperand(nvgpu::LdMatrixOp op) { static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); } +static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); } + static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); } static Value getMemRefOperand(vector::TransferWriteOp op) { @@ -557,6 +559,10 @@ LogicalResult StoreOpOfSubViewOpFolder::matchAndRewrite( subViewOp.getDroppedDims())), op.getMask(), op.getInBoundsAttr()); }) + .Case([&](vector::StoreOp op) { + rewriter.replaceOpWithNewOp( + op, op.getValueToStore(), subViewOp.getSource(), sourceIndices); + }) .Case([&](gpu::SubgroupMmaStoreMatrixOp op) { rewriter.replaceOpWithNewOp( op, op.getSrc(), subViewOp.getSource(), sourceIndices, @@ -698,6 +704,7 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) { StoreOpOfSubViewOpFolder, StoreOpOfSubViewOpFolder, StoreOpOfSubViewOpFolder, + StoreOpOfSubViewOpFolder, StoreOpOfSubViewOpFolder, LoadOpOfExpandShapeOpFolder, LoadOpOfExpandShapeOpFolder,