[mlir][linalg] Fix vectorization of tensor.extract #118105
Conversation
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-linalg

Author: Andrzej Warzyński (banach-space)

Changes

The example below demonstrates a "scalar read followed by a broadcast" pattern for `tensor.extract`:

```mlir
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func.func @scalar_broadcast(
%init : tensor<1x1x3xi32>,
%src: tensor<1x3x2x4xi32>,
%idx :index) -> tensor<1x1x3xi32> {
%c0 = arith.constant 0 :index
%res = linalg.generic {
indexing_maps = [#map],
iterator_types = ["parallel", "parallel", "parallel"]}
outs(%init : tensor<1x1x3xi32>) {
^bb0(%out: i32):
%val = tensor.extract %src[%idx, %idx, %idx, %idx] : tensor<1x3x2x4xi32>
linalg.yield %val : i32
} -> tensor<1x1x3xi32>
return %res : tensor<1x1x3xi32>
}
```

The default masking path within the Linalg vectorizer, which assumes an identity masking map, is not suitable here. Indeed, identity != broadcast.

This patch ensures masking is handled in the `vectorizeTensorExtract` hook, which has the necessary context for proper handling.

Fixes #116197

Full diff: https://github.com/llvm/llvm-project/pull/118105.diff

3 Files Affected:
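With this patch, the broadcast read in the example above is wrapped in a `vector.mask` with a trivially all-true `vector<1xi1>` mask, instead of going through the generic masking path. A rough sketch of the resulting IR, condensed from the CHECK lines of the new masked test added below (value names are illustrative):

```mlir
// Sketch only, condensed from the new test's CHECK lines. %src, %idx and %pad
// stand for the source tensor, the extract index and the padding value.
%pad  = arith.constant 0 : i32
%mask = vector.constant_mask [1] : vector<1xi1>
%read = vector.mask %mask {
  vector.transfer_read %src[%idx, %idx, %idx, %idx], %pad
      {in_bounds = [true, true, true],
       permutation_map = affine_map<(d0, d1, d2, d3) -> (0, 0, 0)>}
      : tensor<1x3x2x4xi32>, vector<1x1x4xi32>
} : vector<1xi1> -> vector<1x1x4xi32>
```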
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 06bb6c0fb1cac9..9226392e378191 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1165,8 +1165,18 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
loc, resultType, extractOp.getTensor(), transferReadIdxs,
permutationMap, inBounds);
+ // Mask this broadcasting xfer_read here rather than relying on the generic
+ // path (the generic path assumes identity masking map, which wouldn't be
+ // valid here).
+ SmallVector<int64_t> readMaskShape{1};
+ auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
+ auto allTrue = rewriter.create<vector::ConstantMaskOp>(
+ loc, readMaskType, vector::ConstantMaskKind::AllTrue);
+ auto *maskedReadOp =
+ mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
+
LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
- return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
+ return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp};
}
// 2b. Handle contiguous access.
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
index 74d23fb5b1e3e1..d0d3b58a057041 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
@@ -425,3 +425,55 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @scalar_broadcast(%init : tensor<1x1x3xi32>, %src: tensor<1x3x2x4xi32>, %idx :index) -> tensor<1x1x3xi32> {
+
+ %c0 = arith.constant 0 :index
+
+ %res = linalg.generic {
+ indexing_maps = [#map],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ outs(%init : tensor<1x1x3xi32>) {
+ ^bb0(%out: i32):
+ %val = tensor.extract %src[%idx, %idx, %idx, %idx] : tensor<1x3x2x4xi32>
+ linalg.yield %val : i32
+ } -> tensor<1x1x3xi32>
+
+ return %res : tensor<1x1x3xi32>
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (0, 0, 0)>
+// CHECK-LABEL: func.func @scalar_broadcast(
+// CHECK-SAME: %[[INIT:.*]]: tensor<1x1x3xi32>,
+// CHECK-SAME: %[[SRC:.*]]: tensor<1x3x2x4xi32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> tensor<1x1x3xi32> {
+
+/// Compute the mask for saving the final result
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C1_2:.*]] = arith.constant 1 : index
+// CHECK: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: %[[MASK_RES:.*]] = vector.create_mask %[[C1]], %[[C1_2]], %[[C3]] : vector<1x1x4xi1>
+
+/// Read and broadcast the scalar
+// CHECK: %[[PAD:.*]] = arith.constant 0 : i32
+// CHECK: %[[MASK_READ:.*]] = vector.constant_mask [1] : vector<1xi1>
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK_READ]] {
+// CHECK-SAME: vector.transfer_read %[[SRC]]{{\[}}%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]], %[[PAD]]
+// CHECK-SAME: {in_bounds = [true, true, true], permutation_map = #[[$MAP]]} : tensor<1x3x2x4xi32>, vector<1x1x4xi32>
+// CHECK-SAME: } : vector<1xi1> -> vector<1x1x4xi32>
+
+/// Save the result in the output tensor
+// CHECK: vector.mask %[[MASK_RES]] {
+// CHECK-SAME: vector.transfer_write %[[READ]], %[[INIT]]{{.*}} {in_bounds = [true, true, true]} : vector<1x1x4xi32>, tensor<1x1x3xi32>
+// CHECK-SAME: } : vector<1x1x4xi1> -> tensor<1x1x3xi32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %module : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [1, 1, 4] {vectorize_nd_extract} : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index c02405f29bcf7b..1a93d1cd9b7880 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -66,7 +66,7 @@ module attributes {transform.with_named_sequence} {
// -----
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func.func @vectorize_nd_tensor_extract_constant_idx(%arg0: tensor<3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+func.func @vectorize_nd_tensor_extract_scalar_broadcast(%arg0: tensor<3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
%c0 = arith.constant 1 : index
%c1 = arith.constant 2 : index
%2 = linalg.generic {
@@ -80,17 +80,17 @@ func.func @vectorize_nd_tensor_extract_constant_idx(%arg0: tensor<3x3xf32>, %arg
return %2 : tensor<1x1x3xf32>
}
-// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (0, 0, 0)>
-// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_constant_idx(
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_scalar_broadcast(
// CHECK-SAME: %[[ARG_0:.*]]: tensor<3x3xf32>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C0_f32_2:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG: %[[C0_f32:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG_0]][%[[C1]], %[[C2]]], %[[C0_f32]] {in_bounds = [true, true, true], permutation_map = #[[$MAP]]} : tensor<3x3xf32>, vector<1x1x3xf32>
-// CHECK: %[[C0_4:.*]] = arith.constant 0 : index
-// CHECK: vector.transfer_write %[[READ]], %[[ARG_1]][%[[C0_4]], %[[C0_4]], %[[C0_4]]] : vector<1x1x3xf32>, tensor<1x1x3xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[MASK:.*]] = vector.constant_mask [1] : vector<1xi1>
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]][%[[C1]], %[[C2]]], {{.*}} {in_bounds = [true, true, true], permutation_map = #[[$MAP]]} : tensor<3x3xf32>, vector<1x1x3xf32> } : vector<1xi1> -> vector<1x1x3xf32>
+// CHECK: %[[C0_2:.*]] = arith.constant 0 : index
+// CHECK: vector.transfer_write %[[READ]], %[[ARG_1]]{{\[}}%[[C0_2]], %[[C0_2]], %[[C0_2]]] : vector<1x1x3xf32>, tensor<1x1x3xf32>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -823,7 +823,7 @@ func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> t
return %out:tensor<1x1x4xi32>
}
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
// CHECK-LABEL: func.func @vectorize_scalar_broadcast_column_tensor(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
// CHECK: %[[VAL_1:.*]] = arith.constant 4 : index
@@ -844,12 +844,14 @@ func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> t
// CHECK: %[[VAL_16:.*]] = arith.constant dense<true> : vector<1x1x4xi1>
// CHECK: %[[VAL_17:.*]] = arith.constant dense<0> : vector<1x1x4xi32>
// CHECK: %[[VAL_18:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_20:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
-// CHECK: %[[VAL_21:.*]] = vector.extract %[[VAL_20]][0] : index from vector<4xindex>
-// CHECK: %[[VAL_22:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_23:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_21]], %[[VAL_2]]], %[[VAL_22]] {in_bounds = [true, true, true], permutation_map = #[[$ATTR_1]]} : tensor<15x1xi32>, vector<1x1x4xi32>
+// CHECK: %[[VAL_19:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
+// CHECK: %[[VAL_20:.*]] = vector.extract %[[VAL_19]][0] : index from vector<4xindex>
+// CHECK: %[[VAL_21:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_22:.*]] = vector.constant_mask [1] : vector<1xi1>
+// CHECK: %[[VAL_23:.*]] = vector.mask %[[VAL_22]] { vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_20]], %[[VAL_2]]], %[[VAL_21]] {in_bounds = [true, true, true], permutation_map = #[[$MAP]]} : tensor<15x1xi32>, vector<1x1x4xi32> } : vector<1xi1> -> vector<1x1x4xi32>
// CHECK: %[[VAL_24:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_25:.*]] = vector.transfer_write %[[VAL_23]], %[[VAL_0]]{{\[}}%[[VAL_24]], %[[VAL_24]], %[[VAL_24]]] : vector<1x1x4xi32>, tensor<1x1x4xi32>
+// CHECK: return %[[VAL_25]] : tensor<1x1x4xi32>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
Thanks for the details and the fix!
// Mask this broadcasting xfer_read here rather than relying on the generic
// path (the generic path assumes identity masking map, which wouldn't be
// valid here).
SmallVector<int64_t> readMaskShape{1};
SmallVector<int64_t> readMaskShape{1};
SmallVector<int64_t> readMaskShape = {1}; |
style nit: use assignment syntax.
Both the LLVM style guide and Abseil's Tips of the Week suggest not using a braced initializer in this case; see https://abseil.io/tips/88 and https://llvm.org/docs/CodingStandards.html#do-not-use-braced-initializer-lists-to-call-a-constructor
Thanks for the reminder 🙏🏻
The example below demonstrates a "scalar read followed by a broadcast" pattern for `tensor.extract`: ```mlir #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> func.func @scalar_broadcast( %init : tensor<1x1x3xi32>, %src: tensor<1x3x2x4xi32>, %idx :index) -> tensor<1x1x3xi32> { %c0 = arith.constant 0 :index %res = linalg.generic { indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel"]} outs(%init : tensor<1x1x3xi32>) { ^bb0(%out: i32): %val = tensor.extract %src[%idx, %idx, %idx, %idx] : tensor<1x3x2x4xi32> linalg.yield %val : i32 } -> tensor<1x1x3xi32> return %res : tensor<1x1x3xi32> } ``` The default masking path within the Linalg vectorizer, which assumes an identity masking map, is not suitable here. Indeed, identity != broadcast. This patch ensures masking is handled in the `vectorizeTensorExtract` hook, which has the necessary context for proper handling. Fixes llvm#116197
Fix SmallVector initialization (force-pushed from ae7648e to aa35313)
The example below demonstrates a "scalar read followed by a broadcast"
pattern for
tensor.extract
:The default masking path within the Linalg vectorizer, which assumes an
identity masking map, is not suitable here. Indeed, identity !=
broadcast.
This patch ensures masking is handled in the
vectorizeTensorExtract
hook, which has the necessary context for proper handling.
Fixes #116197