[mlir][linalg] Retain Op Type of linalg ops in fuseWithReshapeByExpansion pattern #129128
Merged
Conversation
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg
Author: Nirvedh Meshram (nirvedhmeshram)
Changes: This PR preserves linalg Op types instead of fusion always resulting in a generic Op.
Full diff: https://github.com/llvm/llvm-project/pull/129128.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f4b6955823085..f64151db8e5a0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -927,17 +927,43 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
iteratorTypes[j] = type;
TypeRange resultTypes = ValueRange(outputs).getTypes();
- auto fusedOp =
- rewriter.create<GenericOp>(linalgOp.getLoc(), resultTypes,
- /*inputs=*/expandedOpOperands, outputs,
- expandedOpIndexingMaps, iteratorTypes);
- Region &fusedRegion = fusedOp->getRegion(0);
- Region &originalRegion = linalgOp->getRegion(0);
- rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
-
- // Update the index accesses after the expansion.
- updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);
-
+ Operation *fusedOp;
+
+ TypeSwitch<Operation *>(linalgOp.getOperation())
+ .Case<GenericOp>([&](GenericOp op) {
+ fusedOp = rewriter.create<GenericOp>(
+ linalgOp.getLoc(), resultTypes, expandedOpOperands, outputs,
+ expandedOpIndexingMaps, iteratorTypes);
+ Region &fusedRegion = fusedOp->getRegion(0);
+ Region &originalRegion = linalgOp->getRegion(0);
+ rewriter.cloneRegionBefore(originalRegion, fusedRegion,
+ fusedRegion.begin());
+
+ // Update the index accesses after the expansion.
+ updateExpandedGenericOpRegion(rewriter, loc, fusedRegion,
+ expansionInfo);
+ })
+ .Case<TransposeOp>([&](TransposeOp op) {
+ SmallVector<ReassociationIndices> reassociation =
+ isExpanding ? expandingReshapeOp.getReassociationIndices()
+ : collapsingReshapeOp.getReassociationIndices();
+ applyPermutationToVector(reassociation, op.getPermutation());
+ SmallVector<int64_t> newPerm;
+ for (auto reassoc : reassociation) {
+ for (auto dim : reassoc) {
+ newPerm.push_back(dim);
+ }
+ }
+ fusedOp = rewriter.create<TransposeOp>(
+ linalgOp.getLoc(), expandedOpOperands[0], outputs[0], newPerm);
+ })
+ // All other expandable linalg ops that are not generic or transpose can
+ // be cloned with the expanded input and output operands.
+ .Default([&](Operation *op) {
+ fusedOp = clone(
+ rewriter, linalgOp, resultTypes,
+ llvm::to_vector(llvm::concat<Value>(expandedOpOperands, outputs)));
+ });
// Reshape the result values to their original shape if this is a collapsing
// reshape folded into its consumer.
SmallVector<Value> resultVals;
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index ef853e4d662a7..80cebab590f6f 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -753,7 +753,6 @@ func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
return %1 : tensor<?x?x4x5xf32>
}
-// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: func @linalg_add_reshape_consumer_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
@@ -774,18 +773,13 @@ func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
// CHECK: %[[T3:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], %[[VAL_2]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
-// CHECK: %[[T4:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
-// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
+// CHECK: %[[T4:.+]] = linalg.add
// CHECK-SAME: ins(%[[T1]], %[[T2]] : tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32>)
// CHECK-SAME: outs(%[[T3]] : tensor<?x?x4x5xf32>)
// CHECK: return %[[T4]] : tensor<?x?x4x5xf32>
// -----
-#map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
-#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
%arg1 : tensor<?x?xf32>,
%arg2 : tensor<?x?xf32>) ->
@@ -798,7 +792,6 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
return %1 : tensor<?x?xf32>
}
-// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: func @linalg_add_reshape_producer_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
@@ -817,9 +810,7 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_1]], %[[C7]] : index
// CHECK: %[[VAL_3:.+]] = arith.divsi %[[DIM_2]], %[[C8]] : index
// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 7, %[[VAL_3]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
-// CHECK: %[[T3:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]]
-// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
+// CHECK: %[[T3:.+]] = linalg.add
// CHECK-SAME: ins(%[[ARG0]], %[[T1]] : tensor<?x7x?x8xf32>, tensor<?x7x?x8xf32>)
// CHECK-SAME: outs(%[[T2]] : tensor<?x7x?x8xf32>)
// CHECK: %[[T4:.+]] = tensor.collapse_shape %[[T3]]
@@ -827,6 +818,42 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
// CHECK-SAME: tensor<?x7x?x8xf32> into tensor<?x?xf32>
// CHECK: return %[[T4]]
+// -----
+
+func.func @linalg_transpose_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
+ %arg1 : tensor<?x?xf32>) ->
+ tensor<?x?xf32>
+{
+ %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
+ tensor<?x7x?x8xf32> into tensor<?x?xf32>
+ %1 = linalg.transpose ins(%0 : tensor<?x?xf32>)
+ outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0]
+ return %1 : tensor<?x?xf32>
+}
+
+// CHECK: func @linalg_transpose_reshape_producer_fusion
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+// CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+// CHECK-DAG: %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C7]] : index
+// CHECK-DAG: %[[VAL_1:.+]] = arith.divsi %[[DIM]], %[[C8]] : index
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 8, %[[VAL_0]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
+// CHECK: %[[T2:.+]] = linalg.transpose
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
+// CHECK-SAME: outs(%[[T1]] : tensor<?x8x?x7xf32>)
+// CHECK-SAME: permutation = [2, 3, 0, 1]
+// CHECK: %[[T3:.+]] = tensor.collapse_shape %[[T2]]
+// CHECK-SAME: [0, 1], [2, 3]
+// CHECK-SAME: tensor<?x8x?x7xf32> into tensor<?x?xf32>
+// CHECK: return %[[T3]]
+
+
+
// -----
func.func @fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {
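For the TransposeOp case in the diff above, the expanded permutation is built by reordering the reshape's reassociation groups with the op's original permutation and then flattening the groups. As a worked example matching the new linalg_transpose_reshape_producer_fusion test: with reassociation [[0, 1], [2, 3]] and permutation [1, 0], reordering the groups gives [[2, 3], [0, 1]], which flattens to the expanded permutation [2, 3, 0, 1].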
MaheshRavishankar requested changes on Feb 27, 2025
Force-pushed from 3f5678b to 8e07bc5
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
Force-pushed from 8e07bc5 to 3526164
MaheshRavishankar approved these changes on Mar 2, 2025
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
MaheshRavishankar approved these changes on Mar 3, 2025
nirvedhmeshram added a commit that referenced this pull request on Mar 11, 2025
…130344) During #129128, handling of linalg.transpose for reshape-as-consumer fusion was missed. This PR adds that. Also, the transpose reshape-as-producer fusion test is updated to static sizes, as that is more likely to catch any issues with the permutation vector in the verifier if the shapes don't match up. --------- Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request on Mar 11, 2025
…er fusion (#130344) During llvm/llvm-project#129128, handling of linalg.transpose for reshape-as-consumer fusion was missed. This PR adds that. Also, the transpose reshape-as-producer fusion test is updated to static sizes, as that is more likely to catch any issues with the permutation vector in the verifier if the shapes don't match up. --------- Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
jph-13 pushed a commit to jph-13/llvm-project that referenced this pull request on Mar 21, 2025
…sion pattern (llvm#129128) This PR preserves linalg Op types for certain named ops such as Fill, Copy, and Transpose instead of fusion always resulting in a generic Op. --------- Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
This PR preserves linalg Op types for certain named ops such as Fill, Copy, and Transpose instead of fusion always resulting in a generic Op.
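As a minimal sketch of the effect (illustrative IR only, simplified from the reshape_fusion.mlir tests in this PR), consider a linalg.add whose input comes from a tensor.collapse_shape:

  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2, 3]]
      : tensor<?x7x?x8xf32> into tensor<?x?xf32>
  %add = linalg.add ins(%collapsed, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
                    outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>

Before this patch, fuseWithReshapeByExpansion rewrote the add into a linalg.generic on the expanded tensor<?x7x?x8xf32> operands; with this patch it remains a linalg.add on the expanded operands, followed by a tensor.collapse_shape of the result back to tensor<?x?xf32>.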