diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index bc257d17483e3..b39521ac4440c 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -498,7 +498,8 @@ def LowerPackOp : Op:$target); let results = (outs Transform_ConcreteOpType<"tensor.pad">:$pad_op, - Transform_ConcreteOpType<"tensor.expand_shape">:$expand_shape_op, + Type.predicate, + Transform_ConcreteOpType<"tensor.reshape">.predicate]>>:$expand_shape_op, Transform_ConcreteOpType<"linalg.transpose">:$transpose_op); let assemblyFormat = [{ $target attr-dict `:` functional-type(operands, results) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index a848d12fbbb50..06e8586f4288b 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1089,7 +1089,7 @@ collapseOpIterationDims(LinalgType op, struct LowerPackResult { tensor::PadOp padOp; - tensor::ExpandShapeOp expandShapeOp; + Operation *expandShapeOp; // `tensor::ExpandShapeOp` or `tensor::ReshapeOp` linalg::TransposeOp transposeOp; }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 9d230e2c2e574..4550589ded6df 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -218,21 +218,11 @@ struct PackedOperandsDimList { FailureOr linalg::lowerPack(RewriterBase &rewriter, tensor::PackOp packOp) { - // 1. Filter out NYI cases. - auto packedTensorType = - cast(packOp->getResultTypes().front()); - if (llvm::any_of(packOp.getStaticInnerTiles(), - [](int64_t size) { return ShapedType::isDynamic(size); })) { - return rewriter.notifyMatchFailure( - packOp, - "non-static shape NYI, needs a more powerful tensor.expand_shape op"); - } - Location loc = packOp->getLoc(); OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(packOp); - // 2. Compute the permutation vector to shuffle packed shape into the shape + // 1. Compute the permutation vector to shuffle packed shape into the shape // before any outer or inner permutations have been applied. The permutation // can be obtained from two permutations: // a) Compute the permutation vector to move the last `numPackedDims` into @@ -240,6 +230,8 @@ FailureOr linalg::lowerPack(RewriterBase &rewriter, // b) Compute the permutation vector to move outer dims if the pack op // has outer_dims_perm. // Apply (b) permutation on (a) permutation to get the final permutation. + auto packedTensorType = + cast(packOp->getResultTypes().front()); int64_t numPackedDims = packOp.getInnerDimsPos().size(); int64_t packedRank = packedTensorType.getRank(); auto lastDims = llvm::to_vector( @@ -259,12 +251,12 @@ FailureOr linalg::lowerPack(RewriterBase &rewriter, SmallVector packedToStripMinedShapePerm = innerPositionsPerm; applyPermutationToVector(packedToStripMinedShapePerm, outerPositionPerm); - // 3. Compute the stripMinedShape: this is the packed shape before any outer + // 2. Compute the stripMinedShape: this is the packed shape before any outer // or inner permutations have been applied. SmallVector stripMinedShape(packedTensorType.getShape()); applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm); - // 4. Pad the source of packOp to a shape we can expand into stripMinedShape. + // 3. Pad the source of packOp to a shape we can expand into stripMinedShape. SmallVector lows(packOp.getSourceRank(), rewriter.getIndexAttr(0)); SmallVector highs(packOp.getSourceRank(), @@ -351,24 +343,65 @@ FailureOr linalg::lowerPack(RewriterBase &rewriter, /*transposeOp=*/nullptr}; } } - // 5. Expand from the padded result to the stripMinedShape. - auto reshapeOp = rewriter.create( - loc, - RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape), - padOp.getResult(), packingMetadata.reassociations); - // 6. Transpose stripMinedShape to packedShape. + // 4. Expand from the padded result to the stripMinedShape. + RankedTensorType expandDestType = + RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape); SmallVector transpPerm = invertPermutationVector(packedToStripMinedShapePerm); + Operation *reshapeOp; + // Check if any dims are not factorable and thus need a `tensor.reshape` + // instead of a `tensor.expand_shape` op. A dim is factorable if the expansion + // requires at most one dynamnic dim + if (llvm::any_of(packingMetadata.reassociations, + [&](const auto &rAssoc) -> bool { + return llvm::count_if(rAssoc, [&](int64_t r) { + return stripMinedShape[r] == ShapedType::kDynamic; + }) > 1; + })) { + SmallVector sizes = + tensor::getMixedSizes(rewriter, loc, packOp.getDest()); + applyPermutationToVector(sizes, transpPerm); + // Create a `tensor` of `index` types for the `shape` operand of + // `tensor.reshape` + Value shapeInitTensor = rewriter.create( + loc, + RankedTensorType::get({expandDestType.getRank()}, + rewriter.getIndexType()), + ValueRange{}); + Value shapeTensor = shapeInitTensor; + for (const auto &[i, size] : llvm::enumerate(sizes)) { + auto maybeConstInt = getConstantIntValue(size); + assert((maybeConstInt.has_value() || expandDestType.isDynamicDim(i)) && + "expected dynamic dim"); + Value dim = + (maybeConstInt.has_value()) + ? rewriter + .create(loc, maybeConstInt.value()) + .getResult() + : cast(size); + shapeTensor = rewriter.create( + loc, dim, shapeTensor, + SmallVector( + {rewriter.create(loc, i).getResult()})); + } + reshapeOp = rewriter.create( + loc, expandDestType, padOp.getResult(), shapeTensor); + } else { + reshapeOp = rewriter.create( + loc, expandDestType, padOp.getResult(), packingMetadata.reassociations); + } + + // 5. Transpose stripMinedShape to packedShape. auto transposeOp = rewriter.create( - loc, reshapeOp.getResult(), packOp.getDest(), transpPerm); + loc, reshapeOp->getResult(0), packOp.getDest(), transpPerm); LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL(); DBGS() << "reshape op: " << reshapeOp; DBGSNL(); llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: "); DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL();); - // 7. Replace packOp by transposeOp. + // 6. Replace packOp by transposeOp. rewriter.replaceOp(packOp, transposeOp->getResults()); return LowerPackResult{padOp, reshapeOp, transposeOp}; diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir index 316df431a9c0c..13d74cbe43326 100644 --- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir +++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir @@ -61,6 +61,52 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s1 * s0 - 64)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s1 * s0 - 128)> +// CHECK: func.func @pack_dyn_tiles( +// CHECK-SAME: %[[ARG0:.*]]: [[TENSOR_TY_0:tensor<64x128xf32>]] +// CHECK-SAME: %[[ARG1:.*]]: tensor, +// CHECK-SAME: %[[TILE0:.*]]: index, +// CHECK-SAME: %[[TILE1:.*]]: index +func.func @pack_dyn_tiles(%arg0: tensor<64x128xf32>, %arg1: tensor, %tile_0: index, %tile_1: index) -> tensor { +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG1]], %[[C0]] +// CHECK-DAG: %[[PAD0:.*]] = affine.apply #[[MAP0]]()[%[[TILE0]], %[[DIM0]]] +// CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C1]] +// CHECK-DAG: %[[PAD1:.*]] = affine.apply #[[MAP1]]()[%[[TILE1]], %[[DIM1]]] +// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0] high[%[[PAD0]], %[[PAD1]]] +// CHECK-NEXT: ^bb0 +// CHECK-NEXT: tensor.yield %[[CST]] : f32 +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[DIM2:.*]] = tensor.dim %[[ARG1]], %[[C2]] +// CHECK-DAG: %[[DIM3:.*]] = tensor.dim %[[ARG1]], %[[C3]] +// CHECK-NEXT: %[[INIT_SHAPE:.*]] = tensor.empty() : tensor<4xindex> +// CHECK-NEXT: %[[SHAPE0:.*]] = tensor.insert %[[DIM0]] into %[[INIT_SHAPE]][%[[C0]]] +// CHECK-NEXT: %[[SHAPE1:.*]] = tensor.insert %[[DIM2]] into %[[SHAPE0]][%[[C1]]] +// CHECK-NEXT: %[[SHAPE2:.*]] = tensor.insert %[[DIM1]] into %[[SHAPE1]][%[[C2]]] +// CHECK-NEXT: %[[SHAPE3:.*]] = tensor.insert %[[DIM3]] into %[[SHAPE2]][%[[C3]]] +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.reshape %[[PADDED]](%[[SHAPE3]]) +// CHECK-NEXT: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[EXPANDED]] : {{.*}}) outs(%[[ARG1]] {{.*}}) permutation = [0, 2, 1, 3] + %pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [%tile_0, %tile_1] into %arg1 + : tensor<64x128xf32> -> tensor + return %pack : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %pack = transform.structured.match ops{["tensor.pack"]} in %module_op + : (!transform.any_op) -> !transform.op<"tensor.pack"> + transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">) + -> (!transform.op<"tensor.pad">, !transform.op<"tensor.reshape">, !transform.op<"linalg.transpose">) + transform.yield + } +} + +// ----- + // CHECK-LABEL: func.func @pack_as_pad( func.func @pack_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> { %cst_0 = arith.constant 0.0 : f32