diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 3f016fed3519c..33667e7ab0c5c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -811,19 +811,35 @@ validateDynamicDimExpansion(LinalgOp linalgOp, } // Create an expanded transpose op. -static Operation * -createExpandedTransposeOp(PatternRewriter &rewriter, TransposeOp transposeOp, - SmallVector<ReassociationIndices> reassociation, - Value expandedInput, Value output) { - applyPermutationToVector(reassociation, transposeOp.getPermutation()); +// The reassociation map is already permuted, hence we inverse permute and then +// flatten it. Then we inverse permute it again to get the final expanded +// transpose permutation. For example, +// +// permutation = [2, 0, 1] +// reassociation_map for expansion = [[0, 1], [2], [3, 4, 5]] +// +// inverse permutation = [1, 2, 0] +// applied to reassociation_map and then flattened becomes +// flattened permutation = [2, 3, 4, 5, 0, 1] +// final permutation is the inverse of the flattened permutation. +// +// Becomes +// +// permutation = [4, 5, 0, 1, 2, 3] + +static Operation *createExpandedTransposeOp(PatternRewriter &rewriter, + TransposeOp transposeOp, + Value expandedInput, Value output, + ExpansionInfo &expansionInfo) { SmallVector<int64_t> newPerm; - for (const auto &reassoc : reassociation) { - for (auto dim : reassoc) { + for (int64_t perm : invertPermutationVector(transposeOp.getPermutation())) { + auto reassoc = expansionInfo.getExpandedDims(perm); + for (int64_t dim : reassoc) { newPerm.push_back(dim); } } return rewriter.create<TransposeOp>(transposeOp.getLoc(), expandedInput, - output, newPerm); + output, invertPermutationVector(newPerm)); } // Create an expanded generic op. 
@@ -857,16 +873,18 @@ static Operation *createExpandedGenericOp( // Create an expanded fused op that retains the name for certain ops // such as fill, copy and transpose and produce a generic op for // rest of linalg ops. -static Operation *createExpandedOp( - PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes, - ArrayRef<Value> expandedOpOperands, ArrayRef<Value> outputs, - ArrayRef<AffineMap> expandedOpIndexingMaps, ExpansionInfo &expansionInfo, - SmallVector<ReassociationIndices> reassociation) { +static Operation *createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp, + TypeRange resultTypes, + ArrayRef<Value> expandedOpOperands, + ArrayRef<Value> outputs, + ArrayRef<AffineMap> expandedOpIndexingMaps, + ExpansionInfo &expansionInfo) { return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation()) .Case<TransposeOp>([&](TransposeOp transposeOp) { - return createExpandedTransposeOp(rewriter, transposeOp, reassociation, - expandedOpOperands[0], outputs[0]); + return createExpandedTransposeOp(rewriter, transposeOp, + expandedOpOperands[0], outputs[0], + expansionInfo); }) .Case<FillOp, CopyOp>([&](Operation *op) { return clone(rewriter, linalgOp, resultTypes, @@ -986,12 +1004,9 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp, } TypeRange resultTypes = ValueRange(outputs).getTypes(); - SmallVector<ReassociationIndices> reassociationBeforeExpansion = - isExpanding ? expandingReshapeOp.getReassociationIndices() - : collapsingReshapeOp.getReassociationIndices(); - Operation *fusedOp = createExpandedOp( - rewriter, linalgOp, resultTypes, expandedOpOperands, outputs, - expandedOpIndexingMaps, expansionInfo, reassociationBeforeExpansion); + Operation *fusedOp = + createExpandedOp(rewriter, linalgOp, resultTypes, expandedOpOperands, + outputs, expandedOpIndexingMaps, expansionInfo); // Reshape the result values to their original shape if this is a collapsing // reshape folded into its consumer. 
SmallVector resultVals; diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir index c8720ebd98c09..3244418d445b7 100644 --- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir +++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir @@ -195,7 +195,7 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>) // CHECK-SAME: : tensor<8x33x4xf32> // CHECK-DAG: %[[INIT:.+]] = tensor.empty() // CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32> -// CHECK: %[[T1:.+]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32> +// CHECK: %[[T1:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32> // CHECK: %[[T2:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]] // CHECK-SAME: ["parallel", "parallel", "parallel"] @@ -203,6 +203,29 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>) // CHECK-SAME: outs(%[[T1]] : tensor<8x33x4xf32>) // CHECK: return %[[T2]] : tensor<8x33x4xf32> +// ----- + +func.func @reshape_as_consumer_transpose + (%a : tensor<4x210x6xf32>) + -> tensor<2x3x4x5x6x7xf32> { + %b = tensor.empty() : tensor<6x4x210xf32> + %c = linalg.transpose + ins(%a : tensor<4x210x6xf32>) + outs(%b : tensor<6x4x210xf32>) permutation = [2, 0, 1] + %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32> + return %d : tensor<2x3x4x5x6x7xf32> +} +// CHECK: func @reshape_as_consumer_transpose +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x210x6xf32> +// CHECK-DAG: %[[INIT:.+]] = tensor.empty() +// CHECK-DAG: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3], [4, 5]] output_shape [4, 5, 6, 7, 2, 3] : tensor<4x210x6xf32> into tensor<4x5x6x7x2x3xf32> +// 
CHECK-DAG: %[[T1:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32> +// CHECK: %[[T2:.+]] = linalg.transpose ins(%[[T0]] : tensor<4x5x6x7x2x3xf32>) +// CHECK-SAME: outs(%[[T1]] : tensor<2x3x4x5x6x7xf32>) +// CHECK-SAME: permutation = [4, 5, 0, 1, 2, 3] +// CHECK: return %[[T2]] : tensor<2x3x4x5x6x7xf32> + + // ----- #map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)> @@ -884,37 +907,29 @@ func.func @linalg_copy_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>, // ----- -func.func @linalg_transpose_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>, - %arg1 : tensor<?x?xf32>) -> - tensor<?x?xf32> -{ - %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] : - tensor<?x7x?x8xf32> into tensor<?x?xf32> - %1 = linalg.transpose ins(%0 : tensor<?x?xf32>) - outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0] - return %1 : tensor<?x?xf32> + +func.func @reshape_as_producer_transpose + (%a : tensor<4x5x6x7x2x3xf32>) + -> tensor<6x4x210xf32> { + %b = tensor.empty() : tensor<6x4x210xf32> + %c = tensor.collapse_shape %a [[0], [1, 2, 3], [4, 5]] : + tensor<4x5x6x7x2x3xf32> into tensor<4x210x6xf32> + %d = linalg.transpose + ins(%c : tensor<4x210x6xf32>) + outs(%b : tensor<6x4x210xf32>) permutation = [2, 0, 1] + return %d : tensor<6x4x210xf32> } -// CHECK: func @linalg_transpose_reshape_producer_fusion -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32> -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32> -// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index -// CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32> -// CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32> -// CHECK-DAG: %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C7]] : index -// CHECK-DAG: %[[VAL_1:.+]] = arith.divsi %[[DIM]], %[[C8]] : index -// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 8, 
%[[VAL_0]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32> -// CHECK: %[[T2:.+]] = linalg.transpose -// CHECK-SAME: ins(%[[ARG0]] : tensor<?x7x?x8xf32>) -// CHECK-SAME: outs(%[[T1]] : tensor<?x8x?x7xf32>) -// CHECK-SAME: permutation = [2, 3, 0, 1] -// CHECK: %[[T3:.+]] = tensor.collapse_shape %[[T2]] -// CHECK-SAME: [0, 1], [2, 3] -// CHECK-SAME: tensor<?x8x?x7xf32> into tensor<?x?xf32> -// CHECK: return %[[T3]] +// CHECK: func @reshape_as_producer_transpose +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x5x6x7x2x3xf32> +// CHECK-DAG: %[[INIT:.+]] = tensor.empty() +// CHECK-DAG: %[[T0:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32> +// CHECK: %[[T1:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<4x5x6x7x2x3xf32>) +// CHECK-SAME: outs(%[[T0]] : tensor<2x3x4x5x6x7xf32>) +// CHECK-SAME: permutation = [4, 5, 0, 1, 2, 3] +// CHECK: %[[T2:.+]] = tensor.collapse_shape %[[T1]] {{\[\[}}0, 1], [2], [3, 4, 5]] : tensor<2x3x4x5x6x7xf32> into tensor<6x4x210xf32> +// CHECK: return %[[T2]] : tensor<6x4x210xf32> + // -----