-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
release/18.x: [MLIR] [Transforms] Let transform.structured.convert_to_loops
return handles to loops (#83984)
#85942
Conversation
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: None (llvmbot) ChangesRequested by: @lhunloh Full diff: https://github.com/llvm/llvm-project/pull/85942.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index b139f1ef58b3a9..da7183dae75ffc 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1274,33 +1274,29 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
}];
}
+//===----------------------------------------------------------------------===//
+// ConvertToLoopsOp
+//===----------------------------------------------------------------------===//
+
def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
- TransformOpInterface, TransformEachOpTrait,
+ DeclareOpInterfaceMethods<TransformOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
For operations that implement the `TilingInterface`, and implement
the `generateScalarImplementation` method, lowers the operation to
- loops. This operation does not return any handles.
+ loops. The return handle points to all generated loops.
+ Fails if the payload ops cannot be lowered to loops.
}];
let arguments = (ins TransformHandleTypeInterface:$target);
- let results = (outs);
+ let results = (outs TransformHandleTypeInterface:$result);
let assemblyFormat = [{
- $target attr-dict `:` type($target)
- }];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::TilingInterface target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
+ $target attr-dict `:` functional-type(operands, results)
}];
}
-
//===----------------------------------------------------------------------===//
// DecomposeInterfaceOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 140bdd1f2db361..905875ae43ce8a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2083,15 +2083,31 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
// ConvertToLoopsOp
//===----------------------------------------------------------------------===//
-DiagnosedSilenceableFailure transform::ConvertToLoopsOp::applyToOne(
- transform::TransformRewriter &rewriter, TilingInterface target,
- transform::ApplyToEachResultList &results,
- transform::TransformState &state) {
- rewriter.setInsertionPoint(target);
- FailureOr<SmallVector<scf::ForOp>> loops =
- scf::lowerToLoopsUsingSCFForOp(rewriter, target);
- if (failed(loops))
- return emitDefaultDefiniteFailure(target);
+DiagnosedSilenceableFailure
+transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ SmallVector<Operation *> loops;
+ for (Operation *target : state.getPayloadOps(getTarget())) {
+ auto tilingOp = dyn_cast<TilingInterface>(*target);
+ if (!target) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError()
+ << "expected the payload to implement TilingInterface";
+ diag.attachNote(target->getLoc()) << "payload op";
+ return diag;
+ }
+ rewriter.setInsertionPoint(target);
+ FailureOr<SmallVector<scf::ForOp>> generatedLoops =
+ scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
+ if (failed(generatedLoops))
+ return emitDefaultDefiniteFailure(target);
+ for (scf::ForOp &loop : *generatedLoops) {
+ loops.push_back(loop.getOperation());
+ }
+ rewriter.eraseOp(target);
+ }
+ results.set(cast<OpResult>(getResult()), loops);
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index 7969de0d456bb6..8cbee3cbb758b2 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -11,7 +11,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- transform.structured.convert_to_loops %matmul : !transform.any_op
+ %0 = transform.structured.convert_to_loops %matmul
+ : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
@@ -33,6 +34,58 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[MULF:.+]] = arith.mulf %[[LHS]], %[[RHS]]
// CHECK: %[[ADDF:.+]] = arith.addf %[[OUT]], %[[MULF]]
// CHECK: memref.store %[[ADDF]], %[[ARG2]][%[[IV0]], %[[IV1]]]
+// CHECK-NOT: linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
+
+// -----
+
+func.func @gemm(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
+ %arg2 : memref<?x?xf32>, %arg3 : memref<?xf32>, %arg4 : memref<?xf32>) {
+ linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
+ outs(%arg2 : memref<?x?xf32>)
+ linalg.matvec ins(%arg0, %arg3 : memref<?x?xf32>, memref<?xf32>)
+ outs(%arg4 : memref<?xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %linalg_ops = transform.structured.match interface{TilingInterface} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.convert_to_loops %linalg_ops
+ : (!transform.any_op) -> (!transform.any_op)
+ %1:5 = transform.split_handle %0
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-LABEL: func @gemm
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: memref<?xf32>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: memref<?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
+// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]]
+// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C1]]
+// CHECK: scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C1]]
+// CHECK-DAG: %[[LHS:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[IV2]]]
+// CHECK-DAG: %[[RHS:.+]] = memref.load %[[ARG1]][%[[IV2]], %[[IV1]]]
+// CHECK-DAG: %[[OUT:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]]]
+// CHECK: %[[MULF:.+]] = arith.mulf %[[LHS]], %[[RHS]]
+// CHECK: %[[ADDF:.+]] = arith.addf %[[OUT]], %[[MULF]]
+// CHECK: memref.store %[[ADDF]], %[[ARG2]][%[[IV0]], %[[IV1]]]
+// CHECK: scf.for %[[IV3:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]]
+// CHECK: scf.for %[[IV4:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C1]]
+// CHECK-DAG: %[[LHS:.+]] = memref.load %[[ARG0]][%[[IV3]], %[[IV4]]]
+// CHECK-DAG: %[[RHS:.+]] = memref.load %[[ARG3]][%[[IV4]]]
+// CHECK-DAG: %[[OUT:.+]] = memref.load %[[ARG4]][%[[IV3]]]
+// CHECK: %[[MULF:.+]] = arith.mulf %[[LHS]], %[[RHS]]
+// CHECK: %[[ADDF:.+]] = arith.addf %[[OUT]], %[[MULF]]
+// CHECK: memref.store %[[ADDF]], %[[ARG4]][%[[IV3]]]
// -----
@@ -65,7 +118,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%generic = transform.structured.match ops{["linalg.generic"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- transform.structured.convert_to_loops %generic : !transform.any_op
+ %0 = transform.structured.convert_to_loops %generic
+ : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
@@ -110,7 +164,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%conv = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- transform.structured.convert_to_loops %conv : !transform.any_op
+ %0 = transform.structured.convert_to_loops %conv
+ : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
@@ -164,7 +219,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%pool = transform.structured.match ops{["linalg.pooling_nhwc_max"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- transform.structured.convert_to_loops %pool : !transform.any_op
+ %0 = transform.structured.convert_to_loops %pool
+ : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
@@ -215,7 +271,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%map = transform.structured.match ops{["linalg.map"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- transform.structured.convert_to_loops %map : !transform.any_op
+ %0 = transform.structured.convert_to_loops %map
+ : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
@@ -247,7 +304,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%transpose = transform.structured.match ops{["linalg.transpose"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- transform.structured.convert_to_loops %transpose : !transform.any_op
+ %0 = transform.structured.convert_to_loops %transpose
+ : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
@@ -284,7 +342,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%reduce = transform.structured.match ops{["linalg.reduce"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- transform.structured.convert_to_loops %reduce : !transform.any_op
+ %0 = transform.structured.convert_to_loops %reduce
+ : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
@@ -321,7 +380,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%broadcast = transform.structured.match ops{["linalg.broadcast"]} in %arg1
: (!transform.any_op) -> !transform.any_op
- transform.structured.convert_to_loops %broadcast : !transform.any_op
+ %0 = transform.structured.convert_to_loops %broadcast
+ : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
|
(cherry picked from commit 0597644)
…n handles to loops (llvm#83984) This lets `transform.structured.convert_to_loops` return handles to the generated loops, making this transformation more useful to use for (transformation-)nesting purposes. This is modelled after SCFs `transform.loop.forall_to_for` which returns handles to loops. Introduced in commit aa2a96a, with a note that they might move out of the `Linalg`-Dialect, but no reason given for the non-return of handles. As far as I can see, this transform always returns loops. (cherry picked from commit 47bc565)
Also humbly pinging @qedawkins as the reviewer for the second commit. :) |
Pinging @MaheshRavishankar as I don't have permission to add reviewers and it's been a week. I'm relatively new to contributing, so if I am doing something wrong or have to do something else, I'd appreciate a pointer. :) |
Sorry for the delay. I dont know very much about transform dialect. Adding folks who might know better. |
We thought that might bring this transformation more in line with other transforms (which give handles to loops). The main branch doesn't build on our server at the moment, but the release branch does. New Issue for the cherry-pick of only the bugfix: issue #87079 Please close this PR (I can't). |
Backport 0597644 47bc565
Requested by: @lhunloh