[mlir][Linalg] Support dynamic tiles in lower_pack transform #76003

Open · wants to merge 11 commits into main from implement-dynamic-pack-lowering
Conversation

@srcarroll (Contributor)

When an expanded dim is not factorable, emit a tensor.reshape instead of a tensor.expand_shape

@@ -1089,7 +1089,7 @@ collapseOpIterationDims(LinalgType op,

 struct LowerPackResult {
   tensor::PadOp padOp;
-  tensor::ExpandShapeOp expandShapeOp;
+  Operation *expandShapeOp;
@srcarroll (Contributor, Author) commented Dec 20, 2023

Not actually sure what would be appropriate here. Alternatively, we could have two separate fields for the ExpandShapeOp and ReshapeOp, but I haven't looked into the implications of that yet. In some cases (for example, when not all dims are dynamically expanded) we could technically do a tensor.reshape + tensor.expand_shape sequence and keep the LowerPackResult struct as is. However, I don't believe that works when all dims are dynamically expanded: since expand_shape doesn't allow the same rank in and out, it can't be used as a no-op. I welcome suggestions.
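A minimal sketch of that constraint (illustrative only, not from the PR; %t is assumed to be a tensor<?x8xf32> value):

  // Rejected by the verifier: an expand_shape with the same source and result
  // rank is not allowed, so it can't serve as a no-op placeholder.
  // %id = tensor.expand_shape %t [[0], [1]] : tensor<?x8xf32> into tensor<?x8xf32>

  // Legal: the result rank strictly exceeds the source rank, with at most one
  // dynamic dim per reassociation group.
  %e = tensor.expand_shape %t [[0, 1], [2]] : tensor<?x8xf32> into tensor<?x2x8xf32>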

@srcarroll force-pushed the implement-dynamic-pack-lowering branch from 7050c7b to cf0cb00 on December 20, 2023 05:02
@srcarroll force-pushed the implement-dynamic-pack-lowering branch from cf0cb00 to f14c488 on December 20, 2023 05:11
@srcarroll (Contributor, Author)

The current implementation will emit a tensor.reshape op if any of the dims of the input are not factorable (i.e., they require more than one dynamic dim in the expansion); see the sketch below. However, I could instead emit reshapes only for the dims that need it, and a tensor.expand_shape on the rest of the dims. Whichever the reviewer(s) prefer.
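To make the distinction concrete (a sketch, not from the patch; %a is assumed to be a tensor<?xf32>, and %d0, %d1 index values holding the target sizes):

  // Factorable: the group introduces at most one dynamic result dim (? -> ?x4),
  // which tensor.expand_shape supports.
  %e = tensor.expand_shape %a [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>

  // Non-factorable: ? -> ?x? puts two dynamic dims in one group, which
  // expand_shape cannot express; tensor.reshape takes the target shape as a
  // tensor<2xindex> value instead.
  %shape = tensor.from_elements %d0, %d1 : tensor<2xindex>
  %r = tensor.reshape %a(%shape) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>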

@llvmbot (Collaborator) commented Dec 21, 2023

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: None (srcarroll)

Changes

When an expanded dim is not factorable, emit a tensor.reshape instead of a tensor.expand_shape


Full diff: https://github.com/llvm/llvm-project/pull/76003.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+2-1)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+1-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+54-21)
  • (modified) mlir/test/Dialect/Linalg/transform-lower-pack.mlir (+46)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 77ed9db5e71bd1..4abd3740b57105 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -498,7 +498,8 @@ def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
 
   let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target);
   let results = (outs Transform_ConcreteOpType<"tensor.pad">:$pad_op,
-                      Transform_ConcreteOpType<"tensor.expand_shape">:$expand_shape_op,
+                      Type<Or<[Transform_ConcreteOpType<"tensor.expand_shape">.predicate,
+                               Transform_ConcreteOpType<"tensor.reshape">.predicate]>>:$expand_shape_op,
                       Transform_ConcreteOpType<"linalg.transpose">:$transpose_op);
   let assemblyFormat = [{
     $target attr-dict `:` functional-type(operands, results)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a848d12fbbb50e..06e8586f4288b4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1089,7 +1089,7 @@ collapseOpIterationDims(LinalgType op,
 
 struct LowerPackResult {
   tensor::PadOp padOp;
-  tensor::ExpandShapeOp expandShapeOp;
+  Operation *expandShapeOp; // `tensor::ExpandShapeOp` or `tensor::ReshapeOp`
   linalg::TransposeOp transposeOp;
 };
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 9d230e2c2e5749..4550589ded6df8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -218,21 +218,11 @@ struct PackedOperandsDimList {
 
 FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
                                              tensor::PackOp packOp) {
-  // 1. Filter out NYI cases.
-  auto packedTensorType =
-      cast<RankedTensorType>(packOp->getResultTypes().front());
-  if (llvm::any_of(packOp.getStaticInnerTiles(),
-                   [](int64_t size) { return ShapedType::isDynamic(size); })) {
-    return rewriter.notifyMatchFailure(
-        packOp,
-        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
-  }
-
   Location loc = packOp->getLoc();
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(packOp);
 
-  // 2. Compute the permutation vector to shuffle packed shape into the shape
+  // 1. Compute the permutation vector to shuffle packed shape into the shape
   // before any outer or inner permutations have been applied. The permutation
   // can be obtained from two permutations:
   //   a) Compute the permutation vector to move the last `numPackedDims` into
@@ -240,6 +230,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   //   b) Compute the permutation vector to move outer dims if the pack op
   //      has outer_dims_perm.
   // Apply (b) permutation on (a) permutation to get the final permutation.
+  auto packedTensorType =
+      cast<RankedTensorType>(packOp->getResultTypes().front());
   int64_t numPackedDims = packOp.getInnerDimsPos().size();
   int64_t packedRank = packedTensorType.getRank();
   auto lastDims = llvm::to_vector(
@@ -259,12 +251,12 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   SmallVector<int64_t> packedToStripMinedShapePerm = innerPositionsPerm;
   applyPermutationToVector(packedToStripMinedShapePerm, outerPositionPerm);
 
-  // 3. Compute the stripMinedShape: this is the packed shape before any outer
+  // 2. Compute the stripMinedShape: this is the packed shape before any outer
   // or inner permutations have been applied.
   SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
   applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
 
-  // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
+  // 3. Pad the source of packOp to a shape we can expand into stripMinedShape.
   SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
                                  rewriter.getIndexAttr(0));
   SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
@@ -351,24 +343,65 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
                              /*transposeOp=*/nullptr};
     }
   }
-  // 5. Expand from the padded result to the stripMinedShape.
-  auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
-      loc,
-      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
-      padOp.getResult(), packingMetadata.reassociations);
 
-  // 6. Transpose stripMinedShape to packedShape.
+  // 4. Expand from the padded result to the stripMinedShape.
+  RankedTensorType expandDestType =
+      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
   SmallVector<int64_t> transpPerm =
       invertPermutationVector(packedToStripMinedShapePerm);
+  Operation *reshapeOp;
+  // Check if any dims are not factorable and thus need a `tensor.reshape`
+  // instead of a `tensor.expand_shape` op. A dim is factorable if the expansion
+  // requires at most one dynamic dim.
+  if (llvm::any_of(packingMetadata.reassociations,
+                   [&](const auto &rAssoc) -> bool {
+                     return llvm::count_if(rAssoc, [&](int64_t r) {
+                              return stripMinedShape[r] == ShapedType::kDynamic;
+                            }) > 1;
+                   })) {
+    SmallVector<OpFoldResult> sizes =
+        tensor::getMixedSizes(rewriter, loc, packOp.getDest());
+    applyPermutationToVector(sizes, transpPerm);
+    // Create a `tensor` of `index` types for the `shape` operand of
+    // `tensor.reshape`
+    Value shapeInitTensor = rewriter.create<tensor::EmptyOp>(
+        loc,
+        RankedTensorType::get({expandDestType.getRank()},
+                              rewriter.getIndexType()),
+        ValueRange{});
+    Value shapeTensor = shapeInitTensor;
+    for (const auto &[i, size] : llvm::enumerate(sizes)) {
+      auto maybeConstInt = getConstantIntValue(size);
+      assert((maybeConstInt.has_value() || expandDestType.isDynamicDim(i)) &&
+             "expected dynamic dim");
+      Value dim =
+          (maybeConstInt.has_value())
+              ? rewriter
+                    .create<arith::ConstantIndexOp>(loc, maybeConstInt.value())
+                    .getResult()
+              : cast<Value>(size);
+      shapeTensor = rewriter.create<tensor::InsertOp>(
+          loc, dim, shapeTensor,
+          SmallVector<Value>(
+              {rewriter.create<arith::ConstantIndexOp>(loc, i).getResult()}));
+    }
+    reshapeOp = rewriter.create<tensor::ReshapeOp>(
+        loc, expandDestType, padOp.getResult(), shapeTensor);
+  } else {
+    reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
+        loc, expandDestType, padOp.getResult(), packingMetadata.reassociations);
+  }
+
+  // 5. Transpose stripMinedShape to packedShape.
   auto transposeOp = rewriter.create<linalg::TransposeOp>(
-      loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
+      loc, reshapeOp->getResult(0), packOp.getDest(), transpPerm);
 
   LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
              DBGS() << "reshape op: " << reshapeOp; DBGSNL();
              llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: ");
              DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
 
-  // 7. Replace packOp by transposeOp.
+  // 6. Replace packOp by transposeOp.
   rewriter.replaceOp(packOp, transposeOp->getResults());
 
   return LowerPackResult{padOp, reshapeOp, transposeOp};
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 316df431a9c0c8..13d74cbe433264 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -61,6 +61,52 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-DAG:   #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s1 * s0 - 64)>
+// CHECK-DAG:   #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s1 * s0 - 128)>
+// CHECK: func.func @pack_dyn_tiles(
+// CHECK-SAME:                            %[[ARG0:.*]]: [[TENSOR_TY_0:tensor<64x128xf32>]]
+// CHECK-SAME:                            %[[ARG1:.*]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME:                            %[[TILE0:.*]]: index,
+// CHECK-SAME:                            %[[TILE1:.*]]: index
+func.func @pack_dyn_tiles(%arg0: tensor<64x128xf32>, %arg1: tensor<?x?x?x?xf32>, %tile_0: index, %tile_1: index) -> tensor<?x?x?x?xf32> {
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[DIM0:.*]] = tensor.dim %[[ARG1]], %[[C0]]
+// CHECK-DAG:  %[[PAD0:.*]] = affine.apply #[[MAP0]]()[%[[TILE0]], %[[DIM0]]]
+// CHECK-DAG:  %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C1]]
+// CHECK-DAG:  %[[PAD1:.*]] = affine.apply #[[MAP1]]()[%[[TILE1]], %[[DIM1]]]
+// CHECK-DAG:   %[[CST:.*]]  = arith.constant 0.000000e+00 : f32
+// CHECK:      %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0] high[%[[PAD0]], %[[PAD1]]]
+// CHECK-NEXT:                   ^bb0
+// CHECK-NEXT:                    tensor.yield %[[CST]] : f32
+// CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:  %[[DIM2:.*]] = tensor.dim %[[ARG1]], %[[C2]]
+// CHECK-DAG:  %[[DIM3:.*]] = tensor.dim %[[ARG1]], %[[C3]]
+// CHECK-NEXT:  %[[INIT_SHAPE:.*]] = tensor.empty() : tensor<4xindex>
+// CHECK-NEXT:  %[[SHAPE0:.*]] = tensor.insert %[[DIM0]] into %[[INIT_SHAPE]][%[[C0]]]
+// CHECK-NEXT:  %[[SHAPE1:.*]] = tensor.insert %[[DIM2]] into %[[SHAPE0]][%[[C1]]]
+// CHECK-NEXT:  %[[SHAPE2:.*]] = tensor.insert %[[DIM1]] into %[[SHAPE1]][%[[C2]]]
+// CHECK-NEXT:  %[[SHAPE3:.*]] = tensor.insert %[[DIM3]] into %[[SHAPE2]][%[[C3]]]
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.reshape %[[PADDED]](%[[SHAPE3]])
+// CHECK-NEXT:  %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[EXPANDED]] : {{.*}}) outs(%[[ARG1]] {{.*}}) permutation = [0, 2, 1, 3] 
+  %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [%tile_0, %tile_1] into %arg1
+    : tensor<64x128xf32> -> tensor<?x?x?x?xf32>
+  return %pack : tensor<?x?x?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.pack">
+    transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.reshape">, !transform.op<"linalg.transpose">)
+      transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func.func @pack_as_pad(
 func.func @pack_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
   %cst_0 = arith.constant 0.0 : f32

@srcarroll changed the title from "[mlir][Linalg] Support dynamic shapes in lower_pack transform" to "[mlir][Linalg] Support dynamic tiles in lower_pack transform" on Dec 21, 2023
@chelini self-requested a review on December 21, 2023 16:01
@srcarroll (Contributor, Author) commented Jan 4, 2024

It was suggested to me by @chelini to have only the reshape op handle all cases and get rid of the expand_shape op. We can then implement a canonicalizer to convert it back when valid. I'm all for this; however, I want to make sure this is the direction we want to go before I start making test changes, because there will be a lot of them. For each test there will be several extra lines for populating the shape operand of reshape, like this (from the test I added in this PR):

// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:   %[[DIM0:.*]] = tensor.dim %[[ARG1]], %[[C0]]
// CHECK-DAG:  %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG:  %[[DIM2:.*]] = tensor.dim %[[ARG1]], %[[C2]]
// CHECK-DAG:  %[[DIM3:.*]] = tensor.dim %[[ARG1]], %[[C3]]
// CHECK-NEXT:  %[[INIT_SHAPE:.*]] = tensor.empty() : tensor<4xindex>
// CHECK-NEXT:  %[[SHAPE0:.*]] = tensor.insert %[[DIM0]] into %[[INIT_SHAPE]][%[[C0]]]
// CHECK-NEXT:  %[[SHAPE1:.*]] = tensor.insert %[[DIM2]] into %[[SHAPE0]][%[[C1]]]
// CHECK-NEXT:  %[[SHAPE2:.*]] = tensor.insert %[[DIM1]] into %[[SHAPE1]][%[[C2]]]
// CHECK-NEXT:  %[[SHAPE3:.*]] = tensor.insert %[[DIM3]] into %[[SHAPE2]][%[[C3]]]
// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.reshape %[[PADDED]](%[[SHAPE3]])

Since this is a relatively expensive change, I'd like to get opinions before I do it.
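For reference, the suggested reshape-to-expand_shape canonicalization might look like this (a sketch only, assuming a fully static shape operand; none of this is in the patch):

  // Before: tensor.reshape whose shape operand is entirely constant.
  %cst_shape = arith.constant dense<[4, 16, 8, 16]> : tensor<4xindex>
  %r = tensor.reshape %padded(%cst_shape)
      : (tensor<64x128xf32>, tensor<4xindex>) -> tensor<4x16x8x16xf32>

  // After: a canonicalizer could rewrite it to the more transformation-friendly
  // tensor.expand_shape, since 64 = 4 * 16 and 128 = 8 * 16.
  %e = tensor.expand_shape %padded [[0, 1], [2, 3]]
      : tensor<64x128xf32> into tensor<4x16x8x16xf32>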

@chelini (Contributor) commented Jan 4, 2024

> It was suggested to me by @chelini to have only the reshape op handle all cases and get rid of the expand_shape op. [...] Since this is a relatively expensive change, I'd like to get opinions before I do it.

Thank you @srcarroll for pushing on this. Indeed, to generalize the lowering we would need to emit a reshape operation, and I think it would be better to consistently emit the reshape and then strength-reduce it to an expand_shape when possible. What do folks think here, @nicolasvasilache and @hanhanW? Thanks!

@hanhanW (Contributor) commented Jan 5, 2024

I'm -1 on using the tensor.reshape op. IMO, we should only use tensor.expand/collapse_shape; they work much better with existing transformations.

Out of curiosity, what use case do you have in mind? Why do we lower a fully dynamic pack op? If it is at the high-level graph stage, we can just use tensor.pack, which carries more meaningful information. If it is at a low-level stage (e.g., around vectorization), I think the inner tile sizes should already be resolved to static values. In that context, we can still use tensor.expand_shape: it supports the case where one dynamic extent is expanded into a single dynamic extent plus other static extents (e.g., ? -> ?x4).

@srcarroll (Contributor, Author) commented Jan 5, 2024

> I'm -1 on using the tensor.reshape op. IMO, we should only use tensor.expand/collapse_shape; they work much better with existing transformations. [...] Out of curiosity, what use case do you have in mind? Why do we lower a fully dynamic pack op?

I'll admit I don't know the use cases here. I worked on the lower_unpack transform to support dynamic sizes because someone on Discord said they needed it, and then I saw the NYI comment for lower_pack, so I thought I'd do it:

  if (llvm::any_of(packOp.getStaticInnerTiles(),
                   [](int64_t size) { return ShapedType::isDynamic(size); })) {
    return rewriter.notifyMatchFailure(
        packOp,
        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
  }

If you never want to support dynamic tiles, then that's fine by me. But this shouldn't be an NYI comment if you never intend to support it.

Also, that's why I currently have expandShapeOp be either a reshape or an expand_shape. What I currently have will still use expand_shape in cases like the ones you mentioned.

@srcarroll (Contributor, Author) commented Jan 5, 2024

It would be easy enough for me to change what I have to only emit expand_shape and fail to match only on cases that are completely impossible with expand_shape.

I am curious, though: what do you mean by a more powerful expand_shape op? Wouldn't that just be exactly reshape? We don't need a more powerful expand_shape to cover all the cases you mention above.

@hanhanW (Contributor) commented Jan 5, 2024

> I am curious, though: what do you mean by a more powerful expand_shape op? Wouldn't that just be exactly reshape? We don't need a more powerful expand_shape to cover all the cases you mention above.

That could be tensor.reshape, but I don't see a scenario for using it. To be honest, I don't have an answer. Bailing out on the case is fine with me. If someday people think it is needed, this will help bring up the discussion. To be clear, I am not saying that this is not useful; I just don't know why it is needed.

@srcarroll (Contributor, Author)

> To be clear, I am not saying that this is not useful; I just don't know why it is needed.

Fair enough. Me neither. :)

@srcarroll (Contributor, Author) commented Jan 5, 2024

> It would be easy enough for me to change what I have to only emit expand_shape and fail to match only on cases that are completely impossible with expand_shape.

After thinking about it more, if I'm not mistaken, this PR already covers all cases that are possible with expand_shape, but with a caveat. I'll illustrate with this example:

func.func @pack_dyn_tiles(%arg0: tensor<64x128xf32>, %arg1: tensor<4x8x?x?xf32>, %tile_0: index, %tile_1: index) -> tensor<4x8x?x?xf32> {
  %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [%tile_0, %tile_1] into %arg1
    : tensor<64x128xf32> -> tensor<4x8x?x?xf32>
  return %pack : tensor<4x8x?x?xf32>
}

The upstream implementation does not handle this case because it unconditionally fails to match when any tile size is dynamic. I can relax the match-failure condition to allow this case with expand_shape, which would result in:

  func.func @pack_dyn_tiles(%arg0: tensor<64x128xf32>, %arg1: tensor<4x8x?x?xf32>, %arg2: index, %arg3: index) -> tensor<4x8x?x?xf32> {
    %0 = affine.apply #map()[%arg2]
    %1 = affine.apply #map1()[%arg3]
    %cst = arith.constant 0.000000e+00 : f32
    %padded = tensor.pad %arg0 low[0, 0] high[%0, %1] {
    ^bb0(%arg4: index, %arg5: index):
      tensor.yield %cst : f32
    } : tensor<64x128xf32> to tensor<?x?xf32>
    %expanded = tensor.expand_shape %padded [[0, 1], [2, 3]] : tensor<?x?xf32> into tensor<4x?x8x?xf32>
    %transposed = linalg.transpose ins(%expanded : tensor<4x?x8x?xf32>) outs(%arg1 : tensor<4x8x?x?xf32>) permutation = [0, 2, 1, 3] 
    return %transposed : tensor<4x8x?x?xf32>
  }

However, in this example the tile sizes can be inferred from the relationship between the input and output sizes, so they might as well be static (I think you alluded to this in one of your comments). But if we allow them to be dynamic, that can lead to UB.

So I don't think there are any non-trivial cases left to handle dynamic tiles while keeping expand_shape. Am I missing cases?

Questions:

  1. Should the verifier allow such an example? In other words, should users be required to use static tile sizes in the cases where their static values can be inferred?
  2. If we allow it, should there be a canonicalizer to infer the static tile values? (See the sketch after this list.)
  3. Should I just go ahead and allow dynamic tiles for this case in this lowering?
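A sketch of what the canonicalization in question 2 might do, reusing the example above (illustrative only; not part of this PR):

  // Before: dynamic tiles whose values are forced by the static shapes
  // (64/4 = 16 and 128/8 = 16 are the only well-defined choices).
  %p = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [%tile_0, %tile_1]
      into %arg1 : tensor<64x128xf32> -> tensor<4x8x?x?xf32>

  // After: materialize the inferred tiles and cast the refined type back.
  %dest = tensor.cast %arg1 : tensor<4x8x?x?xf32> to tensor<4x8x16x16xf32>
  %q = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 16]
      into %dest : tensor<64x128xf32> -> tensor<4x8x16x16xf32>
  %res = tensor.cast %q : tensor<4x8x16x16xf32> to tensor<4x8x?x?xf32>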

@srcarroll (Contributor, Author) commented Jan 5, 2024

After a very illuminating offline discussion with @chelini, I think we answered some of my questions, so I will relay them here.

> Should the verifier allow such an example? In other words, should users be required to use static tile sizes in the cases where their static values can be inferred?

It's not up to us to enforce this. @chelini helped me realize that UB is part of the semantics of the op. So we should allow users to pass a dynamic tile size even when there is only one possible tile size that yields well-defined behavior, which is currently the case. We did come to the conclusion that a runtime assert should perhaps be emitted to enforce well-defined behavior (see the sketch at the end of this comment).
But that's only for the case where the output size divides the input size. In the case where it doesn't divide, there are two sub-cases:

  1. The padding_value of tensor.pack is specified. In this case, we can also allow dynamic tile sizes, even though there is still only one that yields well-defined behavior, and still emit a runtime assert.
  2. The padding_value is not specified. In this case, non-divisibility is always UB (because the input will always be padded and thus a value must be specified), so the verifier should reject this case. This is already done for a static input size and static tile, but it should also cover the case where the tile is dynamic and the input/output sizes are static.

@chelini, did I miss anything here or get anything wrong?
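For concreteness, the runtime assert might be emitted along these lines (a sketch only; the use of cf.assert and the exact divisibility condition are assumptions, not something this PR implements; %dim and %tile are assumed index values):

  // Guard a dynamic tile: it must evenly divide the corresponding source dim,
  // otherwise the pack's behavior is undefined.
  %c0 = arith.constant 0 : index
  %rem = arith.remui %dim, %tile : index
  %ok = arith.cmpi eq, %rem, %c0 : index
  cf.assert %ok, "dynamic inner tile size does not divide the packed dimension"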

@srcarroll (Contributor, Author)

I made a PR to extend the UB cases in the verifier: #77217.
