Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][MemRef] Add subview folding pattern for vector.maskedload #71380

Merged
merged 1 commit into from
Nov 6, 2023

Conversation

tyb0807
Copy link
Contributor

@tyb0807 tyb0807 commented Nov 6, 2023

This is required for fixing iree-org/iree#15031

@llvmbot
Copy link
Collaborator

llvmbot commented Nov 6, 2023

@llvm/pr-subscribers-mlir-memref

Author: None (tyb0807)

Changes

This is required for fixing iree-org/iree#15031


Full diff: https://github.com/llvm/llvm-project/pull/71380.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp (+7)
  • (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+18-1)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 9899c357daeeeb4..e7ab120114e3b0d 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -187,6 +187,8 @@ static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
 
 static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }
 
+static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }
+
 static Value getMemRefOperand(vector::TransferWriteOp op) {
   return op.getSource();
 }
@@ -415,6 +417,10 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
         rewriter.replaceOpWithNewOp<vector::LoadOp>(
             op, op.getType(), subViewOp.getSource(), sourceIndices);
       })
+      .Case([&](vector::MaskedLoadOp op) {
+        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
+            op, op.getType(), subViewOp.getSource(), sourceIndices, op.getMask(), op.getPassThru());
+      })
       .Case([&](vector::TransferReadOp op) {
         rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
             op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
@@ -687,6 +693,7 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
                LoadOpOfSubViewOpFolder<memref::LoadOp>,
                LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
                LoadOpOfSubViewOpFolder<vector::LoadOp>,
+               LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
                LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
                LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
                StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 47ec0e67d99cb3e..3f11e22749bb16d 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -654,7 +654,7 @@ func.func @test_ldmatrix(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: ind
 // -----
 
 func.func @fold_vector_load(
-  %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> {  
+  %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
   %1 = vector.load %0[] : memref<f32, strided<[], offset: ?>>, vector<12x32xf32>
   return %1 : vector<12x32xf32>
@@ -665,3 +665,20 @@ func.func @fold_vector_load(
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
 //      CHECK:   vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] :  memref<12x32xf32>, vector<12x32xf32>
+
+// -----
+
+func.func @fold_vector_maskedload(
+  %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> vector<32xf32> {
+  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
+  %1 = vector.maskedload %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32> into vector<32xf32>
+  return %1 : vector<32xf32>
+}
+
+//      CHECK: func @fold_vector_maskedload
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<32xi1>
+// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
+//      CHECK:   vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32> into vector<32xf32>

@llvmbot
Copy link
Collaborator

llvmbot commented Nov 6, 2023

@llvm/pr-subscribers-mlir

Author: None (tyb0807)

Changes

This is required for fixing iree-org/iree#15031


Full diff: https://github.com/llvm/llvm-project/pull/71380.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp (+7)
  • (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+18-1)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 9899c357daeeeb4..e7ab120114e3b0d 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -187,6 +187,8 @@ static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
 
 static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }
 
+static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }
+
 static Value getMemRefOperand(vector::TransferWriteOp op) {
   return op.getSource();
 }
@@ -415,6 +417,10 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
         rewriter.replaceOpWithNewOp<vector::LoadOp>(
             op, op.getType(), subViewOp.getSource(), sourceIndices);
       })
+      .Case([&](vector::MaskedLoadOp op) {
+        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
+            op, op.getType(), subViewOp.getSource(), sourceIndices, op.getMask(), op.getPassThru());
+      })
       .Case([&](vector::TransferReadOp op) {
         rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
             op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
@@ -687,6 +693,7 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
                LoadOpOfSubViewOpFolder<memref::LoadOp>,
                LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
                LoadOpOfSubViewOpFolder<vector::LoadOp>,
+               LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
                LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
                LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
                StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 47ec0e67d99cb3e..3f11e22749bb16d 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -654,7 +654,7 @@ func.func @test_ldmatrix(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: ind
 // -----
 
 func.func @fold_vector_load(
-  %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> {  
+  %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
   %1 = vector.load %0[] : memref<f32, strided<[], offset: ?>>, vector<12x32xf32>
   return %1 : vector<12x32xf32>
@@ -665,3 +665,20 @@ func.func @fold_vector_load(
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
 //      CHECK:   vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] :  memref<12x32xf32>, vector<12x32xf32>
+
+// -----
+
+func.func @fold_vector_maskedload(
+  %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> vector<32xf32> {
+  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
+  %1 = vector.maskedload %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32> into vector<32xf32>
+  return %1 : vector<32xf32>
+}
+
+//      CHECK: func @fold_vector_maskedload
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<32xi1>
+// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
+//      CHECK:   vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32> into vector<32xf32>

Copy link

github-actions bot commented Nov 6, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@qedawkins qedawkins left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! LGTM

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool!

@tyb0807 tyb0807 merged commit 5aa2c65 into llvm:main Nov 6, 2023
3 checks passed
@tyb0807 tyb0807 deleted the maskedload branch November 6, 2023 19:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants