diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td index 5eefe2664d0a1..3d7fe7b0f093f 100644 --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td @@ -68,6 +68,32 @@ def ForallToForOp : Op]> { + let summary = "Converts scf.forall into an scf.parallel operation"; + let description = [{ + Converts the `scf.forall` operation pointed to by the given handle into an + `scf.parallel` operation. + + The operand handle must be associated with exactly one payload operation. + + Loops with outputs are not supported. + + #### Return Modes + + Consumes the operand handle. Produces a silenceable failure if the operand + is not associated with a single `scf.forall` payload operation. + Returns a handle to the new `scf.parallel` operation. + Produces a silenceable failure if another number of resulting handles is + requested. + }]; + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs Variadic:$transformed); + + let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)"; +} + def LoopOutlineOp : Op]> { diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h index 31c3d0eb629d2..fb8411418ff9a 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h @@ -62,6 +62,9 @@ std::unique_ptr createForLoopRangeFoldingPass(); /// Creates a pass that converts SCF forall loops to SCF for loops. std::unique_ptr createForallToForLoopPass(); +/// Creates a pass that converts SCF forall loops to SCF parallel loops. +std::unique_ptr createForallToParallelLoopPass(); + // Creates a pass which lowers for loops into while loops. 
std::unique_ptr createForToWhileLoopPass(); diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td index a7aeb42d60c0e..9b29affb97c43 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td @@ -125,6 +125,11 @@ def SCFForallToForLoop : Pass<"scf-forall-to-for"> { let constructor = "mlir::createForallToForLoopPass()"; } +def SCFForallToParallelLoop : Pass<"scf-forall-to-parallel"> { + let summary = "Convert SCF forall loops to SCF parallel loops"; + let constructor = "mlir::createForallToParallelLoopPass()"; +} + def SCFForToWhileLoop : Pass<"scf-for-to-while"> { let summary = "Convert SCF for loops to SCF while loops"; let constructor = "mlir::createForToWhileLoopPass()"; diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h index b063e6e775e63..186331738d64b 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h @@ -39,6 +39,11 @@ class WhileOp; LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp, SmallVectorImpl *results = nullptr); +/// Try converting scf.forall into an scf.parallel loop. +/// The conversion is only supported for forall operations with no results. +LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp, + ParallelOp *result = nullptr); + /// Fuses all adjacent scf.parallel operations with identical bounds and step /// into one scf.parallel operations. Uses a naive aliasing and dependency /// analysis. 
diff --git a/mlir/lib/Conversion/SCFToControlFlow/CMakeLists.txt b/mlir/lib/Conversion/SCFToControlFlow/CMakeLists.txt index 6217976159fbb..63c5199af9290 100644 --- a/mlir/lib/Conversion/SCFToControlFlow/CMakeLists.txt +++ b/mlir/lib/Conversion/SCFToControlFlow/CMakeLists.txt @@ -14,5 +14,6 @@ add_mlir_conversion_library(MLIRSCFToControlFlow MLIRArithDialect MLIRControlFlowDialect MLIRSCFDialect + MLIRSCFTransforms MLIRTransforms ) diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp index 9eb8a289d7d65..16f1db44acc35 100644 --- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp +++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Transforms.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IRMapping.h" @@ -688,33 +689,7 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op, LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp, PatternRewriter &rewriter) const { - Location loc = forallOp.getLoc(); - if (!forallOp.getOutputs().empty()) - return rewriter.notifyMatchFailure( - forallOp, - "only fully bufferized scf.forall ops can be lowered to scf.parallel"); - - // Convert mixed bounds and steps to SSA values. - SmallVector lbs = getValueOrCreateConstantIndexOp( - rewriter, loc, forallOp.getMixedLowerBound()); - SmallVector ubs = getValueOrCreateConstantIndexOp( - rewriter, loc, forallOp.getMixedUpperBound()); - SmallVector steps = - getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep()); - - // Create empty scf.parallel op. 
- auto parallelOp = rewriter.create(loc, lbs, ubs, steps); - rewriter.eraseBlock(&parallelOp.getRegion().front()); - rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(), - parallelOp.getRegion().begin()); - // Replace the terminator. - rewriter.setInsertionPointToEnd(&parallelOp.getRegion().front()); - rewriter.replaceOpWithNewOp( - parallelOp.getRegion().front().getTerminator()); - - // Erase the scf.forall op. - rewriter.replaceOp(forallOp, parallelOp); - return success(); + return scf::forallToParallelLoop(rewriter, forallOp); } void mlir::populateSCFToControlFlowConversionPatterns( diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index 69f83d8bd70da..30699ecdde0a2 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -98,6 +98,50 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter, return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// ForallToParallelOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto payload = state.getPayloadOps(getTarget()); + if (!llvm::hasSingleElement(payload)) + return emitSilenceableError() << "expected a single payload op"; + + auto target = dyn_cast(*payload.begin()); + if (!target) { + DiagnosedSilenceableFailure diag = + emitSilenceableError() << "expected the payload to be scf.forall"; + diag.attachNote((*payload.begin())->getLoc()) << "payload op"; + return diag; + } + + if (!target.getOutputs().empty()) { + return emitSilenceableError() + << "unsupported shared outputs (didn't bufferize?)"; + } + + if (getNumResults() != 1) { + 
DiagnosedSilenceableFailure diag = emitSilenceableError() + << "op expects one result, given " + << getNumResults(); + diag.attachNote(target.getLoc()) << "payload op"; + return diag; + } + + scf::ParallelOp opResult; + if (failed(scf::forallToParallelLoop(rewriter, target, &opResult))) { + DiagnosedSilenceableFailure diag = + emitSilenceableError() << "failed to convert forall into parallel"; + return diag; + } + + results.set(cast(getTransformed()[0]), {opResult}); + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // LoopOutlineOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt index e7671c9cc28f8..d363ffe941fce 100644 --- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRSCFTransforms BufferizableOpInterfaceImpl.cpp Bufferize.cpp ForallToFor.cpp + ForallToParallel.cpp ForToWhile.cpp LoopCanonicalization.cpp LoopPipelining.cpp diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp new file mode 100644 index 0000000000000..1fc0331300379 --- /dev/null +++ b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp @@ -0,0 +1,86 @@ +//===- ForallToParallel.cpp - scf.forall to scf.parallel loop conversion --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Transforms SCF.ForallOp's into SCF.ParallelOp's. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Passes.h" +#include "mlir/Dialect/SCF/Transforms/Transforms.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +#define GEN_PASS_DEF_SCFFORALLTOPARALLELLOOP +#include "mlir/Dialect/SCF/Transforms/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +LogicalResult mlir::scf::forallToParallelLoop(RewriterBase &rewriter, + scf::ForallOp forallOp, + scf::ParallelOp *result) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(forallOp); + + Location loc = forallOp.getLoc(); + if (!forallOp.getOutputs().empty()) + return rewriter.notifyMatchFailure( + forallOp, + "only fully bufferized scf.forall ops can be lowered to scf.parallel"); + + // Convert mixed bounds and steps to SSA values. + SmallVector lbs = getValueOrCreateConstantIndexOp( + rewriter, loc, forallOp.getMixedLowerBound()); + SmallVector ubs = getValueOrCreateConstantIndexOp( + rewriter, loc, forallOp.getMixedUpperBound()); + SmallVector steps = + getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep()); + + // Create empty scf.parallel op. + auto parallelOp = rewriter.create(loc, lbs, ubs, steps); + rewriter.eraseBlock(&parallelOp.getRegion().front()); + rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(), + parallelOp.getRegion().begin()); + // Replace the terminator. + rewriter.setInsertionPointToEnd(&parallelOp.getRegion().front()); + rewriter.replaceOpWithNewOp( + parallelOp.getRegion().front().getTerminator()); + + // If the mapping attribute is present, propagate to the new parallelOp. + if (forallOp.getMapping()) + parallelOp->setAttr("mapping", *forallOp.getMapping()); + + // Erase the scf.forall op. 
+ rewriter.replaceOp(forallOp, parallelOp); + + if (result) + *result = parallelOp; + + return success(); +} + +namespace { +struct ForallToParallelLoop final + : public impl::SCFForallToParallelLoopBase { + void runOnOperation() override { + Operation *parentOp = getOperation(); + IRRewriter rewriter(parentOp->getContext()); + + parentOp->walk([&](scf::ForallOp forallOp) { + if (failed(scf::forallToParallelLoop(rewriter, forallOp))) { + return signalPassFailure(); + } + }); + } +}; +} // namespace + +std::unique_ptr mlir::createForallToParallelLoopPass() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/SCF/forall-to-parallel.mlir b/mlir/test/Dialect/SCF/forall-to-parallel.mlir new file mode 100644 index 0000000000000..acde601d47259 --- /dev/null +++ b/mlir/test/Dialect/SCF/forall-to-parallel.mlir @@ -0,0 +1,80 @@ +// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-forall-to-parallel))' -split-input-file | FileCheck %s + +func.func private @callee(%i: index, %j: index) + +// CHECK-LABEL: @two_iters +// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index +func.func @two_iters(%ub1: index, %ub2: index) { + scf.forall (%i, %j) in (%ub1, %ub2) { + func.call @callee(%i, %j) : (index, index) -> () + } + + // CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]]) + // CHECK: func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> () + // CHECK: scf.reduce + return +} + +// ----- + +func.func private @callee(%i: index, %j: index) + +// CHECK-LABEL: @repeated +// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index +func.func @repeated(%ub1: index, %ub2: index) { + scf.forall (%i, %j) in (%ub1, %ub2) { + func.call @callee(%i, %j) : (index, index) -> () + } + + // CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]]) + // CHECK: func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> () + // CHECK: scf.reduce + scf.forall (%i, %j) in (%ub1, %ub2) { + func.call 
@callee(%i, %j) : (index, index) -> () + } + + // CHECK: scf.parallel (%[[IV3:.+]], %[[IV4:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]]) + // CHECK: func.call @callee(%[[IV3]], %[[IV4]]) + // CHECK: scf.reduce + return +} + +// ----- + +func.func private @callee(%i: index, %j: index, %k: index, %l: index) + +// CHECK-LABEL: @nested +// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index, %[[UB3:.+]]: index, %[[UB4:.+]]: index +func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) { + // CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]]) step (%{{.*}}, %{{.*}}) { + // CHECK: scf.parallel (%[[IV3:.+]], %[[IV4:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB3]], %[[UB4]]) step (%{{.*}}, %{{.*}}) { + // CHECK: func.call @callee(%[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]]) + // CHECK: scf.reduce + // CHECK: } + // CHECK: scf.reduce + // CHECK: } + scf.forall (%i, %j) in (%ub1, %ub2) { + scf.forall (%k, %l) in (%ub3, %ub4) { + func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> () + } + } + return +} + +// ----- + +// CHECK-LABEL: @mapping_attr +func.func @mapping_attr() -> () { + // CHECK: scf.parallel + // CHECK: scf.reduce + // CHECK: {mapping = [#gpu.thread]} + + %num_threads = arith.constant 100 : index + + scf.forall (%thread_idx) in (%num_threads) { + scf.forall.in_parallel { + } + } {mapping = [#gpu.thread]} + return + +} diff --git a/mlir/test/Dialect/SCF/transform-op-forall-to-parallel.mlir b/mlir/test/Dialect/SCF/transform-op-forall-to-parallel.mlir new file mode 100644 index 0000000000000..b64798e06a4d1 --- /dev/null +++ b/mlir/test/Dialect/SCF/transform-op-forall-to-parallel.mlir @@ -0,0 +1,60 @@ +// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics | FileCheck %s + +func.func private @callee(%i: index, %j: index) + +// CHECK-LABEL: @two_iters +// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index +func.func @two_iters(%ub1: index, %ub2: index) { + scf.forall (%i, %j) 
in (%ub1, %ub2) { + func.call @callee(%i, %j) : (index, index) -> () + } + // CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]]) + // CHECK: func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> () + // CHECK: scf.reduce + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.loop.forall_to_parallel %0 : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} + +// ----- + +func.func private @callee(%i: index, %j: index) + +func.func @repeated(%ub1: index, %ub2: index) { + scf.forall (%i, %j) in (%ub1, %ub2) { + func.call @callee(%i, %j) : (index, index) -> () + } + scf.forall (%i, %j) in (%ub1, %ub2) { + func.call @callee(%i, %j) : (index, index) -> () + } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + // expected-error @below {{expected a single payload op}} + transform.loop.forall_to_parallel %0 : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} + +// ----- + +// expected-note @below {{payload op}} +func.func private @callee(%i: index, %j: index) + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + // expected-error @below {{expected the payload to be scf.forall}} + transform.loop.forall_to_parallel %0 : (!transform.any_op) -> !transform.any_op + transform.yield + } +}