From 11e7138d60cf9fbbdbd734ba44d92e7ff407257c Mon Sep 17 00:00:00 2001 From: Quentin Colombet Date: Fri, 6 Oct 2023 17:12:33 +0200 Subject: [PATCH] [mlir] Fix `lower_unpack` when dynamic dimensions are involved When lowering `tensor.unpack`, we need to use the sizes of the destination tensor in the final `tensor.extract_slice` operation. Prior to this patch, when the destination tensor had dynamic dimensions, we would compute them from the result of the `tensor.unpack` operation instead of its destination argument. This would produce invalid IR because the `tensor.dim` operations would need to appear before the `tensor.extract_slice` operation, but the input of the `tensor.dim` operations would consume the final result of the lowering of `tensor.unpack`, which happens after the `tensor.extract_slice` operation. In other words, the definition wouldn't dominate its uses. I.e., we were generating: ``` %dynDim = tensor.dim %defLater, ... <-- %defLater defined below %res = tensor.extract_slice ..., %dynDim, ... %defLater = linalg.copy (ins %res) ``` Note: I checked the implementation of `lower_pack` and the code is correct as far as I can tell. --- .../Dialect/Linalg/Transforms/Transforms.cpp | 2 +- .../Dialect/Linalg/transform-lower-pack.mlir | 39 ++++++++++++++++++- 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 8183b40ad7346..bca343cf87771 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -467,7 +467,7 @@ FailureOr linalg::lowerUnPack(RewriterBase &rewriter, auto extractSliceOp = rewriter.create( loc, destTensorType, reshapeOp->getResult(0), SmallVector(destRank, zero), - tensor::getMixedSizes(rewriter, loc, unPackOp->getResult(0)), + tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()), SmallVector(destRank, one)); // 7. Inject a copy to preserve DPS. diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir index c71feddcc1c84..ad6c6a6f6199c 100644 --- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir +++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir @@ -133,7 +133,7 @@ func.func @unpack(%arg0: tensor<17x2x16x16x32x8xf32>, %arg1: tensor<129x47x16x16 // CHECK-SAME: : tensor<17x8x2x32x16x16xf32> into tensor<136x64x16x16xf32> // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1] // CHECK-SAME: : tensor<136x64x16x16xf32> to tensor<129x47x16x16xf32> - // CHECK: linalg.copy ins(%[[SLICE]] : tensor<129x47x16x16xf32>) + // CHECK: linalg.copy ins(%[[SLICE]] : tensor<129x47x16x16xf32>) // CHECK-SAME: outs(%[[ARG1]] : tensor<129x47x16x16xf32>) %pack = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1 : tensor<17x2x16x16x32x8xf32> -> tensor<129x47x16x16xf32> @@ -397,3 +397,40 @@ transform.sequence failures(propagate) { transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">) -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">) } + +// ----- + +// Check that we can lower unpack with dynamic dimensions in the destination. +// CHECK-LABEL: func.func @unpack_with_dynamic_dest( +// CHECK-SAME: %[[ARG0:.*]]: tensor<32x2x49x16x16xf32>, %[[ARG1:.*]]: tensor<32x?x?xf32>) +// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<32x2x16x49x16xf32> +// CHECK: %[[TRAN:.*]] = linalg.transpose +// CHECK-SAME: ins(%[[ARG0]] : tensor<32x2x49x16x16xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<32x2x16x49x16xf32>) +// CHECK-SAME: permutation = [0, 1, 3, 2, 4] +// CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0], [1, 2], [3, 4]] +// CHECK-SAME: : tensor<32x2x16x49x16xf32> into tensor<32x32x784xf32> +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<32x?x?xf32> +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[DIM2:.*]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<32x?x?xf32> +// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0, 0] [32, %[[DIM1]], %[[DIM2]]] [1, 1, 1] +// CHECK-SAME: : tensor<32x32x784xf32> to tensor<32x?x?xf32> +// CHECK: linalg.copy ins(%[[SLICE]] : tensor<32x?x?xf32>) +// CHECK-SAME: outs(%[[ARG1]] : tensor<32x?x?xf32>) +func.func @unpack_with_dynamic_dest(%arg0: tensor<32x2x49x16x16xf32>, %arg1: tensor<32x?x?xf32>) -> tensor<32x?x?xf32> { + %pack = tensor.unpack %arg0 inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %arg1 + : tensor<32x2x49x16x16xf32> -> tensor<32x?x?xf32> + return %pack : tensor<32x?x?xf32> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !transform.any_op): + %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op + : (!transform.any_op) -> !transform.op<"tensor.unpack"> + transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">) + -> (!transform.op<"tensor.empty">, + !transform.op<"linalg.transpose">, + !transform.op<"tensor.collapse_shape">, + !transform.op<"tensor.extract_slice">) +}