diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index 35d2b6007c628..8b10c00008865 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -1253,7 +1253,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> { inferCollapsedType(RankedTensorType type, ArrayRef reassociation); static RankedTensorType inferCollapsedType(RankedTensorType type, - SmallVector reassociation); + ArrayRef reassociation); }]; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 8c331f90f8a0d..72acd02d0d13d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1900,11 +1900,9 @@ FailureOr mlir::linalg::collapseOpIterationDims( applyPermutationMap(indexingMap, ArrayRef(loopBound)); Value result; if (isa(collapsedOpResult.getType())) { - MemRefType expandShapeResultType = MemRefType::get( - originalResultType.getShape(), originalResultType.getElementType()); result = memref::ExpandShapeOp::create( - rewriter, loc, expandShapeResultType, collapsedOpResult, - reassociation, resultShape); + rewriter, loc, originalResultType, collapsedOpResult, reassociation, + resultShape); } else { result = tensor::ExpandShapeOp::create( rewriter, loc, originalResultType, collapsedOpResult, reassociation, diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 204e9bb73e12c..afd5414a190e5 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -1985,7 +1985,7 @@ SmallVector ExpandShapeOp::getReassociationExprs() { } RankedTensorType CollapseShapeOp::inferCollapsedType( - RankedTensorType type, SmallVector reassociation) { + RankedTensorType type, ArrayRef reassociation) { return inferCollapsedType( type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( type.getContext(), reassociation))); @@ -2023,10 +2023,11 @@ CollapseShapeOp::inferCollapsedType(RankedTensorType type, void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, ArrayRef reassociation, ArrayRef attrs) { - auto resultType = inferCollapsedType( - llvm::cast(src.getType()), - getSymbolLessAffineMaps( - convertReassociationIndicesToExprs(b.getContext(), reassociation))); + auto srcType = llvm::cast(src.getType()); + RankedTensorType collapsedType = inferCollapsedType(srcType, reassociation); + auto resultType = + RankedTensorType::get(collapsedType.getShape(), srcType.getElementType(), + srcType.getEncoding()); result.addAttribute(getReassociationAttrStrName(), getReassociationIndicesAttribute(b, reassociation)); build(b, result, resultType, src, attrs); diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir index c8b03f8dd5151..2995588e9abca 100644 --- a/mlir/test/Dialect/Linalg/collapse-dim.mlir +++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir @@ -168,13 +168,13 @@ func.func @uncollapsable_memref_projected_ops(%arg0: memref<1x24x32x8xf32>, %arg // CHECK-LABEL: func.func @linalg_copy( // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> { -// CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32> -// CHECK: %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32> -// CHECK: %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32> -// CHECK: %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32> -// CHECK: %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32>) outs(%[[VAL_5]] : tensor<1x2x60xf32>) -> tensor<1x2x60xf32> -// CHECK: %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 2, 12, 5] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32> -// CHECK: %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] output_shape [1, 2, 3, 4, 5] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64> +// CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32, 1 : i64> +// CHECK: %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32, 3 : i64> +// CHECK: %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32, 1 : i64> into tensor<1x2x60xf32, 1 : i64> +// CHECK: %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32, 3 : i64> into tensor<1x2x60xf32, 3 : i64> +// CHECK: %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32, 1 : i64>) outs(%[[VAL_5]] : tensor<1x2x60xf32, 3 : i64>) -> tensor<1x2x60xf32, 3 : i64> +// CHECK: %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 2, 12, 5] : tensor<1x2x60xf32, 3 : i64> into tensor<1x2x12x5xf32, 3 : i64> +// CHECK: %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] output_shape [1, 2, 3, 4, 5] : tensor<1x2x12x5xf32, 3 : i64> into tensor<1x2x3x4x5xf32, 3 : i64> // CHECK: return %[[VAL_8]] : tensor<1x2x3x4x5xf32, 3 : i64> // CHECK: }