diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h index 2111a7c581029..5e7945d9b0492 100644 --- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h @@ -24,6 +24,29 @@ namespace mlir { +using ReassociationIndices = SmallVector; + +/// Infer the output shape for a {memref|tensor}.expand_shape when it is +/// possible to do so. +/// +/// Note: This should *only* be used to implement +/// `ExpandShapeOp::inferOutputShape` in both the memref and tensor namespaces. +/// If you need to infer the output shape you should use the static method of +/// `ExpandShapeOp` instead of calling this. +/// +/// `inputShape` is the shape of the tensor or memref being expanded as a +/// sequence of SSA values or constants. `expandedType` is the output shape of +/// the expand_shape operation. `reassociation` is the reassociation denoting +/// the output dims each input dim is mapped to. +/// +/// Returns the output shape in `outputShape` and `staticOutputShape`, following +/// the conventions for the output_shape and static_output_shape inputs to the +/// expand_shape ops. +std::optional> +inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, + ArrayRef reassociation, + ArrayRef inputShape); + /// Matches a ConstantIndexOp. detail::op_matcher matchConstantIndex(); diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index 39e66cd9e6e5a..14b8d95ea15b4 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1548,7 +1548,6 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [ class MemRef_ReassociativeReshapeOp traits = []> : MemRef_Op, - Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>, Results<(outs AnyStridedMemRef:$result)>{ code commonExtraClassDeclaration = [{ @@ -1573,10 +1572,6 @@ class MemRef_ReassociativeReshapeOp traits = []> : Value getViewSource() { return getSrc(); } }]; - let assemblyFormat = [{ - $src $reassociation attr-dict `:` type($src) `into` type($result) - }]; - let hasFolder = 1; let hasCanonicalizer = 1; let hasVerifier = 1; @@ -1598,14 +1593,10 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [ Example: ```mlir - %r = memref.expand_shape %0 [[0, 1], [2]] - : memref into memref + %r = memref.expand_shape %0 [[0, 1], [2]] output_shape [%sz0, %sz1, 32] + : memref into memref ``` - At most one dimension of a reassociation group (e.g., [0, 1] above) may be - dynamic in the result type. Otherwise, the op would be ambiguous, as it - would not be clear how the source dimension is extended. - If an op can be statically proven to be invalid (e.g, an expansion from `memref<10xf32>` to `memref<2x6xf32>`), it is rejected by the verifier. If it cannot statically be proven invalid (e.g., the full example above; it is @@ -1622,41 +1613,80 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [ there must be a dynamic result dimension in the corresponding reassociation group. Same for strides. + The representation for the output shape supports a partially-static + specification via attributes specified through the `static_output_shape` + argument. A special sentinel value `ShapedType::kDynamic` encodes that the + corresponding entry has a dynamic value. There must be exactly as many SSA + inputs in `output_shape` as there are `ShapedType::kDynamic` entries in + `static_output_shape`. + Note: This op currently assumes that the inner strides are of the source/result layout map are the faster-varying ones. }]; + let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation, + Variadic:$output_shape, + DenseI64ArrayAttr:$static_output_shape); + + let assemblyFormat = [{ + $src $reassociation `output_shape` + custom($output_shape, $static_output_shape) attr-dict `:` + type($src) `into` type($result) + }]; + let builders = [ // Builders using ReassociationIndices. OpBuilder<(ins "Type":$resultType, "Value":$src, "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), + "ArrayRef":$outputShape)>, + + // It will infer output shape using inferOutputShape() method. + OpBuilder<(ins "Type":$resultType, "Value":$src, + "ArrayRef":$reassociation)>, + + // Builder using ReassociationExprs. + OpBuilder<(ins "Type":$resultType, "Value":$src, + "ArrayRef":$reassociation), [{ - build($_builder, $_state, resultType, src, attrs); - $_state.addAttribute("reassociation", - getReassociationIndicesAttribute($_builder, reassociation)); + auto reassociationIndices = + convertReassociationMapsToIndices(reassociation); + build($_builder, $_state, resultType, src, reassociationIndices); }]>, - // Builder using ReassociationExprs. OpBuilder<(ins "Type":$resultType, "Value":$src, "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), + "ArrayRef":$outputShape), [{ auto reassociationMaps = - convertReassociationMapsToIndices($_builder, reassociation); - build($_builder, $_state, resultType, src, reassociationMaps, attrs); + convertReassociationMapsToIndices(reassociation); + build($_builder, $_state, resultType, src, reassociationMaps, + outputShape); }]>, + // Builder that infers the result layout map. The result shape must be + // specified. Otherwise, the op may be ambiguous. The output shape for + // the op will be inferred using the inferOutputShape() method. + OpBuilder<(ins "ArrayRef":$resultShape, "Value":$src, + "ArrayRef":$reassociation)>, + // Builder that infers the result layout map. The result shape must be // specified. Otherwise, the op may be ambiguous. OpBuilder<(ins "ArrayRef":$resultShape, "Value":$src, - "ArrayRef":$reassociation)> + "ArrayRef":$reassociation, + "ArrayRef":$outputShape)> ]; let extraClassDeclaration = commonExtraClassDeclaration # [{ static FailureOr computeExpandedType( MemRefType srcType, ArrayRef resultShape, ArrayRef reassociation); + + // Infer the output shape for a memref.expand_shape when it is possible + // to do so. + static FailureOr> inferOutputShape( + OpBuilder &b, Location loc, MemRefType expandedType, + ArrayRef reassociation, + ArrayRef inputShape); }]; let hasVerifier = 1; @@ -1707,6 +1737,12 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [ source/result layout map are the faster-varying ones. }]; + let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation); + + let assemblyFormat = [{ + $src $reassociation attr-dict `:` type($src) `into` type($result) + }]; + let builders = [ // Builders for a contracting reshape whose result type is computed from // `src` and `reassociation`. @@ -1718,7 +1754,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [ CArg<"ArrayRef", "{}">:$attrs), [{ auto reassociationMaps = - convertReassociationMapsToIndices($_builder, reassociation); + convertReassociationMapsToIndices(reassociation); build($_builder, $_state, src, reassociationMaps, attrs); }]>, @@ -1736,7 +1772,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [ CArg<"ArrayRef", "{}">:$attrs), [{ auto reassociationMaps = - convertReassociationMapsToIndices($_builder, reassociation); + convertReassociationMapsToIndices(reassociation); build($_builder, $_state, resultType, src, reassociationMaps, attrs); }]> ]; diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index cf7f3e89079c1..a403e89a39f98 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -1062,8 +1062,7 @@ class Tensor_ReassociativeReshapeOp traits = []> : Tensor_Op, Pure])>, - Arguments<(ins AnyRankedTensor:$src, IndexListArrayAttr:$reassociation)>, - Results<(outs AnyRankedTensor:$result)> { + Results<(outs AnyTensor:$result)> { code commonExtraClassDeclaration = [{ static StringRef getReassociationAttrStrName() { return "reassociation"; } @@ -1086,10 +1085,6 @@ class Tensor_ReassociativeReshapeOp traits = []> : } }]; - let assemblyFormat = [{ - $src $reassociation attr-dict `:` type($src) `into` type($result) - }]; - let hasFolder = 1; let hasCanonicalizer = 1; let hasVerifier = 1; @@ -1102,43 +1097,75 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> { rank than the operand `src` whose dimension sizes are a reassociation of `src`. - A reassociation is defined as a continuous grouping of dimensions. It is - represented with an array of DenseI64ArrayAttr attribute. Entries in the - array are referred to as reassociation maps. + A reassociation is defined as a continuous grouping of dimensions and is + represented with an array of DenseI64ArrayAttr attribute. The reassociation + maps applied to the result tensor with the higher rank must result in the + operand tensor with the smaller rank. - The reassociation maps are applied to the result shape to obtain the operand - shape. + The representation for the output shape supports a partially-static + specification via attributes specified through the `static_output_shape` + argument. A special sentinel value `ShapedType::kDynamic` encodes that the + corresponding entry has a dynamic value. There must be exactly as many SSA + inputs in `output_shape` as there are `ShapedType::kDynamic` entries in + `static_output_shape`. Example: ```mlir // Dimension expansion i -> (i', j') and (k) -> (k') - %b = tensor.expand_shape %a [[0, 1], [2]] - : tensor into tensor + %b = tensor.expand_shape %a [[0, 1], [2]] output_shape [%sz0, %sz1, 32] + : tensor into tensor ``` }]; + + let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation, + Variadic:$output_shape, + DenseI64ArrayAttr:$static_output_shape); + + let assemblyFormat = [{ + $src $reassociation `output_shape` + custom($output_shape, $static_output_shape) attr-dict `:` + type($src) `into` type($result) + }]; + let builders = [ // Builders using ReassociationIndices. OpBuilder<(ins "Type":$resultType, "Value":$src, "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), + "ArrayRef":$outputShape)>, + + // It will infer output shape using inferOutputShape() method. + OpBuilder<(ins "Type":$resultType, "Value":$src, + "ArrayRef":$reassociation)>, + + // Builder using ReassociationExprs. + OpBuilder<(ins "Type":$resultType, "Value":$src, + "ArrayRef":$reassociation), [{ - build($_builder, $_state, resultType, src, attrs); - $_state.addAttribute("reassociation", - getReassociationIndicesAttribute($_builder, reassociation)); + auto reassociationIndices = + convertReassociationMapsToIndices(reassociation); + build($_builder, $_state, resultType, src, reassociationIndices); }]>, OpBuilder<(ins "Type":$resultType, "Value":$src, "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), + "ArrayRef":$outputShape), [{ - auto reassociationMaps = - convertReassociationMapsToIndices($_builder, reassociation); - build($_builder, $_state, resultType, src, reassociationMaps, attrs); + auto reassociationIndices = + convertReassociationMapsToIndices(reassociation); + build($_builder, $_state, resultType, src, reassociationIndices, + outputShape); }]> ]; let extraClassDeclaration = commonExtraClassDeclaration # [{ int64_t getCorrespondingSourceDim(int64_t resultDim); + + // Infer the output shape for a tensor.expand_shape when it is possible + // to do so. + static FailureOr> inferOutputShape( + OpBuilder &b, Location loc, RankedTensorType expandedType, + ArrayRef reassociation, + ArrayRef inputShape); }]; let hasVerifier = 1; @@ -1146,6 +1173,7 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> { def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> { let summary = "operation to produce a tensor with a smaller rank"; + let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation); let description = [{ The `tensor.collapse_shape` op produces a new tensor of lower (or equal) rank whose dimension sizes are a reassociation of the original `src` dimensions. @@ -1163,6 +1191,11 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> { : tensor into tensor ``` }]; + + let assemblyFormat = [{ + $src $reassociation attr-dict `:` type($src) `into` type($result) + }]; + let builders = [ // Builders for a contracting reshape whose result type is computed from // `src` and `reassociation`. @@ -1174,7 +1207,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> { CArg<"ArrayRef", "{}">:$attrs), [{ auto reassociationMaps = - convertReassociationMapsToIndices($_builder, reassociation); + convertReassociationMapsToIndices(reassociation); build($_builder, $_state, src, reassociationMaps, attrs); }]>, @@ -1192,7 +1225,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> { CArg<"ArrayRef", "{}">:$attrs), [{ auto reassociationMaps = - convertReassociationMapsToIndices($_builder, reassociation); + convertReassociationMapsToIndices(reassociation); build($_builder, $_state, resultType, src, reassociationMaps, attrs); }]> ]; diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h index ae9824f728da4..e8f6edc3f133e 100644 --- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h @@ -62,7 +62,7 @@ getReassociationIndicesAttribute(OpBuilder &b, /// Convert Array> to Array>. SmallVector convertReassociationMapsToIndices( - OpBuilder &b, ArrayRef reassociationExprs); + ArrayRef reassociationExprs); /// Return the reassociations maps to use to reshape given the source type and /// the target type when possible. Return std::nullopt when this computation @@ -140,14 +140,11 @@ static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType, op.getReassociationIndices(), isExpansion); } -/// Verify that shapes of the reshaped types using following rules -/// 1) if a dimension in the collapsed type is static, then the corresponding -/// dimensions in the expanded shape should be +/// Verify that shapes of the reshaped types using following rule: +/// if a dimension in the collapsed type is static, then the corresponding +/// dimensions in the expanded shape should be /// a) static /// b) the product should be same as the collaped shape. -/// 2) if a dimension in the collaped type is dynamic, one and only one of the -/// corresponding dimensions in the expanded type should be dynamic. This -/// rule is only needed with reshape operations that are expanding. LogicalResult reshapeLikeShapesAreCompatible( function_ref emitError, ArrayRef collapsedShape, ArrayRef expandedShape, @@ -156,9 +153,11 @@ LogicalResult reshapeLikeShapesAreCompatible( /// Returns true iff the type is a MemRefType and has a non-identity layout. bool hasNonIdentityLayout(Type type); +enum class ReshapeOpKind { kExpand, kCollapse }; + /// Pattern to collapse producer/consumer reshape ops that are both collapsing /// dimensions or are both expanding dimensions. -template +template struct ComposeReassociativeReshapeOps : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp, @@ -181,8 +180,18 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern { rewriter.getContext()); if (!reassociationIndices) return failure(); - rewriter.replaceOpWithNewOp( - reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices); + + if constexpr (opKind == ReshapeOpKind::kExpand) { + SmallVector outputShape( + getMixedValues(reshapeOp.getStaticOutputShape(), + reshapeOp.getOutputShape(), rewriter)); + rewriter.replaceOpWithNewOp( + reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices, + outputShape); + } else { + rewriter.replaceOpWithNewOp( + reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices); + } return success(); } }; @@ -215,7 +224,8 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern { // /// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1` /// `reassociation_2` and produce `expand_shape`. -template +template struct ComposeCollapseOfExpandOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CollapseOpTy collapseOp, @@ -322,8 +332,11 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern { if (!composedReassociation) return failure(); + SmallVector outputShape(getMixedValues( + expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter)); rewriter.replaceOpWithNewOp( - expandOp, resultType, collapseOp.getSrc(), *composedReassociation); + expandOp, resultType, collapseOp.getSrc(), *composedReassociation, + outputShape); return success(); } diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h index 20f019666a2e6..594bcf5dbb399 100644 --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -125,9 +125,8 @@ SmallVector getMixedValues(ArrayRef staticValues, /// Decompose a vector of mixed static or dynamic values into the /// corresponding pair of arrays. This is the inverse function of /// `getMixedValues`. -std::pair> -decomposeMixedValues(Builder &b, - const SmallVectorImpl &mixedValues); +std::pair, SmallVector> +decomposeMixedValues(const SmallVectorImpl &mixedValues); /// Helper to sort `values` according to matching `keys`. SmallVector diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index b6b85cab5a382..7c068d2e94fc2 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -18,7 +18,6 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" diff --git a/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt b/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt index 2be2724d4a917..07fa58b209b5e 100644 --- a/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt @@ -8,5 +8,6 @@ add_mlir_dialect_library(MLIRArithUtils MLIRArithDialect MLIRComplexDialect MLIRDialect + MLIRDialectUtils MLIRIR ) diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp index aa239f5e05396..4ce55a23820cf 100644 --- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp @@ -13,12 +13,74 @@ #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "llvm/ADT/SmallBitVector.h" #include using namespace mlir; +std::optional> +mlir::inferExpandShapeOutputShape(OpBuilder &b, Location loc, + ShapedType expandedType, + ArrayRef reassociation, + ArrayRef inputShape) { + + SmallVector outputShapeValues; + SmallVector outputShapeInts; + // For zero-rank inputs, all dims in result shape are unit extent. + if (inputShape.empty()) { + outputShapeInts.resize(expandedType.getRank(), 1); + return getMixedValues(outputShapeInts, outputShapeValues, b); + } + + // Check for all static shapes. + if (expandedType.hasStaticShape()) { + ArrayRef staticShape = expandedType.getShape(); + outputShapeInts.assign(staticShape.begin(), staticShape.end()); + return getMixedValues(outputShapeInts, outputShapeValues, b); + } + + outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic); + for (const auto &it : llvm::enumerate(reassociation)) { + ReassociationIndices indexGroup = it.value(); + + int64_t indexGroupStaticSizesProductInt = 1; + bool foundDynamicShape = false; + for (int64_t index : indexGroup) { + int64_t outputDimSize = expandedType.getDimSize(index); + // Cannot infer expanded shape with multiple dynamic dims in the + // same reassociation group! + if (ShapedType::isDynamic(outputDimSize)) { + if (foundDynamicShape) + return std::nullopt; + foundDynamicShape = true; + } else { + outputShapeInts[index] = outputDimSize; + indexGroupStaticSizesProductInt *= outputDimSize; + } + } + if (!foundDynamicShape) + continue; + + int64_t inputIndex = it.index(); + // Call get() under the assumption that we're not casting + // dynamism. + Value indexGroupSize = inputShape[inputIndex].get(); + Value indexGroupStaticSizesProduct = + b.create(loc, indexGroupStaticSizesProductInt); + Value dynamicDimSize = b.createOrFold( + loc, indexGroupSize, indexGroupStaticSizesProduct); + outputShapeValues.push_back(dynamicDimSize); + } + + if ((int64_t)outputShapeValues.size() != + llvm::count(outputShapeInts, ShapedType::kDynamic)) + return std::nullopt; + + return getMixedValues(outputShapeInts, outputShapeValues, b); +} + /// Matches a ConstantIndexOp. /// TODO: This should probably just be a general matcher that uses matchConstant /// and checks the operation for an index type. diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 036005ce9d925..e5f83331baf81 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -607,12 +607,20 @@ struct FoldFillWithTensorReshape : OpRewritePattern { return failure(); Location loc = oldFill.getLoc(); - auto newInit = rewriter.create( - loc, reshapeOp.getResultType(), oldFill.output(), - reshapeOp.getReassociation()); + TensorReshapeOp newInit; + if constexpr (std::is_same::value) { + + newInit = rewriter.create( + loc, reshapeOp.getResultType(), oldFill.output(), + reshapeOp.getReassociation(), reshapeOp.getOutputShape(), + reshapeOp.getStaticOutputShape()); + } else { + newInit = rewriter.create(loc, reshapeOp.getResultType(), + oldFill.output(), + reshapeOp.getReassociation()); + } rewriter.replaceOpWithNewOp(reshapeOp, ValueRange{oldFill.value()}, ValueRange{newInit}); - return success(); } }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp index 420b04b3ee28c..81d44ba04fa1d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp @@ -349,7 +349,7 @@ rewriteInIm2Col(RewriterBase &rewriter, SmallVector batchMatVecReassociationIndice = {{0, 1}, {2, 3}}; - Value batchMatVecResultReshaped = rewriter.create( + auto batchMatVecResultReshaped = rewriter.create( loc, transposedOutputTensor.getType(), batchMatVecResult.getResult(0), batchMatVecReassociationIndice); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index 7fd88dec71d49..9a2493a59e019 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -757,7 +757,10 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp, ArrayRef innerDimsPos = unPackOp.getInnerDimsPos(); ArrayRef outerDimsPerm = unPackOp.getOuterDimsPerm(); - ArrayRef dstShape = expandOp.getType().getShape(); + auto expandTy = expandOp.getType().dyn_cast(); + if (!expandTy) + return failure(); + ArrayRef dstShape = expandTy.getShape(); SmallVector reassocIndices = expandOp.getReassociationIndices(); // Project inner tile pos to the dim pos after expanding. For example, if dims @@ -796,9 +799,8 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp, nextPos += 1; } - RankedTensorType newExpandType = - tensor::PackOp::inferPackedType(expandOp.getType(), innerTileSizes, - projectedInnerDimsPos, newOuterDimsPerm); + RankedTensorType newExpandType = tensor::PackOp::inferPackedType( + expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm); auto newExpandOp = rewriter.create( expandOp.getLoc(), newExpandType, unPackOp.getSource(), newReassocIndices); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 023ea277bcf49..65efa18af18f6 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -23,6 +23,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinTypes.h" @@ -272,8 +273,9 @@ expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest, assert(rankReductionStrategy == ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape && "unknown rank reduction strategy"); - return rewriter.create(loc, origResultType, result, - reassociation); + return rewriter + .create(loc, origResultType, result, reassociation) + .getResult(); } /// Collapse the given `value` so that the type matches the type of @@ -536,9 +538,10 @@ LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, resultReplacements.push_back(result); continue; } - resultReplacements.push_back(expandValue(rewriter, loc, result, origDest, - reassociations[opOperandIndex], - options.rankReductionStrategy)); + Value expandedValue = expandValue(rewriter, loc, result, origDest, + reassociations[opOperandIndex], + options.rankReductionStrategy); + resultReplacements.push_back(expandedValue); } rewriter.replaceOp(genericOp, resultReplacements); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 373e9cfc3ce71..89fb4944c0ca3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -625,14 +625,14 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp, return success(); } -/// Epanding the body of a linalg operation requires adaptations of the accessed -/// loop indices. Specifically, access of indices in the original operation need -/// to be replaced with linearizations of indices in the expanded op. That -/// requires the shape of the expanded dimensions to be static (at least all but -/// the most significant). For now check that these are all statically sized. -/// Note that this could be extended to handle dynamic case, but the -/// implementation below uses `affine.apply` which seems to have issues when the -/// shapes are not static. +/// Expanding the body of a linalg operation requires adaptations of the +/// accessed loop indices. Specifically, access of indices in the original +/// operation need to be replaced with linearizations of indices in the expanded +/// op. That requires the shape of the expanded dimensions to be static (at +/// least all but the most significant). For now check that these are all +/// statically sized. Note that this could be extended to handle dynamic case, +/// but the implementation below uses `affine.apply` which seems to have issues +/// when the shapes are not static. static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp, const ExpansionInfo &expansionInfo, PatternRewriter &rewriter) { @@ -750,6 +750,31 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter, } } +/// Checks if a single dynamic dimension expanded into multiple dynamic +/// dimensions. +static LogicalResult +validateDynamicDimExpansion(LinalgOp linalgOp, + const ExpansionInfo &expansionInfo, + PatternRewriter &rewriter) { + for (unsigned i : llvm::seq(0, expansionInfo.getOrigOpNumDims())) { + ArrayRef expandedShape = expansionInfo.getExpandedShapeOfDim(i); + if (expandedShape.size() == 1) + continue; + bool foundDynamic = false; + for (int64_t shape : expandedShape) { + if (!ShapedType::isDynamic(shape)) + continue; + if (foundDynamic) { + return rewriter.notifyMatchFailure( + linalgOp, "cannot infer expanded shape with multiple dynamic " + "dims in the same reassociation group"); + } + foundDynamic = true; + } + } + return success(); +} + /// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op /// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes /// that those conditions have been satisfied. @@ -759,6 +784,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp, PatternRewriter &rewriter) { assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) && "preconditions for fuse operation failed"); + + Location loc = linalgOp.getLoc(); // Check if reshape is expanding or collapsing. auto expandingReshapeOp = dyn_cast(*reshapeOp); auto collapsingReshapeOp = dyn_cast(*reshapeOp); @@ -778,6 +805,11 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp, expandedType.getShape(), collapsedType.getShape(), rewriter))) return std::nullopt; + // TODO: With the support of multiple dynamic dims expansion in + // tensor.expand_shape op, this case can be handled. + if (failed(validateDynamicDimExpansion(linalgOp, expansionInfo, rewriter))) + return std::nullopt; + if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter))) return std::nullopt; @@ -816,15 +848,13 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp, /*isExpandingReshape=*/true))) return std::nullopt; expandedOpOperands.push_back(rewriter.create( - linalgOp.getLoc(), expandedOperandType, opOperand->get(), - reassociation)); + loc, expandedOperandType, opOperand->get(), reassociation)); continue; } } expandedOpOperands.push_back(opOperand->get()); } - Location loc = linalgOp.getLoc(); SmallVector outputs; for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) { AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand); @@ -843,8 +873,7 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp, /*isExpandingReshape=*/true))) return std::nullopt; outputs.push_back(rewriter.create( - linalgOp.getLoc(), expandedOutputType, opOperand.get(), - reassociation)); + loc, expandedOutputType, opOperand.get(), reassociation)); } else { outputs.push_back(opOperand.get()); } @@ -1615,15 +1644,17 @@ FailureOr mlir::linalg::collapseOpIterationDims( op.getIndexingMapMatchingResult(originalResult.value()); SmallVector reassociation = getOperandReassociation(indexingMap, collapsingInfo); + Value result; if (isa(collapsedOpResult.getType())) { - Value result = rewriter.create( - loc, originalResultType, collapsedOpResult, reassociation); - results.push_back(result); + MemRefType expandShapeResultType = MemRefType::get( + originalResultType.getShape(), originalResultType.getElementType()); + result = rewriter.create( + loc, expandShapeResultType, collapsedOpResult, reassociation); } else { - Value result = rewriter.create( + result = rewriter.create( loc, originalResultType, collapsedOpResult, reassociation); - results.push_back(result); } + results.push_back(result); } else { results.push_back(collapsedOpResult); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp index 6559c86c9e0ff..5bfdbc6d0bb59 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -114,6 +114,7 @@ FailureOr mlir::linalg::splitReduction( Type newType = RankedTensorType::get( newShape, cast(operand->get().getType()).getElementType()); + Value newInput = b.create( loc, newType, operand->get(), reassociation); newInputs.push_back(newInput); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 2297bf5e35512..91dfac802ad67 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -329,11 +329,13 @@ FailureOr linalg::lowerPack(RewriterBase &rewriter, /*transposeOp=*/nullptr}; } } + // 5. Expand from the padded result to the stripMinedShape. + auto expandShapeResultType = + RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape); auto reshapeOp = rewriter.create( - loc, - RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape), - padOp.getResult(), packingMetadata.reassociations); + loc, expandShapeResultType, padOp.getResult(), + packingMetadata.reassociations); // 6. Transpose stripMinedShape to packedShape. SmallVector transpPerm = diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 836dcb8f329e7..ced7fdd0a90f0 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2237,6 +2237,44 @@ FailureOr ExpandShapeOp::computeExpandedType( srcType.getMemorySpace()); } +FailureOr> +ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc, + MemRefType expandedType, + ArrayRef reassociation, + ArrayRef inputShape) { + std::optional> outputShape = + inferExpandShapeOutputShape(b, loc, expandedType, reassociation, + inputShape); + if (!outputShape) + return failure(); + return *outputShape; +} + +void ExpandShapeOp::build(OpBuilder &builder, OperationState &result, + Type resultType, Value src, + ArrayRef reassociation, + ArrayRef outputShape) { + auto [staticOutputShape, dynamicOutputShape] = + decomposeMixedValues(SmallVector(outputShape)); + build(builder, result, resultType.cast(), src, + getReassociationIndicesAttribute(builder, reassociation), + dynamicOutputShape, staticOutputShape); +} + +void ExpandShapeOp::build(OpBuilder &builder, OperationState &result, + Type resultType, Value src, + ArrayRef reassociation) { + SmallVector inputShape = + getMixedSizes(builder, result.location, src); + MemRefType memrefResultTy = resultType.cast(); + FailureOr> outputShape = inferOutputShape( + builder, result.location, memrefResultTy, reassociation, inputShape); + // Failure of this assertion usually indicates presence of multiple + // dynamic dimensions in the same reassociation group. + assert(succeeded(outputShape) && "unable to infer output shape"); + build(builder, result, memrefResultTy, src, reassociation, *outputShape); +} + void ExpandShapeOp::build(OpBuilder &builder, OperationState &result, ArrayRef resultShape, Value src, ArrayRef reassociation) { @@ -2250,6 +2288,20 @@ void ExpandShapeOp::build(OpBuilder &builder, OperationState &result, build(builder, result, *resultType, src, reassociation); } +void ExpandShapeOp::build(OpBuilder &builder, OperationState &result, + ArrayRef resultShape, Value src, + ArrayRef reassociation, + ArrayRef outputShape) { + // Only ranked memref source values are supported. + auto srcType = llvm::cast(src.getType()); + FailureOr resultType = + ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation); + // Failure of this assertion usually indicates a problem with the source + // type, e.g., could not get strides/offset. + assert(succeeded(resultType) && "could not compute layout"); + build(builder, result, *resultType, src, reassociation, outputShape); +} + LogicalResult ExpandShapeOp::verify() { MemRefType srcType = getSrcType(); MemRefType resultType = getResultType(); @@ -2266,7 +2318,7 @@ LogicalResult ExpandShapeOp::verify() { if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(), resultType.getShape(), getReassociationIndices(), - /*allowMultipleDynamicDimsPerGroup=*/false))) + /*allowMultipleDynamicDimsPerGroup=*/true))) return failure(); // Compute expected result type (including layout map). @@ -2280,14 +2332,28 @@ LogicalResult ExpandShapeOp::verify() { return emitOpError("expected expanded type to be ") << *expectedResultType << " but found " << resultType; + if ((int64_t)getStaticOutputShape().size() != resultType.getRank()) + return emitOpError("expected number of static shape bounds to be equal to " + "the output rank (") + << resultType.getRank() << ") but found " + << getStaticOutputShape().size() << " inputs instead"; + + if ((int64_t)getOutputShape().size() != + llvm::count(getStaticOutputShape(), ShapedType::kDynamic)) + return emitOpError("mismatch in dynamic dims in output_shape and " + "static_output_shape: static_output_shape has ") + << llvm::count(getStaticOutputShape(), ShapedType::kDynamic) + << " dynamic dims while output_shape has " << getOutputShape().size() + << " values"; + return success(); } void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add, - ComposeExpandOfCollapseOp>( - context); + results.add< + ComposeReassociativeReshapeOps, + ComposeExpandOfCollapseOp>(context); } /// Compute the layout map after collapsing a given source MemRef type with the @@ -2488,9 +2554,11 @@ struct CollapseShapeOpMemRefCastFolder void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add, - ComposeCollapseOfExpandOp, - CollapseShapeOpMemRefCastFolder>(context); + results.add< + ComposeReassociativeReshapeOps, + ComposeCollapseOfExpandOp, + CollapseShapeOpMemRefCastFolder>(context); } OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index 9a8c6422a7ff6..7d469198a653c 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -1063,8 +1063,15 @@ struct ReshapeRewriter : public OpRewritePattern { auto rtp = getRankedTensorType(op.getResult()); auto denseTp = RankedTensorType::get(rtp.getShape(), rtp.getElementType()); - auto reshape = rewriter.create(loc, denseTp, op.getSrc(), - op.getReassociation()); + ReshapeOp reshape; + if constexpr (std::is_same::value) { + reshape = rewriter.create( + loc, denseTp, op.getSrc(), op.getReassociation(), + op.getOutputShape(), op.getStaticOutputShape()); + } else { + reshape = rewriter.create(loc, denseTp, op.getSrc(), + op.getReassociation()); + } Value convert = rewriter.create(loc, rtp, reshape); rewriter.replaceOp(op, convert); return success(); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 5029ed4aa0387..7a5546bf13757 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -1644,6 +1644,44 @@ int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) { llvm_unreachable("could not find reassociation group"); } +FailureOr> +ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc, + RankedTensorType expandedType, + ArrayRef reassociation, + ArrayRef inputShape) { + std::optional> outputShape = + inferExpandShapeOutputShape(b, loc, expandedType, reassociation, + inputShape); + if (!outputShape) + return failure(); + return *outputShape; +} + +void ExpandShapeOp::build(OpBuilder &builder, OperationState &result, + Type resultType, Value src, + ArrayRef reassociation, + ArrayRef outputShape) { + auto [staticOutputShape, dynamicOutputShape] = + decomposeMixedValues(SmallVector(outputShape)); + build(builder, result, resultType.cast(), src, + getReassociationIndicesAttribute(builder, reassociation), + dynamicOutputShape, staticOutputShape); +} + +void ExpandShapeOp::build(OpBuilder &builder, OperationState &result, + Type resultType, Value src, + ArrayRef reassociation) { + SmallVector inputShape = + getMixedSizes(builder, result.location, src); + auto tensorResultTy = resultType.cast(); + FailureOr> outputShape = inferOutputShape( + builder, result.location, tensorResultTy, reassociation, inputShape); + // Failure of this assertion usually indicates presence of multiple + // dynamic dimensions in the same reassociation group. + assert(succeeded(outputShape) && "unable to infer output shape"); + build(builder, result, tensorResultTy, src, reassociation, *outputShape); +} + SmallVector CollapseShapeOp::getReassociationMaps() { return getSymbolLessAffineMaps(getReassociationExprs()); } @@ -1727,7 +1765,24 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, } LogicalResult ExpandShapeOp::verify() { - return verifyTensorReshapeOp(*this, getResultType(), getSrcType()); + auto srcType = getSrcType(); + auto resultType = getResultType(); + + if ((int64_t)getStaticOutputShape().size() != resultType.getRank()) + return emitOpError("expected number of static shape dims to be equal to " + "the output rank (") + << resultType.getRank() << ") but found " + << getStaticOutputShape().size() << " inputs instead"; + + if ((int64_t)getOutputShape().size() != + llvm::count(getStaticOutputShape(), ShapedType::kDynamic)) + return emitOpError("mismatch in dynamic dims in output_shape and " + "static_output_shape: static_output_shape has ") + << llvm::count(getStaticOutputShape(), ShapedType::kDynamic) + << " dynamic dims while output_shape has " << getOutputShape().size() + << " values"; + + return verifyTensorReshapeOp(*this, resultType, srcType); } LogicalResult CollapseShapeOp::verify() { @@ -1911,23 +1966,25 @@ struct FoldDimOfCollapseShape : public OpRewritePattern { void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add, - ComposeExpandOfCollapseOp, - FoldReshapeWithConstant, - FoldReshapeWithSplat, - FoldReshapeWithFromElements, FoldDimOfExpandShape, - FoldDimOfCollapseShape>(context); + results.add< + ComposeReassociativeReshapeOps, + ComposeExpandOfCollapseOp, + FoldReshapeWithConstant, + FoldReshapeWithSplat, + FoldReshapeWithFromElements, FoldDimOfExpandShape, + FoldDimOfCollapseShape>(context); } void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results - .add, - ComposeCollapseOfExpandOp, - FoldReshapeWithConstant, - FoldReshapeWithSplat, - FoldReshapeWithFromElements, FoldCollapseOfCastOp>( - context); + results.add< + ComposeReassociativeReshapeOps, + ComposeCollapseOfExpandOp, + FoldReshapeWithConstant, + FoldReshapeWithSplat, + FoldReshapeWithFromElements, FoldCollapseOfCastOp>( + context); } OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) { diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index 58ea4cc4da3c3..d078a575f40dd 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -338,6 +338,9 @@ struct ExpandShapeOpInterface // Memref result type is inferred by the builder based on reassociation // indices and result shape. + // TODO: Instead of inferring the output shape argument of + // memref.expand_shape op, use output_shape argument of tensor.expand_shape + // op. replaceOpWithNewBufferizedOp( rewriter, op, tensorResultType.getShape(), *buffer, expandShapeOp.getReassociationIndices()); diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp index 666ac56c6cd5c..7011ce23b55a6 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp @@ -52,12 +52,16 @@ static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op, struct SimplifyPackToExpandShape : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - Value insertExpand(RewriterBase &rewriter, Location loc, Value operand, - Type newOperandType, ArrayAttr reassociation) const { + FailureOr + insertExpand(RewriterBase &rewriter, Location loc, Value operand, + Type newOperandType, + ArrayRef reassociation) const { if (operand.getType() == newOperandType) return operand; - return rewriter.create(loc, newOperandType, operand, - reassociation); + return rewriter + .create(loc, newOperandType, operand, + reassociation) + .getResult(); } /// Returns success() if it is only packing on the innermost dimension. @@ -96,10 +100,14 @@ struct SimplifyPackToExpandShape : public OpRewritePattern { getReassociationIndicesForReshape(sourceType, destType); if (!reassociation) return failure(); - Value expanded = insertExpand( - rewriter, packOp.getLoc(), packOp.getSource(), destType, - getReassociationIndicesAttribute(rewriter, *reassociation)); - rewriter.replaceOp(packOp, expanded); + FailureOr expanded = + insertExpand(rewriter, packOp.getLoc(), packOp.getSource(), destType, + *reassociation); + if (failed(expanded)) { + return rewriter.notifyMatchFailure( + packOp, "unable to expand source of tensor.pack"); + } + rewriter.replaceOp(packOp, *expanded); return success(); } }; diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index 41c7af4593c77..e4f387d40ced2 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -168,7 +168,7 @@ ArrayAttr mlir::getReassociationIndicesAttribute( } SmallVector mlir::convertReassociationMapsToIndices( - OpBuilder &b, ArrayRef reassociationExprs) { + ArrayRef reassociationExprs) { SmallVector reassociationIndices; for (const auto &exprs : reassociationExprs) { ReassociationIndices indices; @@ -230,24 +230,17 @@ LogicalResult mlir::reshapeLikeShapesAreCompatible( ArrayRef reassociationMaps, bool isExpandingReshape) { unsigned expandedDimStart = 0; for (const auto &map : llvm::enumerate(reassociationMaps)) { - std::optional dynamicShape; + bool foundDynamicShape = false; int64_t linearizedStaticShape = 1; + for (const auto &dim : llvm::enumerate( expandedShape.slice(expandedDimStart, map.value().size()))) { - if (ShapedType::isDynamic(dim.value())) { - if (isExpandingReshape && dynamicShape) { - return emitError("invalid to have a single dimension (" + - Twine(map.index()) + - ") expanded into multiple dynamic dims (" + - Twine(expandedDimStart + dynamicShape.value()) + - "," + Twine(expandedDimStart + dim.index()) + ")"); - } - dynamicShape = dim.index(); - } else { + if (ShapedType::isDynamic(dim.value())) + foundDynamicShape = true; + else linearizedStaticShape *= dim.value(); - } } - if (dynamicShape) { + if (foundDynamicShape) { if (!ShapedType::isDynamic(collapsedShape[map.index()])) { return emitError( "expected dimension " + Twine(map.index()) + diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index 1e8197e109442..74a53709592dd 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -180,9 +180,8 @@ SmallVector getMixedValues(ArrayRef staticValues, /// Decompose a vector of mixed static or dynamic values into the corresponding /// pair of arrays. This is the inverse function of `getMixedValues`. -std::pair> -decomposeMixedValues(Builder &b, - const SmallVectorImpl &mixedValues) { +std::pair, SmallVector> +decomposeMixedValues(const SmallVectorImpl &mixedValues) { SmallVector staticValues; SmallVector dynamicValues; for (const auto &it : mixedValues) { @@ -193,7 +192,7 @@ decomposeMixedValues(Builder &b, dynamicValues.push_back(it.get()); } } - return {b.getI64ArrayAttr(staticValues), dynamicValues}; + return {staticValues, dynamicValues}; } /// Helper to sort `values` according to matching `keys`. diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir index 87d613986c7c3..b86103422b074 100644 --- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir @@ -453,7 +453,7 @@ func.func @collapse_shape_dynamic_with_non_identity_layout( func.func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> { // Reshapes that expand a contiguous tensor with some 1's. - %0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] + %0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5] : memref<3x4x5xf32> into memref<1x3x4x1x5xf32> return %0 : memref<1x3x4x1x5xf32> } @@ -510,7 +510,7 @@ func.func @collapse_shape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref // ----- func.func @expand_shape_zero_dim(%arg0 : memref) -> memref<1x1xf32> { - %0 = memref.expand_shape %arg0 [] : memref into memref<1x1xf32> + %0 = memref.expand_shape %arg0 [] output_shape [1, 1] : memref into memref<1x1xf32> return %0 : memref<1x1xf32> } @@ -571,13 +571,13 @@ func.func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32> // ----- -func.func @expand_shape_dynamic(%arg0 : memref<1x?xf32>) -> memref<1x2x?xf32> { - %0 = memref.expand_shape %arg0 [[0], [1, 2]]: memref<1x?xf32> into memref<1x2x?xf32> +func.func @expand_shape_dynamic(%arg0 : memref<1x?xf32>, %sz0: index) -> memref<1x2x?xf32> { + %0 = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [1, 2, %sz0]: memref<1x?xf32> into memref<1x2x?xf32> return %0 : memref<1x2x?xf32> } // CHECK-LABEL: func.func @expand_shape_dynamic( -// CHECK-SAME: %[[ARG:.*]]: memref<1x?xf32>) -> memref<1x2x?xf32> { +// CHECK-SAME: %[[ARG:.*]]: memref<1x?xf32>, %[[SZ0:.*]]: index) -> memref<1x2x?xf32> { // CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x?xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64, // CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64, @@ -614,15 +614,15 @@ func.func @expand_shape_dynamic(%arg0 : memref<1x?xf32>) -> memref<1x2x?xf32> { // ----- func.func @expand_shape_dynamic_with_non_identity_layout( - %arg0 : memref<1x?xf32, strided<[?, ?], offset: ?>>) -> + %arg0 : memref<1x?xf32, strided<[?, ?], offset: ?>>, %sz0: index) -> memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> { - %0 = memref.expand_shape %arg0 [[0], [1, 2]]: + %0 = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [1, 2, %sz0] : memref<1x?xf32, strided<[?, ?], offset: ?>> into memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> return %0 : memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> } // CHECK-LABEL: func.func @expand_shape_dynamic_with_non_identity_layout( -// CHECK-SAME: %[[ARG:.*]]: memref<1x?xf32, strided<[?, ?], offset: ?>>) -> memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> { +// CHECK-SAME: %[[ARG:.*]]: memref<1x?xf32, strided<[?, ?], offset: ?>>, %[[SZ0:.*]]: index) -> memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> { // CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x?xf32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64, // CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64, diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir index 37999d6fc14ad..baf9cfe610a5a 100644 --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -334,9 +334,9 @@ memref.global "private" @gv4 : memref = dense<1.0> {alignment = 64} // CHECK-LABEL: func @expand_shape_static( // CHECK-SAME: %[[ARG:.*]]: memref<{{.*}}>) func.func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> { - // CHECK: memref.expand_shape %[[ARG]] {{\[}}[0, 1], [2], [3, 4]] + // CHECK: memref.expand_shape %[[ARG]] {{\[}}[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5] // Reshapes that expand a contiguous tensor with some 1's. - %0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] + %0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5] : memref<3x4x5xf32> into memref<1x3x4x1x5xf32> return %0 : memref<1x3x4x1x5xf32> } diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 3ec15221e2999..03c9fec1c9a83 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -348,7 +348,7 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor, %arg1: tensor, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { - // CHECK: %[[ARG0_EXPANDED:.*]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] : tensor<3x4xf32> into tensor<1x3x4xf32> + // CHECK: %[[ARG0_EXPANDED:.*]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [1, 3, 4] : tensor<3x4xf32> into tensor<1x3x4xf32> // CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<2x3x4xf32> // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG0_EXPANDED]], %[[ARG1]] : tensor<1x3x4xf32>, tensor<2x3x4xf32>) outs(%[[VAL_0]] : tensor<2x3x4xf32>) { // CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32): @@ -871,7 +871,7 @@ func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () { // CHECK: [[RES:%.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32 // CHECK: linalg.yield [[RES]] : f32 // CHECK: } - // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] : tensor<4xf32> into tensor<1x4xf32> + // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] output_shape [1, 4] : tensor<4xf32> into tensor<1x4xf32> %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<5x4xf32>) -> tensor<1x4xf32> // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5xf32> @@ -882,7 +882,7 @@ func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () { // CHECK: [[RES:%.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32 // CHECK: linalg.yield [[RES]] : f32 // CHECK: } - // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] : tensor<5xf32> into tensor<5x1xf32> + // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] output_shape [5, 1] : tensor<5xf32> into tensor<5x1xf32> %1 = tosa.reduce_sum %arg0 {axis = 1 : i32} : (tensor<5x4xf32>) -> tensor<5x1xf32> // CHECK: arith.constant 1.0 @@ -920,7 +920,10 @@ func.func @reduce_float_dyn(%arg0: tensor) -> () { // CHECK: %[[RES:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32 // CHECK: linalg.yield %[[RES]] : f32 // CHECK: } - // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}[0], [1, 2]] : tensor into tensor + // CHECK: %[[C0_0:.+]] = arith.constant 0 : index + // CHECK: %[[DIM_1:.+]] = tensor.dim %[[REDUCE]], %[[C0_0]] : tensor + // CHECK: %[[C1:.+]] = arith.constant 1 : index + // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}[0], [1, 2]] output_shape [%[[DIM_1]], 1, 4] : tensor into tensor %0 = tosa.reduce_sum %arg0 {axis = 1 : i32} : (tensor) -> tensor return } @@ -938,7 +941,7 @@ func.func @reduce_float_dyn_rank_1(%arg0: tensor) -> () { // CHECK: %[[RES:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32 // CHECK: linalg.yield %[[RES]] : f32 // CHECK: } - // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}] : tensor into tensor<1xf32> + // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}] output_shape [1] : tensor into tensor<1xf32> %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor) -> tensor<1xf32> return } @@ -958,7 +961,10 @@ func.func @reduce_float_dyn_nonzero_batch(%arg0: tensor<5x?x4xf32>) -> () { // CHECK: %[[RES:.+]] = arith.mulf %[[ARG1]], %[[ARG2]] : f32 // CHECK: linalg.yield %[[RES]] : f32 // CHECK: } - // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}[0], [1, 2]] : tensor<5x?xf32> into tensor<5x?x1xf32> + // CHECK: %[[C1_0:.+]] = arith.constant 1 : index + // CHECK: %[[DIM_1:.+]] = tensor.dim %[[REDUCE]], %[[C1_0]] : tensor<5x?xf32> + // CHECK: %[[C1_2:.+]] = arith.constant 1 : index + // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}[0], [1, 2]] output_shape [5, %[[DIM_1]], 1] : tensor<5x?xf32> into tensor<5x?x1xf32> %0 = tosa.reduce_prod %arg0 {axis = 2 : i32} : (tensor<5x?x4xf32>) -> tensor<5x?x1xf32> return } @@ -978,7 +984,10 @@ func.func @reduce_float_dyn_multiple(%arg0: tensor) -> () { // CHECK: %[[MAX:.+]] = arith.maximumf %[[ARG1]], %[[ARG2]] : f32 // CHECK: linalg.yield %[[MAX]] : f32 // CHECK: } - // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}[0, 1]] : tensor into tensor + // CHECK: %[[C0_0:.+]] = arith.constant 0 : index + // CHECK: %[[DIM_1:.+]] = tensor.dim %[[REDUCE]], %[[C0_0]] : tensor + // CHECK: %[[C1_2:.+]] = arith.constant 1 : index + // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}[0, 1]] output_shape [%[[DIM_1]], 1] : tensor into tensor %0 = tosa.reduce_max %arg0 {axis = 1 : i32} : (tensor) -> tensor return } @@ -996,7 +1005,7 @@ func.func @reduce_int(%arg0: tensor<5x4xi32>) -> () { // CHECK: [[RES:%.+]] = arith.addi %[[ARG1]], %[[ARG2]] : i32 // CHECK: linalg.yield [[RES]] : i32 // CHECK: } - // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] : tensor<4xi32> into tensor<1x4xi32> + // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] output_shape [1, 4] : tensor<4xi32> into tensor<1x4xi32> %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<5x4xi32>) -> tensor<1x4xi32> // CHECK: [[INIT:%.+]] = tensor.empty() @@ -1007,7 +1016,7 @@ func.func @reduce_int(%arg0: tensor<5x4xi32>) -> () { // CHECK: [[RES:%.+]] = arith.addi %[[ARG1]], %[[ARG2]] : i32 // CHECK: linalg.yield [[RES]] : i32 // CHECK: } - // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] : tensor<5xi32> into tensor<5x1xi32> + // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] output_shape [5, 1] : tensor<5xi32> into tensor<5x1xi32> %1 = tosa.reduce_sum %arg0 {axis = 1 : i32} : (tensor<5x4xi32>) -> tensor<5x1xi32> // CHECK: arith.constant 1 @@ -1043,7 +1052,7 @@ func.func @reduce_bool(%arg0: tensor<5x4xi1>) -> () { // CHECK: [[RES:%.+]] = arith.andi %[[ARG1]], %[[ARG2]] : i1 // CHECK: linalg.yield [[RES]] : i1 // CHECK: } - // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] : tensor<4xi1> into tensor<1x4xi1> + // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] output_shape [1, 4] : tensor<4xi1> into tensor<1x4xi1> %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<5x4xi1>) -> tensor<1x4xi1> // CHECK: arith.constant false diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir index a8a3c42e16842..b8c3d56f21f10 100644 --- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir +++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir @@ -14,7 +14,7 @@ func.func @test_reshape_0d_same_s2s_explicit(%arg0: tensor) -> tensor // CHECK-LABEL: test_reshape_0d_up_s2d_auto // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] : tensor into tensor<1xf32> +// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] output_shape [1] : tensor into tensor<1xf32> // CHECK: %[[VAL_1:.*]] = tensor.cast %[[VAL_0]] : tensor<1xf32> to tensor // CHECK: return %[[VAL_1]] : tensor func.func @test_reshape_0d_up_s2d_auto(%arg0: tensor) -> tensor { @@ -26,7 +26,7 @@ func.func @test_reshape_0d_up_s2d_auto(%arg0: tensor) -> tensor { // CHECK-LABEL: test_reshape_0d_up_s2d_explicit // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] : tensor into tensor<1xf32> +// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] output_shape [1] : tensor into tensor<1xf32> // CHECK: %[[VAL_1:.*]] = tensor.cast %[[VAL_0]] : tensor<1xf32> to tensor // CHECK: return %[[VAL_1]] : tensor func.func @test_reshape_0d_up_s2d_explicit(%arg0: tensor) -> tensor { @@ -38,7 +38,7 @@ func.func @test_reshape_0d_up_s2d_explicit(%arg0: tensor) -> tensor // CHECK-LABEL: test_reshape_0d_up_s2s_auto // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] : tensor into tensor<1xf32> +// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] output_shape [1] : tensor into tensor<1xf32> // CHECK: return %[[VAL_0]] : tensor<1xf32> func.func @test_reshape_0d_up_s2s_auto(%arg0: tensor) -> tensor<1xf32> { %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor<1xf32> @@ -49,7 +49,7 @@ func.func @test_reshape_0d_up_s2s_auto(%arg0: tensor) -> tensor<1xf32> { // CHECK-LABEL: test_reshape_0d_up_s2s_explicit // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] : tensor into tensor<1xf32> +// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] output_shape [1] : tensor into tensor<1xf32> // CHECK: return %[[VAL_0]] : tensor<1xf32> func.func @test_reshape_0d_up_s2s_explicit(%arg0: tensor) -> tensor<1xf32> { %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor<1xf32> @@ -83,8 +83,12 @@ func.func @test_reshape_1d_down_s2s_explicit(%arg0: tensor<1xf32>) -> tensor -// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor into tensor<2x?xf32> -// CHECK: return %[[VAL_0]] : tensor<2x?xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM:.*]] = tensor.dim %arg0, %[[C0]] : tensor +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[VAL_0:.*]] = arith.divui %[[DIM]], %[[C2]] : index +// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[ARG_0]] {{\[\[}}0, 1]] output_shape [2, %[[VAL_0]]] : tensor into tensor<2x?xf32> +// CHECK: return %[[EXPANDED]] : tensor<2x?xf32> func.func @test_reshape_1d_up_d2d_auto(%arg0: tensor) -> tensor<2x?xf32> { %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor<2x?xf32> return %0 : tensor<2x?xf32> @@ -94,7 +98,7 @@ func.func @test_reshape_1d_up_d2d_auto(%arg0: tensor) -> tensor<2x?xf32> // CHECK-LABEL: test_reshape_1d_up_s2s_explicit // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<6xf32> -// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<6xf32> into tensor<2x3xf32> +// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] {{\[\[}}0, 1]] output_shape [2, 3] : tensor<6xf32> into tensor<2x3xf32> // CHECK: return %[[VAL_0]] : tensor<2x3xf32> func.func @test_reshape_1d_up_s2s_explicit(%arg0: tensor<6xf32>) -> tensor<2x3xf32> { %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<6xf32>) -> tensor<2x3xf32> @@ -128,8 +132,12 @@ func.func @test_reshape_2d_down_s2s_explicit(%arg0: tensor<2x3xf32>) -> tensor<6 // CHECK-LABEL: test_reshape_2d_same_d2d_auto // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor // CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor into tensor -// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] : tensor into tensor<2x?xf32> -// CHECK: return %[[VAL_1]] : tensor<2x?xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[DIV:.*]] = arith.divui %[[DIM]], %[[C2]] : index +// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] output_shape [2, %[[DIV]]] : tensor into tensor<2x?xf32> +// CHECK: return %[[EXPANDED]] : tensor<2x?xf32> func.func @test_reshape_2d_same_d2d_auto(%arg0: tensor) -> tensor<2x?xf32> { %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor<2x?xf32> return %0 : tensor<2x?xf32> @@ -140,7 +148,7 @@ func.func @test_reshape_2d_same_d2d_auto(%arg0: tensor) -> tensor<2x?xf // CHECK-LABEL: test_reshape_2d_same_s2d_auto // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<2x4xf32> // CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<2x4xf32> into tensor<8xf32> -// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] : tensor<8xf32> into tensor<4x2xf32> +// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] output_shape [4, 2] : tensor<8xf32> into tensor<4x2xf32> // CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<4x2xf32> to tensor // CHECK: return %[[VAL_2]] : tensor func.func @test_reshape_2d_same_s2d_auto(%arg0: tensor<2x4xf32>) -> tensor { @@ -153,7 +161,7 @@ func.func @test_reshape_2d_same_s2d_auto(%arg0: tensor<2x4xf32>) -> tensor // CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<2x4xf32> into tensor<8xf32> -// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] : tensor<8xf32> into tensor<4x2xf32> +// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] output_shape [4, 2] : tensor<8xf32> into tensor<4x2xf32> // CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<4x2xf32> to tensor // CHECK: return %[[VAL_2]] : tensor func.func @test_reshape_2d_same_s2d_explicit(%arg0: tensor<2x4xf32>) -> tensor { @@ -166,7 +174,7 @@ func.func @test_reshape_2d_same_s2d_explicit(%arg0: tensor<2x4xf32>) -> tensor // CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<3x2xf32> into tensor<6xf32> -// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] : tensor<6xf32> into tensor<2x3xf32> +// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] output_shape [2, 3] : tensor<6xf32> into tensor<2x3xf32> // CHECK: return %[[VAL_1]] : tensor<2x3xf32> func.func @test_reshape_2d_same_s2s_explicit(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> { %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<3x2xf32>) -> tensor<2x3xf32> @@ -178,7 +186,11 @@ func.func @test_reshape_2d_same_s2s_explicit(%arg0: tensor<3x2xf32>) -> tensor<2 // CHECK-LABEL: test_reshape_3d_same_d2d_auto_empty // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<3x2x?xf32> // CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<3x2x?xf32> into tensor -// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor into tensor<0x3x?xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor +// CHECK: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK: %[[DIV:.*]] = arith.divui %[[DIM]], %[[C0_0]] : index +// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] output_shape [0, 3, %[[DIV]]] : tensor into tensor<0x3x?xf32> // CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<0x3x?xf32> to tensor // CHECK: return %[[VAL_2]] : tensor func.func @test_reshape_3d_same_d2d_auto_empty(%arg0: tensor<3x2x?xf32>) -> tensor { @@ -191,7 +203,11 @@ func.func @test_reshape_3d_same_d2d_auto_empty(%arg0: tensor<3x2x?xf32>) -> tens // CHECK-LABEL: test_reshape_3d_same_d2d_auto // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<2x?x?xf32> // CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<2x?x?xf32> into tensor -// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor into tensor<2x?x4xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor +// CHECK: %[[C8:.*]] = arith.constant 8 : index +// CHECK: %[[DIV:.*]] = arith.divui %[[DIM]], %[[C8]] : index +// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] output_shape [2, %[[DIV]], 4] : tensor into tensor<2x?x4xf32> // CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<2x?x4xf32> to tensor // CHECK: return %[[VAL_2]] : tensor func.func @test_reshape_3d_same_d2d_auto(%arg0: tensor<2x?x?xf32>) -> tensor { @@ -204,7 +220,11 @@ func.func @test_reshape_3d_same_d2d_auto(%arg0: tensor<2x?x?xf32>) -> tensor // CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor into tensor -// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor into tensor<2x3x?xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor +// CHECK: %[[C6:.*]] = arith.constant 6 : index +// CHECK: %[[DIV:.*]] = arith.divui %[[DIM]], %[[C6]] : index +// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] output_shape [2, 3, %[[DIV]]] : tensor into tensor<2x3x?xf32> // CHECK: return %[[VAL_1]] : tensor<2x3x?xf32> func.func @test_reshape_3d_same_d2d_auto_identity(%arg0: tensor) -> tensor<2x3x?xf32> { %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor<2x3x?xf32> @@ -216,8 +236,12 @@ func.func @test_reshape_3d_same_d2d_auto_identity(%arg0: tensor) -> t // CHECK-LABEL: test_reshape_3d_same_d2d_explicit_empty // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<3x2x?xf32> // CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<3x2x?xf32> into tensor -// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor into tensor -// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor to tensor +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor +// CHECK: %[[C6:.*]] = arith.constant 6 : index +// CHECK: %[[DIV:.*]] = arith.divui %[[DIM]], %[[C6]] : index +// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] output_shape [%[[DIV]], 3, 2] : tensor into tensor +// CHECK: %[[VAL_2:.*]] = tensor.cast %[[EXPANDED]] : tensor to tensor // CHECK: return %[[VAL_2]] : tensor func.func @test_reshape_3d_same_d2d_explicit_empty(%arg0: tensor<3x2x?xf32>) -> tensor { %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<3x2x?xf32>) -> tensor @@ -229,8 +253,12 @@ func.func @test_reshape_3d_same_d2d_explicit_empty(%arg0: tensor<3x2x?xf32>) -> // CHECK-LABEL: test_reshape_3d_same_d2d_explicit // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor // CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor into tensor -// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor into tensor -// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor to tensor +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor +// CHECK: %[[C12:.*]] = arith.constant 12 : index +// CHECK: %[[DIV:.*]] = arith.divui %[[DIM]], %[[C12]] : index +// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] output_shape [%[[DIV]], 3, 4] : tensor into tensor +// CHECK: %[[VAL_2:.*]] = tensor.cast %[[EXPANDED]] : tensor to tensor // CHECK: return %[[VAL_2]] : tensor func.func @test_reshape_3d_same_d2d_explicit(%arg0: tensor) -> tensor { %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor @@ -253,8 +281,12 @@ func.func @test_reshape_3d_same_d2d_explicit_identity(%arg0: tensor) // CHECK-LABEL: test_reshape_3d_same_d2s_auto // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor // CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor into tensor -// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor into tensor<2x?x4xf32> -// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<2x?x4xf32> to tensor<2x3x4xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor +// CHECK: %[[C8:.*]] = arith.constant 8 : index +// CHECK: %[[DIV:.*]] = arith.divui %[[DIM]], %[[C8]] : index +// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] output_shape [2, %[[DIV]], 4] : tensor into tensor<2x?x4xf32> +// CHECK: %[[VAL_2:.*]] = tensor.cast %[[EXPANDED]] : tensor<2x?x4xf32> to tensor<2x3x4xf32> // CHECK: return %[[VAL_2]] : tensor<2x3x4xf32> func.func @test_reshape_3d_same_d2s_auto(%arg0: tensor) -> tensor<2x3x4xf32> { %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor<2x3x4xf32> @@ -266,8 +298,12 @@ func.func @test_reshape_3d_same_d2s_auto(%arg0: tensor) -> tensor<2x3 // CHECK-LABEL: test_reshape_3d_same_d2s_explicit // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor // CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor into tensor -// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor into tensor -// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor to tensor<2x3x4xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor +// CHECK: %[[C12:.*]] = arith.constant 12 : index +// CHECK: %[[DIV:.*]] = arith.divui %[[DIM]], %[[C12]] : index +// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] output_shape [%[[DIV]], 3, 4] : tensor into tensor +// CHECK: %[[VAL_2:.*]] = tensor.cast %[[EXPANDED]] : tensor to tensor<2x3x4xf32> // CHECK: return %[[VAL_2]] : tensor<2x3x4xf32> func.func @test_reshape_3d_same_d2s_explicit(%arg0: tensor) -> tensor<2x3x4xf32> { %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor<2x3x4xf32> @@ -288,10 +324,14 @@ func.func @test_reshape_3d_same_s2s_explicit_identity(%arg0: tensor<2x3x4xf32>) // CHECK-LABEL: test_reshape_3d_up_d2s_explicit // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor into tensor -// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : tensor into tensor -// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor to tensor<1x3x2x1xf32> -// CHECK: return %[[VAL_2]] : tensor<1x3x2x1xf32> +// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor into tensor +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM:.*]] = tensor.dim %[[COLLAPSED]], %[[C0]] : tensor +// CHECK: %[[C6:.*]] = arith.constant 6 : index +// CHECK: %[[VAL_0:.*]] = arith.divui %[[DIM]], %[[C6]] : index +// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2, 3]] output_shape [%[[VAL_0]], 3, 2, 1] : tensor into tensor +// CHECK: %[[CAST:.*]] = tensor.cast %[[EXPANDED]] : tensor to tensor<1x3x2x1xf32> +// CHECK: return %[[CAST]] : tensor<1x3x2x1xf32> func.func @test_reshape_3d_up_d2s_explicit(%input: tensor) -> tensor<1x3x2x1xf32> { %0 = tosa.reshape %input {new_shape = array} : (tensor) -> tensor<1x3x2x1xf32> return %0 : tensor<1x3x2x1xf32> @@ -313,9 +353,13 @@ func.func @test_reshape_4d_down_d2s_explicit(%arg0: tensor) -> tens // CHECK-LABEL: test_reshape_5d_down_d2d_auto // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2, 3, 4]] : tensor into tensor -// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor into tensor -// CHECK: return %[[VAL_1]] : tensor +// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2, 3, 4]] : tensor into tensor +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM:.*]] = tensor.dim %[[COLLAPSED]], %[[C0]] : tensor +// CHECK: %[[C6:.*]] = arith.constant 6 : index +// CHECK: %[[VAL_0:.*]] = arith.divui %[[DIM]], %[[C6]] : index +// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2]] output_shape [%[[VAL_0]], 2, 3] : tensor into tensor +// CHECK: return %[[EXPANDED]] : tensor func.func @test_reshape_5d_down_d2d_auto(%arg0: tensor) -> tensor { %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor) -> tensor return %0 : tensor @@ -325,9 +369,13 @@ func.func @test_reshape_5d_down_d2d_auto(%arg0: tensor) -> tensor // CHECK-LABEL: test_reshape_6d_down_d2d_auto // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<1x2x?x5x7x11xf32> -// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2, 3, 4, 5]] : tensor<1x2x?x5x7x11xf32> into tensor -// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor into tensor -// CHECK: return %[[VAL_1]] : tensor +// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2, 3, 4, 5]] : tensor<1x2x?x5x7x11xf32> into tensor +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM:.*]] = tensor.dim %[[COLLAPSED]], %[[C0]] : tensor +// CHECK: %[[C385:.*]] = arith.constant 385 : index +// CHECK: %[[VAL_0:.*]] = arith.divui %[[DIM]], %[[C385]] : index +// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2]] output_shape [%[[VAL_0]], 5, 77] : tensor into tensor +// CHECK: return %[[EXPANDED]] : tensor func.func @test_reshape_6d_down_d2d_auto(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor { %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<1x2x?x5x7x11xf32>) -> tensor return %0 : tensor diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir index 9a3e14b6d3917..efe59af97d964 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir @@ -132,7 +132,7 @@ func.func @shape_mismatch(%t: tensor<5x6x128xf32>) -> tensor<5x6x128xf32> { %cst = arith.constant 8.0 : f32 %0 = tensor.empty() : tensor<128xf32> %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128xf32>) -> tensor<128xf32> - %2 = tensor.expand_shape %1 [[0, 1, 2]] + %2 = tensor.expand_shape %1 [[0, 1, 2]] output_shape [1, 1, 128] : tensor<128xf32> into tensor<1x1x128xf32> %3 = tensor.insert_slice %2 into %t[2, 3, 0][1, 1, 128][1, 1, 1] : tensor<1x1x128xf32> into tensor<5x6x128xf32> diff --git a/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir b/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir index 0e353a1fa43fc..4bf81820f0e80 100644 --- a/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir +++ b/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir @@ -165,7 +165,9 @@ func.func @rank_reducing_slice(%width : index) -> tensor<1x1x1x?xf32> { %init = tensor.empty(%width) : tensor<1x?xf32> %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x?xf32>) -> tensor<1x?xf32> %slice = tensor.extract_slice %fill[0, 0] [1, %width] [1, 1] : tensor<1x?xf32> to tensor - %expand = tensor.expand_shape %slice [[0, 1, 2, 3]] : tensor into tensor<1x1x1x?xf32> + %c0 = arith.constant 0 : index + %sz0 = tensor.dim %slice, %c0 : tensor + %expand = tensor.expand_shape %slice [[0, 1, 2, 3]] output_shape [1, 1, 1, %sz0] : tensor into tensor<1x1x1x?xf32> return %expand : tensor<1x1x1x?xf32> } diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir index 547320f533874..61bedecbdca5a 100644 --- a/mlir/test/Dialect/Linalg/collapse-dim.mlir +++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir @@ -52,7 +52,7 @@ func.func @collapse_parallel( // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]} // CHECK-SAME: ins(%[[S]] : tensor<32x2x40960xf32>) outs(%[[D]] : tensor<2x32x40960xf32>) { // CHECK: } -> tensor<2x32x40960xf32> -// CHECK: tensor.expand_shape %[[R]] {{\[}}[0], [1], [2, 3]] : tensor<2x32x40960xf32> into tensor<2x32x10x4096xf32> +// CHECK: tensor.expand_shape %[[R]] {{\[}}[0], [1], [2, 3]] output_shape [2, 32, 10, 4096] : tensor<2x32x40960xf32> into tensor<2x32x10x4096xf32> // ----- @@ -127,8 +127,8 @@ func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: mem // CHECK: %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32> // CHECK: %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32> // CHECK: %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32>) outs(%[[VAL_5]] : tensor<1x2x60xf32>) -> tensor<1x2x60xf32> -// CHECK: %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32> -// CHECK: %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64> +// CHECK: %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 2, 12, 5] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32> +// CHECK: %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] output_shape [1, 2, 3, 4, 5] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64> // CHECK: return %[[VAL_8]] : tensor<1x2x3x4x5xf32, 3 : i64> // CHECK: } diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir index a643199635312..c7c846d7ecc9c 100644 --- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir +++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir @@ -50,7 +50,7 @@ module attributes {transform.with_named_sequence} { // CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32 // CHECK: IR printer: transformed -// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32> +// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32> // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> @@ -78,7 +78,7 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32 // CHECK: linalg.yield %[[ADD]] : f32 // CHECK: } -> tensor<1x196x16xf32> -// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32> +// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32> // CHECK: return %[[RESULT]] func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> { @@ -204,7 +204,7 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32 // CHECK: linalg.yield %[[ADD]] : f32 // CHECK: } -> tensor<8x196x16xf32> -// CHECK: %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1, 2], [3]] : tensor<8x196x16xf32> into tensor<8x14x14x16xf32> +// CHECK: %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1, 2], [3]] output_shape [8, 14, 14, 16] : tensor<8x196x16xf32> into tensor<8x14x14x16xf32> // CHECK: return %[[CS_FINAL]] func.func @batch_nhwc_conv(%arg0: tensor<8x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<8x14x14x16xf32>) -> tensor<8x14x14x16xf32> { %0 = linalg.conv_2d_nhwc_hwcf @@ -269,7 +269,7 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32 // CHECK: linalg.yield %[[ADD]] : f32 // CHECK: } -> tensor<8x16x196xf32> -// CHECK: %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1], [2, 3]] : tensor<8x16x196xf32> into tensor<8x16x14x14xf32> +// CHECK: %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1], [2, 3]] output_shape [8, 16, 14, 14] : tensor<8x16x196xf32> into tensor<8x16x14x14xf32> // CHECK: return %[[CS_FINAL]] func.func @batch_nchw_conv(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32>, %arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32> { %0 = linalg.conv_2d_nchw_fchw @@ -310,7 +310,7 @@ module attributes {transform.with_named_sequence} { // CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32 // CHECK: IR printer: transformed -// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32> +// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32> // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> @@ -338,7 +338,7 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32 // CHECK: linalg.yield %[[ADD]] : f32 // CHECK: } -> tensor<1x196x16xf32> -// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32> +// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32> // CHECK: return %[[RESULT]] func.func @conv_2d_nhwc_fhwc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> { @@ -378,7 +378,7 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[ADD:.+]] = arith.addi %[[MUL]], %[[ARG2]] : i32 // CHECK: linalg.yield %[[ADD]] : i32 // CHECK: } -> tensor<1x196x16xi32> -// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xi32> into tensor<1x14x14x16xi32> +// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xi32> into tensor<1x14x14x16xi32> // CHECK: return %[[RESULT]] func.func @conv_integer_extend(%arg0: tensor<1x16x16x4xi8>, %arg1: tensor<3x3x4x16xi8>, %arg2: tensor<1x14x14x16xi32>) -> tensor<1x14x14x16xi32> { @@ -416,7 +416,7 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[ADD:.+]] = complex.add %[[MUL]], %[[ARG2]] : complex // CHECK: linalg.yield %[[ADD]] : complex // CHECK: } -> tensor<1x196x16xcomplex> -// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xcomplex> into tensor<1x14x14x16xcomplex> +// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xcomplex> into tensor<1x14x14x16xcomplex> // CHECK: return %[[RESULT]] func.func @conv_complex(%arg0: tensor<1x16x16x4xcomplex>, %arg1: tensor<3x3x4x16xcomplex>, %arg2: tensor<1x14x14x16xcomplex>) -> tensor<1x14x14x16xcomplex> { @@ -459,7 +459,7 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[ADD:.+]] = complex.add %[[MUL]], %[[ARG2]] : complex // CHECK: linalg.yield %[[ADD]] : complex // CHECK: } -> tensor<1x196x16xcomplex> -// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xcomplex> into tensor<1x14x14x16xcomplex> +// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xcomplex> into tensor<1x14x14x16xcomplex> // CHECK: return %[[RESULT]] func.func @conv_complex_extended(%arg0: tensor<1x16x16x4xcomplex>, %arg1: tensor<3x3x4x16xcomplex>, %arg2: tensor<1x14x14x16xcomplex>) -> tensor<1x14x14x16xcomplex> { @@ -500,7 +500,7 @@ module attributes {transform.with_named_sequence} { // CHECK: %[[ADD:.+]] = complex.add %[[MUL]], %[[ARG2]] : complex // CHECK: linalg.yield %[[ADD]] : complex // CHECK: } -> tensor<1x196x16xcomplex> -// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xcomplex> into tensor<1x14x14x16xcomplex> +// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xcomplex> into tensor<1x14x14x16xcomplex> // CHECK: return %[[RESULT]] func.func @conv_complex_f16_extended(%arg0: tensor<1x16x16x4xcomplex>, %arg1: tensor<3x3x4x16xf16>, %arg2: tensor<1x14x14x16xcomplex>) -> tensor<1x14x14x16xcomplex> { diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir index 79d61ab757e32..bee08503298fd 100644 --- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir +++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir @@ -988,17 +988,20 @@ func.func @no_bubble_up_pack_through_non_divisible_collapse(%1: tensor<3072x64x4 // ----- -func.func @push_down_unpack_through_expand(%5: tensor, %dim: index) -> tensor { +func.func @push_down_unpack_through_expand(%5: tensor, %dim: index, %sz0: index) -> tensor { %6 = tensor.empty(%dim) : tensor %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor -> tensor - %expanded = tensor.expand_shape %unpack [[0, 1], [2]] : tensor into tensor + %expanded = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [%sz0, 256, 256] : tensor into tensor func.return %expanded : tensor } // CHECK-LABEL: func.func @push_down_unpack_through_expand // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK: %[[C32:.+]] = arith.constant 32 : index // CHECK: %[[C0:.+]] = arith.constant 0 : index -// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] : tensor into tensor +// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor +// CHECK: %[[SZ0:.+]] = arith.divui %[[DIM0]], %[[C32]] : index +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] output_shape [%[[SZ0]], 32, 32, 8, 8] : tensor into tensor // CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED:.+]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor -> tensor @@ -1009,12 +1012,12 @@ func.func @push_down_unpack_through_expand(%5: tensor, %dim: index func.func @push_down_permuted_unpack_through_expand(%5: tensor<4x32x384x8x8xf32>) -> tensor<4x12x256x256xf32> { %6 = tensor.empty() : tensor<4x3072x256xf32> %unpack = tensor.unpack %5 outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 8] into %6 : tensor<4x32x384x8x8xf32> -> tensor<4x3072x256xf32> - %expanded = tensor.expand_shape %unpack [[0], [1, 2], [3]] : tensor<4x3072x256xf32> into tensor<4x12x256x256xf32> + %expanded = tensor.expand_shape %unpack [[0], [1, 2], [3]] output_shape [4, 12, 256, 256] : tensor<4x3072x256xf32> into tensor<4x12x256x256xf32> func.return %expanded : tensor<4x12x256x256xf32> } // CHECK-LABEL: @push_down_permuted_unpack_through_expand // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3], [4], [5]] : tensor<4x32x384x8x8xf32> into tensor<4x32x12x32x8x8xf32> +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3], [4], [5]] output_shape [4, 32, 12, 32, 8, 8] : tensor<4x32x384x8x8xf32> into tensor<4x32x12x32x8x8xf32> // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x12x256x256xf32> // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<4x32x12x32x8x8xf32> -> tensor<4x12x256x256xf32> // CHECK: return %[[UNPACK]] : tensor<4x12x256x256xf32> @@ -1024,29 +1027,32 @@ func.func @push_down_permuted_unpack_through_expand(%5: tensor<4x32x384x8x8xf32> func.func @push_down_unpack_through_unit_expand(%5: tensor<6x32x8x8xf32>) -> tensor<3x16x1x256xf32> { %6 = tensor.empty() : tensor<48x256xf32> %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<6x32x8x8xf32> -> tensor<48x256xf32> - %expanded = tensor.expand_shape %unpack [[0, 1, 2], [3]] : tensor<48x256xf32> into tensor<3x16x1x256xf32> + %expanded = tensor.expand_shape %unpack [[0, 1, 2], [3]] output_shape [3, 16, 1, 256] : tensor<48x256xf32> into tensor<3x16x1x256xf32> func.return %expanded : tensor<3x16x1x256xf32> } // CHECK-LABEL: func.func @push_down_unpack_through_unit_expand // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3], [4], [5]] : tensor<6x32x8x8xf32> into tensor<3x2x1x32x8x8xf32> +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3], [4], [5]] output_shape [3, 2, 1, 32, 8, 8] : tensor<6x32x8x8xf32> into tensor<3x2x1x32x8x8xf32> // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x16x1x256xf32> // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [1, 3] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<3x2x1x32x8x8xf32> -> tensor<3x16x1x256xf32> // CHECK: return %[[UNPACK]] : tensor<3x16x1x256xf32> // ----- -func.func @push_down_unpack_through_expand_on_outer_dims(%5: tensor, %dim: index) -> tensor { +func.func @push_down_unpack_through_expand_on_outer_dims(%5: tensor, %dim: index, %sz0: index) -> tensor { %6 = tensor.empty(%dim) : tensor %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [1] inner_tiles = [8] into %6 : tensor -> tensor - %expanded = tensor.expand_shape %unpack [[0, 1], [2]] : tensor into tensor + %expanded = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [%sz0, 256, 256] : tensor into tensor func.return %expanded : tensor } // CHECK-LABEL: func.func @push_down_unpack_through_expand_on_outer_dims // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK: %[[C256:.+]] = arith.constant 256 : index // CHECK: %[[C0:.+]] = arith.constant 0 : index -// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]] : tensor into tensor +// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor +// CHECK: %[[SZ0:.+]] = arith.divui %[[DIM0]], %[[C256]] : index +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]] output_shape [%[[SZ0]], 256, 32, 8] : tensor into tensor // CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED:.+]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [2] inner_tiles = [8] into %[[EMPTY]] : tensor -> tensor @@ -1057,11 +1063,11 @@ func.func @push_down_unpack_through_expand_on_outer_dims(%5: tensor, func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x8xf32>) -> tensor<256x12x256xf32> { %6 = tensor.empty() : tensor<3072x256xf32> %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<384x32x8x8xf32> -> tensor<3072x256xf32> - %expanded = tensor.expand_shape %unpack [[0, 1], [2]] : tensor<3072x256xf32> into tensor<256x12x256xf32> + %expanded = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32> func.return %expanded : tensor<256x12x256xf32> } // CHECK-LABEL: func.func @no_push_down_unpack_through_non_divisible_expand // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]] -// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] : tensor<3072x256xf32> into tensor<256x12x256xf32> +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32> // CHECK: return %[[EXPANDED]] : tensor<256x12x256xf32> diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir index c140b6abcc37a..a9cbaaf7fdc48 100644 --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -25,13 +25,22 @@ func.func @drop_one_trip_loops(%arg0 : tensor, %arg1 : f32, %shape: t // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()> // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[$MAP4:.*]] = affine_map<()[s0, s1] -> (s0 * s1)> // CHECK-LABEL: func @drop_one_trip_loops -// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1], [2]] -// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]] +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: tensor.collapse_shape %{{.*}} {{\[\[}}0, 1], [2]] +// CHECK: tensor.collapse_shape %{{.*}} {{\[\[}}0, 1], [2, 3], [4]] // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP3]]] // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] -// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]] +// CHECK: %[[DIM:.*]] = tensor.dim %{{.*}}, %[[C0]] +// CHECK: %[[VAL_1:.*]] = affine.apply #[[$MAP4]]()[%[[DIM]], %[[C1]]] +// CHECK: %[[DIM_1:.*]] = tensor.dim %{{.*}}, %[[C2]] +// CHECK: %[[VAL_2:.*]] = affine.apply #[[$MAP4]]()[%[[DIM_1]], %[[C1]]] +// CHECK: %[[DIM_2:.*]] = tensor.dim %{{.*}}, %[[C2]] +// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1], [2, 3], [4]] output_shape [%[[VAL_1]], 1, %[[VAL_2]], 1, %[[DIM_2]]] : tensor into tensor // CHECK-SLICES-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> // CHECK-SLICES-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()> @@ -70,13 +79,18 @@ func.func @drop_one_trip_loops_all_ones(%arg0 : tensor<1x1x1xf32>, %arg1 : f32, } // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> ()> // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (d0)> +// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> ((((s0 * s1) * s2) * s3) * s4)> // CHECK-LABEL: func @drop_one_trip_loops_all_ones +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: tensor.collapse_shape %{{.*}} [] // CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4]] // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP2]]] // CHECK-SAME: iterator_types = ["parallel"] -// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4]] +// CHECK: %[[DIM:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x1x?x1x1xf32> +// CHECK: %[[SZ:.*]] = affine.apply #[[$MAP3]]()[%[[C1]], %[[C1]], %[[DIM]], %[[C1]], %[[C1]]] +// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1, 2, 3, 4]] output_shape [1, 1, %[[SZ]], 1, 1] : tensor into tensor<1x1x?x1x1xf32> // ----- @@ -232,8 +246,8 @@ func.func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>, %shape: tensor func.func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>, %shape : tensor<5x5xf32>) -> tensor<5x5xf32> { - %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<5xf32> into tensor<1x5xf32> - %1 = tensor.expand_shape %arg1 [[0, 1]] : tensor<5xf32> into tensor<5x1xf32> + %0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [1, 5] : tensor<5xf32> into tensor<1x5xf32> + %1 = tensor.expand_shape %arg1 [[0, 1]] output_shape [5, 1] : tensor<5xf32> into tensor<5x1xf32> %2 = linalg.generic #trait ins(%0, %1 : tensor<1x5xf32>, tensor<5x1xf32>) outs(%shape : tensor<5x5xf32>) { @@ -331,7 +345,6 @@ func.func @fold_unit_dim_for_empty_tensor(%input: tensor<1x1000xf32>) -> tensor< // CHECK: func @fold_unit_dim_for_empty_tensor - // CHECK: %[[INPUT_RESHAPE:.+]] = tensor.collapse_shape %{{.+}} {{\[}}[0, 1]] : tensor<1x1000xf32> into tensor<1000xf32> // CHECK: %[[INIT:.+]] = tensor.empty() : tensor // CHECK: %[[FILL:.+]] = linalg.fill ins(%cst : f32) outs(%[[INIT]] : tensor) -> tensor @@ -340,7 +353,7 @@ func.func @fold_unit_dim_for_empty_tensor(%input: tensor<1x1000xf32>) -> tensor< // CHECK-SAME: iterator_types = ["reduction"] // CHECK-SAME: ins(%[[INPUT_RESHAPE]] : tensor<1000xf32>) // CHECK-SAME: outs(%[[FILL]] : tensor) -// CHECK: %[[GENERIC_RESHAPE:.+]] = tensor.expand_shape %[[GENERIC]] [] : tensor into tensor<1xf32> +// CHECK: %[[GENERIC_RESHAPE:.+]] = tensor.expand_shape %[[GENERIC]] [] output_shape [1] : tensor into tensor<1xf32> // CHECK: return %[[GENERIC_RESHAPE:.+]] : tensor<1xf32> @@ -364,11 +377,11 @@ func.func @fold_slice( // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG0]] // CHECK-SAME: to tensor // CHECK: %[[RESULT1:.+]] = tensor.expand_shape %[[SLICE1]] -// CHECK-SAME: [0, 1], [2], [3, 4, 5, 6] +// CHECK-SAME: {{\[\[}}0, 1], [2], [3, 4, 5, 6]] output_shape [1, %arg5, %arg6, 1, %arg7, 1, 1] : tensor into tensor<1x?x?x1x?x1x1xf32> // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG1]] // CHECK-SAME: to tensor // CHECK: %[[RESULT2:.+]] = tensor.expand_shape %[[SLICE2]] -// CHECK-SAME: [0, 1], [2], [3, 4, 5, 6] +// CHECK-SAME: {{\[\[}}0, 1], [2], [3, 4, 5, 6]] output_shape [1, %arg5, %arg6, 1, %arg7, 1, 1] : tensor into tensor<1x?x?x1x?x1x1xf32> // CHECK: return %[[RESULT1]], %[[RESULT2]] // ----- @@ -391,20 +404,27 @@ func.func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32> } -> tensor<1x?xf32> return %3 : tensor<1x?xf32> } -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0)> +// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * s2)> // CHECK: func @unit_dim_for_reduction // CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x1x?xf32> -// CHECK-DAG: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] -// CHECK: %[[INIT:.+]] = tensor.empty(%{{.+}}) : tensor -// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]] -// CHECK: %[[RESULT:.+]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]] +// CHECK: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %[[CST:.+]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[C3:.+]] = arith.constant 3 : index +// CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C3]] : tensor<1x?x1x?xf32> +// CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] +// CHECK: %[[INIT:.+]] = tensor.empty(%{{.+}}) : tensor +// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]] +// CHECK: %[[RESULT:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP2]]] // CHECK-SAME: iterator_types = ["parallel", "reduction"] // CHECK-SAME: ins(%[[RESHAPE]] : tensor) // CHECK-SAME: outs(%[[FILL]] : tensor) -// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] -// CHECK: return %[[RESULT_RESHAPE]] +// CHECK: %[[DIM_0:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<1x?x1x?xf32> +// CHECK: %[[VAL_3:.*]] = affine.apply #[[$MAP3]]()[%[[C1]], %[[DIM_0]], %[[C1]]] +// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1]] output_shape [1, %[[VAL_3]]] : tensor into tensor<1x?xf32> +// CHECK: return %[[EXPANDED]] : tensor<1x?xf32> // ----- @@ -437,7 +457,7 @@ func.func @unit_dim_for_both_reduction(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1 // CHECK-SAME: iterator_types = ["parallel"] // CHECK-SAME: ins(%[[RESHAPE]], %[[FILL]] : tensor, tensor<1xf32>) // CHECK-SAME: outs(%[[INIT2]] : tensor<1xf32>) -// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] +// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [1, 1] // CHECK: return %[[RESULT_RESHAPE]] // ----- @@ -460,20 +480,28 @@ func.func @unit_dim_for_reduction_inner(%arg0: tensor) -> tensor tensor return %3 : tensor } -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0)> +// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0, s1] -> (s0 * s1)> // CHECK: func @unit_dim_for_reduction_inner // CHECK-SAME: %[[ARG0:.+]]: tensor -// CHECK-DAG: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]] -// CHECK: %[[INIT:.+]] = tensor.empty(%{{.+}}) : tensor -// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]] -// CHECK: %[[RESULT:.+]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]] +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[DIM:.*]] = tensor.dim %arg0, %[[C2]] : tensor +// CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]] +// CHECK: %[[INIT:.+]] = tensor.empty(%{{.+}}) : tensor +// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]] +// CHECK: %[[RESULT:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP2]]] // CHECK-SAME: iterator_types = ["parallel", "reduction"] // CHECK-SAME: ins(%[[RESHAPE]] : tensor) // CHECK-SAME: outs(%[[FILL]] : tensor) -// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] -// CHECK: return %[[RESULT_RESHAPE]] +// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor +// CHECK: %[[VAL_3:.+]] = affine.apply #[[$MAP3]]()[%[[DIM_0]], %[[C1]]] +// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [%[[VAL_3]], 1] : tensor into tensor +// CHECK: return %[[RESULT_RESHAPE]] // ----- @@ -484,7 +512,7 @@ func.func @slice_unit_dims(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> { // CHECK-LABEL: func @slice_unit_dims // CHECK: %[[SLICE:.+]] = tensor.extract_slice // CHECK-SAME: tensor<1x3xf32> to tensor -// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[SLICE]] [] +// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[SLICE]] [] output_shape [1, 1] // CHECK: return %[[RESULT]] // ----- @@ -496,7 +524,7 @@ func.func @rank_reduced_extract_slice(%arg0: tensor<1x1x3x1x3xf32>) -> tensor<1x // CHECK-LABEL: func @rank_reduced_extract_slice // CHECK: %[[SLICE:.+]] = tensor.extract_slice // CHECK-SAME: tensor<1x1x3x1x3xf32> to tensor<3x3xf32> -// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[SLICE]] {{\[}}[0, 1], [2]] +// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[SLICE]] {{\[}}[0, 1], [2]] output_shape [1, 3, 3] // CHECK: return %[[RESULT]] // ----- @@ -709,8 +737,8 @@ func.func @leading_dim_1_canonicalization(%arg0: memref<1x5xf32>, %shape: memref func.func @broadcast_test(%arg0 : memref<5xf32>, %arg1 : memref<5xf32>, %shape : memref<5x5xf32>) -> memref<5x5xf32> { - %0 = memref.expand_shape %arg0 [[0, 1]] : memref<5xf32> into memref<1x5xf32> - %1 = memref.expand_shape %arg1 [[0, 1]] : memref<5xf32> into memref<5x1xf32> + %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [1, 5] : memref<5xf32> into memref<1x5xf32> + %1 = memref.expand_shape %arg1 [[0, 1]] output_shape [5, 1] : memref<5xf32> into memref<5x1xf32> linalg.generic #trait ins(%0, %1 : memref<1x5xf32>, memref<5x1xf32>) outs(%shape : memref<5x5xf32>) { @@ -966,7 +994,7 @@ func.func @drop_unit_pad_dims(%arg0: tensor<1x1x3x1x1xf32>) -> tensor<1x2x3x1x3x // CHECK: %[[PADDED:.+]] = tensor.pad %[[COLLAPSE]] low[1, 0, 0] high[0, 0, 2] // CHECK: } : tensor<1x3x1xf32> to tensor<2x3x3xf32> // CHECK: tensor.expand_shape %[[PADDED]] -// CHECK-SAME: {{\[}}[0, 1], [2, 3], [4]{{\]}} : tensor<2x3x3xf32> into tensor<1x2x3x1x3xf32> +// CHECK-SAME: {{\[}}[0, 1], [2, 3], [4]{{\]}} output_shape [1, 2, 3, 1, 3] : tensor<2x3x3xf32> into tensor<1x2x3x1x3xf32> // CHECK-SLICES-LABEL: func @drop_unit_pad_dims // CHECK-SLICES: %[[EXTRACT:.+]] = tensor.extract_slice @@ -989,13 +1017,19 @@ func.func @drop_unit_pad_dynamic_dims(%arg0: tensor<1x?xf32>) -> tensor<1x?xf32> return %0 : tensor<1x?xf32> } +// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)> +// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 + 11)> // CHECK-LABEL: func @drop_unit_pad_dynamic_dims +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape // CHECK-SAME: {{\[}}[0, 1]{{\]}} : tensor<1x?xf32> into tensor // CHECK: %[[PADDED:.+]] = tensor.pad %[[COLLAPSE]] low[5] high[6] // CHECK: } : tensor to tensor -// CHECK: tensor.expand_shape %[[PADDED]] -// CHECK-SAME: {{\[}}[0, 1]{{\]}} : tensor into tensor<1x?xf32> +// CHECK: %[[DIM:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?xf32> +// CHECK: %[[VAL_0:.+]] = affine.apply #[[$MAP]]()[%[[C1]], %[[DIM]]] +// CHECK: %[[VAL_1:.+]] = affine.apply #[[$MAP1]]()[%[[VAL_0]]] +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PADDED]] {{\[\[}}0, 1]] output_shape [1, %[[VAL_1]]] : tensor into tensor<1x?xf32> // CHECK-SLICES: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 11)> @@ -1052,4 +1086,4 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te // CHECK: %[[PADDED:.+]] = tensor.pad %[[COLLAPSE]] low[1, 0] high[0, 0] // CHECK: } : tensor<383x128xf32> to tensor<384x128xf32> // CHECK: tensor.expand_shape %[[PADDED]] -// CHECK-SAME: {{\[}}[0, 1], [2]] : tensor<384x128xf32> into tensor<1x384x128xf32> +// CHECK-SAME: {{\[}}[0, 1], [2]] output_shape [1, 384, 128] : tensor<384x128xf32> into tensor<1x384x128xf32> diff --git a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir index 5a27fe76b1341..9fe50a521d2d8 100644 --- a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir +++ b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir @@ -26,7 +26,7 @@ module attributes {transform.with_named_sequence} { // CHECK-SAME: %[[ARG1:.*]]: tensor<32x7xf32> // CHECK-NEXT: %[[FLATTENED:.*]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1]] // CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = linalg.fill ins(%[[ARG0]] : f32) outs(%[[FLATTENED]] : tensor<224xf32>) -// CHECK-NEXT: %[[RESULT:.*]] = tensor.expand_shape %[[FLATTENED_RESULT]] {{\[}}[0, 1]] +// CHECK-NEXT: %[[RESULT:.*]] = tensor.expand_shape %[[FLATTENED_RESULT]] {{\[}}[0, 1]] output_shape [32, 7] : tensor<224xf32> into tensor<32x7xf32> func.func @fill_tensor(%cst: f32, %arg: tensor<32x7xf32>) -> tensor<32x7xf32> { %0 = linalg.fill ins(%cst: f32) outs(%arg: tensor<32x7xf32>) -> tensor<32x7xf32> return %0 : tensor<32x7xf32> diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir index 50d308b6a9fee..0d40df534a3bb 100644 --- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir +++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir @@ -9,8 +9,7 @@ #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)> func.func @fuse_by_collapsing(%arg0 : tensor<2x12x5x336x9xi32>, %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> tensor<2x3x4x5x6x7x8x9xi32> { - %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] - : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32> + %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32> %init = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32> %generic = linalg.generic { indexing_maps = [#map0, #map1, #map2, #map3], @@ -40,7 +39,7 @@ func.func @fuse_by_collapsing(%arg0 : tensor<2x12x5x336x9xi32>, // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"] // CHECK-SAME: ins(%[[ARG0]], %[[ARG1_RESHAPE]], %[[ARG2_RESHAPE]] : // CHECK-SAME: outs(%[[INIT_RESHAPE]] : -// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}} +// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}} output_shape [2, 3, 4, 5, 6, 7, 8, 9] // CHECK: return %[[RESULT_RESHAPE]] // CONTROL: func @fuse_by_collapsing( @@ -60,8 +59,7 @@ func.func @fuse_by_collapsing(%arg0 : tensor<2x12x5x336x9xi32>, #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)> func.func @fuse_by_collapsing_indexing_op(%arg0 : tensor<2x12x5x336x9xi32>, %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> tensor<2x3x4x5x6x7x8x9xi32> { - %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] - : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32> + %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32> %init = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32> %generic = linalg.generic { indexing_maps = [#map0, #map1, #map2, #map3], @@ -122,8 +120,7 @@ func.func @fuse_by_collapsing_indexing_op(%arg0 : tensor<2x12x5x336x9xi32>, #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)> func.func @fuse_by_collapsing_change_reshape_order(%arg0 : tensor<9x56x2x60x6xi32>, %arg1 : tensor<7x8x2xi32>, %arg2 : tensor<6x3x4x5xi32>) -> tensor<2x3x4x5x6x7x8x9xi32> { - %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] - : tensor<9x56x2x60x6xi32> into tensor<9x7x8x2x3x4x5x6xi32> + %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [9, 7, 8, 2, 3, 4, 5, 6] : tensor<9x56x2x60x6xi32> into tensor<9x7x8x2x3x4x5x6xi32> %init = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32> %generic = linalg.generic { indexing_maps = [#map0, #map1, #map2, #map3], @@ -154,7 +151,7 @@ func.func @fuse_by_collapsing_change_reshape_order(%arg0 : tensor<9x56x2x60x6xi3 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"] // CHECK-SAME: ins(%[[ARG0]], %[[ARG1_RESHAPE]], %[[ARG2_RESHAPE]] : // CHECK-SAME: outs(%[[INIT_RESHAPE]] : -// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]] {{\[}}[0], [1, 2, 3], [4], [5, 6], [7]{{\]}} +// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]] {{\[}}[0], [1, 2, 3], [4], [5, 6], [7]{{\]}} output_shape [2, 3, 4, 5, 6, 7, 8, 9] // CHECK: return %[[RESULT_RESHAPE]] // ----- @@ -165,11 +162,11 @@ func.func @fuse_by_collapsing_change_reshape_order(%arg0 : tensor<9x56x2x60x6xi3 #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d1, d2, d3)> #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)> func.func @fuse_by_collapsing_dynamic(%arg0 : tensor, - %arg1 : tensor, %arg2 : tensor) -> tensor { + %arg1 : tensor, %arg2 : tensor, %sz0: index, %sz1: index, %sz2: index, %sz3: index, %sz4: index) -> tensor { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index - %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] + %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [%sz0, 7, %sz1, %sz2, 3, %sz3, 5, %sz4] : tensor into tensor %d0 = tensor.dim %arg1, %c2 : tensor %d2 = tensor.dim %arg2, %c2 : tensor @@ -203,8 +200,8 @@ func.func @fuse_by_collapsing_dynamic(%arg0 : tensor, } -> tensor return %generic : tensor } -// CHECK: func @fuse_by_collapsing_dynamic( -// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK: func @fuse_by_collapsing_dynamic +// CHECK-SAME: (%[[ARG0:.+]]: tensor, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index, %[[SZ2:.+]]: index, %[[SZ3:.+]]: index, %[[SZ4:.+]]: index) // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index // CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index // CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] @@ -224,8 +221,8 @@ func.func @fuse_by_collapsing_dynamic(%arg0 : tensor, #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3)> -func.func @fuse_reductions(%arg0 : tensor<2x?x5xf32>, %arg1 : tensor<2x5xf32>) -> tensor<2x5xf32> { - %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] : tensor<2x?x5xf32> into tensor<2x6x?x5xf32> +func.func @fuse_reductions(%arg0 : tensor<2x?x5xf32>, %arg1 : tensor<2x5xf32>, %sz0: index) -> tensor<2x5xf32> { + %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] output_shape [2, 6, %sz0, 5] : tensor<2x?x5xf32> into tensor<2x6x?x5xf32> %1 = linalg.generic { indexing_maps = [#map0, #map1], iterator_types = ["parallel", "reduction", "reduction", "parallel"]} @@ -240,7 +237,8 @@ func.func @fuse_reductions(%arg0 : tensor<2x?x5xf32>, %arg1 : tensor<2x5xf32>) - // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> // CHECK: func @fuse_reductions( // CHECK-SAME: %[[ARG0:.+]]: tensor<2x?x5xf32> -// CHECK-SAME: %[[ARG1:.+]]: tensor<2x5xf32>) -> tensor<2x5xf32> +// CHECK-SAME: %[[ARG1:.+]]: tensor<2x5xf32> +// CHECK-SAME: %[[SZ0:.+]]: index) -> tensor<2x5xf32> // CHECK: %[[GENERIC:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] // CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"] @@ -253,7 +251,7 @@ func.func @fuse_reductions(%arg0 : tensor<2x?x5xf32>, %arg1 : tensor<2x5xf32>) - #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1)> func.func @no_fuse_unpreserved_folding(%arg0 : tensor<2x12x5xf32>, %arg1 : tensor<2x3xf32>) -> tensor<2x3x4x5xf32> { - %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] : tensor<2x12x5xf32> into tensor<2x3x4x5xf32> + %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] output_shape [2, 3, 4, 5] : tensor<2x12x5xf32> into tensor<2x3x4x5xf32> %init = tensor.empty(): tensor<2x3x4x5xf32> %1 = linalg.generic { indexing_maps = [#map0, #map1, #map0], @@ -280,7 +278,7 @@ func.func @no_fuse_unpreserved_folding(%arg0 : tensor<2x12x5xf32>, %arg1 : tenso #map1 = affine_map<(d0, d1, d2, d3) -> (d0)> #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)> func.func @no_fuse_unpreserved_folding_transpose(%arg0 : tensor<2x12x5xf32>, %arg1 : tensor<2xf32>) -> tensor<2x4x3x5xf32> { - %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] : tensor<2x12x5xf32> into tensor<2x3x4x5xf32> + %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] output_shape [2, 3, 4, 5] : tensor<2x12x5xf32> into tensor<2x3x4x5xf32> %init = tensor.empty() : tensor<2x4x3x5xf32> %1 = linalg.generic { indexing_maps = [#map0, #map1, #map2], @@ -307,7 +305,7 @@ func.func @no_fuse_unpreserved_folding_transpose(%arg0 : tensor<2x12x5xf32>, %ar #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1)> #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d3)> func.func @no_fuse_mismatched_iterator_types(%arg0 : tensor<2x12x5xf32>, %arg1 : tensor<2x3xf32>) -> tensor<2x5xf32> { - %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] : tensor<2x12x5xf32> into tensor<2x3x4x5xf32> + %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] output_shape [2, 3, 4, 5] : tensor<2x12x5xf32> into tensor<2x3x4x5xf32> %init = tensor.empty() : tensor<2x5xf32> %1 = linalg.generic { indexing_maps = [#map0, #map1, #map2], @@ -335,8 +333,8 @@ func.func @no_fuse_mismatched_iterator_types(%arg0 : tensor<2x12x5xf32>, %arg1 : #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> func.func @control_fusion(%arg0 : tensor<6xf32>, %arg1 : tensor<20xf32>) -> tensor<2x3x4x5xf32> { - %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<6xf32> into tensor<2x3xf32> - %1 = tensor.expand_shape %arg1 [[0, 1]] : tensor<20xf32> into tensor<4x5xf32> + %0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [2, 3] : tensor<6xf32> into tensor<2x3xf32> + %1 = tensor.expand_shape %arg1 [[0, 1]] output_shape [4, 5] : tensor<20xf32> into tensor<4x5xf32> %init = tensor.empty() : tensor<2x3x4x5xf32> %2 = linalg.generic { indexing_maps = [#map0, #map1, #map2], @@ -359,8 +357,8 @@ func.func @control_fusion(%arg0 : tensor<6xf32>, %arg1 : tensor<20xf32>) -> tens // CHECK-SAME: iterator_types = ["parallel", "parallel"] // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : // CHECK-SAME: outs(%{{.+}}: tensor<6x20xf32>) -// CHECK: %[[RESHAPE1:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]{{\]}} -// CHECK: %[[RESHAPE2:.+]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1], [2], [3]{{\]}} +// CHECK: %[[RESHAPE1:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]{{\]}} output_shape [6, 4, 5] +// CHECK: %[[RESHAPE2:.+]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 3, 4, 5] // CHECK: return %[[RESHAPE2]] // CONTROL-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> @@ -375,14 +373,14 @@ func.func @control_fusion(%arg0 : tensor<6xf32>, %arg1 : tensor<20xf32>) -> tens // CONTROL: %[[GENERIC:.+]] = linalg.generic // CONTROL-SAME: ins(%[[EXPAND]], %[[ARG1]] : // CONTROL-SAME: outs(%[[INIT_RESHAPE]] : -// CONTROL: %[[RESULT:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1], [2, 3]{{\]}} +// CONTROL: %[[RESULT:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1], [2, 3]{{\]}} output_shape [2, 3, 4, 5] // ----- // Corner case that isnt handled currently. #map = affine_map<(d0) -> (d0)> func.func @zero_D_test(%arg0: tensor) -> tensor<1xf32> { - %0 = tensor.expand_shape %arg0 [] : tensor into tensor<1xf32> + %0 = tensor.expand_shape %arg0 [] output_shape [1] : tensor into tensor<1xf32> %init = tensor.empty() : tensor<1xf32> %1 = linalg.generic { indexing_maps = [#map, #map], @@ -404,8 +402,8 @@ func.func @zero_D_test(%arg0: tensor) -> tensor<1xf32> { #map0 = affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)> #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -func.func @fuse_only_one_reassociation(%arg0 : tensor, %arg1 : tensor<4x?x?x8xf32>) -> tensor<4x?x?x8xf32> { - %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] : tensor into tensor +func.func @fuse_only_one_reassociation(%arg0 : tensor, %arg1 : tensor<4x?x?x8xf32>, %sz0: index, %sz1: index) -> tensor<4x?x?x8xf32> { + %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [%sz0, 4, %sz1, 8] : tensor into tensor %1 = linalg.generic { indexing_maps = [#map0, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} @@ -419,10 +417,12 @@ func.func @fuse_only_one_reassociation(%arg0 : tensor, %arg1 : tensor<4 } // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK: func @fuse_only_one_reassociation( -// CHECK-SAME: %[[ARG0:.+]]: tensor -// CHECK-SAME: %[[ARG1:.+]]: tensor<4x?x?x8xf32> -// CHECK-DAG: %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} +// CHECK: func @fuse_only_one_reassociation +// CHECK-SAME: (%[[ARG0:.+]]: tensor, %[[ARG1:.+]]: tensor<4x?x?x8xf32>, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index) +// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [%[[SZ0]], 4, %[[SZ1]], 8] // CHECK-DAG: %[[COLLAPSE_ARG0:.+]] = tensor.collapse_shape %[[EXPAND_ARG0]] {{\[}}[0], [1], [2, 3]{{\]}} // CHECK-DAG: %[[COLLAPSE_ARG1_0:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}} // CHECK-DAG: %[[COLLAPSE_ARG1_1:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}} @@ -431,17 +431,20 @@ func.func @fuse_only_one_reassociation(%arg0 : tensor, %arg1 : tensor<4 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] // CHECK-SAME: ins(%[[COLLAPSE_ARG0]], %[[COLLAPSE_ARG1_0]] : // CHECK-SAME: outs(%[[COLLAPSE_ARG1_1]] : -// CHECK: %[[EXPAND_GENERIC:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1], [2, 3]{{\]}} -// CHECK: return %[[EXPAND_GENERIC]] +// CHECK: %[[DIM:.+]] = tensor.dim %[[GENERIC]], %[[C1]] : tensor<4x?x?xf32> +// CHECK: %[[DIM_2:.+]] = tensor.dim %[[GENERIC]], %[[C2]] : tensor<4x?x?xf32> +// CHECK: %[[VAL_1:.+]] = arith.divui %[[DIM_2]], %[[C8]] : index +// CHECK: %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0], [1], [2, 3]] output_shape [4, %[[DIM]], %[[VAL_1]], 8] : tensor<4x?x?xf32> into tensor<4x?x?x8xf32> +// CHECK: return %[[EXPANDED_3]] // ----- #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)> #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d1, d0, d2)> -func.func @fold_non_consecutive_dims(%arg0 : tensor) -> tensor { +func.func @fold_non_consecutive_dims(%arg0 : tensor, %sz0: index, %sz1: index) -> tensor { %c0 = arith.constant 0 : index %c2 = arith.constant 2 : index - %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] : tensor into tensor + %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [%sz0, 4, %sz1, 8] : tensor into tensor %d0 = tensor.dim %0, %c0 : tensor %d1 = tensor.dim %0, %c2 : tensor %init = tensor.empty(%d1, %d0) : tensor @@ -465,10 +468,16 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor) -> tensor (d0, d1)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1, d0)> // CHECK: func @fold_non_consecutive_dims( -// CHECK-SAME: %[[ARG0:.+]]: tensor) -// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index -// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index -// CHECK: %[[INIT:.+]] = tensor.empty +// CHECK-SAME: %[[ARG0:.+]]: tensor, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index) +// CHECK: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %[[C4:.+]] = arith.constant 4 : index +// CHECK: %[[C8:.+]] = arith.constant 8 : index +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[C2:.+]] = arith.constant 2 : index +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 8] : tensor into tensor +// CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] +// CHECK: %[[DIM_0:.+]] = tensor.dim %[[EXPANDED]], %[[C2]] +// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM_0]], %[[DIM]]) // CHECK: %[[COLLAPSE_INIT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2, 3]{{\]}} // CHECK: %[[GENERIC:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] @@ -487,8 +496,12 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor) -> tensor +// CHECK: %[[DIM_2:.+]] = tensor.dim %[[GENERIC]], %[[C1]] : tensor +// CHECK: %[[VAL_2:.+]] = arith.divui %[[DIM_1]], %[[C8]] : index +// CHECK: %[[VAL_3:.+]] = arith.divui %[[DIM_2]], %[[C4]] : index +// CHECK: %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 8, %[[VAL_3]], 4] : tensor into tensor +// CHECK: return %[[EXPANDED_3]] // ----- @@ -496,10 +509,10 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor) -> tensor (d0, d2, d3, d1)> #map1 = affine_map<(d0, d1, d2, d3) -> ()> -func.func @no_fold_non_consecutive_reduction_dims(%arg0 : tensor) -> tensor { +func.func @no_fold_non_consecutive_reduction_dims(%arg0 : tensor, %sz0: index, %sz1: index) -> tensor { %c0 = arith.constant 0 : index %c2 = arith.constant 2 : index - %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] : tensor into tensor + %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [%sz0, 4, %sz1, 8] : tensor into tensor %init = tensor.empty() : tensor %1 = linalg.generic { indexing_maps = [#map0, #map1], @@ -519,8 +532,8 @@ func.func @no_fold_non_consecutive_reduction_dims(%arg0 : tensor) -> te return %1 : tensor } // CHECK: func @no_fold_non_consecutive_reduction_dims( -// CHECK-SAME: %[[ARG0:.+]]: tensor) -// CHECK: %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} +// CHECK-SAME: %[[ARG0:.+]]: tensor, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index) +// CHECK: %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [%[[SZ0]], 4, %[[SZ1]], 8] // CHECK: %[[GENERIC:.+]] = linalg.generic // CHECK-SAME: ins(%[[EXPAND_ARG0]] : // CHECK: return %[[GENERIC]] diff --git a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir index f1c729ef963ba..751ece37bc094 100644 --- a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir +++ b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir @@ -4,15 +4,19 @@ // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)> // CHECK-LABEL: func @reshape -// CHECK-SAME: (%[[A:.*]]: tensor, %[[B:.*]]: tensor<16xf32>, %[[INIT:.*]]: tensor) +// CHECK-SAME: (%[[A:.*]]: tensor, %[[B:.*]]: tensor<16xf32>, %[[INIT:.*]]: tensor, %[[SZ0:.*]]: index) +// CHECK: %[[C112:.*]] = arith.constant 112 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[RI:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]] : tensor into tensor // CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP2]]], // CHECK-SAME: iterator_types = ["parallel", "parallel"]} // CHECK-SAME: ins(%[[A]], %[[B]] : tensor, tensor<16xf32>) outs(%[[RI]] : tensor) -// CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[}}[0, 1], [2]] : tensor into tensor +// CHECK: %[[DIM:.*]] = tensor.dim %[[R]], %[[C0]] : tensor +// CHECK: %[[VAL_1:.*]] = arith.divui %[[DIM]], %[[C112]] : index +// CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[\[}}0, 1], [2]] output_shape [%[[VAL_1]], 112, 16] : tensor into tensor // CHECK: return %[[RR]] : tensor -func.func @reshape(%A: tensor, %B: tensor<16xf32>, %init: tensor) -> tensor { - %0 = tensor.expand_shape %A [[0, 1], [2]] +func.func @reshape(%A: tensor, %B: tensor<16xf32>, %init: tensor, %sz0: index) -> tensor { + %0 = tensor.expand_shape %A [[0, 1], [2]] output_shape [%sz0, 112, 16] : tensor into tensor %2 = linalg.generic {indexing_maps = [ affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>, @@ -39,13 +43,13 @@ func.func @reshape(%A: tensor, %B: tensor<16xf32>, %init: tensor, tensor<12544x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<12544x16xf32>) -// CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[}}[0, 1], [2]] : tensor<12544x16xf32> into tensor<112x112x16xf32> +// CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[}}[0, 1], [2]] output_shape [112, 112, 16] : tensor<12544x16xf32> into tensor<112x112x16xf32> // CHECK: return %[[RR]] : tensor<112x112x16xf32> func.func @reshape_multiple(%A: tensor<12544x16xf32>, %B: tensor<12544x16xf32>, %C: tensor<16xf32>) -> tensor<112x112x16xf32> { - %0 = tensor.expand_shape %A [[0, 1], [2]] + %0 = tensor.expand_shape %A [[0, 1], [2]] output_shape [112, 112, 16] : tensor<12544x16xf32> into tensor<112x112x16xf32> - %1 = tensor.expand_shape %B [[0, 1], [2]] + %1 = tensor.expand_shape %B [[0, 1], [2]] output_shape [112, 112, 16] : tensor<12544x16xf32> into tensor<112x112x16xf32> %2 = tensor.empty() : tensor<112x112x16xf32> %3 = linalg.generic {indexing_maps = [ @@ -69,11 +73,11 @@ func.func @reshape_multiple(%A: tensor<12544x16xf32>, %B: tensor<12544x16xf32>, // Negative test, since the second source is broadcasted from d1 we cannot merge // d0 and d1 dimensions // CHECK-LABEL: func @reshape_negative -// CHECK: tensor.expand_shape {{.*}} : tensor<12544x16xf32> into tensor<112x112x16xf32> +// CHECK: tensor.expand_shape {{.*}} {{\[\[}}0, 1], [2]] output_shape [112, 112, 16] : tensor<12544x16xf32> into tensor<112x112x16xf32> // CHECK: linalg.generic // CHECK: } -> tensor<112x112x16xf32> func.func @reshape_negative(%A: tensor<12544x16xf32>, %B: tensor<112xf32>) -> tensor<112x112x16xf32> { - %20 = tensor.expand_shape %A [[0, 1], [2]] + %20 = tensor.expand_shape %A [[0, 1], [2]] output_shape [112, 112, 16] : tensor<12544x16xf32> into tensor<112x112x16xf32> %21 = tensor.empty() : tensor<112x112x16xf32> %22 = linalg.generic {indexing_maps = [ @@ -96,7 +100,7 @@ func.func @type_correctness(%arg0 : tensor<6x5xi32>, %arg1 : tensor<5xf32>, %cst_6 = arith.constant 1.000000e+00 : f32 %cst_7 = arith.constant 7.000000e+00 : f32 %cst_8 = arith.constant 1.1920929E-7 : f32 - %25 = tensor.expand_shape %arg0 [[0, 1], [2]] + %25 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 3, 5] : tensor<6x5xi32> into tensor<2x3x5xi32> %26 = tensor.empty() : tensor<2x3x5xf32> %28 = linalg.generic { diff --git a/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir index ab948988b7b6e..0f0337a3604e0 100644 --- a/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir +++ b/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir @@ -48,7 +48,7 @@ func.func @control_consumer_reshape_fusion(%arg0 : tensor<1x?x?xf32>, %arg1 : te ^bb0(%arg2: f32): linalg.yield %cst : f32 } -> tensor - %0 = tensor.expand_shape %fill [[0, 1], [2]] : tensor into tensor<1x?x?xf32> + %0 = tensor.expand_shape %fill [[0, 1], [2]] output_shape [1, %d0, %d1] : tensor into tensor<1x?x?xf32> %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?x?xf32>) outs(%0 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32> return %1 : tensor<1x?x?xf32> diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir index 342c067b5c4ba..f42666f81bbad 100644 --- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir +++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir @@ -30,10 +30,20 @@ func.func @generic_op_reshape_producer_fusion(%arg0 : tensor, // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: f32 -// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] -// CHECK-SAME: [0], [1], [2, 3] -// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] -// CHECK-SAME: [0], [1], [2, 3] +// CHECK: %[[C4:.+]] = arith.constant 4 : index +// CHECK: %[[C2:.+]] = arith.constant 2 : index +// CHECK: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor +// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor +// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor +// CHECK: %[[VAL_0:.+]] = arith.divui %[[DIM_1]], %[[C4]] : index +// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM]], %[[DIM_0]], %[[VAL_0]], 4] : tensor into tensor +// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor +// CHECK: %[[DIM_3:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor +// CHECK: %[[DIM_4:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor +// CHECK: %[[VAL_1:.+]] = arith.divui %[[DIM_4]], %[[C4]] : index +// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM_2]], %[[DIM_3]], %[[VAL_1]], 4] : tensor into tensor // CHECK: %[[T3:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]], #[[MAP6]]] // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"] @@ -50,7 +60,9 @@ func.func @generic_op_reshape_producer_fusion(%arg0 : tensor, #map1 = affine_map<(d0, d1) -> ()> func.func @generic_op_reshape_consumer_fusion(%arg0 : tensor, %arg1 : tensor, - %arg2 : f32) -> + %arg2 : f32, + %sz0: index, + %sz1: index) -> tensor { %0 = linalg.generic { @@ -63,7 +75,7 @@ func.func @generic_op_reshape_consumer_fusion(%arg0 : tensor, %2 = arith.addf %1, %arg5 : f32 linalg.yield %2 : f32 } -> tensor - %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] : + %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] output_shape [%sz0, 4, %sz1, 5] : tensor into tensor return %1 : tensor } @@ -75,14 +87,22 @@ func.func @generic_op_reshape_consumer_fusion(%arg0 : tensor, // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: f32 -// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] -// CHECK-SAME: [0], [1, 2, 3] -// CHECK-SAME: tensor into tensor -// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] -// CHECK-SAME: [0], [1, 2, 3] -// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] -// CHECK-SAME: [0], [1, 2, 3] -// CHECK-SAME: tensor into tensor +// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index +// CHECK: %[[C20:.+]] = arith.constant 20 : index +// CHECK: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor +// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor +// CHECK: %[[VAL_0:.+]] = arith.divui %[[DIM_0]], %[[C20]] : index +// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM]], 4, %[[VAL_0]], 5] : tensor into tensor +// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor +// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor +// CHECK: %[[VAL_1:.+]] = arith.divui %[[DIM_2]], %[[C20]] : index +// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_1]], 4, %[[VAL_1]], 5] : tensor into tensor +// CHECK: %[[DIM_4:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor +// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor +// CHECK: %[[VAL_2:.+]] = arith.divui %[[DIM_5]], %[[C20]] : index +// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], 4, %[[VAL_2]], 5] : tensor into tensor // CHECK: %[[T3:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]], #[[MAP2]]] // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"] @@ -94,7 +114,7 @@ func.func @generic_op_reshape_consumer_fusion(%arg0 : tensor, // ----- func.func @reshape_as_consumer_permutation - (%a : tensor, %b : tensor) + (%a : tensor, %b : tensor, %sz0: index, %sz1: index, %sz2: index) -> tensor { %c = linalg.generic { indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>, @@ -107,8 +127,7 @@ func.func @reshape_as_consumer_permutation %1 = arith.addf %arg0, %arg1 : f32 linalg.yield %1 : f32 } -> tensor - %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]] - : tensor into tensor + %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]] output_shape [%sz0, 2, %sz1, 3, 4, %sz2] : tensor into tensor return %d : tensor } // CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)> @@ -117,15 +136,27 @@ func.func @reshape_as_consumer_permutation // CHECK: func @reshape_as_consumer_permutation // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] -// CHECK-SAME: [0, 1, 2], [3, 4], [5] -// CHECK-SAME: tensor into tensor<3x4x?x?x2x?xf32> -// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] -// CHECK-SAME: [0, 1, 2], [3] -// CHECK-SAME: tensor into tensor<3x4x?x?xf32> -// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] -// CHECK-SAME: [0, 1], [2], [3, 4, 5]] -// CHECK-SAME: tensor into tensor +// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index, %[[SZ2:.+]]: index +// CHECK: %[[C12:.+]] = arith.constant 12 : index +// CHECK: %[[C2:.+]] = arith.constant 2 : index +// CHECK: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor +// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor +// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor +// CHECK: %[[VAL_0:.+]] = arith.divui %[[DIM]], %[[C12]] : index +// CHECK: %[[VAL_1:.+]] = arith.divui %[[DIM_0]], %[[C2]] : index +// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [3, 4, %[[VAL_0]], %[[VAL_1]], 2, %[[DIM_1]]] : tensor into tensor<3x4x?x?x2x?xf32> +// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor +// CHECK: %[[DIM_3:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor +// CHECK: %[[VAL_2:.+]] = arith.divui %[[DIM_2]], %[[C12]] : index +// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [3, 4, %[[VAL_2]], %[[DIM_3]]] : tensor into tensor<3x4x?x?xf32> +// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor +// CHECK: %[[DIM_6:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor +// CHECK: %[[DIM_7:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor +// CHECK: %[[VAL_3:.+]] = arith.divui %[[DIM_5]], %[[C2]] : index +// CHECK: %[[VAL_4:.+]] = arith.divui %[[DIM_7]], %[[C12]] : index +// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [%[[VAL_3]], 2, %[[DIM_6]], 3, 4, %[[VAL_4]]] : tensor into tensor // CHECK: %[[T3:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP8]], #[[MAP9]], #[[MAP10]]] // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"] @@ -152,7 +183,7 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>) %2 = arith.mulf %arg1, %arg2 : f32 linalg.yield %2 : f32 } -> tensor<264x4xf32> - %2 = tensor.expand_shape %1 [[0, 1], [2]] : + %2 = tensor.expand_shape %1 [[0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32> return %2 : tensor<8x33x4xf32> } @@ -163,12 +194,8 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>) // CHECK-DAG: %[[CST:.+]] = arith.constant // CHECK-SAME: : tensor<8x33x4xf32> // CHECK-DAG: %[[INIT:.+]] = tensor.empty() -// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] -// CHECK-SAME: [0, 1], [2] -// CHECK-SAME: tensor<264x4xf32> into tensor<8x33x4xf32> -// CHECK: %[[T1:.+]] = tensor.expand_shape %[[INIT]] -// CHECK-SAME: [0, 1], [2] -// CHECK-SAME: : tensor<264x4xf32> into tensor<8x33x4xf32> +// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32> +// CHECK: %[[T1:.+]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32> // CHECK: %[[T2:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]] // CHECK-SAME: ["parallel", "parallel", "parallel"] @@ -232,7 +259,8 @@ func.func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor, #map0 = affine_map<(d0, d1) -> (d0, d1)> func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor, - %arg1 : tensor) -> + %arg1 : tensor, + %sz0: index, %sz1: index) -> tensor { %0 = linalg.generic { @@ -250,7 +278,7 @@ func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor, %5 = arith.addi %3, %4 : i32 linalg.yield %5 : i32 } -> tensor - %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] : + %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] output_shape [%sz0, %sz1, 4, 5] : tensor into tensor return %1 : tensor } @@ -302,8 +330,7 @@ func.func @reshape_as_consumer_permutation %7 = arith.addi %5, %6 : i32 linalg.yield %7 : i32 } -> tensor<6x4x210xi32> - %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]] - : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32> + %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32> return %d : tensor<2x3x4x5x6x7xi32> } @@ -319,13 +346,9 @@ func.func @reshape_as_consumer_permutation // CHECK-SAME: %[[ARG0:.+]]: tensor<210x6x4xi32> // CHECK-SAME: %[[ARG1:.+]]: tensor<210x4xi32> // CHECK-DAG: %[[INIT:.+]] = tensor.empty() -// CHECK-DAG: %[[T1:.+]] = tensor.expand_shape %[[ARG0]] -// CHECK-SAME: [0, 1, 2], [3, 4], [5] -// CHECK-DAG: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] -// CHECK-SAME: [0, 1, 2], [3] -// CHECK-DAG: %[[T3:.+]] = tensor.expand_shape %[[INIT]] -// CHECK-SAME: [0, 1], [2], [3, 4, 5] -// CHECK-SAME: : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32> +// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [5, 6, 7, 2, 3, 4] : tensor<210x6x4xi32> into tensor<5x6x7x2x3x4xi32> +// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [5, 6, 7, 4] : tensor<210x4xi32> into tensor<5x6x7x4xi32> +// CHECK: %[[T3:.+]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32> // CHECK: %[[T4:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] // CHECK-SAME: ins(%[[T1]], %[[T2]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>) @@ -411,7 +434,8 @@ func.func @reshape_as_producer_projected_permutation( #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> (d1, d0)> func.func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor, - %arg1 : tensor) -> + %arg1 : tensor, + %sz0: index, %sz1: index) -> tensor { %0 = linalg.generic { @@ -423,7 +447,7 @@ func.func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor, %1 = arith.mulf %arg3, %arg4 : f32 linalg.yield %1 : f32 } -> tensor - %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] : + %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] output_shape [%sz0, %sz1, 4, 5] : tensor into tensor return %1 : tensor } @@ -433,15 +457,22 @@ func.func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor, // CHECK: func @generic_op_reshape_consumer_fusion_projected // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] -// CHECK-SAME: [0, 1, 2], [3] -// CHECK-SAME: tensor into tensor -// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] -// CHECK-SAME: [0, 1, 2], [3] -// CHECK-SAME: tensor into tensor -// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] -// CHECK-SAME: [0], [1, 2, 3] -// CHECK-SAME: tensor into tensor +// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index +// CHECK: %[[C20:.+]] = arith.constant 20 : index +// CHECK: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor +// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor +// CHECK: %[[VAL_0:.+]] = arith.divui %[[DIM]], %[[C20]] : index +// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[VAL_0]], 4, 5, %[[DIM_0]]] : tensor into tensor +// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor +// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor +// CHECK: %[[VAL_1:.+]] = arith.divui %[[DIM_1]], %[[C20]] : index +// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[VAL_1]], 4, 5, %[[DIM_2]]] : tensor into tensor +// CHECK: %[[DIM_4:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor +// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor +// CHECK: %[[VAL_2:.+]] = arith.divui %[[DIM_5]], %[[C20]] : index +// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], %[[VAL_2]], 4, 5] : tensor into tensor // CHECK: %[[T3:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP4]], #[[MAP4]], #[[MAP5]]] // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"] @@ -466,6 +497,7 @@ func.func @no_fuse_dynamic_dims(%arg0: tensor) -> tensor { } -> tensor return %3 : tensor } + // CHECK: func @no_fuse_dynamic_dims // CHECK-SAME: %[[ARG0:.+]]: tensor // CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] @@ -503,7 +535,8 @@ func.func @no_fuse_mismatched_dynamism(%arg0: tensor<2x1xi64>, %arg1: tensor, %b : tensor) + (%a : tensor, %b : tensor, %sz0: index, + %sz1: index, %sz2: index, %sz3: index, %sz4: index) -> (tensor, tensor) { %c:2 = linalg.generic { indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>, @@ -517,10 +550,8 @@ func.func @reshape_as_consumer_permutation_with_multiple_results %1 = arith.addf %arg0, %arg1 : f32 linalg.yield %1, %1 : f32, f32 } -> (tensor, tensor) - %d = tensor.expand_shape %c#0 [[0, 1], [2], [3, 4, 5]] - : tensor into tensor - %e = tensor.expand_shape %c#1 [[0], [1, 2], [3, 4, 5]] - : tensor into tensor + %d = tensor.expand_shape %c#0 [[0, 1], [2], [3, 4, 5]] output_shape [%sz0, 2, %sz1, 3, 4, %sz2] : tensor into tensor + %e = tensor.expand_shape %c#1 [[0], [1, 2], [3, 4, 5]] output_shape [%sz3, %sz4, 2, 3, 4, %sz2] : tensor into tensor return %d, %e : tensor, tensor } // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)> @@ -528,17 +559,40 @@ func.func @reshape_as_consumer_permutation_with_multiple_results // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)> // CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d5, d0, d1, d2, d3, d4)> // CHECK: func @reshape_as_consumer_permutation_with_multiple_results -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor -// CHECK-DAG: %[[RESHAPE0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3, 4], [5]{{\]}} -// CHECK-DAG: %[[RESHAPE1:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1, 2], [3]{{\]}} -// CHECK-DAG: %[[RESHAPE2:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3, 4, 5]{{\]}} -// CHECK-DAG: %[[RESHAPE3:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3, 4, 5]{{\]}} -// CHECK: %[[GENERIC:.+]]:2 = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]] -// CHECK-SAME: ins(%[[RESHAPE0]], %[[RESHAPE1]] : -// CHECK-SAME: outs(%[[RESHAPE2]], %[[RESHAPE3]] : -// CHECK: return %[[GENERIC]]#0, %[[GENERIC]]#1 +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index, %[[SZ2:.+]]: index, %[[SZ3:.+]]: index, %[[SZ4:.+]]: index +// CHECK: %[[C12:.+]] = arith.constant 12 : index +// CHECK: %[[C2:.+]] = arith.constant 2 : index +// CHECK: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor +// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor +// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor +// CHECK: %[[VAL_0:.+]] = arith.divui %[[DIM]], %[[C12]] : index +// CHECK: %[[VAL_1:.+]] = arith.divui %[[DIM_0]], %[[C2]] : index +// CHECK: %[[RESHAPE0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [3, 4, %[[VAL_0]], %[[VAL_1]], 2, %[[DIM_1]]] : tensor into tensor<3x4x?x?x2x?xf32> +// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor +// CHECK: %[[DIM_3:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor +// CHECK: %[[VAL_2:.+]] = arith.divui %[[DIM_2]], %[[C12]] : index +// CHECK: %[[RESHAPE1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [3, 4, %[[VAL_2]], %[[DIM_3]]] : tensor into tensor<3x4x?x?xf32> +// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor +// CHECK: %[[DIM_6:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor +// CHECK: %[[DIM_7:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor +// CHECK: %[[VAL_3:.+]] = arith.divui %[[DIM_5]], %[[C2]] : index +// CHECK: %[[VAL_4:.+]] = arith.divui %[[DIM_7]], %[[C12]] : index +// CHECK: %[[RESHAPE2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [%[[VAL_3]], 2, %[[DIM_6]], 3, 4, %[[VAL_4]]] : tensor into tensor +// CHECK: %[[DIM_9:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor +// CHECK: %[[DIM_10:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor +// CHECK: %[[DIM_11:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor +// CHECK: %[[VAL_5:.+]] = arith.divui %[[DIM_10]], %[[C2]] : index +// CHECK: %[[VAL_6:.+]] = arith.divui %[[DIM_11]], %[[C12]] : index +// CHECK: %[[RESHAPE3:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2], [3, 4, 5]] output_shape [%[[DIM_9]], %[[VAL_5]], 2, 3, 4, %[[VAL_6]]] : tensor into tensor +// CHECK: %[[GENERIC:.+]]:2 = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]] +// CHECK-SAME: ins(%[[RESHAPE0]], %[[RESHAPE1]] : +// CHECK-SAME: outs(%[[RESHAPE2]], %[[RESHAPE3]] : +// CHECK: return %[[GENERIC]]#0, %[[GENERIC]]#1 // ----- @@ -556,7 +610,7 @@ module { %2 = arith.addf %arg4, %arg5 : f32 linalg.yield %2, %2 : f32, f32 } -> (tensor<512xf32>, tensor<200x512xf32>) - %1 = tensor.expand_shape %0#1 [[0, 1, 2], [3]] : tensor<200x512xf32> into tensor<25x8x1x512xf32> + %1 = tensor.expand_shape %0#1 [[0, 1, 2], [3]] output_shape [25, 8, 1, 512] : tensor<200x512xf32> into tensor<25x8x1x512xf32> return %1 : tensor<25x8x1x512xf32> } } @@ -567,7 +621,7 @@ module { // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<512xf32> // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<512xf32> // CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: tensor<200x512xf32> -// CHECK: %[[OUTS:.+]] = tensor.expand_shape %[[ARG3]] {{\[}}[0, 1, 2], [3]{{\]}} +// CHECK: %[[OUTS:.+]] = tensor.expand_shape %[[ARG3]] {{\[\[}}0, 1, 2], [3]] output_shape [25, 8, 1, 512] : tensor<200x512xf32> into tensor<25x8x1x512xf32> // CHECK: %[[GENERIC:.+]]:2 = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]], #[[MAP1]]] // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : @@ -581,7 +635,9 @@ module { #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> func.func @generic_op_reshape_consumer_fusion_reduction(%arg0 : tensor, %arg1 : tensor, - %arg2 : tensor) -> + %arg2 : tensor, + %sz0: index, + %sz1: index) -> tensor { %0 = linalg.generic { @@ -593,7 +649,7 @@ func.func @generic_op_reshape_consumer_fusion_reduction(%arg0 : tensor, %1 = arith.mulf %arg3, %arg4 : f32 linalg.yield %1 : f32 } -> tensor - %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] : + %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] output_shape [%sz0, %sz1, 4, 5] : tensor into tensor return %1 : tensor } @@ -605,12 +661,18 @@ func.func @generic_op_reshape_consumer_fusion_reduction(%arg0 : tensor, // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] -// CHECK-SAME: [0, 1, 2], [3] -// CHECK-SAME: tensor into tensor -// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] -// CHECK-SAME: [0], [1, 2, 3] -// CHECK-SAME: tensor into tensor +// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index +// CHECK: %[[C20:.+]] = arith.constant 20 : index +// CHECK: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor +// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor +// CHECK: %[[VAL_0:.+]] = arith.divui %[[DIM]], %[[C20]] : index +// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[VAL_0]], 4, 5, %[[DIM_0]]] : tensor into tensor +// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor +// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor +// CHECK: %[[VAL_1:.+]] = arith.divui %[[DIM_2]], %[[C20]] : index +// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_1]], %[[VAL_1]], 4, 5] : tensor into tensor // CHECK: %[[T3:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel", "reduction"] @@ -650,10 +712,21 @@ func.func @generic_op_reshape_producer_fusion_with_reduction(%arg0 : tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] -// CHECK-SAME: [0, 1], [2], [3, 4] -// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] -// CHECK-SAME: [0, 1], [2, 3] +// CHECK: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %[[C7:.+]] = arith.constant 7 : index +// CHECK: %[[C8:.+]] = arith.constant 8 : index +// CHECK: %[[C2:.+]] = arith.constant 2 : index +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor +// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor +// CHECK: %[[VAL_0:.+]] = arith.divui %[[DIM]], %[[C8]] : index +// CHECK: %[[VAL_1:.+]] = arith.divui %[[DIM_0]], %[[C7]] : index +// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2], [3, 4]] output_shape [%[[VAL_0]], 8, 4, %[[VAL_1]], 7] : tensor into tensor +// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor +// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor +// CHECK: %[[VAL_2:.+]] = arith.divui %[[DIM_1]], %[[C8]] : index +// CHECK: %[[VAL_3:.+]] = arith.divui %[[DIM_2]], %[[C7]] : index +// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 8, %[[VAL_3]], 7] : tensor into tensor // CHECK: %[[T3:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]] // CHECK-SAME: ["parallel", "parallel", "reduction", "parallel", "parallel"] @@ -668,12 +741,14 @@ func.func @generic_op_reshape_producer_fusion_with_reduction(%arg0 : tensor, %arg1 : tensor, - %arg2 : tensor) -> + %arg2 : tensor, + %sz0: index, + %sz1: index) -> tensor { %0 = linalg.add ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) -> tensor - %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] : + %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] output_shape [%sz0, %sz1, 4, 5] : tensor into tensor return %1 : tensor } @@ -683,15 +758,22 @@ func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor, // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG0]] -// CHECK-SAME: [0], [1, 2, 3] -// CHECK-SAME: tensor into tensor -// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] -// CHECK-SAME: [0], [1, 2, 3] -// CHECK-SAME: tensor into tensor -// CHECK: %[[T3:.+]] = tensor.expand_shape %[[ARG2]] -// CHECK-SAME: [0], [1, 2, 3] -// CHECK-SAME: tensor into tensor +// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index +// CHECK: %[[C20:.+]] = arith.constant 20 : index +// CHECK: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor +// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor +// CHECK: %[[VAL_0:.+]] = arith.divui %[[DIM_0]], %[[C20]] : index +// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM]], %[[VAL_0]], 4, 5] : tensor into tensor +// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor +// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor +// CHECK: %[[VAL_1:.+]] = arith.divui %[[DIM_2]], %[[C20]] : index +// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_1]], %[[VAL_1]], 4, 5] : tensor into tensor +// CHECK: %[[DIM_4:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor +// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor +// CHECK: %[[VAL_2:.+]] = arith.divui %[[DIM_5]], %[[C20]] : index +// CHECK: %[[T3:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], %[[VAL_2]], 4, 5] : tensor into tensor // CHECK: %[[T4:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]] // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"] @@ -721,10 +803,20 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor, // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] -// CHECK-SAME: [0, 1], [2, 3] -// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] -// CHECK-SAME: [0, 1], [2, 3] +// CHECK: %[[C8:.+]] = arith.constant 8 : index +// CHECK: %[[C7:.+]] = arith.constant 7 : index +// CHECK: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor +// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor +// CHECK: %[[VAL_0:.+]] = arith.divui %[[DIM]], %[[C7]] : index +// CHECK: %[[VAL_1:.+]] = arith.divui %[[DIM_0]], %[[C8]] : index +// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_0]], 7, %[[VAL_1]], 8] : tensor into tensor +// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor +// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor +// CHECK: %[[VAL_2:.+]] = arith.divui %[[DIM_1]], %[[C7]] : index +// CHECK: %[[VAL_3:.+]] = arith.divui %[[DIM_2]], %[[C8]] : index +// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 7, %[[VAL_3]], 8] : tensor into tensor // CHECK: %[[T3:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]] // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"] diff --git a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir index 4262cd23e7469..8fb84248c9613 100644 --- a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir +++ b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir @@ -199,13 +199,12 @@ func.func @empty_tensor_dim_of_linalg_result(%arg_0 : tensor, // ----- -func.func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>) -> (index, index, index) +func.func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>, %sz0: index) -> (index, index, index) { %c1 = arith.constant 1 : index %c3 = arith.constant 3 : index %c4 = arith.constant 4 : index - %0 = tensor.expand_shape %arg0 [[0, 1], [2], [3, 4, 5]] - : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32> + %0 = tensor.expand_shape %arg0 [[0, 1], [2], [3, 4, 5]] output_shape [2, 3, 5, 4, %sz0, 7] : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32> %1 = tensor.dim %0, %c1 : tensor<2x3x5x4x?x7xf32> %2 = tensor.dim %0, %c3 : tensor<2x3x5x4x?x7xf32> %3 = tensor.dim %0, %c4 : tensor<2x3x5x4x?x7xf32> diff --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir index 006d6105677e9..31e9fd00cffa0 100644 --- a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir @@ -13,8 +13,8 @@ func.func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: ten // CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> // CHECK-LABEL: @matmul_split // CHECK-DAG: %[[ID:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<16x256xf32> into tensor<16x4x64xf32> -// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<256x32xf32> into tensor<4x64x32xf32> +// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] output_shape [16, 4, 64] : tensor<16x256xf32> into tensor<16x4x64xf32> +// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] output_shape [4, 64, 32] : tensor<256x32xf32> into tensor<4x64x32xf32> // CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<16x32x4xf32> // CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<16x32x4xf32>) -> tensor<16x32x4xf32> // CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]] @@ -65,7 +65,7 @@ func.func @generic_split_1d(%arg0: tensor<32xf32>, %arg1: tensor, %out: ten // CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0) -> ()> //CHECK-LABEL: @generic_split_1d // CHECK-DAG: %[[ID:.*]] = arith.constant 1.000000e+00 : f32 -// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] : tensor<32xf32> into tensor<4x8xf32> +// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] output_shape [4, 8] : tensor<32xf32> into tensor<4x8xf32> // CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<4xf32> // CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32> // CHECK: %[[G:.*]] = linalg.generic @@ -119,8 +119,8 @@ func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32> // CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> // CHECK-LABEL: func @generic_split_3d // CHECK-DAG: %[[ID:.*]] = arith.constant 0xFF800000 : f32 -// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<4x8x2xf32> -// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x4x8xf32> +// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] output_shape [4, 8, 2] : tensor<32x2xf32> into tensor<4x8x2xf32> +// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] output_shape [5, 4, 8] : tensor<5x32xf32> into tensor<5x4x8xf32> // CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32> // CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32> // CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]} @@ -177,8 +177,8 @@ func.func @generic_split_3d_ninf(%input: tensor<32x2xf32>, %input_2: tensor<5x32 // CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> // CHECK-LABEL: func @generic_split_3d_ninf // CHECK-DAG: %[[ID:.*]] = arith.constant -3.40282347E+38 : f32 -// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<4x8x2xf32> -// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x4x8xf32> +// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] output_shape [4, 8, 2] : tensor<32x2xf32> into tensor<4x8x2xf32> +// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] output_shape [5, 4, 8] : tensor<5x32xf32> into tensor<5x4x8xf32> // CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32> // CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32> // CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]} @@ -218,8 +218,8 @@ func.func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: ten // CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> // CHECK-LABEL: @matmul_split // CHECK-DAG: %[[ID:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<16x256xf32> into tensor<16x64x4xf32> -// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<256x32xf32> into tensor<64x4x32xf32> +// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] output_shape [16, 64, 4] : tensor<16x256xf32> into tensor<16x64x4xf32> +// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] output_shape [64, 4, 32] : tensor<256x32xf32> into tensor<64x4x32xf32> // CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<16x32x4xf32> // CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<16x32x4xf32>) -> tensor<16x32x4xf32> // CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]] @@ -270,7 +270,7 @@ func.func @generic_split_1d(%arg0: tensor<32xf32>, %arg1: tensor, %out: ten // CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0) -> ()> //CHECK-LABEL: @generic_split_1d // CHECK-DAG: %[[ID:.*]] = arith.constant 1.000000e+00 : f32 -// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] : tensor<32xf32> into tensor<8x4xf32> +// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] output_shape [8, 4] : tensor<32xf32> into tensor<8x4xf32> // CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<4xf32> // CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32> // CHECK: %[[G:.*]] = linalg.generic @@ -324,8 +324,8 @@ func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32> // CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> // CHECK-LABEL: func @generic_split_3d // CHECK-DAG: %[[ID:.*]] = arith.constant 0x7F800000 : f32 -// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<8x4x2xf32> -// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x8x4xf32> +// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] output_shape [8, 4, 2] : tensor<32x2xf32> into tensor<8x4x2xf32> +// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] output_shape [5, 8, 4] : tensor<5x32xf32> into tensor<5x8x4xf32> // CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32> // CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32> // CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]} @@ -382,8 +382,8 @@ func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32> // CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> // CHECK-LABEL: func @generic_split_3d // CHECK-DAG: %[[ID:.*]] = arith.constant 3.40282347E+38 : f32 -// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<8x4x2xf32> -// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x8x4xf32> +// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] output_shape [8, 4, 2] : tensor<32x2xf32> into tensor<8x4x2xf32> +// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] output_shape [5, 8, 4] : tensor<5x32xf32> into tensor<5x8x4xf32> // CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32> // CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32> // CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]} diff --git a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir index 58d4b21ea2dd9..d7ff1ded9d933 100644 --- a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir +++ b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir @@ -1710,10 +1710,12 @@ module attributes {transform.with_named_sequence} { #map = affine_map<(d0) -> (d0)> // CHECK-LABEL: @not_vectorizable func.func @not_vectorizable(%arg0: tensor<1x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> tensor<1x128xf32> { + %c0 = arith.constant 0 : index %0 = tensor.empty() : tensor<1x128xf32> %1 = scf.for %arg5 = %arg2 to %arg1 step %arg3 iter_args(%arg6 = %0) -> (tensor<1x128xf32>) { %extracted_slice = tensor.extract_slice %arg6[0, 0] [1, %arg1] [1, 1] : tensor<1x128xf32> to tensor - %expanded = tensor.expand_shape %extracted_slice [[0, 1]] : tensor into tensor<1x?xf32> + %sz0 = tensor.dim %extracted_slice, %c0 : tensor + %expanded = tensor.expand_shape %extracted_slice [[0, 1]] output_shape [1, %sz0] : tensor into tensor<1x?xf32> %extracted_slice_0 = tensor.extract_slice %arg0[0, %arg3] [1, %arg2] [1, 1] : tensor<1x?xf32> to tensor %extracted_slice_1 = tensor.extract_slice %expanded[0, %arg3] [1, %arg2] [1, 1] : tensor<1x?xf32> to tensor %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%extracted_slice_0 : tensor) outs(%extracted_slice_1 : tensor) { diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index 506ed1f1c10b1..f442a61dc31ed 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -13,7 +13,7 @@ func.func @collapse_shape_identity_fold(%arg0 : memref<5xi8>) -> memref<5xi8> { // CHECK-LABEL: expand_shape_identity_fold // CHECK-NEXT: return func.func @expand_shape_identity_fold(%arg0 : memref<5x4xi8>) -> memref<5x4xi8> { - %0 = memref.expand_shape %arg0 [[0], [1]] : memref<5x4xi8> into memref<5x4xi8> + %0 = memref.expand_shape %arg0 [[0], [1]] output_shape [5, 4] : memref<5x4xi8> into memref<5x4xi8> return %0 : memref<5x4xi8> } @@ -23,7 +23,7 @@ func.func @expand_shape_identity_fold(%arg0 : memref<5x4xi8>) -> memref<5x4xi8> // CHECK-NEXT: return func.func @collapse_expand_rank0_cancel(%arg0 : memref<1x1xi8>) -> memref<1x1xi8> { %0 = memref.collapse_shape %arg0 [] : memref<1x1xi8> into memref - %1 = memref.expand_shape %0 [] : memref into memref<1x1xi8> + %1 = memref.expand_shape %0 [] output_shape [1, 1] : memref into memref<1x1xi8> return %1 : memref<1x1xi8> } @@ -455,9 +455,9 @@ func.func @compose_collapse_of_collapse(%arg0 : memref) // ----- func.func @do_not_compose_collapse_of_expand_non_identity_layout( - %arg0: memref>) + %arg0: memref>, %sz0: index, %sz1: index) -> memref> { - %1 = memref.expand_shape %arg0 [[0, 1], [2]] : + %1 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [%sz0, 4, %sz1] : memref> into memref> %2 = memref.collapse_shape %1 [[0, 1, 2]] : @@ -471,35 +471,34 @@ func.func @do_not_compose_collapse_of_expand_non_identity_layout( // ----- -func.func @compose_expand_of_expand(%arg0 : memref) +func.func @compose_expand_of_expand(%arg0 : memref, %sz0: index, %sz1: index, %sz2: index, %sz3: index) -> memref { - %0 = memref.expand_shape %arg0 [[0, 1], [2]] + %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [%sz0, 4, %sz1] : memref into memref - %1 = memref.expand_shape %0 [[0, 1], [2], [3, 4]] - : memref into memref + %1 = memref.expand_shape %0 [[0, 1], [2], [3, 4]] output_shape [%sz2, 6, 4, 5, %sz3] : memref into memref return %1 : memref } // CHECK-LABEL: func @compose_expand_of_expand -// CHECK: memref.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] +// CHECK: memref.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] output_shape [%{{.*}}, 6, 4, 5, %{{.*}}] // CHECK-NOT: memref.expand_shape // ----- func.func @compose_expand_of_expand_of_zero_dim(%arg0 : memref) -> memref<1x1x1xf32> { - %0 = memref.expand_shape %arg0 [] : memref into memref<1xf32> - %1 = memref.expand_shape %0 [[0, 1, 2]] + %0 = memref.expand_shape %arg0 [] output_shape [1] : memref into memref<1xf32> + %1 = memref.expand_shape %0 [[0, 1, 2]] output_shape [1, 1, 1] : memref<1xf32> into memref<1x1x1xf32> return %1 : memref<1x1x1xf32> } // CHECK-LABEL: func @compose_expand_of_expand_of_zero_dim -// CHECK: memref.expand_shape %{{.*}} [] +// CHECK: memref.expand_shape %{{.*}} [] output_shape [1, 1, 1] // CHECK-SAME: memref into memref<1x1x1xf32> // ----- func.func @fold_collapse_of_expand(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> { - %0 = memref.expand_shape %arg0 [[0, 1], [2]] + %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [3, 4, 4] : memref<12x4xf32> into memref<3x4x4xf32> %1 = memref.collapse_shape %0 [[0, 1], [2]] : memref<3x4x4xf32> into memref<12x4xf32> @@ -510,9 +509,9 @@ func.func @fold_collapse_of_expand(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> // ----- -func.func @fold_collapse_collapse_of_expand(%arg0 : memref) +func.func @fold_collapse_collapse_of_expand(%arg0 : memref, %sz0: index, %sz1: index) -> memref { - %0 = memref.expand_shape %arg0 [[0, 1], [2]] + %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [%sz0, 4, %sz1] : memref into memref %1 = memref.collapse_shape %0 [[0, 1], [2]] : memref into memref @@ -525,7 +524,7 @@ func.func @fold_collapse_collapse_of_expand(%arg0 : memref) func.func @fold_memref_expand_cast(%arg0 : memref) -> memref<2x4x4xf32> { %0 = memref.cast %arg0 : memref to memref<8x4xf32> - %1 = memref.expand_shape %0 [[0, 1], [2]] + %1 = memref.expand_shape %0 [[0, 1], [2]] output_shape [2, 4, 4] : memref<8x4xf32> into memref<2x4x4xf32> return %1 : memref<2x4x4xf32> } @@ -981,10 +980,10 @@ func.func @memref_realloc_dead(%src : memref<2xf32>, %v : f32) -> memref<2xf32>{ // CHECK-SAME: %[[m:.*]]: memref, 3> // CHECK: %[[casted:.*]] = memref.cast %[[m]] : memref, 3> to memref, 3>) +func.func @collapse_expand_fold_to_cast(%m: memref, 3>, %sz0: index) -> (memref) { - %0 = memref.expand_shape %m [[0, 1]] + %0 = memref.expand_shape %m [[0, 1]] output_shape [1, %sz0] : memref, 3> into memref<1x?xf32, 3> %1 = memref.collapse_shape %0 [[0, 1]] : memref<1x?xf32, 3> into memref diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir index 0705b30ca45d8..3bd6b7c1fd791 100644 --- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir +++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir @@ -421,10 +421,11 @@ func.func @simplify_expand_shape( %base: memref>, %offset0: index, %offset1: index, %offset2: index, %size0: index, %size1: index, %size2: index, - %stride0: index, %stride1: index, %stride2: index) + %stride0: index, %stride1: index, %stride2: index, + %sz0: index, %sz1: index) -> memref> { - %subview = memref.expand_shape %base[[0, 1, 2, 3],[4, 5, 6, 7]] : + %subview = memref.expand_shape %base [[0, 1, 2, 3],[4, 5, 6, 7]] output_shape [%sz0, 7, 8, 9, 10, 2, %sz1, 3] : memref> into memref> @@ -491,7 +492,7 @@ func.func @extract_strided_metadata_of_expand_shape_all_static( index, index, index, index, index, index, index, index, index, index) { - %expand_shape = memref.expand_shape %arg[[0, 1, 2], [3, 4]] : + %expand_shape = memref.expand_shape %arg[[0, 1, 2], [3, 4]] output_shape [3, 5, 2, 2, 2] : memref<30x4xi16> into memref<3x5x2x2x2xi16> %base, %offset, %sizes:5, %strides:5 = memref.extract_strided_metadata %expand_shape : @@ -595,12 +596,13 @@ func.func @extract_strided_metadata_of_expand_shape_all_dynamic( %base: memref>, %offset0: index, %offset1: index, %offset2: index, %size0: index, %size1: index, %size2: index, - %stride0: index, %stride1: index, %stride2: index) + %stride0: index, %stride1: index, %stride2: index, + %sz0: index, %sz1: index) -> (memref, index, index, index, index, index, index, index, index, index, index, index, index, index, index, index, index, index) { - %subview = memref.expand_shape %base[[0, 1, 2, 3],[4, 5, 6, 7]] : + %subview = memref.expand_shape %base[[0, 1, 2, 3],[4, 5, 6, 7]] output_shape [%sz0, 7, 8, 9, 10, 2, %sz1, 3] : memref> into memref> @@ -643,7 +645,7 @@ func.func @extract_strided_metadata_of_expand_shape_all_static_0_rank( index, index, index, index, index, index, index, index, index, index) { - %expand_shape = memref.expand_shape %arg[] : + %expand_shape = memref.expand_shape %arg[] output_shape [1, 1, 1, 1, 1] : memref> into memref<1x1x1x1x1xi16, strided<[1,1,1,1,1], offset: ?>> %base, %offset, %sizes:5, %strides:5 = memref.extract_strided_metadata %expand_shape : @@ -1456,6 +1458,7 @@ func.func @extract_strided_metadata_of_cast_w_csts( index, index, index, index } + // ----- // Check that we don't simplify extract_strided_metadata of @@ -1497,6 +1500,7 @@ func.func @extract_strided_metadata_of_cast_unranked( // ----- + memref.global "private" @dynamicShmem : memref<0xf16,3> // CHECK-LABEL: func @zero_sized_memred diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir index 5b853a6cc5a37..254cd4015eed9 100644 --- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir +++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir @@ -412,7 +412,7 @@ func.func @fold_static_stride_subview_with_affine_load_store(%arg0 : memref<12x3 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape // CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32 { func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index) -> f32 { - %0 = memref.expand_shape %arg0 [[0, 1], [2]] : memref<12x32xf32> into memref<2x6x32xf32> + %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 6, 32] : memref<12x32xf32> into memref<2x6x32xf32> %1 = affine.load %0[%arg1, %arg2, %arg3] : memref<2x6x32xf32> return %1 : f32 } @@ -458,7 +458,7 @@ func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index) -> f32 { func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: index) -> f32 { - %0 = memref.expand_shape %arg0 [[0, 1, 2], [3]] : memref<12x32xf32> into memref<2x2x3x32xf32> + %0 = memref.expand_shape %arg0 [[0, 1, 2], [3]] output_shape [2, 2, 3, 32] : memref<12x32xf32> into memref<2x2x3x32xf32> %1 = affine.load %0[%arg1, %arg2, %arg3, %arg4] : memref<2x2x3x32xf32> return %1 : f32 } @@ -469,15 +469,17 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar // ----- // CHECK-LABEL: fold_dynamic_subview_with_memref_load_store_expand_shape -func.func @fold_dynamic_subview_with_memref_load_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index) -> f32 { +// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[SZ0:.*]]: index) +func.func @fold_dynamic_subview_with_memref_load_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0: index) -> f32 { %c0 = arith.constant 0 : index - %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>> + %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>> %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>> return %0 : f32 } -// CHECK: %[[EXPAND_SHAPE:.+]] = memref.expand_shape {{.+}} : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>> -// CHECK: %[[LOAD:.+]] = memref.load %[[EXPAND_SHAPE]] -// CHECK: return %[[LOAD]] +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[EXPAND_SHAPE:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [1, 16, %[[SZ0]], 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>> +// CHECK: %[[VAL_0:.*]] = memref.load %[[EXPAND_SHAPE]][%[[C0]], %[[ARG1]], %[[ARG2]], %[[C0]]] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>> +// CHECK: return %[[VAL_0]] : f32 // ----- @@ -486,7 +488,7 @@ func.func @fold_dynamic_subview_with_memref_load_store_expand_shape(%arg0 : memr // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index) func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 { - %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32> + %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 1024, 1024, 1] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32> affine.for %arg3 = 0 to 1 { affine.for %arg4 = 0 to 1024 { affine.for %arg5 = 0 to 1020 { @@ -515,7 +517,7 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0: // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index) func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 { - %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32> + %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 1024, 1024, 1] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32> affine.for %arg3 = 0 to 1 { affine.for %arg4 = 0 to 1024 { affine.for %arg5 = 0 to 1020 { @@ -544,7 +546,7 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_a // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index) func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 { - %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32> + %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 1024, 1024, 1] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32> %cst = arith.constant 0 : index affine.for %arg3 = 0 to 1 { affine.for %arg4 = 0 to 1024 { diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir index 1aef417549d9a..70c96aad9555e 100644 --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -392,9 +392,9 @@ func.func @copy_different_eltype(%arg0: memref<2xf32>, %arg1: memref<2xf16>) { // ----- -func.func @expand_shape(%arg0: memref) { +func.func @expand_shape(%arg0: memref, %sz0: index, %sz1: index) { // expected-error @+1 {{invalid number of reassociation groups: found 1, expected 2}} - %0 = memref.expand_shape %arg0 [[0, 1]] : memref into memref + %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [%sz0, 5, %sz1] : memref into memref return } @@ -402,7 +402,15 @@ func.func @expand_shape(%arg0: memref) { func.func @expand_shape(%arg0: memref) { // expected-error @+1 {{rank 0 memrefs can only be extended/collapsed with/from ones}} - %0 = memref.expand_shape %arg0 [] : memref into memref<1x2xf32> + %0 = memref.expand_shape %arg0 [] output_shape [1, 2] : memref into memref<1x2xf32> + return +} + +// ----- + +func.func @expand_shape_illegal_output_shape(%arg0: memref<2xf32>) { + // expected-error @+1 {{expected number of static shape bounds to be equal to the output rank (3) but found 2 inputs instead}} + %0 = memref.expand_shape %arg0 [[0, 1, 2]] output_shape [1, 2] : memref<2xf32> into memref<1x1x2xf32> return } @@ -415,9 +423,9 @@ func.func @collapse_shape_out_of_bounds(%arg0: memref) { // ----- -func.func @expand_shape_out_of_bounds(%arg0: memref) { +func.func @expand_shape_out_of_bounds(%arg0: memref, %sz0: index) { // expected-error @+1 {{op reassociation index 2 is out of bounds}} - %0 = memref.expand_shape %arg0 [[0, 1, 2]] : memref into memref<4x?xf32> + %0 = memref.expand_shape %arg0 [[0, 1, 2]] output_shape [4, %sz0] : memref into memref<4x?xf32> } // ----- @@ -425,7 +433,7 @@ func.func @expand_shape_out_of_bounds(%arg0: memref) { func.func @expand_shape_invalid_result_layout( %arg0: memref<30x20xf32, strided<[4000, 2], offset: 100>>) { // expected-error @+1 {{expected expanded type to be 'memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>>' but found 'memref<2x15x20xf32, strided<[5000, 4000, 2], offset: 100>>'}} - %0 = memref.expand_shape %arg0 [[0, 1], [2]] : + %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 15, 20] : memref<30x20xf32, strided<[4000, 2], offset: 100>> into memref<2x15x20xf32, strided<[5000, 4000, 2], offset: 100>> } @@ -462,7 +470,7 @@ func.func @collapse_shape_invalid_reassociation_expansion(%arg0: memref) // like this. Verify that a sensible error is emitted in this case. func.func @expand_shape_invalid_reassociation(%arg0: memref<2x3x1xf32>) { // expected-error @+1 {{'memref.expand_shape' op has source rank 3 and result rank 2. This is not an expansion (3 > 2)}} - %0 = memref.expand_shape %arg0 [[0], [1], [1]] : + %0 = memref.expand_shape %arg0 [[0], [1], [1]] output_shape [2, 3] : memref<2x3x1xf32> into memref<2x3xf32> } @@ -495,20 +503,10 @@ func.func @collapse_shape_wrong_collapsed_type(%arg0: memref) { // ----- -func.func @expand_shape_illegal_dynamic_memref - (%arg0: memref) -> memref { - // expected-error @+1 {{at most one dimension in a reassociation group may be dynamic}} - %0 = memref.expand_shape %arg0 [[0], [1], [2, 3, 4]] - : memref into memref - return %0 : memref -} - -// ----- - func.func @expand_shape_illegal_static_memref (%arg0: memref<2x3x20xf32>) -> memref<2x3x2x4x5xf32> { // expected-error @+1 {{collapsed dim size (20) must equal reassociation group size (40)}} - %0 = memref.expand_shape %arg0 [[0], [1], [2, 3, 4]] + %0 = memref.expand_shape %arg0 [[0], [1], [2, 3, 4]] output_shape [2, 3, 2, 4, 5] : memref<2x3x20xf32> into memref<2x3x2x4x5xf32> return %0 : memref<2x3x2x4x5xf32> } @@ -525,30 +523,30 @@ func.func @collapse_shape_illegal_static_memref // ----- -func.func @expand_shape_illegal_mixed_memref(%arg0 : memref) +func.func @expand_shape_illegal_mixed_memref(%arg0 : memref, %sz0: index) -> memref { // expected-error @+1 {{collapsed dim (1) must be dynamic if and only if reassociation group is dynamic}} - %0 = memref.expand_shape %arg0 [[0, 1], [2]] + %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [%sz0, 4, 5] : memref into memref return %0 : memref } // ----- -func.func @expand_shape_illegal_mixed_memref_2(%arg0 : memref) +func.func @expand_shape_illegal_mixed_memref_2(%arg0 : memref, %sz0: index) -> memref { // expected-error @+1 {{collapsed dim (1) must be dynamic if and only if reassociation group is dynamic}} - %0 = memref.expand_shape %arg0 [[0], [1, 2]] + %0 = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [%sz0, 4, 5] : memref into memref return %0 : memref } // ----- -func.func @expand_shape_invalid_static_dim_size(%arg0 : memref) +func.func @expand_shape_invalid_static_dim_size(%arg0 : memref, %sz0: index) -> memref { // expected-error @+1 {{collapsed dim size (21) must equal reassociation group size (20)}} - %0 = memref.expand_shape %arg0 [[0], [1, 2]] + %0 = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [%sz0, 4, 5] : memref into memref return %0 : memref } diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir index 2d69904f27db5..60fb0ffeee240 100644 --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -106,9 +106,9 @@ func.func @expand_collapse_shape_static( %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<3x4x5xf32> into memref<12x5xf32> -// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] +// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] output_shape [3, 4, 5] // CHECK-SAME: memref<12x5xf32> into memref<3x4x5xf32> - %r0 = memref.expand_shape %0 [[0, 1], [2]] : + %r0 = memref.expand_shape %0 [[0, 1], [2]] output_shape [3, 4, 5] : memref<12x5xf32> into memref<3x4x5xf32> // CHECK: memref.collapse_shape {{.*}} {{\[}}[0], [1, 2]] @@ -116,9 +116,9 @@ func.func @expand_collapse_shape_static( %1 = memref.collapse_shape %arg0 [[0], [1, 2]] : memref<3x4x5xf32> into memref<3x20xf32> -// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]] +// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]] output_shape [3, 4, 5] // CHECK-SAME: memref<3x20xf32> into memref<3x4x5xf32> - %r1 = memref.expand_shape %1 [[0], [1, 2]] : + %r1 = memref.expand_shape %1 [[0], [1, 2]] output_shape [3, 4, 5] : memref<3x20xf32> into memref<3x4x5xf32> // CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1, 2]] @@ -126,29 +126,29 @@ func.func @expand_collapse_shape_static( %2 = memref.collapse_shape %arg0 [[0, 1, 2]] : memref<3x4x5xf32> into memref<60xf32> -// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1, 2]] +// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1, 2]] output_shape [3, 4, 5] // CHECK-SAME: memref<60xf32> into memref<3x4x5xf32> - %r2 = memref.expand_shape %2 [[0, 1, 2]] : + %r2 = memref.expand_shape %2 [[0, 1, 2]] output_shape [3, 4, 5] : memref<60xf32> into memref<3x4x5xf32> -// CHECK: memref.expand_shape {{.*}} [] +// CHECK: memref.expand_shape {{.*}} [] output_shape [1, 1] // CHECK-SAME: memref into memref<1x1xf32> - %r5 = memref.expand_shape %arg5 [] : + %r5 = memref.expand_shape %arg5 [] output_shape [1, 1] : memref into memref<1x1xf32> // Reshapes with a custom layout map. -// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]] - %l0 = memref.expand_shape %arg3 [[0], [1, 2]] : +// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]] output_shape [30, 4, 5] + %l0 = memref.expand_shape %arg3 [[0], [1, 2]] output_shape [30, 4, 5] : memref<30x20xf32, strided<[4000, 2], offset: 100>> into memref<30x4x5xf32, strided<[4000, 10, 2], offset: 100>> -// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] - %l1 = memref.expand_shape %arg3 [[0, 1], [2]] : +// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] output_shape [2, 15, 20] + %l1 = memref.expand_shape %arg3 [[0, 1], [2]] output_shape [2, 15, 20] : memref<30x20xf32, strided<[4000, 2], offset: 100>> into memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>> -// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]] - %r4 = memref.expand_shape %arg4 [[0], [1, 2]] : +// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]] output_shape [1, 1, 5] + %r4 = memref.expand_shape %arg4 [[0], [1, 2]] output_shape [1, 1, 5] : memref<1x5xf32, strided<[5, 1], offset: ?>> into memref<1x1x5xf32, strided<[5, 5, 1], offset: ?>> @@ -164,9 +164,9 @@ func.func @expand_collapse_shape_static( memref<2049xi64, strided<[?], offset: ?>> // Reshapes that expand and collapse back a contiguous buffer with some 1's. -// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]] +// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5] // CHECK-SAME: memref<3x4x5xf32> into memref<1x3x4x1x5xf32> - %3 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] : + %3 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5]: memref<3x4x5xf32> into memref<1x3x4x1x5xf32> // CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]] @@ -176,15 +176,18 @@ func.func @expand_collapse_shape_static( // Reshapes on tensors. // CHECK: tensor.expand_shape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32> - %t0 = tensor.expand_shape %arg1 [[0, 1], [2], [3, 4]] : + %t0 = tensor.expand_shape %arg1 [[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5] : tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32> // CHECK: tensor.collapse_shape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32> %rt0 = tensor.collapse_shape %t0 [[0, 1], [2], [3, 4]] : tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32> +// CHECK: tensor.dim %arg2, {{.*}} : tensor<3x?x5xf32> // CHECK: tensor.expand_shape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32> - %t1 = tensor.expand_shape %arg2 [[0, 1], [2], [3, 4]] : + %c1 = arith.constant 1 : index + %sz1 = tensor.dim %arg2, %c1 : tensor<3x?x5xf32> + %t1 = tensor.expand_shape %arg2 [[0, 1], [2], [3, 4]] output_shape [1, 3, %sz1, 1, 5] : tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32> // CHECK: tensor.collapse_shape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32> @@ -197,15 +200,18 @@ func.func @expand_collapse_shape_static( func.func @expand_collapse_shape_dynamic(%arg0: memref, %arg1: memref>, %arg2: memref>, - %arg3: memref>) { + %arg3: memref>, + %arg4: index, + %arg5: index, + %arg6: index) { // CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]] // CHECK-SAME: memref into memref %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref into memref -// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] +// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] output_shape [%arg4, 4, %arg5] // CHECK-SAME: memref into memref - %r0 = memref.expand_shape %0 [[0, 1], [2]] : + %r0 = memref.expand_shape %0 [[0, 1], [2]] output_shape [%arg4, 4, %arg5] : memref into memref // CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]] @@ -214,9 +220,9 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref, memref> into memref> -// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] +// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] output_shape [%arg4, 4, %arg5] // CHECK-SAME: memref> into memref> - %r1 = memref.expand_shape %1 [[0, 1], [2]] : + %r1 = memref.expand_shape %1 [[0, 1], [2]] output_shape [%arg4, 4, %arg5] : memref> into memref> @@ -226,9 +232,9 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref, memref> into memref> -// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] +// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] output_shape [%arg4, 4, %arg5] // CHECK-SAME: memref> into memref> - %r2 = memref.expand_shape %2 [[0, 1], [2]] : + %r2 = memref.expand_shape %2 [[0, 1], [2]] output_shape [%arg4, 4, %arg5] : memref> into memref> @@ -238,9 +244,9 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref, memref> into memref> -// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1]] +// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1]] output_shape [%arg6, 42] // CHECK-SAME: memref> into memref - %r3 = memref.expand_shape %3 [[0, 1]] : + %r3 = memref.expand_shape %3 [[0, 1]] output_shape [%arg6, 42] : memref> into memref return } @@ -248,12 +254,12 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref, func.func @expand_collapse_shape_zero_dim(%arg0 : memref<1x1xf32>, %arg1 : memref) -> (memref, memref<1x1xf32>) { %0 = memref.collapse_shape %arg0 [] : memref<1x1xf32> into memref - %1 = memref.expand_shape %0 [] : memref into memref<1x1xf32> + %1 = memref.expand_shape %0 [] output_shape [1, 1] : memref into memref<1x1xf32> return %0, %1 : memref, memref<1x1xf32> } // CHECK-LABEL: func @expand_collapse_shape_zero_dim // CHECK: memref.collapse_shape %{{.*}} [] : memref<1x1xf32> into memref -// CHECK: memref.expand_shape %{{.*}} [] : memref into memref<1x1xf32> +// CHECK: memref.expand_shape %{{.*}} [] output_shape [1, 1] : memref into memref<1x1xf32> func.func @collapse_shape_to_dynamic (%arg0: memref) -> memref { @@ -270,16 +276,18 @@ func.func @collapse_shape_to_dynamic // CHECK-LABEL: func @expand_collapse_shape_transposed_layout func.func @expand_collapse_shape_transposed_layout( %m0: memref>, - %m1: memref<4x5x6xf32, strided<[1, ?, 1000], offset: 0>>) { + %m1: memref<4x5x6xf32, strided<[1, ?, 1000], offset: 0>>, + %sz0: index, + %sz1: index) { - %r0 = memref.expand_shape %m0 [[0], [1, 2]] : + %r0 = memref.expand_shape %m0 [[0], [1, 2]] output_shape [%sz0, %sz1, 5] : memref> into memref> %rr0 = memref.collapse_shape %r0 [[0], [1, 2]] : memref> into memref> - %r1 = memref.expand_shape %m1 [[0, 1], [2], [3, 4]] : + %r1 = memref.expand_shape %m1 [[0, 1], [2], [3, 4]] output_shape [2, 2, 5, 2, 3] : memref<4x5x6xf32, strided<[1, ?, 1000], offset: 0>> into memref<2x2x5x2x3xf32, strided<[2, 1, ?, 3000, 1000], offset: 0>> %rr1 = memref.collapse_shape %r1 [[0, 1], [2], [3, 4]] : diff --git a/mlir/test/Dialect/MemRef/runtime-verification.mlir b/mlir/test/Dialect/MemRef/runtime-verification.mlir index 4d7fcf6ac7cbb..28777a3e88672 100644 --- a/mlir/test/Dialect/MemRef/runtime-verification.mlir +++ b/mlir/test/Dialect/MemRef/runtime-verification.mlir @@ -2,13 +2,14 @@ // CHECK-LABEL: func @expand_shape( // CHECK-SAME: %[[m:.*]]: memref +// CHECK-SAME: %[[sz0:.*]]: index // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[c5:.*]] = arith.constant 5 : index // CHECK-DAG: %[[dim:.*]] = memref.dim %[[m]], %[[c0]] // CHECK: %[[mod:.*]] = arith.remsi %[[dim]], %[[c5]] // CHECK: %[[cmpi:.*]] = arith.cmpi eq, %[[mod]], %[[c0]] // CHECK: cf.assert %[[cmpi]], "ERROR: Runtime op verification failed -func.func @expand_shape(%m: memref) -> memref { - %0 = memref.expand_shape %m [[0, 1]] : memref into memref +func.func @expand_shape(%m: memref, %sz0: index) -> memref { + %0 = memref.expand_shape %m [[0, 1]] output_shape [%sz0, 5] : memref into memref return %0 : memref } diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir index edb53fa024c26..c96f9c31443db 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir @@ -12,7 +12,7 @@ // // CHECK-ROUND-LABEL: func.func @sparse_expand( // CHECK-ROUND-SAME: %[[A:.*]]: tensor<100xf64, #sparse{{[0-9]*}}>) -> tensor<10x10xf64, #sparse{{[0-9]*}}> -// CHECK-ROUND: %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] : tensor<100xf64, #sparse{{[0-9]*}}> into tensor<10x10xf64, #sparse{{[0-9]*}}> +// CHECK-ROUND: %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] output_shape [10, 10] : tensor<100xf64, #sparse{{[0-9]*}}> into tensor<10x10xf64, #sparse{{[0-9]*}}> // CHECK-ROUND: return %[[E]] : tensor<10x10xf64, #sparse{{[0-9]*}}> // // CHECK-LABEL: func.func @sparse_expand( @@ -39,7 +39,7 @@ // CHECK: return %[[NT1]] : tensor<10x10xf64, #sparse{{[0-9]*}}> // func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10xf64, #SparseMatrix> { - %0 = tensor.expand_shape %arg0 [[0, 1]] : + %0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [10, 10] : tensor<100xf64, #SparseVector> into tensor<10x10xf64, #SparseMatrix> return %0 : tensor<10x10xf64, #SparseMatrix> } @@ -94,8 +94,8 @@ func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<10 // roundtrip: // // CHECK-ROUND-LABEL: func.func @dynamic_sparse_expand( -// CHECK-ROUND-SAME: %[[A:.*]]: tensor) -> tensor -// CHECK-ROUND: %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] : tensor into tensor +// CHECK-ROUND-SAME: %[[A:.*]]: tensor, %[[SZ0:.*]]: index) -> tensor +// CHECK-ROUND: %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] output_shape [%[[SZ0]], 10] : tensor into tensor // CHECK-ROUND: return %[[E]] : tensor // // CHECK-LABEL: func.func @dynamic_sparse_expand( @@ -127,8 +127,8 @@ func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<10 // CHECK-NOT: sparse_tensor.convert // CHECK: return %[[NT1]] : tensor // -func.func @dynamic_sparse_expand(%arg0: tensor) -> tensor { - %0 = tensor.expand_shape %arg0 [[0, 1]] : +func.func @dynamic_sparse_expand(%arg0: tensor, %sz0: index) -> tensor { + %0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [%sz0, 10] : tensor into tensor return %0 : tensor } diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir index 815bc383af95a..4f553adcc500f 100644 --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -367,11 +367,14 @@ func.func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5x // CHECK-LABEL: func @tensor.expand_shape( // CHECK-SAME: %[[t1:.*]]: tensor -func.func @tensor.expand_shape(%t1: tensor) -> tensor<2x?x10xf32> { +func.func @tensor.expand_shape(%t1: tensor, %sz0: index) -> tensor<2x?x10xf32> { // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref - // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] [ - // CHECK-SAME: [0, 1], [2]] : memref into memref<2x?x10xf32> - %0 = tensor.expand_shape %t1 [[0, 1], [2]] + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[DIM:.*]] = memref.dim %[[m1]], %[[C0]] : memref + // CHECK: %[[C2:.*]] = arith.constant 2 : index + // CHECK: %[[VAL_1:.*]] = arith.divui %[[DIM]], %[[C2]] : index + // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1], [2]] output_shape [2, %[[VAL_1]], 10] : memref into memref<2x?x10xf32> + %0 = tensor.expand_shape %t1 [[0, 1], [2]] output_shape [2, %sz0, 10] : tensor into tensor<2x?x10xf32> // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]] @@ -384,14 +387,15 @@ func.func @tensor.expand_shape(%t1: tensor) -> tensor<2x?x10xf32> { // CHECK-LABEL: func @tensor.expand_shape_of_slice( // CHECK-SAME: %[[t1:.*]]: tensor func.func @tensor.expand_shape_of_slice( - %t1: tensor, %o1: index, %s1: index) -> tensor { + %t1: tensor, %o1: index, %s1: index, %sz0: index) -> tensor { // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref // CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}, 5] [%{{.*}}, 10] [1, 1] : memref to memref> %0 = tensor.extract_slice %t1[%o1, 5][%s1, 10][1, 1] : tensor to tensor - // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [ - // CHECK-SAME: [0, 1], [2, 3]] : memref> into memref> - %1 = tensor.expand_shape %0 [[0, 1], [2, 3]] : + // CHECK: %[[C7:.*]] = arith.constant 7 : index + // CHECK: %[[VAL_1:.*]] = arith.divui %{{.*}}, %[[C7]] : index + // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 7, 2, 5] : memref> into memref> + %1 = tensor.expand_shape %0 [[0, 1], [2, 3]] output_shape [%sz0, 7, 2, 5] : tensor into tensor // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]] // CHECK: return %[[r]] @@ -407,8 +411,8 @@ func.func @tensor.expand_shape_of_scalar_slice( // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref // CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}] [1] [1] : memref to memref> %0 = tensor.extract_slice %t1[%o1][1][1] : tensor to tensor - // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [] : memref into memref<1xf32, strided<[1], offset: ?>> - %1 = tensor.expand_shape %0 [] : tensor into tensor<1xf32> + // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [] output_shape [1] : memref into memref<1xf32, strided<[1], offset: ?>> + %1 = tensor.expand_shape %0 [] output_shape [1] : tensor into tensor<1xf32> // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]] // CHECK: return %[[r]] return %1 : tensor<1xf32> diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 9a4dd2f3b5cc1..6177fe3c752c9 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -4,7 +4,7 @@ // CHECK-LABEL: expand_shape_identity_fold // CHECK-NEXT: return func.func @expand_shape_identity_fold(%arg0 : tensor<5xf32>) -> tensor<5xf32> { - %0 = tensor.expand_shape %arg0 [[0]] : tensor<5xf32> into tensor<5xf32> + %0 = tensor.expand_shape %arg0 [[0]] output_shape [5] : tensor<5xf32> into tensor<5xf32> return %0 : tensor<5xf32> } @@ -13,7 +13,7 @@ func.func @expand_shape_identity_fold(%arg0 : tensor<5xf32>) -> tensor<5xf32> { // CHECK-LABEL: expand_shape_rank0_identity_fold // CHECK-NEXT: return func.func @expand_shape_rank0_identity_fold(%arg0 : tensor) -> tensor { - %0 = tensor.expand_shape %arg0 [] : tensor into tensor + %0 = tensor.expand_shape %arg0 [] output_shape [] : tensor into tensor return %0 : tensor } @@ -1051,29 +1051,28 @@ func.func @fold_overlapping_insert(%input : tensor, %slice1: tensor<4 // ----- -func.func @compose_expand_of_expand(%arg0 : tensor) +func.func @compose_expand_of_expand(%arg0 : tensor, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> tensor { - %0 = tensor.expand_shape %arg0 [[0, 1], [2]] + %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2] : tensor into tensor - %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4]] - : tensor into tensor + %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4]] output_shape [%arg3, 6, 4, %arg4, 5] : tensor into tensor return %1 : tensor } // CHECK-LABEL: compose_expand_of_expand -// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] +// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] output_shape [%arg3, 6, 4, %arg4, 5] // CHECK-NOT: tensor.expand_shape // ----- func.func @compose_expand_of_expand_of_zero_dim(%arg0 : tensor) -> tensor<1x1x1xf32> { - %0 = tensor.expand_shape %arg0 [] : tensor into tensor<1xf32> - %1 = tensor.expand_shape %0 [[0, 1, 2]] + %0 = tensor.expand_shape %arg0 [] output_shape [1] : tensor into tensor<1xf32> + %1 = tensor.expand_shape %0 [[0, 1, 2]] output_shape [1, 1, 1] : tensor<1xf32> into tensor<1x1x1xf32> return %1 : tensor<1x1x1xf32> } // CHECK-LABEL: compose_expand_of_expand_of_zero_dim -// CHECK: tensor.expand_shape %{{.*}} [] +// CHECK: tensor.expand_shape %{{.*}} [] output_shape [1, 1, 1] // CHECK-SAME: tensor into tensor<1x1x1xf32> // ----- @@ -1093,7 +1092,7 @@ func.func @collapse_of_cast(%t: tensor<8x12x32xf32>) -> tensor { // ----- func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> { - %0 = tensor.expand_shape %arg0 [[0, 1], [2]] + %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [3, 4, 4] : tensor<12x4xf32> into tensor<3x4x4xf32> %1 = tensor.collapse_shape %0 [[0, 1], [2]] : tensor<3x4x4xf32> into tensor<12x4xf32> @@ -1104,9 +1103,9 @@ func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> // ----- -func.func @fold_collapse_of_expand_dynamic(%arg0 : tensor) +func.func @fold_collapse_of_expand_dynamic(%arg0 : tensor, %arg1: index, %arg2: index) -> tensor { - %0 = tensor.expand_shape %arg0 [[0, 1], [2]] + %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2] : tensor into tensor %1 = tensor.collapse_shape %0 [[0, 1], [2]] : tensor into tensor @@ -1121,7 +1120,7 @@ func.func @compose_expand_of_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>) -> tensor<24x5x42x8xf32> { %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3, 4, 5, 6]] : tensor<2x3x4x5x6x7x8xf32> into tensor<40320xf32> - %1 = tensor.expand_shape %0 [[0, 1, 2, 3]] + %1 = tensor.expand_shape %0 [[0, 1, 2, 3]] output_shape [24, 5, 42, 8] : tensor<40320xf32> into tensor<24x5x42x8xf32> return %1 : tensor<24x5x42x8xf32> } @@ -1137,7 +1136,7 @@ func.func @compose_expand_of_collapse_7D(%arg0 : tensor<24x5x42x8xf32>) -> tensor<2x3x4x5x6x7x8xf32> { %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3]] : tensor<24x5x42x8xf32> into tensor<40320xf32> - %1 = tensor.expand_shape %0 [[0, 1, 2, 3, 4, 5, 6]] + %1 = tensor.expand_shape %0 [[0, 1, 2, 3, 4, 5, 6]] output_shape [2, 3, 4, 5, 6, 7, 8] : tensor<40320xf32> into tensor<2x3x4x5x6x7x8xf32> return %1 : tensor<2x3x4x5x6x7x8xf32> } @@ -1149,16 +1148,16 @@ func.func @compose_expand_of_collapse_7D(%arg0 : tensor<24x5x42x8xf32>) // ----- -func.func @compose_collapse_of_expand(%arg : tensor) +func.func @compose_collapse_of_expand(%arg : tensor, %arg1: index, %arg2: index, %arg3: index) -> tensor { - %0 = tensor.expand_shape %arg [[0], [1], [2, 3]] + %0 = tensor.expand_shape %arg [[0], [1], [2, 3]] output_shape [%arg1, %arg2, %arg3, 1] : tensor into tensor %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]] : tensor into tensor return %1 : tensor } // CHECK-LABEL: func @compose_collapse_of_expand -// CHECK: (%[[ARG:.*]]: tensor) +// CHECK: (%[[ARG:.*]]: tensor, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) // CHECK-NEXT: tensor.collapse_shape %[[ARG]] // CHECK-SAME: [0, 1], [2] // CHECK-SAME: : tensor into tensor @@ -1167,14 +1166,14 @@ func.func @compose_collapse_of_expand(%arg : tensor) func.func @compose_collapse_of_expand_1D(%arg0 : tensor<2048xf32>) -> tensor<4x512xf32> { - %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3]] + %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3]] output_shape [1, 4, 1, 512] : tensor<2048xf32> into tensor<1x4x1x512xf32> %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]] : tensor<1x4x1x512xf32> into tensor<4x512xf32> return %1 : tensor<4x512xf32> } // CHECK: func @compose_collapse_of_expand_1D -// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] +// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] output_shape [4, 512] // CHECK-SAME: tensor<2048xf32> into tensor<4x512xf32> // ----- @@ -1183,14 +1182,14 @@ func.func @compose_expand_of_collapse_0_rank_to_expand(%arg0 : tensor<1x1x1xf32> -> tensor<1x1x1x1xf32> { %0 = tensor.collapse_shape %arg0 [] : tensor<1x1x1xf32> into tensor - %1 = tensor.expand_shape %0 [] + %1 = tensor.expand_shape %0 [] output_shape [1, 1, 1, 1] : tensor into tensor<1x1x1x1xf32> return %1 : tensor<1x1x1x1xf32> } // CHECK: func @compose_expand_of_collapse_0_rank_to_expand // CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x1xf32> // CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[ARG0]] -// CHECK-SAME: [0], [1], [2, 3] +// CHECK-SAME: {{\[}}[0], [1], [2, 3]] output_shape [1, 1, 1, 1] // CHECK: return %[[RESULT]] // ----- @@ -1199,7 +1198,7 @@ func.func @compose_expand_of_collapse_0_rank_to_collapse(%arg0 : tensor<1x1x1x1x -> tensor<1x1x1xf32> { %0 = tensor.collapse_shape %arg0 [] : tensor<1x1x1x1xf32> into tensor - %1 = tensor.expand_shape %0 [] + %1 = tensor.expand_shape %0 [] output_shape [1, 1, 1] : tensor into tensor<1x1x1xf32> return %1 : tensor<1x1x1xf32> } @@ -1214,8 +1213,8 @@ func.func @compose_expand_of_collapse_0_rank_to_collapse(%arg0 : tensor<1x1x1x1x // CHECK-LABEL: func @zero_rank_reshape_multi func.func @zero_rank_reshape_multi(%arg0: tensor) -> tensor { // CHECK: return %arg0 - %0 = tensor.expand_shape %arg0 [] : tensor into tensor<1xf32> - %1 = tensor.expand_shape %0 [[0, 1]] : tensor<1xf32> into tensor<1x1xf32> + %0 = tensor.expand_shape %arg0 [] output_shape [1] : tensor into tensor<1xf32> + %1 = tensor.expand_shape %0 [[0, 1]] output_shape [1, 1] : tensor<1xf32> into tensor<1x1xf32> %2 = tensor.collapse_shape %1 [] : tensor<1x1xf32> into tensor return %2 : tensor } @@ -1250,7 +1249,7 @@ func.func @compose_collapse_of_collapse_zero_dim(%arg0 : tensor<1x1x1xf32>) // ----- func.func @fold_collapse_of_expand_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32> { - %0 = tensor.expand_shape %arg0 [[0, 1, 2], [3]] + %0 = tensor.expand_shape %arg0 [[0, 1, 2], [3]] output_shape [1, 4, 1, 512] : tensor<4x512xf32> into tensor<1x4x1x512xf32> %1 = tensor.collapse_shape %0 [[0, 1, 2, 3]] : tensor<1x4x1x512xf32> into tensor<2048xf32> @@ -1264,42 +1263,40 @@ func.func @fold_collapse_of_expand_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048x func.func @fold_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x1xf32>) -> tensor<4x512x1x1xf32> { - %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3], [4], [5]] - : tensor<2048x1x1xf32> into tensor<1x4x1x512x1x1xf32> + %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3], [4], [5]] output_shape [1, 4, 1, 512, 1, 1] : tensor<2048x1x1xf32> into tensor<1x4x1x512x1x1xf32> %1 = tensor.collapse_shape %0 [[0, 1, 2], [3], [4], [5]] : tensor<1x4x1x512x1x1xf32> into tensor<4x512x1x1xf32> return %1 : tensor<4x512x1x1xf32> } // CHECK: func @fold_collapse_of_expand_unit_dims -// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3]] +// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3]] output_shape [4, 512, 1, 1] // CHECK-SAME: tensor<2048x1x1xf32> into tensor<4x512x1x1xf32> // ----- func.func @compose_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x2048xf32>) -> tensor<4x512x1x512x4xf32> { - %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]] - : tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32> + %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]] output_shape [1, 4, 1, 512, 1, 1, 512, 1, 4] : tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32> %1 = tensor.collapse_shape %0 [[0, 1, 2], [3, 4], [5], [6, 7], [8]] : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32> return %1 : tensor<4x512x1x512x4xf32> } // CHECK: func @compose_collapse_of_expand_unit_dims -// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3, 4]] +// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3, 4]] output_shape [4, 512, 1, 512, 4] // CHECK-SAME: tensor<2048x1x2048xf32> into tensor<4x512x1x512x4xf32> // ----- func.func @compose_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32> { - %0 = tensor.expand_shape %arg0 [[0, 1, 2]] + %0 = tensor.expand_shape %arg0 [[0, 1, 2]] output_shape [2, 1, 1] : tensor<2xf32> into tensor<2x1x1xf32> %1 = tensor.collapse_shape %0 [[0], [1, 2]] : tensor<2x1x1xf32> into tensor<2x1xf32> return %1 : tensor<2x1xf32> } // CHECK: func @compose_collapse_of_expand_trailing_unit_dims -// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] +// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] output_shape [2, 1] // CHECK-SAME: tensor<2xf32> into tensor<2x1xf32> // ----- @@ -1321,14 +1318,13 @@ func.func @compose_collapse_of_collapse_unit_dims_dynamic( func.func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32> { - %0 = tensor.expand_shape %arg0 [[0, 1, 2]] - : tensor<2xf32> into tensor<2x1x1xf32> + %0 = tensor.expand_shape %arg0 [[0, 1, 2]] output_shape [2, 1, 1] : tensor<2xf32> into tensor<2x1x1xf32> %1 = tensor.collapse_shape %0 [[0], [1, 2]] : tensor<2x1x1xf32> into tensor<2x1xf32> return %1 : tensor<2x1xf32> } // CHECK: func @fold_collapse_of_expand_trailing_unit_dims -// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] +// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] output_shape [2, 1] // CHECK-SAME: tensor<2xf32> into tensor<2x1xf32> // ----- @@ -1349,8 +1345,7 @@ func.func @fold_collapse_of_collapse_trailing_unit_dims_dynamic( func.func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>) -> tensor<12x42xf32> { - %0 = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]] - : tensor<12x42x1x1xf32> into tensor<12x42x1x1x1xf32> + %0 = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [12, 42, 1, 1, 1] : tensor<12x42x1x1xf32> into tensor<12x42x1x1x1xf32> %1 = tensor.collapse_shape %0 [[0], [1, 2, 3, 4]] : tensor<12x42x1x1x1xf32> into tensor<12x42xf32> return %1 : tensor<12x42xf32> @@ -1361,9 +1356,9 @@ func.func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<12x42x1x1xf3 // ----- -func.func @fold_collapse_of_expand_unit_dims_in_middle(%arg0 : tensor) +func.func @fold_collapse_of_expand_unit_dims_in_middle(%arg0 : tensor, %sz0: index, %sz1: index, %sz2: index) -> tensor { - %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]] + %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [%sz0, %sz1, 1, %sz2] : tensor into tensor %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] : tensor into tensor @@ -1378,7 +1373,7 @@ func.func @fold_collapse_of_expand_unit_dims_in_middle(%arg0 : tensor func.func @no_fold_collapse_of_expand_incompatible(%arg0 : tensor<4x6x8xf32>) -> tensor<2x6x16xf32> { - %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3], [4]] + %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3], [4]] output_shape [2, 2, 3, 2, 8] : tensor<4x6x8xf32> into tensor<2x2x3x2x8xf32> %1 = tensor.collapse_shape %0 [[0], [1, 2], [3, 4]] : tensor<2x2x3x2x8xf32> into tensor<2x6x16xf32> @@ -1392,7 +1387,7 @@ func.func @no_fold_collapse_of_expand_incompatible(%arg0 : tensor<4x6x8xf32>) func.func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>) -> tensor<12x1xf32> { - %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]] + %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [3, 2, 2, 1] : tensor<3x2x2xf32> into tensor<3x2x2x1xf32> %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]] : tensor<3x2x2x1xf32> into tensor<12x1xf32> @@ -1401,7 +1396,7 @@ func.func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>) // CHECK: func @no_fold_collapse_of_expand_empty_expr // CHECK-SAME: %[[ARG0:.+]]: tensor<3x2x2xf32> // CHECK: %[[RARG0:.+]] = tensor.expand_shape %[[ARG0]] -// CHECK-SAME: [0], [1], [2, 3] +// CHECK-SAME: {{\[}}[0], [1], [2, 3]] output_shape [3, 2, 2, 1] // CHECK: %[[RES:.+]] = tensor.collapse_shape %[[RARG0]] // CHECK-SAME: [0, 1, 2], [3] // CHECK: return %[[RES:.+]] : tensor<12x1xf32> @@ -1410,7 +1405,7 @@ func.func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>) func.func @reshape_splat_constant_int32() -> tensor<2x4x2xi32> { %c0 = arith.constant dense<42> : tensor<2x8xi32> - %0 = tensor.expand_shape %c0 [[0], [1, 2]] + %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2] : tensor<2x8xi32> into tensor<2x4x2xi32> return %0 : tensor<2x4x2xi32> } @@ -1421,7 +1416,7 @@ func.func @reshape_splat_constant_int32() -> tensor<2x4x2xi32> { // ----- func.func @expand_shape_splat(%arg : f32) -> tensor<2x2x2xf32> { %c0 = tensor.splat %arg : tensor<2x4xf32> - %0 = tensor.expand_shape %c0 [[0], [1, 2]] + %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 2, 2] : tensor<2x4xf32> into tensor<2x2x2xf32> return %0 : tensor<2x2x2xf32> } @@ -1434,13 +1429,12 @@ func.func @expand_shape_splat(%arg : f32) -> tensor<2x2x2xf32> { // ----- // CHECK-LABEL: @expand_shape_splat_dynamic_no_fold -// CHECK-SAME: %[[F:.+]]: f32 -// CHECK-SAME: %[[M:.+]]: index -func.func @expand_shape_splat_dynamic_no_fold(%arg: f32, %m: index) -> tensor<2x2x?xf32> { - // CHECK: %[[SPLAT:.+]] = tensor.splat %[[F]][%[[M]]] +// CHECK-SAME: (%[[F:.+]]: f32, %[[M:.+]]: index, %[[SZ0:.+]]: index) +func.func @expand_shape_splat_dynamic_no_fold(%arg: f32, %m: index, %sz0: index) -> tensor<2x2x?xf32> { + // CHECK: %[[SPLAT:.+]] = tensor.splat %[[F]][%[[M]]] : tensor<2x?xf32> // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[SPLAT]] %c0 = tensor.splat %arg[%m] : tensor<2x?xf32> - %0 = tensor.expand_shape %c0 [[0], [1, 2]] : tensor<2x?xf32> into tensor<2x2x?xf32> + %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 2, %sz0] : tensor<2x?xf32> into tensor<2x2x?xf32> return %0 : tensor<2x2x?xf32> } @@ -1475,7 +1469,7 @@ func.func @collapse_shape_splat_dynamic_no_fold(%f: f32, %m: index) -> tensor<2x func.func @reshape_splat_constant_int16() -> tensor<2x4x2xi16> { %c0 = arith.constant dense<42> : tensor<2x8xi16> - %0 = tensor.expand_shape %c0 [[0], [1, 2]] + %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2] : tensor<2x8xi16> into tensor<2x4x2xi16> return %0 : tensor<2x4x2xi16> } @@ -1488,7 +1482,7 @@ func.func @reshape_splat_constant_int16() -> tensor<2x4x2xi16> { func.func @reshape_splat_constant_float32() -> tensor<2x4x2xf32> { %c0 = arith.constant dense<42.0> : tensor<2x8xf32> - %0 = tensor.expand_shape %c0 [[0], [1, 2]] + %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2] : tensor<2x8xf32> into tensor<2x4x2xf32> return %0 : tensor<2x4x2xf32> } @@ -1501,7 +1495,7 @@ func.func @reshape_splat_constant_float32() -> tensor<2x4x2xf32> { func.func @reshape_splat_constant_float64() -> tensor<2x4x2xf64> { %c0 = arith.constant dense<42.0> : tensor<2x8xf64> - %0 = tensor.expand_shape %c0 [[0], [1, 2]] + %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2] : tensor<2x8xf64> into tensor<2x4x2xf64> return %0 : tensor<2x4x2xf64> } @@ -1851,7 +1845,7 @@ func.func @fold_expand_shape_from_elements(%arg0: i32) -> tensor<1xi32> { // CHECK: %[[FROM:.+]] = tensor.from_elements %arg0 : tensor<1xi32> // CHECK: return %[[FROM]] : tensor<1xi32> %0 = tensor.from_elements %arg0 : tensor - %1 = tensor.expand_shape %0 [] : tensor into tensor<1xi32> + %1 = tensor.expand_shape %0 [] output_shape [1] : tensor into tensor<1xi32> return %1 : tensor<1xi32> } @@ -2073,9 +2067,9 @@ func.func @empty_tensor_canonicalize(%i : index) { // CHECK: %[[dim:.*]] = tensor.dim %[[t]], %[[c1]] : tensor // CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim]]] // CHECK: return %[[apply]] -func.func @dim_of_expand_shape(%t: tensor) -> index { +func.func @dim_of_expand_shape(%t: tensor, %sz0: index, %sz1: index) -> index { %c2 = arith.constant 2 : index - %0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]] + %0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]] output_shape [%sz0, 1, %sz1, 5, 1, 8] : tensor into tensor %1 = tensor.dim %0, %c2 : tensor return %1 : index @@ -2107,9 +2101,9 @@ func.func @dim_of_collapse_shape(%t: tensor) -> index { // CHECK-LABEL: func @collapse_expand_fold_to_cast( // CHECK-SAME: %[[t:.*]]: tensor // CHECK: return %[[t]] -func.func @collapse_expand_fold_to_cast(%t: tensor) -> (tensor) +func.func @collapse_expand_fold_to_cast(%t: tensor, %sz0: index) -> (tensor) { - %0 = tensor.expand_shape %t [[0, 1]] : tensor into tensor<1x?xf32> + %0 = tensor.expand_shape %t [[0, 1]] output_shape [1, %sz0] : tensor into tensor<1x?xf32> %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x?xf32> into tensor return %1 : tensor } diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir index 15f841f2128ed..e200a4f892613 100644 --- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir +++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir @@ -13,10 +13,9 @@ module attributes {transform.with_named_sequence} { // CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)> // CHECK: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 * 28)> -func.func @empty_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> { +func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4x?x7xf32> { %0 = tensor.empty(%arg0) : tensor<6x5x?xf32> - %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4, 5]] - : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32> + %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4, 5]] output_shape [2, 3, 5, 4, %sz0, 7] : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32> return %1 : tensor<2x3x5x4x?x7xf32> } // CHECK-LABEL: func @empty_reshape_expansion diff --git a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir index 625408dfefe21..d3ac6ce792f36 100644 --- a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir +++ b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir @@ -11,9 +11,11 @@ func.func @expand_shape_of_rank_reducing_extract( { %0 = tensor.extract_slice %t[0, 0, 0, 0][%idx, 1, 1, 5][1, 1, 1, 1] : tensor to tensor - %1 = tensor.expand_shape %0 [[0], [1, 2], [3]] + %c0 = arith.constant 0 : index + %sz0 = tensor.dim %0, %c0 : tensor + %1 = tensor.expand_shape %0 [[0], [1, 2], [3]] output_shape [%sz0, 1, 1, 5] : tensor into tensor - %2 = tensor.expand_shape %0 [[0, 1], [2], [3]] + %2 = tensor.expand_shape %0 [[0, 1], [2], [3]] output_shape [%sz0, 1, 1, 5] : tensor into tensor return %1, %2 : tensor, tensor } diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir index 79ca0de68a1e9..41b6529f64afa 100644 --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -273,21 +273,10 @@ func.func @insert_slice_wrong_dynamic_type(%t1: tensor, %t2: tensor<8 // ----- -func.func @illegal_expanding_reshape_dynamic_tensor - (%arg0: tensor) -> tensor { - // expected-error @+1 {{invalid to have a single dimension (2) expanded into multiple dynamic dims (2,4)}} - %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3, 4]] - : tensor into tensor - return %0 : tensor -} - -// ----- - - func.func @illegal_expanding_reshape_static_tensor (%arg0: tensor<2x3x20xf32>) -> tensor<2x3x2x4x5xf32> { // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}} - %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3, 4]] + %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3, 4]] output_shape [2, 3, 2, 4, 5] : tensor<2x3x20xf32> into tensor<2x3x2x4x5xf32> return %0 : tensor<2x3x2x4x5xf32> } @@ -304,24 +293,33 @@ func.func @illegal_collapsing_reshape_static_tensor // ----- -func.func @illegal_expanding_reshape_mixed_tensor(%arg0 : tensor) +func.func @illegal_expanding_reshape_mixed_tensor(%arg0 : tensor, %sz0: index) -> tensor { // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}} - %0 = tensor.expand_shape %arg0 [[0, 1], [2]] + %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%sz0, 4, 5] : tensor into tensor return %0 : tensor } // ----- -func.func @illegal_expanding_reshape_mixed_tensor_2(%arg0 : tensor) +func.func @illegal_expanding_reshape_mixed_tensor_2(%arg0 : tensor, %sz0: index) -> tensor { // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}} - %0 = tensor.expand_shape %arg0 [[0], [1, 2]] + %0 = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [%sz0, 4, 5] : tensor into tensor return %0 : tensor } +// ----- + +func.func @expand_shape_illegal_output_shape(%arg0: tensor<2xf32>) { + // expected-error @+1 {{expected number of static shape dims to be equal to the output rank (3) but found 2 inputs instead}} + %0 = tensor.expand_shape %arg0 [[0, 1, 2]] output_shape [1, 2] : tensor<2xf32> into tensor<1x1x2xf32> + return +} + + // ----- func.func @illegal_collapsing_reshape_mixed_tensor(%arg0 : tensor) -> tensor { diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir index 2b0a74acce082..378137a14b59f 100644 --- a/mlir/test/Dialect/Tensor/ops.mlir +++ b/mlir/test/Dialect/Tensor/ops.mlir @@ -194,12 +194,26 @@ func.func @insert_slice( func.func @tensor_reshape_zero_dim(%arg0 : tensor<1x1xf32>, %arg1 : tensor) -> (tensor, tensor<1x1xf32>) { %0 = tensor.collapse_shape %arg0 [] : tensor<1x1xf32> into tensor - %1 = tensor.expand_shape %0 [] : tensor into tensor<1x1xf32> + %1 = tensor.expand_shape %0 [] output_shape [1, 1] : tensor into tensor<1x1xf32> return %0, %1 : tensor, tensor<1x1xf32> } // CHECK-LABEL: func @tensor_reshape_zero_dim // CHECK: tensor.collapse_shape %{{.*}} [] : tensor<1x1xf32> into tensor -// CHECK: tensor.expand_shape %{{.*}} [] : tensor into tensor<1x1xf32> +// CHECK: tensor.expand_shape %{{.*}} [] output_shape [1, 1] : tensor into tensor<1x1xf32> + +// ----- + +func.func @tensor_expand_shape_dynamic_dim(%arg0 : tensor, %sz0 : index, %sz1 : index, %sz2 : index) + -> (tensor<5x?x?x?xf32>) { + %1 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [5, %sz0, %sz1, %sz2] : tensor into tensor<5x?x?x?xf32> + return %1 : tensor<5x?x?x?xf32> +} + +// CHECK-LABEL: func.func @tensor_expand_shape_dynamic_dim(%arg0: tensor, %arg1: index, %arg2: index, %arg3: index) -> tensor<5x?x?x?xf32> { +// CHECK: %expanded = tensor.expand_shape %arg0 {{\[\[}}0, 1], [2, 3{{\]\]}} output_shape [5, %arg1, %arg2, %arg3] : tensor into tensor<5x?x?x?xf32> +// CHECK: return %expanded : tensor<5x?x?x?xf32> +// CHECK: } + // ----- diff --git a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir index 9948c0246e6ed..5a2eade0ecccf 100644 --- a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir +++ b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir @@ -2,7 +2,7 @@ // CHECK-LABEL: func.func @single_dim_packing( // CHECK-SAME: %[[ARG0:.+]]: tensor<256xf32>) -// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<256xf32> into tensor<8x32xf32> +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [8, 32] : tensor<256xf32> into tensor<8x32xf32> // CHECK: return %[[EXPANDED]] : tensor<8x32xf32> func.func @single_dim_packing(%arg0: tensor<256xf32>) -> tensor<8x32xf32> { %empty = tensor.empty() : tensor<8x32xf32> @@ -27,7 +27,7 @@ func.func @single_dim_packing_with_padding(%arg0: tensor<255xf32>) -> tensor<8x3 // CHECK-LABEL: func.func @single_last_inner_dim_packing( // CHECK-SAME: %[[ARG0:.+]]: tensor<5x256xf32>) -// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x256xf32> into tensor<5x8x32xf32> +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] output_shape [5, 8, 32] : tensor<5x256xf32> into tensor<5x8x32xf32> // CHECK: return %[[EXPANDED]] : tensor<5x8x32xf32> func.func @single_last_inner_dim_packing(%arg0: tensor<5x256xf32>) -> tensor<5x8x32xf32> { %empty = tensor.empty() : tensor<5x8x32xf32> @@ -39,7 +39,7 @@ func.func @single_last_inner_dim_packing(%arg0: tensor<5x256xf32>) -> tensor<5x8 // CHECK-LABEL: func.func @pack_1d_with_outer_dims_perm( // CHECK-SAME: %[[ARG0:.+]]: tensor<64xf32>) -// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<64xf32> into tensor<2x32xf32> +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [2, 32] : tensor<64xf32> into tensor<2x32xf32> // CHECK: return %[[EXPANDED]] : tensor<2x32xf32> func.func @pack_1d_with_outer_dims_perm(%arg0: tensor<64xf32>) -> tensor<2x32xf32> { %empty = tensor.empty() : tensor<2x32xf32> @@ -51,7 +51,7 @@ func.func @pack_1d_with_outer_dims_perm(%arg0: tensor<64xf32>) -> tensor<2x32xf3 // CHECK-LABEL: func.func @single_last_inner_dim_packing_with_identity_outer_dims_perm( // CHECK-SAME: %[[ARG0:.+]]: tensor<5x256xf32>) -// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x256xf32> into tensor<5x8x32xf32> +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] output_shape [5, 8, 32] : tensor<5x256xf32> into tensor<5x8x32xf32> // CHECK: return %[[EXPANDED]] : tensor<5x8x32xf32> func.func @single_last_inner_dim_packing_with_identity_outer_dims_perm(%arg0: tensor<5x256xf32>) -> tensor<5x8x32xf32> { %empty = tensor.empty() : tensor<5x8x32xf32> @@ -85,7 +85,7 @@ func.func @single_first_inner_dim_packing(%arg0: tensor<256x5xf32>) -> tensor<8x // CHECK-LABEL: func.func @pack_1x32_to_1x32x1x1 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]] -// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]] +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]] output_shape [1, 32, 1, 1] // CHECK: return %[[EXPANDED]] func.func @pack_1x32_to_1x32x1x1(%arg0 : tensor<1x32xf32>) -> tensor<1x32x1x1xf32> { %empty = tensor.empty() : tensor<1x32x1x1xf32> @@ -98,7 +98,7 @@ func.func @pack_1x32_to_1x32x1x1(%arg0 : tensor<1x32xf32>) -> tensor<1x32x1x1xf3 // CHECK-LABEL: func.func @pack_1x32_to_1x16x1x2 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]] -// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]] +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]] output_shape [1, 16, 1, 2] // CHECK: return %[[EXPANDED]] func.func @pack_1x32_to_1x16x1x2(%arg0 : tensor<1x32xf32>) -> tensor<1x16x1x2xf32> { %empty = tensor.empty() : tensor<1x16x1x2xf32> @@ -111,7 +111,7 @@ func.func @pack_1x32_to_1x16x1x2(%arg0 : tensor<1x32xf32>) -> tensor<1x16x1x2xf3 // CHECK-LABEL: func.func @pack_32x1_to_16x1x2x1 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]] -// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] output_shape [1, 16, 2, 1] // CHECK: return %[[EXPANDED]] func.func @pack_32x1_to_16x1x2x1(%arg0 : tensor<32x1xf32>) -> tensor<1x16x2x1xf32> { %empty = tensor.empty() : tensor<1x16x2x1xf32> diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 52c874c344c5e..acd2d3a14d741 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -12635,6 +12635,7 @@ cc_library( deps = [ ":ArithDialect", ":ComplexDialect", + ":DialectUtils", ":IR", "//llvm:Support", ],