Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

release/18.x: [MLIR] [Transforms] Let transform.structured.convert_to_loops return handles to loops (#83984) #85942

Closed
wants to merge 2 commits into from

Conversation

llvmbot
Copy link
Collaborator

@llvmbot llvmbot commented Mar 20, 2024

Backport 0597644 47bc565

Requested by: @lhunloh

@llvmbot
Copy link
Collaborator Author

llvmbot commented Mar 20, 2024

@hanhanW @lhunloh What do you think about merging this PR to the release branch?

@llvmbot
Copy link
Collaborator Author

llvmbot commented Mar 20, 2024

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: None (llvmbot)

Changes

Backport 0597644 47bc565

Requested by: @lhunloh


Full diff: https://github.com/llvm/llvm-project/pull/85942.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+9-13)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+25-9)
  • (modified) mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir (+68-8)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index b139f1ef58b3a9..da7183dae75ffc 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1274,33 +1274,29 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// ConvertToLoopsOp
+//===----------------------------------------------------------------------===//
+
 def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-     TransformOpInterface, TransformEachOpTrait,
+     DeclareOpInterfaceMethods<TransformOpInterface>,
      ReportTrackingListenerFailuresOpTrait]> {
   let description = [{
     For operations that implement the `TilingInterface`, and implement
     the `generateScalarImplementation` method, lowers the operation to
-    loops. This operation does not return any handles.
+    loops. The return handle points to all generated loops.
+    Fails if the payload ops cannot be lowered to loops.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target);
-  let results = (outs);
+  let results = (outs TransformHandleTypeInterface:$result);
 
   let assemblyFormat = [{
-    $target attr-dict `:` type($target)
-  }];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::TilingInterface target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
+    $target attr-dict `:` functional-type(operands, results)
   }];
 }
 
-
 //===----------------------------------------------------------------------===//
 // DecomposeInterfaceOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 140bdd1f2db361..905875ae43ce8a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2083,15 +2083,31 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
 // ConvertToLoopsOp
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilenceableFailure transform::ConvertToLoopsOp::applyToOne(
-    transform::TransformRewriter &rewriter, TilingInterface target,
-    transform::ApplyToEachResultList &results,
-    transform::TransformState &state) {
-  rewriter.setInsertionPoint(target);
-  FailureOr<SmallVector<scf::ForOp>> loops =
-      scf::lowerToLoopsUsingSCFForOp(rewriter, target);
-  if (failed(loops))
-    return emitDefaultDefiniteFailure(target);
+DiagnosedSilenceableFailure
+transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
+                                   transform::TransformResults &results,
+                                   transform::TransformState &state) {
+  SmallVector<Operation *> loops;
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    auto tilingOp = dyn_cast<TilingInterface>(*target);
+    if (!target) {
+      DiagnosedSilenceableFailure diag =
+          emitSilenceableError()
+          << "expected the payload to implement TilingInterface";
+      diag.attachNote(target->getLoc()) << "payload op";
+      return diag;
+    }
+    rewriter.setInsertionPoint(target);
+    FailureOr<SmallVector<scf::ForOp>> generatedLoops =
+        scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
+    if (failed(generatedLoops))
+      return emitDefaultDefiniteFailure(target);
+    for (scf::ForOp &loop : *generatedLoops) {
+      loops.push_back(loop.getOperation());
+    }
+    rewriter.eraseOp(target);
+  }
+  results.set(cast<OpResult>(getResult()), loops);
   return DiagnosedSilenceableFailure::success();
 }
 
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index 7969de0d456bb6..8cbee3cbb758b2 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -11,7 +11,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %matmul : !transform.any_op
+    %0 = transform.structured.convert_to_loops %matmul
+      : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }
 }
@@ -33,6 +34,58 @@ module attributes {transform.with_named_sequence} {
 //       CHECK:         %[[MULF:.+]] = arith.mulf %[[LHS]], %[[RHS]]
 //       CHECK:         %[[ADDF:.+]] = arith.addf %[[OUT]], %[[MULF]]
 //       CHECK:         memref.store %[[ADDF]], %[[ARG2]][%[[IV0]], %[[IV1]]]
+//   CHECK-NOT:   linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
+
+// -----
+
+func.func @gemm(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
+  %arg2 : memref<?x?xf32>, %arg3 : memref<?xf32>, %arg4 : memref<?xf32>) {
+  linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%arg2 : memref<?x?xf32>)
+  linalg.matvec ins(%arg0, %arg3 : memref<?x?xf32>, memref<?xf32>)
+      outs(%arg4 : memref<?xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %linalg_ops = transform.structured.match interface{TilingInterface} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %linalg_ops
+      : (!transform.any_op) -> (!transform.any_op)
+    %1:5 = transform.split_handle %0
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-LABEL: func @gemm
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32>
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: memref<?xf32>
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: memref<?xf32>
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
+//   CHECK-DAG:   %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
+//       CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]]
+//       CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C1]]
+//       CHECK:       scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C1]]
+//   CHECK-DAG:         %[[LHS:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[IV2]]]
+//   CHECK-DAG:         %[[RHS:.+]] = memref.load %[[ARG1]][%[[IV2]], %[[IV1]]]
+//   CHECK-DAG:         %[[OUT:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]]]
+//       CHECK:         %[[MULF:.+]] = arith.mulf %[[LHS]], %[[RHS]]
+//       CHECK:         %[[ADDF:.+]] = arith.addf %[[OUT]], %[[MULF]]
+//       CHECK:         memref.store %[[ADDF]], %[[ARG2]][%[[IV0]], %[[IV1]]]
+//       CHECK:   scf.for %[[IV3:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]]
+//       CHECK:     scf.for %[[IV4:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C1]]
+//   CHECK-DAG:         %[[LHS:.+]] = memref.load %[[ARG0]][%[[IV3]], %[[IV4]]]
+//   CHECK-DAG:         %[[RHS:.+]] = memref.load %[[ARG3]][%[[IV4]]]
+//   CHECK-DAG:         %[[OUT:.+]] = memref.load %[[ARG4]][%[[IV3]]]
+//       CHECK:         %[[MULF:.+]] = arith.mulf %[[LHS]], %[[RHS]]
+//       CHECK:         %[[ADDF:.+]] = arith.addf %[[OUT]], %[[MULF]]
+//       CHECK:         memref.store %[[ADDF]], %[[ARG4]][%[[IV3]]]
 
 // -----
 
@@ -65,7 +118,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %generic : !transform.any_op
+    %0 = transform.structured.convert_to_loops %generic
+      : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }
 }
@@ -110,7 +164,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %conv = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %conv : !transform.any_op
+    %0 = transform.structured.convert_to_loops %conv
+      : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }
 }
@@ -164,7 +219,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %pool = transform.structured.match ops{["linalg.pooling_nhwc_max"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %pool : !transform.any_op
+    %0 = transform.structured.convert_to_loops %pool
+      : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }
 }
@@ -215,7 +271,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %map = transform.structured.match ops{["linalg.map"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %map : !transform.any_op
+    %0 = transform.structured.convert_to_loops %map
+      : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }
 }
@@ -247,7 +304,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %transpose = transform.structured.match ops{["linalg.transpose"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %transpose : !transform.any_op
+    %0 = transform.structured.convert_to_loops %transpose
+      : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }
 }
@@ -284,7 +342,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %reduce = transform.structured.match ops{["linalg.reduce"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %reduce : !transform.any_op
+    %0 = transform.structured.convert_to_loops %reduce
+      : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }
 }
@@ -321,7 +380,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %broadcast = transform.structured.match ops{["linalg.broadcast"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
-    transform.structured.convert_to_loops %broadcast : !transform.any_op
+    %0 = transform.structured.convert_to_loops %broadcast
+      : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }
 }

HerrCai0907 and others added 2 commits March 20, 2024 15:06
…n handles to loops (llvm#83984)

This lets `transform.structured.convert_to_loops` return handles to the
generated loops, making this transformation more useful to use for
(transformation-)nesting purposes. This is modelled after SCFs
`transform.loop.forall_to_for` which returns handles to loops.

Introduced in commit aa2a96a, with a
note that they might move out of the `Linalg`-Dialect, but no reason
given for the non-return of handles. As far as I can see, this transform
always returns loops.

(cherry picked from commit 47bc565)
@lhunloh
Copy link
Contributor

lhunloh commented Mar 20, 2024

Also humbly pinging @qedawkins as the reviewer for the second commit. :)

@qedawkins qedawkins self-requested a review March 20, 2024 16:14
@lhunloh
Copy link
Contributor

lhunloh commented Mar 27, 2024

Pinging @MaheshRavishankar as I don't have permission to add reviewers and it's been a week. I'm relatively new to contributing, so if I am doing something wrong or have to do something else, I'd appreciate a pointer. :)

@MaheshRavishankar
Copy link
Contributor

Pinging @MaheshRavishankar as I don't have permission to add reviewers and it's been a week. I'm relatively new to contributing, so if I am doing something wrong or have to do something else, I'd appreciate a pointer. :)

Sorry for the delay. I dont know very much about transform dialect. Adding folks who might know better.

@ftynse
Copy link
Member

ftynse commented Mar 29, 2024

0597644 looks like a bugfix, but 47bc565 is a arguably a new feature and likely should not be backported. What is the reason for backporting the latter?

@lhunloh
Copy link
Contributor

lhunloh commented Mar 29, 2024

We thought that might bring this transformation more in line with other transforms (which give handles to loops). The main branch doesn't build on our server at the moment, but the release branch does.

New Issue for the cherry-pick of only the bugfix: issue #87079

Please close this PR (I can't).
Sorry for the inconvenience, I'm still learning. :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Development

Successfully merging this pull request may close these issues.

None yet

6 participants