
[mlir] Fix lower_unpack when dynamic dimensions are involved #68423

Merged
merged 1 commit into llvm:main from fix_lower_unpack on Oct 6, 2023

Conversation

qcolombet (Collaborator)

When lowering `tensor.unpack`, we need to use the sizes of the destination
tensor in the final `tensor.extract_slice` operation.
Prior to this patch, when the destination tensor had dynamic dimensions, we
would compute them from the result of the `tensor.unpack` operation instead
of its destination argument.

This would produce invalid IR because the `tensor.dim` operations would
need to appear before the `tensor.extract_slice` operation, but the input
of the `tensor.dim` operations would consume the final result of the
lowering of `tensor.unpack`, which happens after the
`tensor.extract_slice` operation. In other words, the definition wouldn't
dominate its uses.

I.e., we were generating:
```
%dynDim = tensor.dim %defLater, ... <-- %defLater defined below
%res = tensor.extract_slice ..., %dynDim, ...
%defLater = linalg.copy (ins %res)
```
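
With the fix, the dynamic sizes are taken from the destination operand of
`tensor.unpack`, which is already defined at that point, so the
`tensor.extract_slice` no longer depends on a value produced later. A minimal
sketch of the corrected ordering (the `%dest` name is illustrative and stands
for the unpack destination operand):
```
%dynDim = tensor.dim %dest, ... <-- %dest is the unpack destination, defined earlier
%res = tensor.extract_slice ..., %dynDim, ...
%defLater = linalg.copy (ins %res) (outs %dest)
```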

Note: I checked the implementation of `lower_pack` and the code is correct
as far as I can tell.
llvmbot (Collaborator) commented Oct 6, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Full diff: https://github.com/llvm/llvm-project/pull/68423.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+1-1)
  • (modified) mlir/test/Dialect/Linalg/transform-lower-pack.mlir (+38-1)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 8183b40ad7346f4..bca343cf8777149 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -467,7 +467,7 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
   auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
       loc, destTensorType, reshapeOp->getResult(0),
       SmallVector<OpFoldResult>(destRank, zero),
-      tensor::getMixedSizes(rewriter, loc, unPackOp->getResult(0)),
+      tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
       SmallVector<OpFoldResult>(destRank, one));
 
   // 7. Inject a copy to preserve DPS.
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index c71feddcc1c8486..ad6c6a6f6199cc6 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -133,7 +133,7 @@ func.func @unpack(%arg0: tensor<17x2x16x16x32x8xf32>, %arg1: tensor<129x47x16x16
   // CHECK-SAME:   : tensor<17x8x2x32x16x16xf32> into tensor<136x64x16x16xf32>
   //      CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]
   // CHECK-SAME:   : tensor<136x64x16x16xf32> to tensor<129x47x16x16xf32>
-  //      CHECK: linalg.copy ins(%[[SLICE]] : tensor<129x47x16x16xf32>) 
+  //      CHECK: linalg.copy ins(%[[SLICE]] : tensor<129x47x16x16xf32>)
   // CHECK-SAME:        outs(%[[ARG1]] : tensor<129x47x16x16xf32>)
   %pack = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1
     : tensor<17x2x16x16x32x8xf32> -> tensor<129x47x16x16xf32>
@@ -397,3 +397,40 @@ transform.sequence failures(propagate) {
   transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
     -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
 }
+
+// -----
+
+// Check that we can lower unpack with dynamic dimensions in the destination.
+// CHECK-LABEL: func.func @unpack_with_dynamic_dest(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<32x2x49x16x16xf32>, %[[ARG1:.*]]: tensor<32x?x?xf32>)
+//      CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<32x2x16x49x16xf32>
+//      CHECK: %[[TRAN:.*]] = linalg.transpose
+// CHECK-SAME:    ins(%[[ARG0]] : tensor<32x2x49x16x16xf32>)
+// CHECK-SAME:   outs(%[[EMPTY]] : tensor<32x2x16x49x16xf32>)
+// CHECK-SAME:   permutation = [0, 1, 3, 2, 4]
+//      CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0], [1, 2], [3, 4]]
+// CHECK-SAME:   : tensor<32x2x16x49x16xf32> into tensor<32x32x784xf32>
+//      CHECK:  %[[C1:.*]] = arith.constant 1 : index
+//      CHECK: %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<32x?x?xf32>
+//      CHECK: %[[C2:.*]] = arith.constant 2 : index
+//      CHECK: %[[DIM2:.*]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<32x?x?xf32>
+//      CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0, 0] [32, %[[DIM1]], %[[DIM2]]] [1, 1, 1]
+// CHECK-SAME:   : tensor<32x32x784xf32> to tensor<32x?x?xf32>
+//      CHECK: linalg.copy ins(%[[SLICE]] : tensor<32x?x?xf32>)
+// CHECK-SAME:        outs(%[[ARG1]] : tensor<32x?x?xf32>)
+func.func @unpack_with_dynamic_dest(%arg0: tensor<32x2x49x16x16xf32>, %arg1: tensor<32x?x?xf32>) -> tensor<32x?x?xf32> {
+  %pack = tensor.unpack %arg0 inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %arg1
+    : tensor<32x2x49x16x16xf32> -> tensor<32x?x?xf32>
+  return %pack : tensor<32x?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+    : (!transform.any_op) -> !transform.op<"tensor.unpack">
+  transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+    -> (!transform.op<"tensor.empty">,
+        !transform.op<"linalg.transpose">,
+        !transform.op<"tensor.collapse_shape">,
+        !transform.op<"tensor.extract_slice">)
+}

qcolombet merged commit 7050ff4 into llvm:main on Oct 6, 2023
5 checks passed
qcolombet deleted the fix_lower_unpack branch on October 6, 2023 at 20:10