From 5ec360c58961379f29e8a69cb98c352412329f77 Mon Sep 17 00:00:00 2001 From: Guray Ozen Date: Thu, 25 May 2023 16:10:38 +0200 Subject: [PATCH] [mlir] Enable folding memref alias for`vector.load` This work enables folding memref alias pass for`vector.load` Reviewed By: qcolombet Differential Revision: https://reviews.llvm.org/D151447 --- .../MemRef/Transforms/FoldMemRefAliasOps.cpp | 7 +++++++ .../Dialect/MemRef/fold-memref-alias-ops.mlir | 15 +++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp index 5916d6489cbc8..1fee97a0dd747 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -173,6 +173,8 @@ static Value getMemRefOperand(nvgpu::LdMatrixOp op) { return op.getSrcMemref(); } +static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); } + static Value getMemRefOperand(vector::TransferWriteOp op) { return op.getSource(); } @@ -397,6 +399,10 @@ LogicalResult LoadOpOfSubViewOpFolder::matchAndRewrite( rewriter.replaceOpWithNewOp( loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal()); }) + .Case([&](vector::LoadOp op) { + rewriter.replaceOpWithNewOp( + op, op.getType(), subViewOp.getSource(), sourceIndices); + }) .Case([&](vector::TransferReadOp op) { rewriter.replaceOpWithNewOp( op, op.getVectorType(), subViewOp.getSource(), sourceIndices, @@ -668,6 +674,7 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) { patterns.add, LoadOpOfSubViewOpFolder, LoadOpOfSubViewOpFolder, + LoadOpOfSubViewOpFolder, LoadOpOfSubViewOpFolder, LoadOpOfSubViewOpFolder, StoreOpOfSubViewOpFolder, diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir index 0e9df2969023e..acfddc366df16 100644 --- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir +++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir @@ -621,3 +621,18 @@ func.func @test_ldmatrix(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: ind // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index // CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index // CHECK: nvgpu.ldmatrix %[[ARG0]][%[[ARG1]], %[[ARG2]], %[[ARG3]]] {numTiles = 4 : i32, transpose = false} : memref<4x32x32xf16, 3> -> vector<4x2xf16> + +// ----- + +func.func @fold_vector_load( + %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> { + %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref> + %1 = vector.load %0[] : memref>, vector<12x32xf32> + return %1 : vector<12x32xf32> +} + +// CHECK: func @fold_vector_load +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index +// CHECK: vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<12x32xf32>, vector<12x32xf32>