diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp index 960cbe9ea4f0f5..8ac0b4f9636aa4 100644 --- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp @@ -29,6 +29,7 @@ #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/CallGraphSCCPass.h" +#include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/LazyCallGraph.h" #include "llvm/IR/Argument.h" #include "llvm/IR/Attributes.h" @@ -1197,6 +1198,15 @@ scanPHIsAndUpdateValueMap(Instruction *Prev, BasicBlock *NewBlock, static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) { DenseMap ResolvedValues; BasicBlock *UnconditionalSucc = nullptr; + assert(InitialInst->getModule()); + const DataLayout &DL = InitialInst->getModule()->getDataLayout(); + + auto TryResolveConstant = [&ResolvedValues](Value *V) { + auto It = ResolvedValues.find(V); + if (It != ResolvedValues.end()) + V = It->second; + return dyn_cast(V); + }; Instruction *I = InitialInst; while (I->isTerminator() || @@ -1213,47 +1223,65 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) { } if (auto *BR = dyn_cast(I)) { if (BR->isUnconditional()) { - BasicBlock *BB = BR->getSuccessor(0); + BasicBlock *Succ = BR->getSuccessor(0); if (I == InitialInst) - UnconditionalSucc = BB; - scanPHIsAndUpdateValueMap(I, BB, ResolvedValues); - I = BB->getFirstNonPHIOrDbgOrLifetime(); + UnconditionalSucc = Succ; + scanPHIsAndUpdateValueMap(I, Succ, ResolvedValues); + I = Succ->getFirstNonPHIOrDbgOrLifetime(); + continue; + } + + BasicBlock *BB = BR->getParent(); + // Handle the case the condition of the conditional branch is constant. + // e.g., + // + // br i1 false, label %cleanup, label %CoroEnd + // + // It is possible during the transformation. We could continue the + // simplifying in this case. + if (ConstantFoldTerminator(BB, /*DeleteDeadConditions=*/true)) { + // Handle this branch in next iteration. + I = BB->getTerminator(); continue; } } else if (auto *CondCmp = dyn_cast(I)) { + // If the case number of suspended switch instruction is reduced to + // 1, then it is simplified to CmpInst in llvm::ConstantFoldTerminator. auto *BR = dyn_cast(I->getNextNode()); - if (BR && BR->isConditional() && CondCmp == BR->getCondition()) { - // If the case number of suspended switch instruction is reduced to - // 1, then it is simplified to CmpInst in llvm::ConstantFoldTerminator. - // And the comparsion looks like : %cond = icmp eq i8 %V, constant. - ConstantInt *CondConst = dyn_cast(CondCmp->getOperand(1)); - if (CondConst && CondCmp->getPredicate() == CmpInst::ICMP_EQ) { - Value *V = CondCmp->getOperand(0); - auto it = ResolvedValues.find(V); - if (it != ResolvedValues.end()) - V = it->second; - - if (ConstantInt *Cond0 = dyn_cast(V)) { - BasicBlock *BB = Cond0->equalsInt(CondConst->getZExtValue()) - ? BR->getSuccessor(0) - : BR->getSuccessor(1); - scanPHIsAndUpdateValueMap(I, BB, ResolvedValues); - I = BB->getFirstNonPHIOrDbgOrLifetime(); - continue; - } - } - } + if (!BR || !BR->isConditional() || CondCmp != BR->getCondition()) + return false; + + // And the comparsion looks like : %cond = icmp eq i8 %V, constant. + // So we try to resolve constant for the first operand only since the + // second operand should be literal constant by design. + ConstantInt *Cond0 = TryResolveConstant(CondCmp->getOperand(0)); + auto *Cond1 = dyn_cast(CondCmp->getOperand(1)); + if (!Cond0 || !Cond1) + return false; + + // Both operands of the CmpInst are Constant. So that we could evaluate + // it immediately to get the destination. + auto *ConstResult = + dyn_cast_or_null(ConstantFoldCompareInstOperands( + CondCmp->getPredicate(), Cond0, Cond1, DL)); + if (!ConstResult) + return false; + + CondCmp->replaceAllUsesWith(ConstResult); + CondCmp->eraseFromParent(); + + // Handle this branch in next iteration. + I = BR; + continue; } else if (auto *SI = dyn_cast(I)) { - Value *V = SI->getCondition(); - auto it = ResolvedValues.find(V); - if (it != ResolvedValues.end()) - V = it->second; - if (ConstantInt *Cond = dyn_cast(V)) { - BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor(); - scanPHIsAndUpdateValueMap(I, BB, ResolvedValues); - I = BB->getFirstNonPHIOrDbgOrLifetime(); - continue; - } + ConstantInt *Cond = TryResolveConstant(SI->getCondition()); + if (!Cond) + return false; + + BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor(); + scanPHIsAndUpdateValueMap(I, BB, ResolvedValues); + I = BB->getFirstNonPHIOrDbgOrLifetime(); + continue; } return false; } diff --git a/llvm/test/Transforms/Coroutines/coro-split-musttail4.ll b/llvm/test/Transforms/Coroutines/coro-split-musttail4.ll index 9fd80179962069..0d73d94a93f581 100644 --- a/llvm/test/Transforms/Coroutines/coro-split-musttail4.ll +++ b/llvm/test/Transforms/Coroutines/coro-split-musttail4.ll @@ -42,9 +42,9 @@ coro.end: ret void } -; FIXME: The fakerresume1 here should be musttail call. ; CHECK-LABEL: @f.resume( -; CHECK-NOT: musttail call fastcc void @fakeresume1( +; CHECK: musttail call fastcc void @fakeresume1( +; CHECK-NEXT: ret void declare token @llvm.coro.id(i32, i8* readnone, i8* nocapture readonly, i8*) #1 declare i1 @llvm.coro.alloc(token) #2