Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][vector] Fix invalid LoadOp indices being created #76292

Merged
merged 5 commits into from
Jan 3, 2024
Merged

[mlir][vector] Fix invalid LoadOp indices being created #76292

merged 5 commits into from
Jan 3, 2024

Conversation

rikhuijzer
Copy link
Member

@rikhuijzer rikhuijzer commented Dec 23, 2023

Fixes #71326.

This is the second PR. The first PR at #75519 was reverted because an integration test failed. The failed integration test was simplified and added to the core MLIR tests. Compared to the first PR, the current PR uses a more reliable approach. In summary, the current PR determines the mask indices by looking up the mask buffer load indices from the previous iteration, whereas main looks up the indices for the data buffer. The mask and data indices can differ when using a permutation_map.

The cause of the issue was that a new LoadOp was created which looked something like:

func.func main(%arg1 : index, %arg2 : index) {
  %alloca_0 = memref.alloca() : memref<vector<1x32xi1>>
  %1 = vector.type_cast %alloca_0 : memref<vector<1x32xi1>> to memref<1xvector<32xi1>>
  %2 = memref.load %1[%arg1, %arg2] : memref<1xvector<32xi1>>
  return
}

which crashed inside the LoadOp::verify. Note here that %alloca_0 is the mask as can be seen from the i1 element type and note it is 0 dimensional. Next, %1 has one dimension, but memref.load tries to index it with two indices.

This issue occured in the following code (a simplified version of the bug report):

#map1 = affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>
func.func @main(%subview:  memref<1x1x1x1xi32>, %mask: vector<1x1xi1>) -> vector<1x1x1x1xi32> {
  %c0 = arith.constant 0 : index
  %c0_i32 = arith.constant 0 : i32
  %3 = vector.transfer_read %subview[%c0, %c0, %c0, %c0], %c0_i32, %mask {permutation_map = #map1}
          : memref<1x1x1x1xi32>, vector<1x1x1x1xi32>
  return %3 : vector<1x1x1x1xi32>
}

After this patch, it is lowered to the following by -convert-vector-to-scf:

func.func @main(%arg0: memref<1x1x1x1xi32>, %arg1: vector<1x1xi1>) -> vector<1x1x1x1xi32> {
  %c0_i32 = arith.constant 0 : i32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %alloca = memref.alloca() : memref<vector<1x1x1x1xi32>>
  %alloca_0 = memref.alloca() : memref<vector<1x1xi1>>
  memref.store %arg1, %alloca_0[] : memref<vector<1x1xi1>>
  %0 = vector.type_cast %alloca : memref<vector<1x1x1x1xi32>> to memref<1xvector<1x1x1xi32>>
  %1 = vector.type_cast %alloca_0 : memref<vector<1x1xi1>> to memref<1xvector<1xi1>>
  scf.for %arg2 = %c0 to %c1 step %c1 {
    %3 = vector.type_cast %0 : memref<1xvector<1x1x1xi32>> to memref<1x1xvector<1x1xi32>>
    scf.for %arg3 = %c0 to %c1 step %c1 {
      %4 = vector.type_cast %3 : memref<1x1xvector<1x1xi32>> to memref<1x1x1xvector<1xi32>>
      scf.for %arg4 = %c0 to %c1 step %c1 {
        %5 = memref.load %1[%arg2] : memref<1xvector<1xi1>>
        %6 = vector.transfer_read %arg0[%arg2, %c0, %c0, %c0], %c0_i32, %5 {in_bounds = [true]} : memref<1x1x1x1xi32>, vector<1xi32>
        memref.store %6, %4[%arg2, %arg3, %arg4] : memref<1x1x1xvector<1xi32>>
      }
    }
  }
  %2 = memref.load %alloca[] : memref<vector<1x1x1x1xi32>>
  return %2 : vector<1x1x1x1xi32>
}

What was causing the problems is that one dimension of the data buffer %alloca (eltype i32) is unpacked (vector.type_cast) inside the outmost loop (loop with index variable %arg2) and the nested loop (loop with index variable %arg3), whereas the mask buffer %alloca_0 (eltype i1) is not unpacked in these loops.

Before this patch, the load indices would be determined by looking up the load indices for the data buffer load op. However, as shown in the specific example, when a permutation map is specified then the load indices from the data buffer load op start to differ from the indices for the mask op. To fix this, this patch ensures that the load indices for the mask buffer are used instead.

@llvmbot
Copy link
Collaborator

llvmbot commented Dec 23, 2023

@llvm/pr-subscribers-mlir

Author: Rik Huijzer (rikhuijzer)

Changes

Second attempt at fixing #71326. The first attempt at #75519 was reverted because an integration test failed.

The cause of the issue was that a new LoadOp was created which looked something like:

func.func main(%arg1 : index, %arg2 : index) {
  %alloca_0 = memref.alloca() : memref&lt;vector&lt;1x32xi1&gt;&gt;
  %1 = vector.type_cast %alloca_0 : memref&lt;vector&lt;1x32xi1&gt;&gt; to memref&lt;1xvector&lt;32xi1&gt;&gt;
  %2 = memref.load %1[%arg1, %arg2] : memref&lt;1xvector&lt;32xi1&gt;&gt;
  return
}

which crashed inside the LoadOp::verify. Note here that %alloca_0 is the mask as can be seen from the i1 element type and note it is 0 dimensional. Next, %1 has one dimension, but memref.load tries to index it with two indices.

This issue occured in the following code (a simplified version of the bug report):

#map1 = affine_map&lt;(d0, d1, d2, d3) -&gt; (d0, 0, 0, d3)&gt;
func.func @<!-- -->main(%subview:  memref&lt;1x1x1x1xi32&gt;, %mask: vector&lt;1x1xi1&gt;) -&gt; vector&lt;1x1x1x1xi32&gt; {
  %c0 = arith.constant 0 : index
  %c0_i32 = arith.constant 0 : i32
  %3 = vector.transfer_read %subview[%c0, %c0, %c0, %c0], %c0_i32, %mask {permutation_map = #map1}
          : memref&lt;1x1x1x1xi32&gt;, vector&lt;1x1x1x1xi32&gt;
  return %3 : vector&lt;1x1x1x1xi32&gt;
}

After this patch, it is lowered to the following by -convert-vector-to-scf:

func.func @<!-- -->main(%arg0: memref&lt;1x1x1x1xi32&gt;, %arg1: vector&lt;1x1xi1&gt;) -&gt; vector&lt;1x1x1x1xi32&gt; {
  %c0_i32 = arith.constant 0 : i32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %alloca = memref.alloca() : memref&lt;vector&lt;1x1x1x1xi32&gt;&gt;
  %alloca_0 = memref.alloca() : memref&lt;vector&lt;1x1xi1&gt;&gt;
  memref.store %arg1, %alloca_0[] : memref&lt;vector&lt;1x1xi1&gt;&gt;
  %0 = vector.type_cast %alloca : memref&lt;vector&lt;1x1x1x1xi32&gt;&gt; to memref&lt;1xvector&lt;1x1x1xi32&gt;&gt;
  %1 = vector.type_cast %alloca_0 : memref&lt;vector&lt;1x1xi1&gt;&gt; to memref&lt;1xvector&lt;1xi1&gt;&gt;
  scf.for %arg2 = %c0 to %c1 step %c1 {
    %3 = vector.type_cast %0 : memref&lt;1xvector&lt;1x1x1xi32&gt;&gt; to memref&lt;1x1xvector&lt;1x1xi32&gt;&gt;
    scf.for %arg3 = %c0 to %c1 step %c1 {
      %4 = vector.type_cast %3 : memref&lt;1x1xvector&lt;1x1xi32&gt;&gt; to memref&lt;1x1x1xvector&lt;1xi32&gt;&gt;
      scf.for %arg4 = %c0 to %c1 step %c1 {
        %5 = memref.load %1[%arg2] : memref&lt;1xvector&lt;1xi1&gt;&gt;
        %6 = vector.transfer_read %arg0[%arg2, %c0, %c0, %c0], %c0_i32, %5 {in_bounds = [true]} : memref&lt;1x1x1x1xi32&gt;, vector&lt;1xi32&gt;
        memref.store %6, %4[%arg2, %arg3, %arg4] : memref&lt;1x1x1xvector&lt;1xi32&gt;&gt;
      }
    }
  }
  %2 = memref.load %alloca[] : memref&lt;vector&lt;1x1x1x1xi32&gt;&gt;
  return %2 : vector&lt;1x1x1x1xi32&gt;
}

What was causing the problems is that one dimension of the data buffer %alloca (eltype i32) is unpacked (vector.type_cast) inside the outmost loop (loop with index variable %arg2) and the nested loop (loop with index variable %arg3), whereas the mask buffer %alloca_0 (eltype i1) is not unpacked in these loops.

Before this patch, the load indices would be determined by looking up the load indices for the data buffer load op. However, as shown in the specific example, when a permutation map is specified then the load indices from the data buffer load op start to differ from the indices for the mask op. To fix this, this patch ensures that the load indices for the mask buffer are used instead.


Full diff: https://github.com/llvm/llvm-project/pull/76292.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+34-14)
  • (modified) mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir (+37)
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 2ee314e9fedfe3..13d2513a88804c 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -866,6 +866,31 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
     this->setHasBoundedRewriteRecursion();
   }
 
+  static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer,
+                                       SmallVector<Value, 8> &loadIndices,
+                                       Value iv) {
+    assert(xferOp.getMask() && "Expected transfer op to have mask");
+
+    // Add load indices from the previous iteration.
+    // The mask buffer depends on the permutation map, which makes determining
+    // the indices quite complex, so this is why we need to "look back" to the
+    // previous iteration to find the right indices.
+    Value maskBuffer = getMaskBuffer(xferOp);
+    for (OpOperand &use : maskBuffer.getUses()) {
+      // If there is no previous load op, then the indices are empty.
+      if (auto loadOp = dyn_cast<memref::LoadOp>(use.getOwner())) {
+        Operation::operand_range prevIndices = loadOp.getIndices();
+        loadIndices.append(prevIndices.begin(), prevIndices.end());
+        break;
+      }
+    }
+
+    // In case of broadcast: Use same indices to load from memref
+    // as before.
+    if (!xferOp.isBroadcastDim(0))
+      loadIndices.push_back(iv);
+  }
+
   LogicalResult matchAndRewrite(OpTy xferOp,
                                 PatternRewriter &rewriter) const override {
     if (!xferOp->hasAttr(kPassLabel))
@@ -873,9 +898,9 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
 
     // Find and cast data buffer. How the buffer can be found depends on OpTy.
     ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
-    auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
+    Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
     auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
-    auto castedDataType = unpackOneDim(dataBufferType);
+    FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType);
     if (failed(castedDataType))
       return failure();
 
@@ -885,8 +910,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
     // If the xferOp has a mask: Find and cast mask buffer.
     Value castedMaskBuffer;
     if (xferOp.getMask()) {
-      auto maskBuffer = getMaskBuffer(xferOp);
-      auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
+      Value maskBuffer = getMaskBuffer(xferOp);
       if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
         // Do not unpack a dimension of the mask, if:
         // * To-be-unpacked transfer op dimension is a broadcast.
@@ -897,7 +921,8 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
       } else {
         // It's safe to assume the mask buffer can be unpacked if the data
         // buffer was unpacked.
-        auto castedMaskType = *unpackOneDim(maskBufferType);
+        auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
+        MemRefType castedMaskType = *unpackOneDim(maskBufferType);
         castedMaskBuffer =
             locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
       }
@@ -929,21 +954,16 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
 
                 // If old transfer op has a mask: Set mask on new transfer op.
                 // Special case: If the mask of the old transfer op is 1D and
-                // the
-                //               unpacked dim is not a broadcast, no mask is
-                //               needed on the new transfer op.
+                // the unpacked dim is not a broadcast, no mask is needed on
+                // the new transfer op.
                 if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
                                          xferOp.getMaskType().getRank() > 1)) {
                   OpBuilder::InsertionGuard guard(b);
                   b.setInsertionPoint(newXfer); // Insert load before newXfer.
 
                   SmallVector<Value, 8> loadIndices;
-                  Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
-                  // In case of broadcast: Use same indices to load from memref
-                  // as before.
-                  if (!xferOp.isBroadcastDim(0))
-                    loadIndices.push_back(iv);
-
+                  getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
+                                           loadIndices, iv);
                   auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
                                                        loadIndices);
                   rewriter.updateRootInPlace(newXfer, [&]() {
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index ad78f0c945b24d..8316b4005cc168 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -740,6 +740,43 @@ func.func @cannot_lower_transfer_read_with_leading_scalable(%arg0: memref<?x4xf3
 
 //  -----
 
+// Check that the `TransferOpConversion` generates valid indices for the LoadOp.
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>
+func.func @does_not_crash_on_unpack_one_dim(%subview:  memref<1x1x1x1xi32>, %mask: vector<1x1xi1>) -> vector<1x1x1x1xi32> {
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %3 = vector.transfer_read %subview[%c0, %c0, %c0, %c0], %c0_i32, %mask {permutation_map = #map1}
+          : memref<1x1x1x1xi32>, vector<1x1x1x1xi32>
+  return %3 : vector<1x1x1x1xi32>
+}
+// CHECK-LABEL: func.func @does_not_crash_on_unpack_one_dim
+// CHECK: %[[ALLOCA_0:.*]] = memref.alloca() : memref<vector<1x1xi1>>
+// CHECK: %[[MASK:.*]] = vector.type_cast %[[ALLOCA_0]] : memref<vector<1x1xi1>> to memref<1xvector<1xi1>>
+// CHECK: memref.load %[[MASK]][%{{.*}}] : memref<1xvector<1xi1>>
+
+//  -----
+
+// Check that the `TransferOpConversion` generates valid indices for the StoreOp.
+// This test is pulled from an integration test for ArmSVE.
+
+func.func @add_arrays_of_scalable_vectors(%a: memref<1x2x?xf32>, %b: memref<1x2x?xf32>) -> vector<1x2x[4]xf32> {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 2 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %dim_a = memref.dim %a, %c2 : memref<1x2x?xf32>
+  %mask_a = vector.create_mask %c2, %c3, %dim_a : vector<1x2x[4]xi1>
+  %vector_a = vector.transfer_read %a[%c0, %c0, %c0], %cst, %mask_a {in_bounds = [true, true, true]} : memref<1x2x?xf32>, vector<1x2x[4]xf32>
+  return %vector_a : vector<1x2x[4]xf32>
+}
+// CHECK-LABEL: func.func @add_arrays_of_scalable_vectors
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: memref.load
+
+//  -----
+
 // FULL-UNROLL-LABEL: @cannot_fully_unroll_transfer_write_of_nd_scalable_vector
 func.func @cannot_fully_unroll_transfer_write_of_nd_scalable_vector(%vec: vector<[4]x[4]xf32>, %memref: memref<?x?xf32>) {
   // FULL-UNROLL-NOT: vector.extract

@rikhuijzer rikhuijzer merged commit 6b21948 into llvm:main Jan 3, 2024
4 checks passed
@rikhuijzer rikhuijzer deleted the rh/loadop-indices branch January 3, 2024 12:46
@rikhuijzer
Copy link
Member Author

@joker-eph, thanks again for the review!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Possible bug in convert-vector-to-scf
3 participants