Skip to content

Commit

Permalink
[Coroutines] Ignore instructions more aggressively in addMustTailToCoroResumes() (#85271)
Browse files Browse the repository at this point in the history

The old code used isInstructionTriviallyDead() and removed instructions
when walking the path from a resume call to function return to check if
the call is in tail position.

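However, since the code was walking forwards, it was not able to get past
instructions whose results still have a user at the point they are visited
(here %gep is used by %foo, so %gep is not trivially dead yet), such as: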

  %gep = getelementptr inbounds i64, ptr %alloc.var, i32 0
  %foo = ptrtoint ptr %gep to i64

This patch instead ignores such instructions as long as their values are
not needed. This enables the code to emit tail calls in more situations.
  • Loading branch information
zmodem committed Mar 20, 2024
1 parent ada24ae commit 9d1cb18
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 53 deletions.
66 changes: 66 additions & 0 deletions clang/test/CodeGenCoroutines/coro-symmetric-transfer-04.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// This tests that the symmetric transfer at the final suspend point could happen successfully.
// Based on https://github.com/llvm/llvm-project/pull/85271#issuecomment-2007554532
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O2 -emit-llvm %s -o - | FileCheck %s

#include "Inputs/coroutine.h"

// Minimal coroutine return type for this test. It owns a handle to the
// coroutine and carries one extra int member that the constructor sets to 123.
struct Task {
// Promise type for Task coroutines. It stores the handle to resume at the
// final suspend point and the value passed to co_return.
struct promise_type {
// Awaiter returned by final_suspend(). await_ready() is always false, so the
// coroutine always suspends at its final suspend point.
struct FinalAwaiter {
bool await_ready() const noexcept { return false; }
// Returns the continuation stored in the promise. Returning a coroutine
// handle from await_suspend() asks for that coroutine to be resumed next,
// which is the symmetric transfer this test is about.
template <typename PromiseType>
std::coroutine_handle<> await_suspend(std::coroutine_handle<PromiseType> h) noexcept {
return h.promise().continuation;
}
void await_resume() noexcept {}
};
// Obtains the coroutine handle from this promise; the handle is converted to
// a Task through Task's non-explicit constructor.
Task get_return_object() noexcept {
return std::coroutine_handle<promise_type>::from_promise(*this);
}
// The coroutine body does not start running until it is resumed.
std::suspend_always initial_suspend() noexcept { return {}; }
FinalAwaiter final_suspend() noexcept { return {}; }
// Empty body: nothing is done with an exception that escapes the coroutine.
void unhandled_exception() noexcept {}
// Stores the operand of co_return in the promise.
void return_value(int x) noexcept {
_value = x;
}
// Written by Awaiter::await_suspend(), returned by FinalAwaiter::await_suspend().
std::coroutine_handle<> continuation;
// Written by return_value(), read by Awaiter::await_resume().
int _value;
};

// Takes the coroutine handle and initializes `stuff` to the constant 123.
Task(std::coroutine_handle<promise_type> handle) : handle(handle), stuff(123) {}

// Awaiter produced by `co_await task`. It holds the awaited coroutine's handle.
struct Awaiter {
std::coroutine_handle<promise_type> handle;
Awaiter(std::coroutine_handle<promise_type> handle) : handle(handle) {}
bool await_ready() const noexcept { return false; }
// Records the awaiting coroutine as the continuation of the awaited one, then
// returns the awaited coroutine's handle so that it is resumed next.
std::coroutine_handle<void> await_suspend(std::coroutine_handle<void> continuation) noexcept {
handle.promise().continuation = continuation;
return handle;
}
// Reads the result out of the promise before destroying the coroutine, since
// the promise cannot be read after destroy().
int await_resume() noexcept {
int ret = handle.promise()._value;
handle.destroy();
return ret;
}
};

// Moves the handle out of the Task into a new Awaiter: the Task's own handle
// is reset to nullptr and the Awaiter receives the previous value.
auto operator co_await() {
auto handle_ = handle;
handle = nullptr;
return Awaiter(handle_);
}

private:
// Handle to the coroutine; nullptr after operator co_await() has been called.
std::coroutine_handle<promise_type> handle;
// Set to 123 by the constructor and not read anywhere in this file.
int stuff;
};

// Coroutine under test: it does nothing except co_return 43. The FileCheck
// directives that follow inspect the `.resume` function generated for it
// (matched by the mangled name _Z5task0v).
Task task0() {
co_return 43;
}

// CHECK-LABEL: define{{.*}} void @_Z5task0v.resume
// This checks we are still in the scope of the current function.
// CHECK-NOT: {{^}}}
// CHECK: musttail call fastcc void
// CHECK-NEXT: ret void
79 changes: 29 additions & 50 deletions llvm/lib/Transforms/Coroutines/CoroSplit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1198,22 +1198,6 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
assert(InitialInst->getModule());
const DataLayout &DL = InitialInst->getModule()->getDataLayout();

auto GetFirstValidInstruction = [](Instruction *I) {
while (I) {
// BitCastInst wouldn't generate actual code so that we could skip it.
if (isa<BitCastInst>(I) || I->isDebugOrPseudoInst() ||
I->isLifetimeStartOrEnd())
I = I->getNextNode();
else if (isInstructionTriviallyDead(I))
// Duing we are in the middle of the transformation, we need to erase
// the dead instruction manually.
I = &*I->eraseFromParent();
else
break;
}
return I;
};

auto TryResolveConstant = [&ResolvedValues](Value *V) {
auto It = ResolvedValues.find(V);
if (It != ResolvedValues.end())
Expand All @@ -1222,8 +1206,9 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
};

Instruction *I = InitialInst;
while (I->isTerminator() || isa<CmpInst>(I)) {
while (true) {
if (isa<ReturnInst>(I)) {
assert(!cast<ReturnInst>(I)->getReturnValue());
ReplaceInstWithInst(InitialInst, I->clone());
return true;
}
Expand All @@ -1247,54 +1232,48 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {

BasicBlock *Succ = BR->getSuccessor(SuccIndex);
scanPHIsAndUpdateValueMap(I, Succ, ResolvedValues);
I = GetFirstValidInstruction(Succ->getFirstNonPHIOrDbgOrLifetime());

I = Succ->getFirstNonPHIOrDbgOrLifetime();
continue;
}

if (auto *CondCmp = dyn_cast<CmpInst>(I)) {
if (auto *Cmp = dyn_cast<CmpInst>(I)) {
// If the case number of suspended switch instruction is reduced to
// 1, then it is simplified to CmpInst in llvm::ConstantFoldTerminator.
auto *BR = dyn_cast<BranchInst>(
GetFirstValidInstruction(CondCmp->getNextNode()));
if (!BR || !BR->isConditional() || CondCmp != BR->getCondition())
return false;

// And the comparsion looks like : %cond = icmp eq i8 %V, constant.
// So we try to resolve constant for the first operand only since the
// second operand should be literal constant by design.
ConstantInt *Cond0 = TryResolveConstant(CondCmp->getOperand(0));
auto *Cond1 = dyn_cast<ConstantInt>(CondCmp->getOperand(1));
if (!Cond0 || !Cond1)
return false;

// Both operands of the CmpInst are Constant. So that we could evaluate
// it immediately to get the destination.
auto *ConstResult =
dyn_cast_or_null<ConstantInt>(ConstantFoldCompareInstOperands(
CondCmp->getPredicate(), Cond0, Cond1, DL));
if (!ConstResult)
return false;

ResolvedValues[BR->getCondition()] = ConstResult;

// Handle this branch in next iteration.
I = BR;
continue;
// Try to constant fold it.
ConstantInt *Cond0 = TryResolveConstant(Cmp->getOperand(0));
ConstantInt *Cond1 = TryResolveConstant(Cmp->getOperand(1));
if (Cond0 && Cond1) {
ConstantInt *Result =
dyn_cast_or_null<ConstantInt>(ConstantFoldCompareInstOperands(
Cmp->getPredicate(), Cond0, Cond1, DL));
if (Result) {
ResolvedValues[Cmp] = Result;
I = I->getNextNode();
continue;
}
}
}

if (auto *SI = dyn_cast<SwitchInst>(I)) {
ConstantInt *Cond = TryResolveConstant(SI->getCondition());
if (!Cond)
return false;

BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor();
scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
I = GetFirstValidInstruction(BB->getFirstNonPHIOrDbgOrLifetime());
BasicBlock *Succ = SI->findCaseValue(Cond)->getCaseSuccessor();
scanPHIsAndUpdateValueMap(I, Succ, ResolvedValues);
I = Succ->getFirstNonPHIOrDbgOrLifetime();
continue;
}

if (I->isDebugOrPseudoInst() || I->isLifetimeStartOrEnd() ||
wouldInstructionBeTriviallyDead(I)) {
// We can skip instructions without side effects. If their values are
// needed, we'll notice later, e.g. when hitting a conditional branch.
I = I->getNextNode();
continue;
}

return false;
break;
}

return false;
Expand Down
12 changes: 9 additions & 3 deletions llvm/test/Transforms/Coroutines/coro-split-musttail7.ll
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
; Tests that sunk lifetime markers don't prevent the optimization
; that converts a resuming call to a musttail call.
; The difference between this and coro-split-musttail5.ll and coro-split-musttail5.ll
; The difference between this and coro-split-musttail5.ll and coro-split-musttail6.ll
; is that this contains dead instruction generated during the transformation,
; which makes the optimization harder.
; RUN: opt < %s -passes='cgscc(coro-split),simplifycfg,early-cse' -S | FileCheck %s
; RUN: opt < %s -passes='pgo-instr-gen,cgscc(coro-split),simplifycfg,early-cse' -S | FileCheck %s

declare void @fakeresume1(ptr align 8)

define void @g() #0 {
define i64 @g() #0 {
entry:
%id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null)
%alloc = call ptr @malloc(i64 16) #3
Expand All @@ -27,6 +27,11 @@ await.suspend:
%save2 = call token @llvm.coro.save(ptr null)
call fastcc void @fakeresume1(ptr align 8 null)
%suspend2 = call i8 @llvm.coro.suspend(token %save2, i1 false)

; These (non-trivially) dead instructions are in the way.
%gep = getelementptr inbounds i64, ptr %alloc.var, i32 0
%foo = ptrtoint ptr %gep to i64

switch i8 %suspend2, label %exit [
i8 0, label %await.ready
i8 1, label %exit
Expand All @@ -36,8 +41,9 @@ await.ready:
call void @llvm.lifetime.end.p0(i64 1, ptr %alloc.var)
br label %exit
exit:
%result = phi i64 [0, %entry], [0, %entry], [%foo, %await.suspend], [%foo, %await.suspend], [%foo, %await.ready]
call i1 @llvm.coro.end(ptr null, i1 false, token none)
ret void
ret i64 %result
}

; Verify that in the resume part resume call is marked with musttail.
Expand Down

0 comments on commit 9d1cb18

Please sign in to comment.