[mlir][linalg] Retain Op Type of linalg ops in fuseWithReshapeByExpansion pattern #129128

Merged: 3 commits into llvm:main on Mar 3, 2025

Conversation


@nirvedhmeshram (Contributor) commented Feb 27, 2025

This PR preserves linalg op types for certain named ops such as Fill, Copy, and Transpose, instead of fusion always resulting in a generic op.

Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
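
A minimal sketch of the effect, adapted from the updated reshape_fusion.mlir tests in the diff below (shapes and the %t1/%t2 names are illustrative): fusing a reshape into a named op now keeps the op type instead of rewriting it to linalg.generic.

  // Before fusion: a tensor.collapse_shape feeding a named op.
  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]]
      : tensor<?x7x?x8xf32> into tensor<?x?xf32>
  %1 = linalg.add ins(%0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
                  outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>

  // After fusion: the fused op used to be a linalg.generic; it now stays a
  // linalg.add on the expanded operands, where %t1 and %t2 stand for the
  // tensor.expand_shape of %arg1 and %arg2.
  %2 = linalg.add ins(%arg0, %t1 : tensor<?x7x?x8xf32>, tensor<?x7x?x8xf32>)
                  outs(%t2 : tensor<?x7x?x8xf32>) -> tensor<?x7x?x8xf32>
  %3 = tensor.collapse_shape %2 [[0, 1], [2, 3]]
      : tensor<?x7x?x8xf32> into tensor<?x?xf32>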
@llvmbot (Member) commented Feb 27, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Nirvedh Meshram (nirvedhmeshram)

Changes

This PR preserves linalg op types instead of fusion always resulting in a generic op.


Full diff: https://github.com/llvm/llvm-project/pull/129128.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+37-11)
  • (modified) mlir/test/Dialect/Linalg/reshape_fusion.mlir (+38-11)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f4b6955823085..f64151db8e5a0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -927,17 +927,43 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
       iteratorTypes[j] = type;
 
   TypeRange resultTypes = ValueRange(outputs).getTypes();
-  auto fusedOp =
-      rewriter.create<GenericOp>(linalgOp.getLoc(), resultTypes,
-                                 /*inputs=*/expandedOpOperands, outputs,
-                                 expandedOpIndexingMaps, iteratorTypes);
-  Region &fusedRegion = fusedOp->getRegion(0);
-  Region &originalRegion = linalgOp->getRegion(0);
-  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
-
-  // Update the index accesses after the expansion.
-  updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);
-
+  Operation *fusedOp;
+
+  TypeSwitch<Operation *>(linalgOp.getOperation())
+      .Case<GenericOp>([&](GenericOp op) {
+        fusedOp = rewriter.create<GenericOp>(
+            linalgOp.getLoc(), resultTypes, expandedOpOperands, outputs,
+            expandedOpIndexingMaps, iteratorTypes);
+        Region &fusedRegion = fusedOp->getRegion(0);
+        Region &originalRegion = linalgOp->getRegion(0);
+        rewriter.cloneRegionBefore(originalRegion, fusedRegion,
+                                   fusedRegion.begin());
+
+        // Update the index accesses after the expansion.
+        updateExpandedGenericOpRegion(rewriter, loc, fusedRegion,
+                                      expansionInfo);
+      })
+      .Case<TransposeOp>([&](TransposeOp op) {
+        SmallVector<ReassociationIndices> reassociation =
+            isExpanding ? expandingReshapeOp.getReassociationIndices()
+                        : collapsingReshapeOp.getReassociationIndices();
+        applyPermutationToVector(reassociation, op.getPermutation());
+        SmallVector<int64_t> newPerm;
+        for (auto reassoc : reassociation) {
+          for (auto dim : reassoc) {
+            newPerm.push_back(dim);
+          }
+        }
+        fusedOp = rewriter.create<TransposeOp>(
+            linalgOp.getLoc(), expandedOpOperands[0], outputs[0], newPerm);
+      })
+      // All other expandable linalg ops that are not generic or transpose can
+      // be cloned with the expanded input and output operands.
+      .Default([&](Operation *op) {
+        fusedOp = clone(
+            rewriter, linalgOp, resultTypes,
+            llvm::to_vector(llvm::concat<Value>(expandedOpOperands, outputs)));
+      });
   // Reshape the result values to their original shape if this is a collapsing
   // reshape folded into its consumer.
   SmallVector<Value> resultVals;
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index ef853e4d662a7..80cebab590f6f 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -753,7 +753,6 @@ func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
   return %1 : tensor<?x?x4x5xf32>
 }
 
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 //      CHECK: func @linalg_add_reshape_consumer_fusion
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
@@ -774,18 +773,13 @@ func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
 //      CHECK:   %[[DIM_5:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
 //      CHECK:   %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
 //      CHECK:   %[[T3:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], %[[VAL_2]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
-//      CHECK:   %[[T4:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
-// CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
+//      CHECK:   %[[T4:.+]] = linalg.add
 // CHECK-SAME:     ins(%[[T1]], %[[T2]] : tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32>)
 // CHECK-SAME:     outs(%[[T3]] : tensor<?x?x4x5xf32>)
 //      CHECK:   return %[[T4]] : tensor<?x?x4x5xf32>
 
 // -----
 
-#map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
-#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
 func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
                                               %arg1 : tensor<?x?xf32>,
                                               %arg2 : tensor<?x?xf32>) ->
@@ -798,7 +792,6 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
   return %1 : tensor<?x?xf32>
 }
 
-//  CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 //      CHECK: func @linalg_add_reshape_producer_fusion
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
@@ -817,9 +810,7 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
 //      CHECK:   %[[VAL_2:.+]] = arith.divsi %[[DIM_1]], %[[C7]] : index
 //      CHECK:   %[[VAL_3:.+]] = arith.divsi %[[DIM_2]], %[[C8]] : index
 //      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 7, %[[VAL_3]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
-//      CHECK:   %[[T3:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]]
-// CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
+//      CHECK:   %[[T3:.+]] = linalg.add
 // CHECK-SAME:     ins(%[[ARG0]], %[[T1]] : tensor<?x7x?x8xf32>, tensor<?x7x?x8xf32>)
 // CHECK-SAME:     outs(%[[T2]] : tensor<?x7x?x8xf32>)
 //      CHECK:   %[[T4:.+]] = tensor.collapse_shape %[[T3]]
@@ -827,6 +818,42 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
 // CHECK-SAME:     tensor<?x7x?x8xf32> into tensor<?x?xf32>
 //      CHECK:   return %[[T4]]
 
+// -----
+
+func.func @linalg_transpose_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
+                                              %arg1 : tensor<?x?xf32>) ->
+                                              tensor<?x?xf32>
+{
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
+    tensor<?x7x?x8xf32> into tensor<?x?xf32>
+  %1 = linalg.transpose ins(%0 : tensor<?x?xf32>)
+       outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0]
+  return %1 : tensor<?x?xf32>
+}
+
+//      CHECK: func @linalg_transpose_reshape_producer_fusion
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-DAG:   %[[C8:.+]] = arith.constant 8 : index
+//  CHECK-DAG:   %[[C7:.+]] = arith.constant 7 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+//  CHECK-DAG:   %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+//  CHECK-DAG:   %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C7]] : index
+//  CHECK-DAG:   %[[VAL_1:.+]] = arith.divsi %[[DIM]], %[[C8]] : index
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 8, %[[VAL_0]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
+//      CHECK:   %[[T2:.+]] = linalg.transpose
+// CHECK-SAME:     ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
+// CHECK-SAME:     outs(%[[T1]] : tensor<?x8x?x7xf32>)
+// CHECK-SAME:   permutation = [2, 3, 0, 1]
+//      CHECK:   %[[T3:.+]] = tensor.collapse_shape %[[T2]]
+// CHECK-SAME:     [0, 1], [2, 3]
+// CHECK-SAME:     tensor<?x8x?x7xf32> into tensor<?x?xf32>
+//      CHECK:   return %[[T3]]
+
+
+
 // -----
 
 func.func @fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {

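For the new TransposeOp case, the expanded permutation can be read off the test above: with reassociation [[0, 1], [2, 3]] and the original permutation [1, 0], applyPermutationToVector reorders the groups to [[2, 3], [0, 1]], and flattening the groups yields newPerm = [2, 3, 0, 1], which is exactly the permutation = [2, 3, 0, 1] checked in linalg_transpose_reshape_producer_fusion.
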
@nirvedhmeshram force-pushed the reshape_fix branch 2 times, most recently from 3f5678b to 8e07bc5 on February 28, 2025 at 17:48
@nirvedhmeshram merged commit 70b95d1 into llvm:main on Mar 3, 2025
11 checks passed
nirvedhmeshram added a commit that referenced this pull request Mar 11, 2025
…130344)

During #129128, reshape-as-consumer fusion handling of linalg.transpose was missed. This PR adds
that.
The transpose reshape-as-producer fusion test is also updated to static
sizes, as that is more likely to catch any issues with the permutation
vector in the verifier if the shapes don't match up.

---------

Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Mar 11, 2025
…er fusion (#130344)

jph-13 pushed a commit to jph-13/llvm-project that referenced this pull request Mar 21, 2025
…sion pattern (llvm#129128)
