Skip to content

Commit

Permalink
[mlir] Enable folding memref alias for vector.load
Browse files Browse the repository at this point in the history
This work enables the memref alias folding pass for `vector.load`.

Reviewed By: qcolombet

Differential Revision: https://reviews.llvm.org/D151447
  • Loading branch information
grypp committed May 25, 2023
1 parent 2b1678c commit 5ec360c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
7 changes: 7 additions & 0 deletions mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
Expand Up @@ -173,6 +173,8 @@ static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
return op.getSrcMemref();
}

/// Memref-operand accessor overload for `vector::LoadOp`: yields the op's
/// base, i.e. the value the load reads from.
static Value getMemRefOperand(vector::LoadOp op) {
  Value baseMemRef = op.getBase();
  return baseMemRef;
}

/// Memref-operand accessor overload for `vector::TransferWriteOp`: forwards
/// to the op's `getSource()`.
static Value getMemRefOperand(vector::TransferWriteOp op) {
  Value sourceOperand = op.getSource();
  return sourceOperand;
}
Expand Down Expand Up @@ -397,6 +399,10 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
rewriter.replaceOpWithNewOp<memref::LoadOp>(
loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
})
.Case([&](vector::LoadOp op) {
rewriter.replaceOpWithNewOp<vector::LoadOp>(
op, op.getType(), subViewOp.getSource(), sourceIndices);
})
.Case([&](vector::TransferReadOp op) {
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
Expand Down Expand Up @@ -668,6 +674,7 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
LoadOpOfSubViewOpFolder<memref::LoadOp>,
LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
LoadOpOfSubViewOpFolder<vector::LoadOp>,
LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
Expand Down
15 changes: 15 additions & 0 deletions mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
Expand Up @@ -621,3 +621,18 @@ func.func @test_ldmatrix(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: ind
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
// CHECK: nvgpu.ldmatrix %[[ARG0]][%[[ARG1]], %[[ARG2]], %[[ARG3]]] {numTiles = 4 : i32, transpose = false} : memref<4x32x32xf16, 3> -> vector<4x2xf16>

// -----

// A rank-reducing 1x1 subview taken at dynamic offsets (%arg1, %arg2) feeds a
// vector.load that uses no indices. The expected output that follows loads
// directly from %arg0 at [%arg1, %arg2].
func.func @fold_vector_load(%arg0 : memref<12x32xf32>,
                            %arg1 : index,
                            %arg2 : index) -> vector<12x32xf32> {
  %subview = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1]
      : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
  %loaded = vector.load %subview[]
      : memref<f32, strided<[], offset: ?>>, vector<12x32xf32>
  return %loaded : vector<12x32xf32>
}

// CHECK: func @fold_vector_load
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK: vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<12x32xf32>, vector<12x32xf32>

0 comments on commit 5ec360c

Please sign in to comment.