diff --git a/clang/test/CodeGenCoroutines/coro-symmetric-transfer-04.cpp b/clang/test/CodeGenCoroutines/coro-symmetric-transfer-04.cpp new file mode 100644 index 0000000000000..cf9170d7e7110 --- /dev/null +++ b/clang/test/CodeGenCoroutines/coro-symmetric-transfer-04.cpp @@ -0,0 +1,66 @@ +// This tests that the symmetric transfer at the final suspend point could happen successfully. +// Based on https://github.com/llvm/llvm-project/pull/85271#issuecomment-2007554532 +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O2 -emit-llvm %s -o - | FileCheck %s + +#include "Inputs/coroutine.h" + +struct Task { + struct promise_type { + struct FinalAwaiter { + bool await_ready() const noexcept { return false; } + template + std::coroutine_handle<> await_suspend(std::coroutine_handle h) noexcept { + return h.promise().continuation; + } + void await_resume() noexcept {} + }; + Task get_return_object() noexcept { + return std::coroutine_handle::from_promise(*this); + } + std::suspend_always initial_suspend() noexcept { return {}; } + FinalAwaiter final_suspend() noexcept { return {}; } + void unhandled_exception() noexcept {} + void return_value(int x) noexcept { + _value = x; + } + std::coroutine_handle<> continuation; + int _value; + }; + + Task(std::coroutine_handle handle) : handle(handle), stuff(123) {} + + struct Awaiter { + std::coroutine_handle handle; + Awaiter(std::coroutine_handle handle) : handle(handle) {} + bool await_ready() const noexcept { return false; } + std::coroutine_handle await_suspend(std::coroutine_handle continuation) noexcept { + handle.promise().continuation = continuation; + return handle; + } + int await_resume() noexcept { + int ret = handle.promise()._value; + handle.destroy(); + return ret; + } + }; + + auto operator co_await() { + auto handle_ = handle; + handle = nullptr; + return Awaiter(handle_); + } + +private: + std::coroutine_handle handle; + int stuff; +}; + +Task task0() { + co_return 43; +} + +// CHECK-LABEL: define{{.*}} void @_Z5task0v.resume +// This checks we are still in the scope of the current function. +// CHECK-NOT: {{^}}} +// CHECK: musttail call fastcc void +// CHECK-NEXT: ret void diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp index 3f3d81474fafe..3a43b1edcaba3 100644 --- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp @@ -1198,22 +1198,6 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) { assert(InitialInst->getModule()); const DataLayout &DL = InitialInst->getModule()->getDataLayout(); - auto GetFirstValidInstruction = [](Instruction *I) { - while (I) { - // BitCastInst wouldn't generate actual code so that we could skip it. - if (isa(I) || I->isDebugOrPseudoInst() || - I->isLifetimeStartOrEnd()) - I = I->getNextNode(); - else if (isInstructionTriviallyDead(I)) - // Duing we are in the middle of the transformation, we need to erase - // the dead instruction manually. - I = &*I->eraseFromParent(); - else - break; - } - return I; - }; - auto TryResolveConstant = [&ResolvedValues](Value *V) { auto It = ResolvedValues.find(V); if (It != ResolvedValues.end()) @@ -1222,8 +1206,9 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) { }; Instruction *I = InitialInst; - while (I->isTerminator() || isa(I)) { + while (true) { if (isa(I)) { + assert(!cast(I)->getReturnValue()); ReplaceInstWithInst(InitialInst, I->clone()); return true; } @@ -1247,40 +1232,26 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) { BasicBlock *Succ = BR->getSuccessor(SuccIndex); scanPHIsAndUpdateValueMap(I, Succ, ResolvedValues); - I = GetFirstValidInstruction(Succ->getFirstNonPHIOrDbgOrLifetime()); - + I = Succ->getFirstNonPHIOrDbgOrLifetime(); continue; } - if (auto *CondCmp = dyn_cast(I)) { + if (auto *Cmp = dyn_cast(I)) { // If the case number of suspended switch instruction is reduced to // 1, then it is simplified to CmpInst in llvm::ConstantFoldTerminator. - auto *BR = dyn_cast( - GetFirstValidInstruction(CondCmp->getNextNode())); - if (!BR || !BR->isConditional() || CondCmp != BR->getCondition()) - return false; - - // And the comparsion looks like : %cond = icmp eq i8 %V, constant. - // So we try to resolve constant for the first operand only since the - // second operand should be literal constant by design. - ConstantInt *Cond0 = TryResolveConstant(CondCmp->getOperand(0)); - auto *Cond1 = dyn_cast(CondCmp->getOperand(1)); - if (!Cond0 || !Cond1) - return false; - - // Both operands of the CmpInst are Constant. So that we could evaluate - // it immediately to get the destination. - auto *ConstResult = - dyn_cast_or_null(ConstantFoldCompareInstOperands( - CondCmp->getPredicate(), Cond0, Cond1, DL)); - if (!ConstResult) - return false; - - ResolvedValues[BR->getCondition()] = ConstResult; - - // Handle this branch in next iteration. - I = BR; - continue; + // Try to constant fold it. + ConstantInt *Cond0 = TryResolveConstant(Cmp->getOperand(0)); + ConstantInt *Cond1 = TryResolveConstant(Cmp->getOperand(1)); + if (Cond0 && Cond1) { + ConstantInt *Result = + dyn_cast_or_null(ConstantFoldCompareInstOperands( + Cmp->getPredicate(), Cond0, Cond1, DL)); + if (Result) { + ResolvedValues[Cmp] = Result; + I = I->getNextNode(); + continue; + } + } } if (auto *SI = dyn_cast(I)) { @@ -1288,13 +1259,21 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) { if (!Cond) return false; - BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor(); - scanPHIsAndUpdateValueMap(I, BB, ResolvedValues); - I = GetFirstValidInstruction(BB->getFirstNonPHIOrDbgOrLifetime()); + BasicBlock *Succ = SI->findCaseValue(Cond)->getCaseSuccessor(); + scanPHIsAndUpdateValueMap(I, Succ, ResolvedValues); + I = Succ->getFirstNonPHIOrDbgOrLifetime(); + continue; + } + + if (I->isDebugOrPseudoInst() || I->isLifetimeStartOrEnd() || + wouldInstructionBeTriviallyDead(I)) { + // We can skip instructions without side effects. If their values are + // needed, we'll notice later, e.g. when hitting a conditional branch. + I = I->getNextNode(); continue; } - return false; + break; } return false; diff --git a/llvm/test/Transforms/Coroutines/coro-split-musttail7.ll b/llvm/test/Transforms/Coroutines/coro-split-musttail7.ll index 2257d5aee473e..d0d5005587bda 100644 --- a/llvm/test/Transforms/Coroutines/coro-split-musttail7.ll +++ b/llvm/test/Transforms/Coroutines/coro-split-musttail7.ll @@ -1,6 +1,6 @@ ; Tests that sinked lifetime markers wouldn't provent optimization ; to convert a resuming call to a musttail call. -; The difference between this and coro-split-musttail5.ll and coro-split-musttail5.ll +; The difference between this and coro-split-musttail5.ll and coro-split-musttail6.ll ; is that this contains dead instruction generated during the transformation, ; which makes the optimization harder. ; RUN: opt < %s -passes='cgscc(coro-split),simplifycfg,early-cse' -S | FileCheck %s @@ -8,7 +8,7 @@ declare void @fakeresume1(ptr align 8) -define void @g() #0 { +define i64 @g() #0 { entry: %id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null) %alloc = call ptr @malloc(i64 16) #3 @@ -27,6 +27,11 @@ await.suspend: %save2 = call token @llvm.coro.save(ptr null) call fastcc void @fakeresume1(ptr align 8 null) %suspend2 = call i8 @llvm.coro.suspend(token %save2, i1 false) + + ; These (non-trivially) dead instructions are in the way. + %gep = getelementptr inbounds i64, ptr %alloc.var, i32 0 + %foo = ptrtoint ptr %gep to i64 + switch i8 %suspend2, label %exit [ i8 0, label %await.ready i8 1, label %exit @@ -36,8 +41,9 @@ await.ready: call void @llvm.lifetime.end.p0(i64 1, ptr %alloc.var) br label %exit exit: + %result = phi i64 [0, %entry], [0, %entry], [%foo, %await.suspend], [%foo, %await.suspend], [%foo, %await.ready] call i1 @llvm.coro.end(ptr null, i1 false, token none) - ret void + ret i64 %result } ; Verify that in the resume part resume call is marked with musttail.