-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][MemRef] Add subview folding pattern for vector.maskedload #71380
Conversation
@llvm/pr-subscribers-mlir-memref Author: None (tyb0807) ChangesThis is required for fixing iree-org/iree#15031 Full diff: https://github.com/llvm/llvm-project/pull/71380.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 9899c357daeeeb4..e7ab120114e3b0d 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -187,6 +187,8 @@ static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }
+static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }
+
static Value getMemRefOperand(vector::TransferWriteOp op) {
return op.getSource();
}
@@ -415,6 +417,10 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
rewriter.replaceOpWithNewOp<vector::LoadOp>(
op, op.getType(), subViewOp.getSource(), sourceIndices);
})
+ .Case([&](vector::MaskedLoadOp op) {
+ rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
+ op, op.getType(), subViewOp.getSource(), sourceIndices, op.getMask(), op.getPassThru());
+ })
.Case([&](vector::TransferReadOp op) {
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
@@ -687,6 +693,7 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
LoadOpOfSubViewOpFolder<memref::LoadOp>,
LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
LoadOpOfSubViewOpFolder<vector::LoadOp>,
+ LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 47ec0e67d99cb3e..3f11e22749bb16d 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -654,7 +654,7 @@ func.func @test_ldmatrix(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: ind
// -----
func.func @fold_vector_load(
- %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> {
+ %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> {
%0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
%1 = vector.load %0[] : memref<f32, strided<[], offset: ?>>, vector<12x32xf32>
return %1 : vector<12x32xf32>
@@ -665,3 +665,20 @@ func.func @fold_vector_load(
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK: vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<12x32xf32>, vector<12x32xf32>
+
+// -----
+
+func.func @fold_vector_maskedload(
+ %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> vector<32xf32> {
+ %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
+ %1 = vector.maskedload %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32> into vector<32xf32>
+ return %1 : vector<32xf32>
+}
+
+// CHECK: func @fold_vector_maskedload
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<32xi1>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
+// CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32> into vector<32xf32>
|
@llvm/pr-subscribers-mlir Author: None (tyb0807) ChangesThis is required for fixing iree-org/iree#15031 Full diff: https://github.com/llvm/llvm-project/pull/71380.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 9899c357daeeeb4..e7ab120114e3b0d 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -187,6 +187,8 @@ static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }
+static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }
+
static Value getMemRefOperand(vector::TransferWriteOp op) {
return op.getSource();
}
@@ -415,6 +417,10 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
rewriter.replaceOpWithNewOp<vector::LoadOp>(
op, op.getType(), subViewOp.getSource(), sourceIndices);
})
+ .Case([&](vector::MaskedLoadOp op) {
+ rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
+ op, op.getType(), subViewOp.getSource(), sourceIndices, op.getMask(), op.getPassThru());
+ })
.Case([&](vector::TransferReadOp op) {
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
@@ -687,6 +693,7 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
LoadOpOfSubViewOpFolder<memref::LoadOp>,
LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
LoadOpOfSubViewOpFolder<vector::LoadOp>,
+ LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 47ec0e67d99cb3e..3f11e22749bb16d 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -654,7 +654,7 @@ func.func @test_ldmatrix(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: ind
// -----
func.func @fold_vector_load(
- %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> {
+ %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> {
%0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
%1 = vector.load %0[] : memref<f32, strided<[], offset: ?>>, vector<12x32xf32>
return %1 : vector<12x32xf32>
@@ -665,3 +665,20 @@ func.func @fold_vector_load(
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK: vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<12x32xf32>, vector<12x32xf32>
+
+// -----
+
+func.func @fold_vector_maskedload(
+ %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> vector<32xf32> {
+ %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
+ %1 = vector.maskedload %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32> into vector<32xf32>
+ return %1 : vector<32xf32>
+}
+
+// CHECK: func @fold_vector_maskedload
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<32xi1>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
+// CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32> into vector<32xf32>
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
This is required for fixing iree-org/iree#15031
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool!
This is required for fixing iree-org/iree#15031