Skip to content

Commit

Permalink
[Coro][WebAssembly] Add tail-call check for async lowering (#81481)
Browse files Browse the repository at this point in the history
This patch fixes a verifier error when async lowering is used for a
WebAssembly target without the tail-call feature. The missing check was
revealed by b1ac052, which removed
inlining of the musttail'ed call, so the invalid call started to survive
until the verification stage. Additionally, `TTI::supportsTailCallFor` did
not respect the concrete TTI's `supportsTailCalls` implementation: it
always returned true even when `supportsTailCalls` returned false. This
patch therefore also fixes the wrong CRTP base class implementation.
  • Loading branch information
kateinoigakukun committed Feb 20, 2024
1 parent 8ac7c4f commit 8c5c4d9
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 12 deletions.
8 changes: 4 additions & 4 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -367,10 +367,6 @@ class TargetTransformInfoImplBase {

bool supportsTailCalls() const { return true; }

bool supportsTailCallFor(const CallBase *CB) const {
return supportsTailCalls();
}

bool enableAggressiveInterleaving(bool LoopHasReductions) const {
return false;
}
Expand Down Expand Up @@ -1427,6 +1423,10 @@ class TargetTransformInfoImplCRTPBase : public TargetTransformInfoImplBase {
I, Ops, TargetTransformInfo::TCK_SizeAndLatency);
return Cost >= TargetTransformInfo::TCC_Expensive;
}

/// Returns whether a tail call is supported for \p CB.
///
/// Forwards to supportsTailCalls() on the concrete implementation type T
/// (CRTP dispatch), so that an implementation which provides its own
/// supportsTailCalls() is the one consulted. \p CB itself is not inspected
/// by this default.
bool supportsTailCallFor(const CallBase *CB) const {
return static_cast<const T *>(this)->supportsTailCalls();
}
};
} // namespace llvm

Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Transforms/Coroutines/CoroFrame.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3064,7 +3064,7 @@ static void doRematerializations(
}

void coro::buildCoroutineFrame(
Function &F, Shape &Shape,
Function &F, Shape &Shape, TargetTransformInfo &TTI,
const std::function<bool(Instruction &)> &MaterializableCallback) {
// Don't eliminate swifterror in async functions that won't be split.
if (Shape.ABI != coro::ABI::Async || !Shape.CoroSuspends.empty())
Expand Down Expand Up @@ -3100,7 +3100,7 @@ void coro::buildCoroutineFrame(
SmallVector<Value *, 8> Args(AsyncEnd->args());
auto Arguments = ArrayRef<Value *>(Args).drop_front(3);
auto *Call = createMustTailCall(AsyncEnd->getDebugLoc(), MustTailCallFn,
Arguments, Builder);
TTI, Arguments, Builder);
splitAround(Call, "MustTailCall.Before.CoroEnd");
}
}
Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/Transforms/Coroutines/CoroInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#define LLVM_LIB_TRANSFORMS_COROUTINES_COROINTERNAL_H

#include "CoroInstr.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/IRBuilder.h"

namespace llvm {
Expand Down Expand Up @@ -272,9 +273,10 @@ struct LLVM_LIBRARY_VISIBILITY Shape {

bool defaultMaterializable(Instruction &V);
void buildCoroutineFrame(
Function &F, Shape &Shape,
Function &F, Shape &Shape, TargetTransformInfo &TTI,
const std::function<bool(Instruction &)> &MaterializableCallback);
CallInst *createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
TargetTransformInfo &TTI,
ArrayRef<Value *> Arguments, IRBuilder<> &);
} // End namespace coro.
} // End namespace llvm
Expand Down
15 changes: 10 additions & 5 deletions llvm/lib/Transforms/Coroutines/CoroSplit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1746,6 +1746,7 @@ static void coerceArguments(IRBuilder<> &Builder, FunctionType *FnTy,
}

CallInst *coro::createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
TargetTransformInfo &TTI,
ArrayRef<Value *> Arguments,
IRBuilder<> &Builder) {
auto *FnTy = MustTailCallFn->getFunctionType();
Expand All @@ -1755,14 +1756,18 @@ CallInst *coro::createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
coerceArguments(Builder, FnTy, Arguments, CallArgs);

auto *TailCall = Builder.CreateCall(FnTy, MustTailCallFn, CallArgs);
TailCall->setTailCallKind(CallInst::TCK_MustTail);
// Skip targets which don't support tail call.
if (TTI.supportsTailCallFor(TailCall)) {
TailCall->setTailCallKind(CallInst::TCK_MustTail);
}
TailCall->setDebugLoc(Loc);
TailCall->setCallingConv(MustTailCallFn->getCallingConv());
return TailCall;
}

static void splitAsyncCoroutine(Function &F, coro::Shape &Shape,
SmallVectorImpl<Function *> &Clones) {
SmallVectorImpl<Function *> &Clones,
TargetTransformInfo &TTI) {
assert(Shape.ABI == coro::ABI::Async);
assert(Clones.empty());
// Reset various things that the optimizer might have decided it
Expand Down Expand Up @@ -1837,7 +1842,7 @@ static void splitAsyncCoroutine(Function &F, coro::Shape &Shape,
SmallVector<Value *, 8> Args(Suspend->args());
auto FnArgs = ArrayRef<Value *>(Args).drop_front(
CoroSuspendAsyncInst::MustTailCallFuncArg + 1);
coro::createMustTailCall(Suspend->getDebugLoc(), Fn, FnArgs, Builder);
coro::createMustTailCall(Suspend->getDebugLoc(), Fn, TTI, FnArgs, Builder);
Builder.CreateRetVoid();

// Replace the llvm.coro.async.resume intrinsic call.
Expand Down Expand Up @@ -2010,7 +2015,7 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
return Shape;

simplifySuspendPoints(Shape);
buildCoroutineFrame(F, Shape, MaterializableCallback);
buildCoroutineFrame(F, Shape, TTI, MaterializableCallback);
replaceFrameSizeAndAlignment(Shape);

// If there are no suspend points, no split required, just remove
Expand All @@ -2023,7 +2028,7 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
SwitchCoroutineSplitter::split(F, Shape, Clones, TTI);
break;
case coro::ABI::Async:
splitAsyncCoroutine(F, Shape, Clones);
splitAsyncCoroutine(F, Shape, Clones, TTI);
break;
case coro::ABI::Retcon:
case coro::ABI::RetconOnce:
Expand Down
31 changes: 31 additions & 0 deletions llvm/test/Transforms/Coroutines/coro-async-notail-wasm.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
; RUN: opt < %s -O0 -S -mtriple=wasm32-unknown-unknown | FileCheck %s
; REQUIRES: webassembly-registered-target

%swift.async_func_pointer = type <{ i32, i32 }>
@checkTu = global %swift.async_func_pointer <{ i32 ptrtoint (ptr @check to i32), i32 8 }>

define swiftcc void @check(ptr %0) {
entry:
%1 = call token @llvm.coro.id.async(i32 0, i32 0, i32 0, ptr @checkTu)
%2 = call ptr @llvm.coro.begin(token %1, ptr null)
%3 = call ptr @llvm.coro.async.resume()
store ptr %3, ptr %0, align 4
%4 = call { ptr, i32 } (i32, ptr, ptr, ...) @llvm.coro.suspend.async.sl_p0i32s(i32 0, ptr %3, ptr @__swift_async_resume_project_context, ptr @check.0, ptr null, ptr null)
ret void
}

declare swiftcc void @check.0()
declare { ptr, i32 } @llvm.coro.suspend.async.sl_p0i32s(i32, ptr, ptr, ...)
declare token @llvm.coro.id.async(i32, i32, i32, ptr)
declare ptr @llvm.coro.begin(token, ptr writeonly)
declare ptr @llvm.coro.async.resume()

define ptr @__swift_async_resume_project_context(ptr %0) {
entry:
ret ptr null
}

; Verify that the resume call is not marked as musttail.
; CHECK-LABEL: define swiftcc void @check(
; CHECK-NOT: musttail call swiftcc void @check.0()
; CHECK: call swiftcc void @check.0()

0 comments on commit 8c5c4d9

Please sign in to comment.