Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][loops] Reland Refactor LoopFuseSiblingOp and support parallel fusion #94391 #97607

Merged
merged 40 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from 39 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
5020e49
Add getters for multi dim loop variables in LoopLikeOpInterface
srcarroll Jun 5, 2024
50852d5
Refactor LoopFuseSiblingOp and support parallel fusion
srcarroll Jun 4, 2024
b73238a
add checkFusionStructuralLegality
srcarroll Jun 5, 2024
f5bbd13
replace isLoopWithIdenticalConfiguration with checkFusionStructuralLe…
srcarroll Jun 5, 2024
7d99581
address review comment
srcarroll Jun 5, 2024
a5fa3b3
Make return types optional and change names
srcarroll Jun 6, 2024
1babe68
change return type of getInductionVars to SmallVector<Value>
srcarroll Jun 6, 2024
009fd15
address maks's comments
srcarroll Jun 6, 2024
d34ad95
change interface method names again and revert steps operand change
srcarroll Jun 6, 2024
e0e5262
return option induction vars
srcarroll Jun 6, 2024
7115a6e
address review comments
srcarroll Jun 7, 2024
1d4a444
Merge branch 'main' into add-loop-like-interface-methods
srcarroll Jun 7, 2024
af6b030
Merge branch 'add-loop-like-interface-methods' into scf-parallel-loop…
srcarroll Jun 7, 2024
6336fdf
update after rebase
srcarroll Jun 7, 2024
aa15617
Merge branch 'main' into scf-parallel-loop-fusion
srcarroll Jun 7, 2024
7dbe646
Merge branch 'main' into scf-parallel-loop-fusion
srcarroll Jun 8, 2024
86406c3
refactor main parallel fusion logic from fuseIfLegal to util func
srcarroll Jun 9, 2024
694d589
remove unused functions
srcarroll Jun 9, 2024
67cb64f
refactor fuseIndependentSiblingForLoops to reuse replaceWithAdditiona…
srcarroll Jun 9, 2024
cc8599f
refactor fuseIndependentSiblingForallLoops to reuse replaceWithAdditi…
srcarroll Jun 9, 2024
48b1af9
wip
srcarroll Jun 10, 2024
7a51cb3
Decouple concrete loop type from `createFused` function
srcarroll Jun 17, 2024
3087326
Refactor ForallOp::replaceWithAdditionalYields
srcarroll Jun 17, 2024
bcf3d4a
revert unnecessary changes
srcarroll Jun 17, 2024
0cb3c4e
cleanup
srcarroll Jun 18, 2024
7e41a54
address some review comments
srcarroll Jun 21, 2024
cc95d75
move `createFused` to `LoopLikeInterface.h`
srcarroll Jun 24, 2024
3430a36
address more review comments
srcarroll Jun 26, 2024
8447c12
switch to function_ref
srcarroll Jun 27, 2024
fbd7b72
check optional values
srcarroll Jun 27, 2024
ffb73a7
replace equalIterationSpaces with checkFusionStructuredLegality
srcarroll Jun 27, 2024
a6d0588
check if isOpSibling in checkFusionStructuralLegality
srcarroll Jun 27, 2024
ff47980
remove extra dominance check
srcarroll Jun 27, 2024
c6847ec
address more review comments
srcarroll Jun 27, 2024
f50c6aa
add more lit tests for scf.parallel
srcarroll Jun 27, 2024
6dd68c1
check for equal loop types in checkFusionStructuralLegality
srcarroll Jun 27, 2024
99d821b
address more comments
srcarroll Jun 27, 2024
6825c15
Merge branch 'main' into scf-parallel-loop-fusion
srcarroll Jun 27, 2024
7f9c172
Fix bug in fusion refactor and add test
srcarroll Jul 3, 2024
4b4fd91
add comment
srcarroll Jul 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,8 @@ def ForallOp : SCF_Op<"forall", [
DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
"getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps",
"promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
"replaceWithAdditionalYields", "promoteIfSingleIteration",
"yieldTiledValuesAndReplace"]>,
RecursiveMemoryEffects,
SingleBlockImplicitTerminator<"scf::InParallelOp">,
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
Expand Down
20 changes: 20 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,16 @@ Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes);
void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
scf::ForOp root);

//===----------------------------------------------------------------------===//
// Fusion related helpers
//===----------------------------------------------------------------------===//

/// Check structural compatibility between two loops such as iteration space
/// and dominance.
bool checkFusionStructuralLegality(LoopLikeOpInterface target,
LoopLikeOpInterface source,
Diagnostic &diag);

/// Given two scf.forall loops, `target` and `source`, fuses `target` into
/// `source`. Assumes that the given loops are siblings and are independent of
/// each other.
Expand All @@ -203,6 +213,16 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
RewriterBase &rewriter);

/// Given two scf.parallel loops, `target` and `source`, fuses `target` into
/// `source`. Assumes that the given loops are siblings and are independent of
/// each other.
///
/// This function does not perform any legality checks and simply fuses the
/// loops. The caller is responsible for ensuring that the loops are legal to
/// fuse.
scf::ParallelOp fuseIndependentSiblingParallelLoops(scf::ParallelOp target,
scf::ParallelOp source,
RewriterBase &rewriter);
} // namespace mlir

#endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
20 changes: 20 additions & 0 deletions mlir/include/mlir/Interfaces/LoopLikeInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,24 @@ struct JamBlockGatherer {
/// Include the generated interface declarations.
#include "mlir/Interfaces/LoopLikeInterface.h.inc"

namespace mlir {
/// A function that rewrites `target`'s terminator as a teminator obtained by
/// fusing `source` into `target`.
using FuseTerminatorFn =
function_ref<void(RewriterBase &rewriter, LoopLikeOpInterface source,
LoopLikeOpInterface &target, IRMapping mapping)>;

/// Returns a fused `LoopLikeOpInterface` created by fusing `source` to
/// `target`. The `NewYieldValuesFn` callback is used to pass to the
/// `replaceWithAdditionalYields` interface method to replace the loop with a
/// new loop with (possibly) additional yields, while the `FuseTerminatorFn`
/// callback is repsonsible for updating the fused loop terminator.
LoopLikeOpInterface createFused(LoopLikeOpInterface target,
LoopLikeOpInterface source,
RewriterBase &rewriter,
NewYieldValuesFn newYieldValuesFn,
FuseTerminatorFn fuseTerminatorFn);

} // namespace mlir

#endif // MLIR_INTERFACES_LOOPLIKEINTERFACE_H_
38 changes: 38 additions & 0 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,44 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point,

SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }

FailureOr<LoopLikeOpInterface> ForallOp::replaceWithAdditionalYields(
RewriterBase &rewriter, ValueRange newInitOperands,
bool replaceInitOperandUsesInLoop,
const NewYieldValuesFn &newYieldValuesFn) {
// Create a new loop before the existing one, with the extra operands.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(getOperation());
SmallVector<Value> inits(getOutputs());
llvm::append_range(inits, newInitOperands);
scf::ForallOp newLoop = rewriter.create<scf::ForallOp>(
getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(),
inits, getMapping(),
/*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});

// Move the loop body to the new op.
rewriter.mergeBlocks(getBody(), newLoop.getBody(),
newLoop.getBody()->getArguments().take_front(
getBody()->getNumArguments()));

if (replaceInitOperandUsesInLoop) {
// Replace all uses of `newInitOperands` with the corresponding basic block
// arguments.
for (auto &&[newOperand, oldOperand] :
llvm::zip(newInitOperands, newLoop.getBody()->getArguments().take_back(
newInitOperands.size()))) {
rewriter.replaceUsesWithIf(newOperand, oldOperand, [&](OpOperand &use) {
Operation *user = use.getOwner();
return newLoop->isProperAncestor(user);
});
}
}

// Replace the old loop.
rewriter.replaceOp(getOperation(),
newLoop->getResults().take_front(getNumResults()));
return cast<LoopLikeOpInterface>(newLoop.getOperation());
}

/// Promotes the loop body of a forallOp to its containing block if it can be
/// determined that the loop has a single iteration.
LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
Expand Down
140 changes: 21 additions & 119 deletions mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,10 @@ loopScheduling(scf::ForOp forOp,
return 1;
};

std::optional<int64_t> ubConstant = getConstantIntValue(forOp.getUpperBound());
std::optional<int64_t> lbConstant = getConstantIntValue(forOp.getLowerBound());
std::optional<int64_t> ubConstant =
getConstantIntValue(forOp.getUpperBound());
std::optional<int64_t> lbConstant =
getConstantIntValue(forOp.getLowerBound());
DenseMap<Operation *, unsigned> opCycles;
std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
for (Operation &op : forOp.getBody()->getOperations()) {
Expand Down Expand Up @@ -447,113 +449,6 @@ void transform::TakeAssumedBranchOp::getEffects(
// LoopFuseSiblingOp
//===----------------------------------------------------------------------===//

/// Check if `target` and `source` are siblings, in the context that `target`
/// is being fused into `source`.
///
/// This is a simple check that just checks if both operations are in the same
/// block and some checks to ensure that the fused IR does not violate
/// dominance.
static DiagnosedSilenceableFailure isOpSibling(Operation *target,
Operation *source) {
// Check if both operations are same.
if (target == source)
return emitSilenceableFailure(source)
<< "target and source need to be different loops";

// Check if both operations are in the same block.
if (target->getBlock() != source->getBlock())
return emitSilenceableFailure(source)
<< "target and source are not in the same block";

// Check if fusion will violate dominance.
DominanceInfo domInfo(source);
if (target->isBeforeInBlock(source)) {
// Since `target` is before `source`, all users of results of `target`
// need to be dominated by `source`.
for (Operation *user : target->getUsers()) {
if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
return emitSilenceableFailure(target)
<< "user of results of target should be properly dominated by "
"source";
}
}
} else {
// Since `target` is after `source`, all values used by `target` need
// to dominate `source`.

// Check if operands of `target` are dominated by `source`.
for (Value operand : target->getOperands()) {
Operation *operandOp = operand.getDefiningOp();
// Operands without defining operations are block arguments. When `target`
// and `source` occur in the same block, these operands dominate `source`.
if (!operandOp)
continue;

// Operand's defining operation should properly dominate `source`.
if (!domInfo.properlyDominates(operandOp, source,
/*enclosingOpOk=*/false))
return emitSilenceableFailure(target)
<< "operands of target should be properly dominated by source";
}

// Check if values used by `target` are dominated by `source`.
bool failed = false;
OpOperand *failedValue = nullptr;
visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
Operation *operandOp = operand->get().getDefiningOp();
if (operandOp && !domInfo.properlyDominates(operandOp, source,
/*enclosingOpOk=*/false)) {
// `operand` is not an argument of an enclosing block and the defining
// op of `operand` is outside `target` but does not dominate `source`.
failed = true;
failedValue = operand;
}
});

if (failed)
return emitSilenceableFailure(failedValue->getOwner())
<< "values used inside regions of target should be properly "
"dominated by source";
}

return DiagnosedSilenceableFailure::success();
}

/// Check if `target` scf.forall can be fused into `source` scf.forall.
///
/// This simply checks if both loops have the same bounds, steps and mapping.
/// No attempt is made at checking that the side effects of `target` and
/// `source` are independent of each other.
static bool isForallWithIdenticalConfiguration(Operation *target,
Operation *source) {
auto targetOp = dyn_cast<scf::ForallOp>(target);
auto sourceOp = dyn_cast<scf::ForallOp>(source);
if (!targetOp || !sourceOp)
return false;

return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
targetOp.getMixedStep() == sourceOp.getMixedStep() &&
targetOp.getMapping() == sourceOp.getMapping();
}

/// Check if `target` scf.for can be fused into `source` scf.for.
///
/// This simply checks if both loops have the same bounds and steps. No attempt
/// is made at checking that the side effects of `target` and `source` are
/// independent of each other.
static bool isForWithIdenticalConfiguration(Operation *target,
Operation *source) {
auto targetOp = dyn_cast<scf::ForOp>(target);
auto sourceOp = dyn_cast<scf::ForOp>(source);
if (!targetOp || !sourceOp)
return false;

return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
targetOp.getUpperBound() == sourceOp.getUpperBound() &&
targetOp.getStep() == sourceOp.getStep();
}

DiagnosedSilenceableFailure
transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
Expand All @@ -569,25 +464,32 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
<< "source handle (got " << llvm::range_size(sourceOps) << ")";
}

Operation *target = *targetOps.begin();
Operation *source = *sourceOps.begin();
auto target = dyn_cast<LoopLikeOpInterface>(*targetOps.begin());
auto source = dyn_cast<LoopLikeOpInterface>(*sourceOps.begin());
if (!target || !source)
return emitSilenceableFailure(target->getLoc())
<< "target or source is not a loop op";

// Check if the target and source are siblings.
DiagnosedSilenceableFailure diag = isOpSibling(target, source);
if (!diag.succeeded())
return diag;
// Check if loops can be fused
Diagnostic diag(target.getLoc(), DiagnosticSeverity::Error);
if (!mlir::checkFusionStructuralLegality(target, source, diag))
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));

Operation *fusedLoop;
/// TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
if (isForWithIdenticalConfiguration(target, source)) {
// TODO: Support fusion for loop-like ops besides scf.for, scf.forall
// and scf.parallel.
if (isa<scf::ForOp>(target) && isa<scf::ForOp>(source)) {
fusedLoop = fuseIndependentSiblingForLoops(
cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
} else if (isForallWithIdenticalConfiguration(target, source)) {
} else if (isa<scf::ForallOp>(target) && isa<scf::ForallOp>(source)) {
fusedLoop = fuseIndependentSiblingForallLoops(
cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
} else if (isa<scf::ParallelOp>(target) && isa<scf::ParallelOp>(source)) {
fusedLoop = fuseIndependentSiblingParallelLoops(
cast<scf::ParallelOp>(target), cast<scf::ParallelOp>(source), rewriter);
} else
return emitSilenceableFailure(target->getLoc())
<< "operations cannot be fused";
<< "unsupported loop type for fusion";

assert(fusedLoop && "failed to fuse operations");

Expand Down
80 changes: 6 additions & 74 deletions mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpDefinition.h"
Expand All @@ -37,24 +38,6 @@ static bool hasNestedParallelOp(ParallelOp ploop) {
return walkResult.wasInterrupted();
}

/// Verify equal iteration spaces.
static bool equalIterationSpaces(ParallelOp firstPloop,
ParallelOp secondPloop) {
if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
return false;

auto matchOperands = [&](const OperandRange &lhs,
const OperandRange &rhs) -> bool {
// TODO: Extend this to support aliases and equal constants.
return std::equal(lhs.begin(), lhs.end(), rhs.begin());
};
return matchOperands(firstPloop.getLowerBound(),
secondPloop.getLowerBound()) &&
matchOperands(firstPloop.getUpperBound(),
secondPloop.getUpperBound()) &&
matchOperands(firstPloop.getStep(), secondPloop.getStep());
}

/// Checks if the parallel loops have mixed access to the same buffers. Returns
/// `true` if the first parallel loop writes to the same indices that the second
/// loop reads.
Expand Down Expand Up @@ -153,9 +136,10 @@ verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
const IRMapping &firstToSecondPloopIndices,
llvm::function_ref<bool(Value, Value)> mayAlias) {
Diagnostic diag(firstPloop.getLoc(), DiagnosticSeverity::Remark);
return !hasNestedParallelOp(firstPloop) &&
!hasNestedParallelOp(secondPloop) &&
equalIterationSpaces(firstPloop, secondPloop) &&
checkFusionStructuralLegality(firstPloop, secondPloop, diag) &&
succeeded(verifyDependencies(firstPloop, secondPloop,
firstToSecondPloopIndices, mayAlias));
}
Expand All @@ -174,61 +158,9 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
mayAlias))
return;

DominanceInfo dom;
// We are fusing first loop into second, make sure there are no users of the
// first loop results between loops.
for (Operation *user : firstPloop->getUsers())
if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
return;

ValueRange inits1 = firstPloop.getInitVals();
ValueRange inits2 = secondPloop.getInitVals();

SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
newInitVars.append(inits2.begin(), inits2.end());

IRRewriter b(builder);
b.setInsertionPoint(secondPloop);
auto newSecondPloop = b.create<ParallelOp>(
secondPloop.getLoc(), secondPloop.getLowerBound(),
secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);

Block *newBlock = newSecondPloop.getBody();
auto term1 = cast<ReduceOp>(block1->getTerminator());
auto term2 = cast<ReduceOp>(block2->getTerminator());

b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
newBlock->getArguments());
b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
newBlock->getArguments());

ValueRange results = newSecondPloop.getResults();
if (!results.empty()) {
b.setInsertionPointToEnd(newBlock);

ValueRange reduceArgs1 = term1.getOperands();
ValueRange reduceArgs2 = term2.getOperands();
SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());

auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);

for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
term1.getReductions(), term2.getReductions()))) {
Block &oldRedBlock = reg.front();
Block &newRedBlock = newReduceOp.getReductions()[i].front();
b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
newRedBlock.getArguments());
}

firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
}
term1->erase();
term2->erase();
firstPloop.erase();
secondPloop.erase();
secondPloop = newSecondPloop;
IRRewriter rewriter(builder);
secondPloop = mlir::fuseIndependentSiblingParallelLoops(
firstPloop, secondPloop, rewriter);
}

void mlir::scf::naivelyFuseParallelOps(
Expand Down
Loading