[MLIR][Linalg] Support dynamic sizes in lower_unpack #75494

Merged: 9 commits from srcarroll:implement-dyn-unpack-lowering into llvm:main on Dec 18, 2023

Conversation

srcarroll
Contributor

No description provided.


Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be
notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this does not work for you, it is probably because you do not have write
permissions for the repository; in that case, you can instead tag reviewers by
name in a comment, using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review
by "pinging" the PR, i.e. adding a comment that says "Ping". The common courtesy
ping rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@srcarroll changed the title from "[mlir][linalg] Support dynamic sizes in lower_unpack" to "[MLIR][LINALG] Support dynamic sizes in lower_unpack" on Dec 14, 2023
@srcarroll changed the title from "[MLIR][LINALG] Support dynamic sizes in lower_unpack" to "[MLIR][Linalg] Support dynamic sizes in lower_unpack" on Dec 14, 2023
@chelini self-requested a review on December 15, 2023
@srcarroll marked this pull request as ready for review on December 16, 2023
@llvmbot
Collaborator

llvmbot commented Dec 16, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: None (srcarroll)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/75494.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+14-6)
  • (modified) mlir/test/Dialect/Linalg/transform-lower-pack.mlir (+123)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 10dfbe6cec781d..ed81c3b540eed6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -381,11 +381,6 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
     return rewriter.notifyMatchFailure(unPackOp, "outer dims perm NYI");
 
   RankedTensorType packedTensorType = unPackOp.getSourceType();
-  if (!packedTensorType.hasStaticShape()) {
-    return rewriter.notifyMatchFailure(
-        unPackOp,
-        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
-  }
 
   Location loc = unPackOp->getLoc();
   OpBuilder::InsertionGuard g(rewriter);
@@ -434,8 +429,21 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
       RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
   RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
       stripMinedTensorType, packingMetadata.reassociations);
+
+  // Get the dynamic dims of the input tensor, permuted into the order of the
+  // strip-mined shape, to use as the dynamic sizes of the `tensor.empty` op.
+  SmallVector<OpFoldResult, 4> dims =
+      tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
+  applyPermutationToVector(dims, lastDimsToInsertPositionsPerm);
+  SmallVector<Value, 4> dynDims;
+  dynDims.reserve(stripMinedTensorType.getNumDynamicDims());
+  for (int64_t i = 0; i < stripMinedTensorType.getRank(); i++) {
+    if (stripMinedTensorType.isDynamicDim(i))
+      dynDims.push_back(cast<Value>(dims[i]));
+  }
+
   auto emptyOp =
-      rewriter.create<tensor::EmptyOp>(loc, stripMinedTensorType, ValueRange{});
+      rewriter.create<tensor::EmptyOp>(loc, stripMinedTensorType, dynDims);
   auto transposeOp = rewriter.create<linalg::TransposeOp>(
       loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm);
 
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index b9706eed54b608..0140290491be09 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -464,6 +464,129 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Check that we can lower unpack with dynamic dimensions in the input and destination.
+// CHECK-LABEL: func.func @unpack_with_dynamic_input_dest(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x8x16xf32>, %[[ARG1:.*]]: tensor<?x?xf32>)
+//      CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//      CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//      CHECK-DAG: %[[DIM00:.*]] = tensor.dim %[[ARG0]], %[[C0]]
+//      CHECK-DAG: %[[DIM01:.*]] = tensor.dim %[[ARG0]], %[[C1]]
+//      CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM00]], %[[DIM01]]) : tensor<?x8x?x16xf32>
+//      CHECK: %[[TRAN:.*]] = linalg.transpose
+// CHECK-SAME:    ins(%[[ARG0]] : tensor<?x?x8x16xf32>)
+// CHECK-SAME:   outs(%[[EMPTY]] : tensor<?x8x?x16xf32>)
+// CHECK-SAME:   permutation = [0, 2, 1, 3]
+//      CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
+// CHECK-SAME:   : tensor<?x8x?x16xf32> into tensor<?x?xf32>
+//      CHECK: %[[DIM10:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+//      CHECK: %[[DIM11:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+//      CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0] [%[[DIM10]], %[[DIM11]]] [1, 1]
+// CHECK-SAME:   : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK: linalg.copy ins(%[[SLICE]] : tensor<?x?xf32>)
+// CHECK-SAME:        outs(%[[ARG1]] : tensor<?x?xf32>)
+func.func @unpack_with_dynamic_input_dest(%arg0: tensor<?x?x8x16xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+    %unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 16] into %arg1 : tensor<?x?x8x16xf32> -> tensor<?x?xf32>
+    return %unpack : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.unpack">
+    transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+      -> (!transform.op<"tensor.empty">,
+          !transform.op<"linalg.transpose">,
+          !transform.op<"tensor.collapse_shape">,
+          !transform.op<"tensor.extract_slice">)
+    transform.yield
+  }
+}
+
+// -----
+
+// Check that we can lower unpack with dynamic dimensions in the input, destination, inner_tiles.
+// CHECK-LABEL: func.func @unpack_fully_dynamic(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+//      CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//      CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//      CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+//      CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+//      CHECK-DAG: %[[DIM00:.*]] = tensor.dim %[[ARG0]], %[[C0]]
+//      CHECK-DAG: %[[DIM01:.*]] = tensor.dim %[[ARG0]], %[[C1]]
+//      CHECK-DAG: %[[DIM02:.*]] = tensor.dim %[[ARG0]], %[[C2]]
+//      CHECK-DAG: %[[DIM03:.*]] = tensor.dim %[[ARG0]], %[[C3]]
+//      CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM00]], %[[DIM02]], %[[DIM01]], %[[DIM03]]) : tensor<?x?x?x?xf32>
+//      CHECK: %[[TRAN:.*]] = linalg.transpose
+// CHECK-SAME:    ins(%[[ARG0]] : tensor<?x?x?x?xf32>)
+// CHECK-SAME:   outs(%[[EMPTY]] : tensor<?x?x?x?xf32>)
+// CHECK-SAME:   permutation = [0, 2, 1, 3]
+//      CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
+// CHECK-SAME:   : tensor<?x?x?x?xf32> into tensor<?x?xf32>
+//      CHECK: %[[DIM10:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+//      CHECK: %[[DIM11:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+//      CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0] [%[[DIM10]], %[[DIM11]]] [1, 1]
+// CHECK-SAME:   : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK: linalg.copy ins(%[[SLICE]] : tensor<?x?xf32>)
+// CHECK-SAME:        outs(%[[ARG1]] : tensor<?x?xf32>)
+func.func @unpack_fully_dynamic(%source: tensor<?x?x?x?xf32>, %dest: tensor<?x?xf32>, %tile_n : index, %tile_m : index) -> tensor<?x?xf32> {
+  %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [%tile_n, %tile_m] into %dest : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.unpack">
+    transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+      -> (!transform.op<"tensor.empty">,
+          !transform.op<"linalg.transpose">,
+          !transform.op<"tensor.collapse_shape">,
+          !transform.op<"tensor.extract_slice">)
+    transform.yield
+  }
+}
+
+// -----
+
+// Check that we can lower unpack "as unpad" with dynamic dims.
+// CHECK-LABEL: func.func @unpack_as_pad_dynamic(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<1x1x1x1x?x?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?x?x?xf32>
+//      CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//      CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//      CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+//      CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+//      CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG1]], %[[C0]]
+//      CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C1]]
+//      CHECK-DAG: %[[DIM2:.*]] = tensor.dim %[[ARG1]], %[[C2]]
+//      CHECK-DAG: %[[DIM3:.*]] = tensor.dim %[[ARG1]], %[[C3]]
+//      CHECK: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
+// offsets.
+// CHECK-SAME:   [0, 0, 0, 0, 0, 0, 0, 0]
+// sizes.
+// CHECK-SAME:   [1, 1, 1, 1, %[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]]]
+// strides.
+// CHECK-SAME:   [1, 1, 1, 1, 1, 1, 1, 1]
+// CHECK-SAME:   : tensor<1x1x1x1x?x?x?x?xf32> to tensor<?x?x?x?xf32>
+func.func @unpack_as_pad_dynamic(%arg0: tensor<1x1x1x1x?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+    : tensor<1x1x1x1x?x?x?x?xf32> -> tensor<?x?x?x?xf32>
+  return %pack : tensor<?x?x?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.unpack">
+    transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+      -> (!transform.op<"tensor.empty">,
+          !transform.op<"linalg.transpose">,
+          !transform.op<"tensor.collapse_shape">,
+          !transform.op<"tensor.extract_slice">)
+    transform.yield
+  }
+}
+
+// -----
+
 // At the moment, we cannot lower tensor.unpack with outer_dims_perm.
 func.func @diagnostic_unpack(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> {
   // expected-note @below {{target payload op}}

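For readers skimming the diff: lower_unpack decomposes tensor.unpack into tensor.empty + linalg.transpose + tensor.collapse_shape + tensor.extract_slice, followed by a linalg.copy into the destination. The snippet below is a hand-written sketch of the IR this now produces when the outer dims are dynamic, reconstructed from the CHECK lines of the first new test; the function and SSA value names are illustrative, not part of the patch:

func.func @lowered_unpack(%src: tensor<?x?x8x16xf32>, %dest: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  // The patch's key change: the dynamic sizes of tensor.empty are taken from
  // tensor.dim on the unpack source, permuted into strip-mined order.
  %d0 = tensor.dim %src, %c0 : tensor<?x?x8x16xf32>
  %d1 = tensor.dim %src, %c1 : tensor<?x?x8x16xf32>
  %empty = tensor.empty(%d0, %d1) : tensor<?x8x?x16xf32>
  // Interleave each inner tile with its matching outer dim.
  %transposed = linalg.transpose ins(%src : tensor<?x?x8x16xf32>)
                                 outs(%empty : tensor<?x8x?x16xf32>)
                                 permutation = [0, 2, 1, 3]
  // Fold every (outer, tile) pair back into a single dimension.
  %collapsed = tensor.collapse_shape %transposed [[0, 1], [2, 3]]
      : tensor<?x8x?x16xf32> into tensor<?x?xf32>
  // Strip padding down to the destination's runtime extents, then copy.
  %e0 = tensor.dim %dest, %c0 : tensor<?x?xf32>
  %e1 = tensor.dim %dest, %c1 : tensor<?x?xf32>
  %slice = tensor.extract_slice %collapsed[0, 0] [%e0, %e1] [1, 1]
      : tensor<?x?xf32> to tensor<?x?xf32>
  %result = linalg.copy ins(%slice : tensor<?x?xf32>) outs(%dest : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %result : tensor<?x?xf32>
}

Before this change, any non-static source shape bailed out at the notifyMatchFailure deleted in the first hunk; with dynamic inner_tiles (second test) the same structure appears, only with all four transposed dims dynamic.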
@srcarroll
Contributor Author

srcarroll commented Dec 16, 2023

Not sure what the deal is with the buildkite failure; it seems unrelated to my changes.

Update: Passes now. Maybe it's because I needed to rebase?

Contributor

@chelini left a comment

Thanks!

3 resolved review threads on mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (outdated)
Contributor

@chelini left a comment

Thanks, it looks good to me.

2 resolved review threads (outdated): mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp, mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@chelini merged commit b26ee97 into llvm:main on Dec 18, 2023
4 checks passed
@srcarroll deleted the implement-dyn-unpack-lowering branch on December 18, 2023