diff --git a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp index 012aa5dbb9ca00..fd80d0882b4b9f 100644 --- a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp @@ -365,6 +365,65 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI, } } +// Check if duplicating the header to the preheader will fold all duplicated +// instructions, terminating with an unconditional jump. So No code duplication +// would be necessary. +static bool canRotateWithoutDuplication(Loop *L, llvm::LoopInfo *LI, + llvm::SimplifyQuery SQ) { + BasicBlock *OrigHeader = L->getHeader(); + BasicBlock *OrigPreheader = L->getLoopPreheader(); + + BranchInst *T = dyn_cast(OrigHeader->getTerminator()); + assert(T && T->isConditional() && "Header Terminator should be conditional!"); + + // a value map to holds the values incoming from preheader + ValueToValueMapTy ValueMap; + + // iterate over non-debug instructions and check which ones fold + for (Instruction &Inst : OrigHeader->instructionsWithoutDebug()) { + // For PHI nodes, the values in the cloned header are the values in the + // preheader + if (PHINode *PN = dyn_cast(&Inst)) { + InsertNewValueIntoMap(ValueMap, PN, + PN->getIncomingValueForBlock(OrigPreheader)); + continue; + } + if (Inst.mayHaveSideEffects()) + return false; + + Instruction *C = Inst.clone(); + C->insertBefore(&Inst); + + // Eagerly remap the operands of the instruction. + RemapInstruction(C, ValueMap, + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); + + // if we hit the terminator, return whether it folds + if (static_cast(&Inst) == T) { + if (isa(dyn_cast(C)->getCondition())) { + LLVM_DEBUG(dbgs() << "Loop Rotation: bypassing header threshold " + "because the header folds\n"); + C->eraseFromParent(); + return true; + } else { + LLVM_DEBUG( + dbgs() << "Loop Rotation: skipping. Header threshold exceeded " + "and header doesn't fold\n"); + C->eraseFromParent(); + return false; + } + } + + // other instructions, evaluate if possible + Value *V = simplifyInstruction(C, SQ); + if (V) + InsertNewValueIntoMap(ValueMap, &Inst, V); + C->eraseFromParent(); + } + + llvm_unreachable("OrigHeader didn't contain terminal branch!"); +} + /// Rotate loop LP. Return true if the loop is rotated. /// /// \param SimplifiedLatch is true if the latch was just folded into the final @@ -437,7 +496,8 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { L->dump()); return Rotated; } - if (Metrics.NumInsts > MaxHeaderSize) { + if (Metrics.NumInsts > MaxHeaderSize && + !canRotateWithoutDuplication(L, LI, SQ)) { LLVM_DEBUG(dbgs() << "LoopRotation: NOT rotating - contains " << Metrics.NumInsts << " instructions, which is more than the threshold (" diff --git a/llvm/test/Transforms/LoopRotate/rotate-oz-if-branch-fold.ll b/llvm/test/Transforms/LoopRotate/rotate-oz-if-branch-fold.ll new file mode 100644 index 00000000000000..68d65a21c71340 --- /dev/null +++ b/llvm/test/Transforms/LoopRotate/rotate-oz-if-branch-fold.ll @@ -0,0 +1,39 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4 +; RUN: opt < %s -S -Oz -debug -debug-only=loop-rotate 2>&1 | FileCheck %s -check-prefix=OZ + +; Loop should be rotated for -Oz if the duplicated branch in the header +; can be folded. +; Test adapted from gcc-c-torture/execute/pr90949.c + +; OZ: bypassing header threshold +; OZ: rotating Loop at depth 1 containing: %tailrecurse.i
,%tailrecurse.backedge.i +; OZ: into Loop at depth 1 containing: %tailrecurse.backedge.i
+ +%struct.Node = type { %struct.Node* } + +@space = dso_local global [2 x %struct.Node] zeroinitializer, align 8 + +define dso_local signext i32 @main() local_unnamed_addr #2 { +; OZ-LABEL: define dso_local signext i32 @main( +; OZ-SAME: ) local_unnamed_addr #[[ATTR0:[0-9]+]] { +; OZ-NEXT: entry: +; OZ-NEXT: store ptr null, ptr @space, align 8 +; OZ-NEXT: ret i32 0 +; +entry: + store %struct.Node* null, %struct.Node** getelementptr inbounds ([2 x %struct.Node], [2 x %struct.Node]* @space, i64 0, i64 0, i32 0), align 8 + br label %tailrecurse.i + +tailrecurse.i: ; preds = %tailrecurse.backedge.i, %entry + %module.tr.i = phi %struct.Node* [ getelementptr inbounds ([2 x %struct.Node], [2 x %struct.Node]* @space, i64 0, i64 0), %entry ], [ %module.tr.be.i, %tailrecurse.backedge.i ] + %cmp.i = icmp eq %struct.Node* %module.tr.i, null + br i1 %cmp.i, label %walk.exit, label %tailrecurse.backedge.i + +tailrecurse.backedge.i: ; preds = %tailrecurse.i + %module.tr.be.in.i = getelementptr inbounds %struct.Node, %struct.Node* %module.tr.i, i64 0, i32 0 + %module.tr.be.i = load %struct.Node*, %struct.Node** %module.tr.be.in.i, align 8 + br label %tailrecurse.i + +walk.exit: ; preds = %tailrecurse.i + ret i32 0 +}