diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h index ba648181daecb..830b49321c2e4 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h +++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h @@ -58,7 +58,7 @@ ForallOp getForallOpThreadIndexOwner(Value val); bool insideMutuallyExclusiveBranches(Operation *a, Operation *b); /// Promotes the loop body of a scf::ForallOp to its containing block. -void promote(RewriterBase &rewriter, scf::ForallOp forallOp); +LogicalResult promote(RewriterBase &rewriter, scf::ForallOp forallOp); /// An owning vector of values, handy to return from functions. using ValueVector = SmallVector<Value>; diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index 2453cf5b5b5a4..4fb4cc8410230 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -1474,7 +1474,8 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [ AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface, DeclareOpInterfaceMethods<ParallelCombiningOpInterface, - ["getUpdatedDestinations", "getIteratingParent"]>, + ["getUpdatedDestinations", "getIteratingParent", + "promoteInParallelLoop", "canPromoteInParallelLoop"]>, // TODO: Cannot use an interface here atm, verify this manually for now. 
// HasParent<"InParallelOpInterface"> ]> { diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h index 82ab427699f64..85cc18c47a527 100644 --- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h +++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h @@ -15,6 +15,8 @@ #define MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_ #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LogicalResult.h" namespace mlir { namespace detail { diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td index ace26f723ef53..1a333d82d8468 100644 --- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td +++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td @@ -106,6 +106,27 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> { /*methodName=*/"getIteratingParent", /*args=*/(ins) >, + InterfaceMethod< + /*desc=*/[{ + Promotes this parallel combining op out of its enclosing parallel loop + and returns the value that should replace the destination updated by + this op. + }], + /*retTy=*/"::mlir::FailureOr<::mlir::Value>", + /*methodName=*/"promoteInParallelLoop", + /*args=*/(ins "::mlir::RewriterBase &":$rewriter) + >, + InterfaceMethod< + /*desc=*/[{ + Returns true if this op can be promoted out of its enclosing parallel + loop. 
+ }], + /*retTy=*/"bool", + /*methodName=*/"canPromoteInParallelLoop", + /*args=*/(ins "::mlir::RewriterBase &":$rewriter), + /*methodBody=*/"", + /*defaultImplementation=*/[{ return false; }] + >, ]; } diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index c35989ecba6cd..04737738d8593 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -651,8 +651,7 @@ LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) { return failure(); } - promote(rewriter, *this); - return success(); + return promote(rewriter, *this); } Block::BlockArgListType ForallOp::getRegionIterArgs() { @@ -664,10 +663,21 @@ MutableArrayRef<OpOperand> ForallOp::getInitsMutable() { } /// Promotes the loop body of a scf::ForallOp to its containing block. -void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) { +LogicalResult mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) { OpBuilder::InsertionGuard g(rewriter); scf::InParallelOp terminator = forallOp.getTerminator(); + // Make sure we can promote all parallel combining ops in terminator: + for (auto &yieldingOp : terminator.getYieldingOps()) { + auto parallelCombiningOp = + dyn_cast<ParallelCombiningOpInterface>(&yieldingOp); + if (!parallelCombiningOp) + continue; + if (!parallelCombiningOp.canPromoteInParallelLoop(rewriter)) + return rewriter.notifyMatchFailure( + forallOp, "parallel combining op cannot be promoted"); + } + // Replace block arguments with lower bounds (replacements for IVs) and // outputs. 
SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter); @@ -683,30 +693,29 @@ void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) { SmallVector<Value> results; results.reserve(forallOp.getResults().size()); for (auto &yieldingOp : terminator.getYieldingOps()) { - auto parallelInsertSliceOp = - dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp); - if (!parallelInsertSliceOp) + auto parallelCombiningOp = + dyn_cast<ParallelCombiningOpInterface>(&yieldingOp); + if (!parallelCombiningOp) continue; - Value dst = parallelInsertSliceOp.getDest(); - Value src = parallelInsertSliceOp.getSource(); - if (llvm::isa<TensorType>(src.getType())) { - results.push_back(tensor::InsertSliceOp::create( - rewriter, forallOp.getLoc(), dst.getType(), src, dst, - parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(), - parallelInsertSliceOp.getStrides(), - parallelInsertSliceOp.getStaticOffsets(), - parallelInsertSliceOp.getStaticSizes(), - parallelInsertSliceOp.getStaticStrides())); - } else { - llvm_unreachable("unsupported terminator"); - } + assert(parallelCombiningOp.canPromoteInParallelLoop(rewriter)); + + FailureOr<Value> promotedValue = + parallelCombiningOp.promoteInParallelLoop(rewriter); + if (failed(promotedValue)) + return failure(); + + results.push_back(*promotedValue); } + if (results.size() != forallOp.getResults().size()) + return rewriter.notifyMatchFailure( + forallOp, "failed to materialize replacements for all results"); rewriter.replaceAllUsesWith(forallOp.getResults(), results); // Erase the old terminator and the loop. rewriter.eraseOp(terminator); rewriter.eraseOp(forallOp); + return success(); } LoopNest mlir::scf::buildLoopNest( @@ -1789,7 +1798,8 @@ struct ForallOpSingleOrZeroIterationDimsFolder // All of the loop dimensions perform a single iteration. Inline loop body. 
if (newMixedLowerBounds.empty()) { - promote(rewriter, op); + if (failed(promote(rewriter, op))) + return failure(); return success(); } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index fa97b49a41d97..f05c58a40fde0 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -3947,6 +3947,25 @@ Operation *ParallelInsertSliceOp::getIteratingParent() { return nullptr; } +FailureOr<Value> +ParallelInsertSliceOp::promoteInParallelLoop(RewriterBase &rewriter) { + Value dst = getDest(); + Value src = getSource(); + if (!isa<TensorType>(src.getType())) + return failure(); + + Value inserted = tensor::InsertSliceOp::create( + rewriter, getLoc(), dst.getType(), src, dst, getOffsets(), getSizes(), + getStrides(), getStaticOffsets(), getStaticSizes(), getStaticStrides()); + + return inserted; +} + +bool ParallelInsertSliceOp::canPromoteInParallelLoop(RewriterBase &) { + return isa<TensorType>(getSource().getType()) && + isa<TensorType>(getDest().getType()); +} + //===----------------------------------------------------------------------===// // ScatterOp //===----------------------------------------------------------------------===//