[mlir][Linalg] Support dynamic tiles in `lower_pack` transform #76003
@@ -1089,7 +1089,7 @@ collapseOpIterationDims(LinalgType op,
struct LowerPackResult {
tensor::PadOp padOp;
- tensor::ExpandShapeOp expandShapeOp;
+ Operation *expandShapeOp;
Not actually sure what would be appropriate here. Alternatively, we could have two separate fields for the `ExpandShapeOp` and `ReshapeOp`, but I haven't looked into the implications of this yet. In some cases (when not all dims are dynamically expanded, for example) we technically could do a `tensor.reshape` + `tensor.expand_shape` sequence and keep the `LowerPackResult` struct as is. However, I don't believe that will work when all dims are dynamically expanded: since `expand_shape` doesn't allow the same rank in and out, it couldn't be used as a no-op. I welcome suggestions.
7050c7b to cf0cb00
cf0cb00 to f14c488
The current implementation will emit a
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir
Author: None (srcarroll)
Changes
When an expanded dim is not factorable, emit a `tensor.reshape` instead of a `tensor.expand_shape`.
Full diff: https://github.com/llvm/llvm-project/pull/76003.diff
4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 77ed9db5e71bd1..4abd3740b57105 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -498,7 +498,8 @@ def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target);
let results = (outs Transform_ConcreteOpType<"tensor.pad">:$pad_op,
- Transform_ConcreteOpType<"tensor.expand_shape">:$expand_shape_op,
+ Type<Or<[Transform_ConcreteOpType<"tensor.expand_shape">.predicate,
+ Transform_ConcreteOpType<"tensor.reshape">.predicate]>>:$expand_shape_op,
Transform_ConcreteOpType<"linalg.transpose">:$transpose_op);
let assemblyFormat = [{
$target attr-dict `:` functional-type(operands, results)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a848d12fbbb50e..06e8586f4288b4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1089,7 +1089,7 @@ collapseOpIterationDims(LinalgType op,
struct LowerPackResult {
tensor::PadOp padOp;
- tensor::ExpandShapeOp expandShapeOp;
+ Operation *expandShapeOp; // `tensor::ExpandShapeOp` or `tensor::ReshapeOp`
linalg::TransposeOp transposeOp;
};
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 9d230e2c2e5749..4550589ded6df8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -218,21 +218,11 @@ struct PackedOperandsDimList {
FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
tensor::PackOp packOp) {
- // 1. Filter out NYI cases.
- auto packedTensorType =
- cast<RankedTensorType>(packOp->getResultTypes().front());
- if (llvm::any_of(packOp.getStaticInnerTiles(),
- [](int64_t size) { return ShapedType::isDynamic(size); })) {
- return rewriter.notifyMatchFailure(
- packOp,
- "non-static shape NYI, needs a more powerful tensor.expand_shape op");
- }
-
Location loc = packOp->getLoc();
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(packOp);
- // 2. Compute the permutation vector to shuffle packed shape into the shape
+ // 1. Compute the permutation vector to shuffle packed shape into the shape
// before any outer or inner permutations have been applied. The permutation
// can be obtained from two permutations:
// a) Compute the permutation vector to move the last `numPackedDims` into
@@ -240,6 +230,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
// b) Compute the permutation vector to move outer dims if the pack op
// has outer_dims_perm.
// Apply (b) permutation on (a) permutation to get the final permutation.
+ auto packedTensorType =
+ cast<RankedTensorType>(packOp->getResultTypes().front());
int64_t numPackedDims = packOp.getInnerDimsPos().size();
int64_t packedRank = packedTensorType.getRank();
auto lastDims = llvm::to_vector(
@@ -259,12 +251,12 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
SmallVector<int64_t> packedToStripMinedShapePerm = innerPositionsPerm;
applyPermutationToVector(packedToStripMinedShapePerm, outerPositionPerm);
- // 3. Compute the stripMinedShape: this is the packed shape before any outer
+ // 2. Compute the stripMinedShape: this is the packed shape before any outer
// or inner permutations have been applied.
SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
- // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
+ // 3. Pad the source of packOp to a shape we can expand into stripMinedShape.
SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
@@ -351,24 +343,65 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
/*transposeOp=*/nullptr};
}
}
- // 5. Expand from the padded result to the stripMinedShape.
- auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
- loc,
- RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
- padOp.getResult(), packingMetadata.reassociations);
- // 6. Transpose stripMinedShape to packedShape.
+ // 4. Expand from the padded result to the stripMinedShape.
+ RankedTensorType expandDestType =
+ RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
SmallVector<int64_t> transpPerm =
invertPermutationVector(packedToStripMinedShapePerm);
+ Operation *reshapeOp;
+ // Check if any dims are not factorable and thus need a `tensor.reshape`
+ // instead of a `tensor.expand_shape` op. A dim is factorable if the expansion
+ // requires at most one dynamic dim
+ if (llvm::any_of(packingMetadata.reassociations,
+ [&](const auto &rAssoc) -> bool {
+ return llvm::count_if(rAssoc, [&](int64_t r) {
+ return stripMinedShape[r] == ShapedType::kDynamic;
+ }) > 1;
+ })) {
+ SmallVector<OpFoldResult> sizes =
+ tensor::getMixedSizes(rewriter, loc, packOp.getDest());
+ applyPermutationToVector(sizes, transpPerm);
+ // Create a `tensor` of `index` types for the `shape` operand of
+ // `tensor.reshape`
+ Value shapeInitTensor = rewriter.create<tensor::EmptyOp>(
+ loc,
+ RankedTensorType::get({expandDestType.getRank()},
+ rewriter.getIndexType()),
+ ValueRange{});
+ Value shapeTensor = shapeInitTensor;
+ for (const auto &[i, size] : llvm::enumerate(sizes)) {
+ auto maybeConstInt = getConstantIntValue(size);
+ assert((maybeConstInt.has_value() || expandDestType.isDynamicDim(i)) &&
+ "expected dynamic dim");
+ Value dim =
+ (maybeConstInt.has_value())
+ ? rewriter
+ .create<arith::ConstantIndexOp>(loc, maybeConstInt.value())
+ .getResult()
+ : cast<Value>(size);
+ shapeTensor = rewriter.create<tensor::InsertOp>(
+ loc, dim, shapeTensor,
+ SmallVector<Value>(
+ {rewriter.create<arith::ConstantIndexOp>(loc, i).getResult()}));
+ }
+ reshapeOp = rewriter.create<tensor::ReshapeOp>(
+ loc, expandDestType, padOp.getResult(), shapeTensor);
+ } else {
+ reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
+ loc, expandDestType, padOp.getResult(), packingMetadata.reassociations);
+ }
+
+ // 5. Transpose stripMinedShape to packedShape.
auto transposeOp = rewriter.create<linalg::TransposeOp>(
- loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
+ loc, reshapeOp->getResult(0), packOp.getDest(), transpPerm);
LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
DBGS() << "reshape op: " << reshapeOp; DBGSNL();
llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: ");
DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
- // 7. Replace packOp by transposeOp.
+ // 6. Replace packOp by transposeOp.
rewriter.replaceOp(packOp, transposeOp->getResults());
return LowerPackResult{padOp, reshapeOp, transposeOp};
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 316df431a9c0c8..13d74cbe433264 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -61,6 +61,52 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s1 * s0 - 64)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s1 * s0 - 128)>
+// CHECK: func.func @pack_dyn_tiles(
+// CHECK-SAME: %[[ARG0:.*]]: [[TENSOR_TY_0:tensor<64x128xf32>]]
+// CHECK-SAME: %[[ARG1:.*]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME: %[[TILE0:.*]]: index,
+// CHECK-SAME: %[[TILE1:.*]]: index
+func.func @pack_dyn_tiles(%arg0: tensor<64x128xf32>, %arg1: tensor<?x?x?x?xf32>, %tile_0: index, %tile_1: index) -> tensor<?x?x?x?xf32> {
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG1]], %[[C0]]
+// CHECK-DAG: %[[PAD0:.*]] = affine.apply #[[MAP0]]()[%[[TILE0]], %[[DIM0]]]
+// CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[PAD1:.*]] = affine.apply #[[MAP1]]()[%[[TILE1]], %[[DIM1]]]
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0] high[%[[PAD0]], %[[PAD1]]]
+// CHECK-NEXT: ^bb0
+// CHECK-NEXT: tensor.yield %[[CST]] : f32
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[DIM2:.*]] = tensor.dim %[[ARG1]], %[[C2]]
+// CHECK-DAG: %[[DIM3:.*]] = tensor.dim %[[ARG1]], %[[C3]]
+// CHECK-NEXT: %[[INIT_SHAPE:.*]] = tensor.empty() : tensor<4xindex>
+// CHECK-NEXT: %[[SHAPE0:.*]] = tensor.insert %[[DIM0]] into %[[INIT_SHAPE]][%[[C0]]]
+// CHECK-NEXT: %[[SHAPE1:.*]] = tensor.insert %[[DIM2]] into %[[SHAPE0]][%[[C1]]]
+// CHECK-NEXT: %[[SHAPE2:.*]] = tensor.insert %[[DIM1]] into %[[SHAPE1]][%[[C2]]]
+// CHECK-NEXT: %[[SHAPE3:.*]] = tensor.insert %[[DIM3]] into %[[SHAPE2]][%[[C3]]]
+// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.reshape %[[PADDED]](%[[SHAPE3]])
+// CHECK-NEXT: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[EXPANDED]] : {{.*}}) outs(%[[ARG1]] {{.*}}) permutation = [0, 2, 1, 3]
+ %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [%tile_0, %tile_1] into %arg1
+ : tensor<64x128xf32> -> tensor<?x?x?x?xf32>
+ return %pack : tensor<?x?x?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+ : (!transform.any_op) -> !transform.op<"tensor.pack">
+ transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+ -> (!transform.op<"tensor.pad">, !transform.op<"tensor.reshape">, !transform.op<"linalg.transpose">)
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: func.func @pack_as_pad(
func.func @pack_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
%cst_0 = arith.constant 0.0 : f32
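The factorability check in step 4 of the patch can be illustrated outside of MLIR. The following is a minimal, self-contained C++ sketch, not the actual implementation: `kDynamic` stands in for `ShapedType::kDynamic` (its value here is an assumption), `needsReshape` is a hypothetical helper name, and plain `std::vector`s replace the MLIR containers.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Stand-in for ShapedType::kDynamic; the sentinel value is an assumption.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical helper mirroring the check in the patch: returns true if any
// reassociation group expands into more than one dynamic dim, which is the
// case where the patch emits tensor.reshape instead of tensor.expand_shape.
bool needsReshape(const std::vector<std::vector<int64_t>> &reassociations,
                  const std::vector<int64_t> &stripMinedShape) {
  return std::any_of(
      reassociations.begin(), reassociations.end(),
      [&](const std::vector<int64_t> &group) {
        // Count the dynamic dims that this source dim expands into.
        auto numDynamic =
            std::count_if(group.begin(), group.end(), [&](int64_t r) {
              assert(r >= 0 &&
                     static_cast<std::size_t>(r) < stripMinedShape.size());
              return stripMinedShape[r] == kDynamic;
            });
        return numDynamic > 1;
      });
}
```

For the `@pack_dyn_tiles` test, assuming reassociation groups `{{0, 1}, {2, 3}}` and an all-dynamic strip-mined shape, `needsReshape` returns true. With static tiles, for example a strip-mined shape of `{?, 8, ?, 16}`, each group has a single dynamic dim and it returns false, so `tensor.expand_shape` suffices.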
It was suggested to me by @chelini to only have the
Since this is a relatively expensive change, I'd like to get opinions before I do it.
Thank you @srcarroll for pushing on this. Indeed, to generalize the lowering, we would need to emit a reshape operation, and I think it would be better to consistently emit the reshape and then strength-reduce it to an `expand_shape` when possible. What do folks think, @nicolasvasilache and @hanhanW? Thanks!
I'm -1 on using
Out of curiosity, what use case do you have in mind? Why do we lower a fully dynamic pack op? If it is at the high-level graph level, we can just use
I'll admit I don't know the use cases here. I worked on the
If you never want to support dynamic tiles, then fine by me. But this shouldn't be an NYI comment if you never intend to support it. Also, that's why I currently have the
It would be easy enough for me to change what I have to only do expand and only match fail on completely impossible (with
I am curious though, what do you mean by more powerful
That could be `tensor.reshape`, but I don't see a scenario for using it. To be honest, I don't have an answer. Bailing out in that case is fine with me. If someday people think it is needed, this will help bring up the discussion. To be clear, I am not saying that this is not useful. I just don't know why it is needed.
Fair enough. Me neither. :)
After thinking about it more, if I'm not mistaken, the current implementation already covers all possible cases with
The current implementation does not handle this case because it unconditionally match fails when any tile size is dynamic. I can change the match failure condition to allow this case with
However, in this example the tile sizes can be inferred from the relationship between input and output sizes, so they might as well be static (I think you alluded to this in one of your comments). But if we allow them to be dynamic, then that can lead to UB. So I don't think there are any non-trivial cases left to handle dynamic tiles while keeping
Questions:
After a very illuminating discussion offline with @chelini, I think we answered some of my questions, so I will relay them here.
It's not up to us to enforce this. @chelini helped me realize that UB is part of the semantics of the op. So we should allow users to have a dynamic tile size even when there's only one possible tile size that yields well-defined behavior, which is currently the case. We did come to the conclusion that maybe a runtime assert should be emitted to enforce well-defined behavior.
@chelini, did I miss anything here or get anything wrong?
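To make the runtime-assert idea concrete, here is a hedged C++ sketch of what such a check could compute for a single packed dim. This is an illustration only: `isTileConsistent`, `highPad`, and the `allowPadding` flag are hypothetical names, and the consistency conditions are assumptions rather than the op's documented UB rules. The padding amount follows the affine maps in the test above (`s1 * s0 - 64`, that is, outer dim times tile size minus source dim).

```cpp
#include <cassert>
#include <cstdint>

// Ceiling division for a positive divisor.
int64_t ceilDiv(int64_t a, int64_t b) {
  assert(b > 0 && "divisor must be positive");
  return (a + b - 1) / b;
}

// High padding the lowering would request for one dim: outerDim * tile - srcDim.
// A negative result means the destination cannot hold the source.
int64_t highPad(int64_t srcDim, int64_t outerDim, int64_t tile) {
  return outerDim * tile - srcDim;
}

// Assumed consistency condition a runtime assert might enforce (hypothetical):
//  - without padding, the tiles must cover the source dim exactly;
//  - with padding, the outer dim must equal ceil(srcDim / tile).
bool isTileConsistent(int64_t srcDim, int64_t outerDim, int64_t tile,
                      bool allowPadding) {
  if (tile <= 0 || outerDim < 0 || srcDim < 0)
    return false;
  if (!allowPadding)
    return outerDim * tile == srcDim;
  return outerDim == ceilDiv(srcDim, tile);
}
```

Under these assumptions, a source dim of 64 with outer dim 4 admits only tile size 16 when no padding is allowed, which matches the observation that a dynamic tile may have just one well-defined value. A tile of 8 with the same outer dim yields a negative pad of -32, the kind of case a runtime assert would catch.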
I made a PR to extend the UB cases in the verifier: #77217
When an expanded dim is not factorable, emit a `tensor.reshape` instead of a `tensor.expand_shape`.