
[mlir][linalg] Fix vectorization of tensor.extract #118105


Merged: 2 commits into llvm:main from andrzej/fix_vec_bcast_extract, Dec 5, 2024

Conversation

banach-space (Contributor)

The example below demonstrates a "scalar read followed by a broadcast"
pattern for `tensor.extract`:

```mlir
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func.func @scalar_broadcast(
    %init : tensor<1x1x3xi32>,
    %src : tensor<1x3x2x4xi32>,
    %idx : index) -> tensor<1x1x3xi32> {

  %c0 = arith.constant 0 : index

  %res = linalg.generic {
    indexing_maps = [#map],
    iterator_types = ["parallel", "parallel", "parallel"]}
    outs(%init : tensor<1x1x3xi32>) {
    ^bb0(%out: i32):
      %val = tensor.extract %src[%idx, %idx, %idx, %idx] : tensor<1x3x2x4xi32>
      linalg.yield %val : i32
  } -> tensor<1x1x3xi32>

  return %res : tensor<1x1x3xi32>
}
```

The default masking path within the Linalg vectorizer, which assumes an
identity masking map, is not suitable here. Indeed, identity !=
broadcast.

This patch ensures masking is handled in the `vectorizeTensorExtract`
hook, which has the necessary context for proper handling.
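
With this change, the scalar read is emitted as a transfer_read wrapped in an explicitly all-true, single-element mask. A sketch of the resulting IR for the example above (adapted from the new test's CHECK lines; SSA names are illustrative):

```mlir
%pad = arith.constant 0 : i32
%mask = vector.constant_mask [1] : vector<1xi1>
%read = vector.mask %mask {
  vector.transfer_read %src[%idx, %idx, %idx, %idx], %pad
    {in_bounds = [true, true, true],
     permutation_map = affine_map<(d0, d1, d2, d3) -> (0, 0, 0)>}
    : tensor<1x3x2x4xi32>, vector<1x1x4xi32>
} : vector<1xi1> -> vector<1x1x4xi32>
```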

Fixes #116197

llvmbot (Member) commented Nov 29, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Andrzej Warzyński (banach-space)


Full diff: https://github.com/llvm/llvm-project/pull/118105.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+11-1)
  • (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir (+52)
  • (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir (+15-13)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 06bb6c0fb1cac9..9226392e378191 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1165,8 +1165,18 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
         loc, resultType, extractOp.getTensor(), transferReadIdxs,
         permutationMap, inBounds);
 
+    // Mask this broadcasting xfer_read here rather than relying on the generic
+    // path (the generic path assumes identity masking map, which wouldn't be
+    // valid here).
+    SmallVector<int64_t> readMaskShape{1};
+    auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
+    auto allTrue = rewriter.create<vector::ConstantMaskOp>(
+        loc, readMaskType, vector::ConstantMaskKind::AllTrue);
+    auto *maskedReadOp =
+        mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
+
     LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
-    return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
+    return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp};
   }
 
   // 2b. Handle contiguous access.
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
index 74d23fb5b1e3e1..d0d3b58a057041 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
@@ -425,3 +425,55 @@ module attributes {transform.with_named_sequence} {
      transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @scalar_broadcast(%init : tensor<1x1x3xi32>, %src: tensor<1x3x2x4xi32>, %idx :index) -> tensor<1x1x3xi32> {
+
+  %c0 = arith.constant 0 :index
+
+  %res = linalg.generic {
+    indexing_maps = [#map],
+    iterator_types = ["parallel", "parallel", "parallel"]}
+    outs(%init : tensor<1x1x3xi32>) {
+    ^bb0(%out: i32):
+      %val = tensor.extract %src[%idx, %idx, %idx, %idx] : tensor<1x3x2x4xi32>
+      linalg.yield %val : i32
+  } -> tensor<1x1x3xi32>
+
+  return %res : tensor<1x1x3xi32>
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (0, 0, 0)>
+// CHECK-LABEL:   func.func @scalar_broadcast(
+// CHECK-SAME:      %[[INIT:.*]]: tensor<1x1x3xi32>,
+// CHECK-SAME:      %[[SRC:.*]]: tensor<1x3x2x4xi32>,
+// CHECK-SAME:      %[[IDX:.*]]: index) -> tensor<1x1x3xi32> {
+
+/// Compute the mask for saving the final result
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[C1_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[C3:.*]] = arith.constant 3 : index
+// CHECK:           %[[MASK_RES:.*]] = vector.create_mask %[[C1]], %[[C1_2]], %[[C3]] : vector<1x1x4xi1>
+
+/// Read and broadcast the scalar
+// CHECK:           %[[PAD:.*]] = arith.constant 0 : i32
+// CHECK:           %[[MASK_READ:.*]] = vector.constant_mask [1] : vector<1xi1>
+// CHECK:           %[[READ:.*]] = vector.mask %[[MASK_READ]] {
+// CHECK-SAME:          vector.transfer_read %[[SRC]]{{\[}}%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]],  %[[PAD]]
+// CHECK-SAME:          {in_bounds = [true, true, true], permutation_map = #[[$MAP]]} : tensor<1x3x2x4xi32>, vector<1x1x4xi32>
+// CHECK-SAME:      } : vector<1xi1> -> vector<1x1x4xi32>
+
+/// Save the result in the output tensor
+// CHECK:           vector.mask %[[MASK_RES]] {
+// CHECK-SAME:        vector.transfer_write %[[READ]], %[[INIT]]{{.*}} {in_bounds = [true, true, true]} : vector<1x1x4xi32>, tensor<1x1x3xi32>
+// CHECK-SAME:      } : vector<1x1x4xi1> -> tensor<1x1x3xi32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %module : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, 1, 4] {vectorize_nd_extract} : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index c02405f29bcf7b..1a93d1cd9b7880 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -66,7 +66,7 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func.func @vectorize_nd_tensor_extract_constant_idx(%arg0: tensor<3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+func.func @vectorize_nd_tensor_extract_scalar_broadcast(%arg0: tensor<3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
   %c0 = arith.constant 1 : index
   %c1 = arith.constant 2 : index
   %2 = linalg.generic {
@@ -80,17 +80,17 @@ func.func @vectorize_nd_tensor_extract_constant_idx(%arg0: tensor<3x3xf32>, %arg
   return %2 : tensor<1x1x3xf32>
 }
 
-// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (0, 0, 0)>
-// CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_constant_idx(
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
+// CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_scalar_broadcast(
 // CHECK-SAME:      %[[ARG_0:.*]]: tensor<3x3xf32>,
 // CHECK-SAME:      %[[ARG_1:.*]]: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
 // CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG:       %[[C0_f32_2:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG:       %[[C0_f32:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[READ:.*]] = vector.transfer_read  %[[ARG_0]][%[[C1]], %[[C2]]], %[[C0_f32]] {in_bounds = [true, true, true], permutation_map = #[[$MAP]]} : tensor<3x3xf32>, vector<1x1x3xf32>
-// CHECK:           %[[C0_4:.*]] = arith.constant 0 : index
-// CHECK:           vector.transfer_write %[[READ]], %[[ARG_1]][%[[C0_4]], %[[C0_4]], %[[C0_4]]]  : vector<1x1x3xf32>, tensor<1x1x3xf32>
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[MASK:.*]] = vector.constant_mask [1] : vector<1xi1>
+// CHECK:           %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]][%[[C1]], %[[C2]]], {{.*}} {in_bounds = [true, true, true], permutation_map = #[[$MAP]]} : tensor<3x3xf32>, vector<1x1x3xf32> } : vector<1xi1> -> vector<1x1x3xf32>
+// CHECK:           %[[C0_2:.*]] = arith.constant 0 : index
+// CHECK:           vector.transfer_write %[[READ]], %[[ARG_1]]{{\[}}%[[C0_2]], %[[C0_2]], %[[C0_2]]] : vector<1x1x3xf32>, tensor<1x1x3xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -823,7 +823,7 @@ func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> t
   return %out:tensor<1x1x4xi32>
 }
 
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
 // CHECK-LABEL:   func.func @vectorize_scalar_broadcast_column_tensor(
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 4 : index
@@ -844,12 +844,14 @@ func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> t
 // CHECK:           %[[VAL_16:.*]] = arith.constant dense<true> : vector<1x1x4xi1>
 // CHECK:           %[[VAL_17:.*]] = arith.constant dense<0> : vector<1x1x4xi32>
 // CHECK:           %[[VAL_18:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_20:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
-// CHECK:           %[[VAL_21:.*]] = vector.extract %[[VAL_20]][0] : index from vector<4xindex>
-// CHECK:           %[[VAL_22:.*]] = arith.constant 0 : i32
-// CHECK:           %[[VAL_23:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_21]], %[[VAL_2]]], %[[VAL_22]] {in_bounds = [true, true, true], permutation_map = #[[$ATTR_1]]} : tensor<15x1xi32>, vector<1x1x4xi32>
+// CHECK:           %[[VAL_19:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
+// CHECK:           %[[VAL_20:.*]] = vector.extract %[[VAL_19]][0] : index from vector<4xindex>
+// CHECK:           %[[VAL_21:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_22:.*]] = vector.constant_mask [1] : vector<1xi1>
+// CHECK:           %[[VAL_23:.*]] = vector.mask %[[VAL_22]] { vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_20]], %[[VAL_2]]], %[[VAL_21]] {in_bounds = [true, true, true], permutation_map = #[[$MAP]]} : tensor<15x1xi32>, vector<1x1x4xi32> } : vector<1xi1> -> vector<1x1x4xi32>
 // CHECK:           %[[VAL_24:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_25:.*]] = vector.transfer_write %[[VAL_23]], %[[VAL_0]]{{\[}}%[[VAL_24]], %[[VAL_24]], %[[VAL_24]]] : vector<1x1x4xi32>, tensor<1x1x4xi32>
+// CHECK:           return %[[VAL_25]] : tensor<1x1x4xi32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {

hanhanW (Contributor) left a comment

Thanks for the details and the fix!

// Mask this broadcasting xfer_read here rather than relying on the generic
// path (the generic path assumes identity masking map, which wouldn't be
// valid here).
SmallVector<int64_t> readMaskShape{1};

Suggested change:
-    SmallVector<int64_t> readMaskShape{1};
+    SmallVector<int64_t> readMaskShape = {1};

Style nit: use assignment syntax.

Both the LLVM style guide and Abseil's Tip of the Week #88 suggest not using braced initializer lists to call a constructor; see https://abseil.io/tips/88 and https://llvm.org/docs/CodingStandards.html#do-not-use-braced-initializer-lists-to-call-a-constructor
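
To see why the braced form is easy to misread, here is a minimal sketch using std::vector, which has the same initializer_list/size constructor overlap as SmallVector:

```cpp
#include <vector>

int main() {
  std::vector<int> a{3};    // braces select the initializer_list ctor: one element, value 3
  std::vector<int> b(3);    // parentheses select the size ctor: three zero elements
  std::vector<int> c = {3}; // assignment syntax signals list-initialization explicitly
  return 0;
}
```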

banach-space (Contributor, Author)

Thanks for the reminder 🙏🏻

banach-space force-pushed the andrzej/fix_vec_bcast_extract branch from ae7648e to aa35313 on December 4, 2024, 16:41
banach-space merged commit a2acb2f into llvm:main on Dec 5, 2024
8 checks passed
banach-space deleted the andrzej/fix_vec_bcast_extract branch on December 5, 2024, 09:28

Successfully merging this pull request may close the following issue:

#116197: [linalg] Vectorization failure - masked scalar read + broadcast