Skip to content

Commit

Permalink
[WebAssembly] Make Emscripten EH work with Emscripten SjLj
Browse files Browse the repository at this point in the history
When Emscripten EH mixes with Emscripten SjLj, we are not currently
handling some of them correctly. There are three cases:
1. The current function calls `setjmp` and there is an `invoke` to a
   function that can either throw or longjmp. In this case, we have to
   check both for exception and longjmp. We are currently handling this
   case correctly:
   https://github.com/llvm/llvm-project/blob/0c0eb76782d5224b8d81a5afbb9a152bcf7c94c7/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp#L1058-L1090
   When inserting routines for functions that can longjmp, which we do
   only for setjmp-calling functions, we check if the function was
   previously an `invoke` and handle it correctly.

2. The current function does NOT call `setjmp` and there is an `invoke`
   to a function that can either throw or longjmp. Because there is no
   `setjmp` call, we haven't been doing any check for functions that can
   longjmp. But in that case, for `invoke`, we only check for an
   exception and if it is not an exception we reset `__THREW__` to 0,
   which can silently swallow the longjmp:
   https://github.com/llvm/llvm-project/blob/0c0eb76782d5224b8d81a5afbb9a152bcf7c94c7/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp#L70-L80
   This CL fixes this.

3. The current function calls `setjmp` and there is no `invoke`. Because
   it is not an `invoke`, we haven't been doing any check for functions
   that can throw, and only insert longjmp-checking routines for
   functions that can longjmp. But in that case, if a longjmpable
   function throws, we only check for a longjmp so if it is not a
   longjmp we reset `__THREW__` to 0, which can silently swallow the
   exception:
   https://github.com/llvm/llvm-project/blob/0c0eb76782d5224b8d81a5afbb9a152bcf7c94c7/llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp#L156-L169
   This CL fixes this.

To do that, this moves around some code, so we register necessary
functions for both EH and SjLj and precompute some data (the set of
functions that contains `setjmp`) before doing actual EH or SjLj
transformation.

This CL makes 2nd and 3rd tests in
emscripten-core/emscripten#14732 work.

Reviewed By: dschuff

Differential Revision: https://reviews.llvm.org/D106525
  • Loading branch information
aheejin authored and memfrob committed Oct 4, 2022
1 parent f7ce387 commit 509475b
Show file tree
Hide file tree
Showing 4 changed files with 249 additions and 63 deletions.
3 changes: 2 additions & 1 deletion llvm/lib/Target/WebAssembly/WebAssembly.h
Expand Up @@ -25,7 +25,8 @@ class ModulePass;
class FunctionPass;

// LLVM IR passes.
ModulePass *createWebAssemblyLowerEmscriptenEHSjLj(bool DoEH, bool DoSjLj);
ModulePass *createWebAssemblyLowerEmscriptenEHSjLj(bool EnableEH,
bool EnableSjLj);
ModulePass *createWebAssemblyLowerGlobalDtors();
ModulePass *createWebAssemblyAddMissingPrototypes();
ModulePass *createWebAssemblyFixFunctionBitcasts();
Expand Down
134 changes: 112 additions & 22 deletions llvm/lib/Target/WebAssembly/WebAssemblyLowerEmscriptenEHSjLj.cpp
Expand Up @@ -216,6 +216,7 @@ namespace {
class WebAssemblyLowerEmscriptenEHSjLj final : public ModulePass {
bool EnableEH; // Enable exception handling
bool EnableSjLj; // Enable setjmp/longjmp handling
bool DoSjLj; // Whether we actually perform setjmp/longjmp handling

GlobalVariable *ThrewGV = nullptr;
GlobalVariable *ThrewValueGV = nullptr;
Expand All @@ -234,6 +235,8 @@ class WebAssemblyLowerEmscriptenEHSjLj final : public ModulePass {
StringMap<Function *> InvokeWrappers;
// Set of allowed function names for exception handling
std::set<std::string> EHAllowlistSet;
// Functions that contains calls to setjmp
SmallPtrSet<Function *, 8> SetjmpUsers;

StringRef getPassName() const override {
return "WebAssembly Lower Emscripten Exceptions";
Expand All @@ -252,6 +255,10 @@ class WebAssemblyLowerEmscriptenEHSjLj final : public ModulePass {
bool areAllExceptionsAllowed() const { return EHAllowlistSet.empty(); }
bool canLongjmp(Module &M, const Value *Callee) const;
bool isEmAsmCall(Module &M, const Value *Callee) const;
bool supportsException(const Function *F) const {
return EnableEH && (areAllExceptionsAllowed() ||
EHAllowlistSet.count(std::string(F->getName())));
}

void rebuildSSA(Function &F);

Expand Down Expand Up @@ -287,7 +294,7 @@ static bool canThrow(const Value *V) {
return false;
StringRef Name = F->getName();
// leave setjmp and longjmp (mostly) alone, we process them properly later
if (Name == "setjmp" || Name == "longjmp")
if (Name == "setjmp" || Name == "longjmp" || Name == "emscripten_longjmp")
return false;
return !F->doesNotThrow();
}
Expand Down Expand Up @@ -693,7 +700,7 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runOnModule(Module &M) {
Function *LongjmpF = M.getFunction("longjmp");
bool SetjmpUsed = SetjmpF && !SetjmpF->use_empty();
bool LongjmpUsed = LongjmpF && !LongjmpF->use_empty();
bool DoSjLj = EnableSjLj && (SetjmpUsed || LongjmpUsed);
DoSjLj = EnableSjLj && (SetjmpUsed || LongjmpUsed);

auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
assert(TPC && "Expected a TargetPassConfig");
Expand All @@ -718,7 +725,7 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runOnModule(Module &M) {

bool Changed = false;

// Exception handling
// Function registration for exception handling
if (EnableEH) {
// Register __resumeException function
FunctionType *ResumeFTy =
Expand All @@ -729,26 +736,15 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runOnModule(Module &M) {
FunctionType *EHTypeIDTy =
FunctionType::get(IRB.getInt32Ty(), IRB.getInt8PtrTy(), false);
EHTypeIDF = getEmscriptenFunction(EHTypeIDTy, "llvm_eh_typeid_for", &M);

for (Function &F : M) {
if (F.isDeclaration())
continue;
Changed |= runEHOnFunction(F);
}
}

// Setjmp/longjmp handling
// Function registration and data pre-gathering for setjmp/longjmp handling
if (DoSjLj) {
Changed = true; // We have setjmp or longjmp somewhere

// Register emscripten_longjmp function
FunctionType *FTy = FunctionType::get(
IRB.getVoidTy(), {getAddrIntType(&M), IRB.getInt32Ty()}, false);
EmLongjmpF = getEmscriptenFunction(FTy, "emscripten_longjmp", &M);

if (LongjmpF)
replaceLongjmpWithEmscriptenLongjmp(LongjmpF, EmLongjmpF);

if (SetjmpF) {
// Register saveSetjmp function
FunctionType *SetjmpFTy = SetjmpF->getFunctionType();
Expand All @@ -765,16 +761,33 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runOnModule(Module &M) {
false);
TestSetjmpF = getEmscriptenFunction(FTy, "testSetjmp", &M);

// Only traverse functions that uses setjmp in order not to insert
// unnecessary prep / cleanup code in every function
SmallPtrSet<Function *, 8> SetjmpUsers;
// Precompute setjmp users
for (User *U : SetjmpF->users()) {
auto *UI = cast<Instruction>(U);
SetjmpUsers.insert(UI->getFunction());
}
}
}

// Exception handling transformation
if (EnableEH) {
for (Function &F : M) {
if (F.isDeclaration())
continue;
Changed |= runEHOnFunction(F);
}
}

// Setjmp/longjmp handling transformation
if (DoSjLj) {
Changed = true; // We have setjmp or longjmp somewhere
if (LongjmpF)
replaceLongjmpWithEmscriptenLongjmp(LongjmpF, EmLongjmpF);
// Only traverse functions that uses setjmp in order not to insert
// unnecessary prep / cleanup code in every function
if (SetjmpF)
for (Function *F : SetjmpUsers)
runSjLjOnFunction(*F);
}
}

if (!Changed) {
Expand Down Expand Up @@ -802,8 +815,6 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runEHOnFunction(Function &F) {
bool Changed = false;
SmallVector<Instruction *, 64> ToErase;
SmallPtrSet<LandingPadInst *, 32> LandingPads;
bool AllowExceptions = areAllExceptionsAllowed() ||
EHAllowlistSet.count(std::string(F.getName()));

for (BasicBlock &BB : F) {
auto *II = dyn_cast<InvokeInst>(BB.getTerminator());
Expand All @@ -813,12 +824,51 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runEHOnFunction(Function &F) {
LandingPads.insert(II->getLandingPadInst());
IRB.SetInsertPoint(II);

bool NeedInvoke = AllowExceptions && canThrow(II->getCalledOperand());
const Value *Callee = II->getCalledOperand();
bool NeedInvoke = supportsException(&F) && canThrow(Callee);
if (NeedInvoke) {
// Wrap invoke with invoke wrapper and generate preamble/postamble
Value *Threw = wrapInvoke(II);
ToErase.push_back(II);

// If setjmp/longjmp handling is enabled, the thrown value can be not an
// exception but a longjmp. If the current function contains calls to
// setjmp, it will be appropriately handled in runSjLjOnFunction. But even
// if the function does not contain setjmp calls, we shouldn't silently
// ignore longjmps; we should rethrow them so they can be correctly
// handled in somewhere up the call chain where setjmp is.
// __THREW__'s value is 0 when nothing happened, 1 when an exception is
// thrown, other values when longjmp is thrown.
//
// if (%__THREW__.val == 0 || %__THREW__.val == 1)
// goto %tail
// else
// goto %longjmp.rethrow
//
// longjmp.rethrow: ;; This is longjmp. Rethrow it
// %__threwValue.val = __threwValue
// emscripten_longjmp(%__THREW__.val, %__threwValue.val);
//
// tail: ;; Nothing happened or an exception is thrown
// ... Continue exception handling ...
if (DoSjLj && !SetjmpUsers.count(&F) && canLongjmp(M, Callee)) {
BasicBlock *Tail = BasicBlock::Create(C, "tail", &F);
BasicBlock *RethrowBB = BasicBlock::Create(C, "longjmp.rethrow", &F);
Value *CmpEqOne =
IRB.CreateICmpEQ(Threw, getAddrSizeInt(&M, 1), "cmp.eq.one");
Value *CmpEqZero =
IRB.CreateICmpEQ(Threw, getAddrSizeInt(&M, 0), "cmp.eq.zero");
Value *Or = IRB.CreateOr(CmpEqZero, CmpEqOne, "or");
IRB.CreateCondBr(Or, Tail, RethrowBB);
IRB.SetInsertPoint(RethrowBB);
Value *ThrewValue = IRB.CreateLoad(IRB.getInt32Ty(), ThrewValueGV,
ThrewValueGV->getName() + ".val");
IRB.CreateCall(EmLongjmpF, {Threw, ThrewValue});

IRB.CreateUnreachable();
IRB.SetInsertPoint(Tail);
}

// Insert a branch based on __THREW__ variable
Value *Cmp = IRB.CreateICmpEQ(Threw, getAddrSizeInt(&M, 1), "cmp");
IRB.CreateCondBr(Cmp, II->getUnwindDest(), II->getNormalDest());
Expand Down Expand Up @@ -1098,6 +1148,46 @@ bool WebAssemblyLowerEmscriptenEHSjLj::runSjLjOnFunction(Function &F) {
Threw = wrapInvoke(CI);
ToErase.push_back(CI);
Tail = SplitBlock(BB, CI->getNextNode());

// If exception handling is enabled, the thrown value can be not a
// longjmp but an exception, in which case we shouldn't silently ignore
// exceptions; we should rethrow them.
// __THREW__'s value is 0 when nothing happened, 1 when an exception is
// thrown, other values when longjmp is thrown.
//
// if (%__THREW__.val == 1)
// goto %eh.rethrow
// else
// goto %normal
//
// eh.rethrow: ;; Rethrow exception
// %exn = call @__cxa_find_matching_catch_2() ;; Retrieve thrown ptr
// __resumeException(%exn)
//
// normal:
// <-- Insertion point. Will insert sjlj handling code from here
// goto %tail
//
// tail:
// ...
if (supportsException(&F) && canThrow(Callee)) {
IRB.SetInsertPoint(CI);
// We will add a new conditional branch. So remove the branch created
// when we split the BB
ToErase.push_back(BB->getTerminator());
BasicBlock *NormalBB = BasicBlock::Create(C, "normal", &F);
BasicBlock *RethrowBB = BasicBlock::Create(C, "eh.rethrow", &F);
Value *CmpEqOne =
IRB.CreateICmpEQ(Threw, getAddrSizeInt(&M, 1), "cmp.eq.one");
IRB.CreateCondBr(CmpEqOne, RethrowBB, NormalBB);
IRB.SetInsertPoint(RethrowBB);
CallInst *Exn = IRB.CreateCall(getFindMatchingCatch(M, 0), {}, "exn");
IRB.CreateCall(ResumeF, {Exn});
IRB.CreateUnreachable();
IRB.SetInsertPoint(NormalBB);
IRB.CreateBr(Tail);
BB = NormalBB; // New insertion point to insert testSetjmp()
}
}

// We need to replace the terminator in Tail - SplitBlock makes BB go
Expand Down
132 changes: 132 additions & 0 deletions llvm/test/CodeGen/WebAssembly/lower-em-ehsjlj.ll
@@ -0,0 +1,132 @@
; RUN: opt < %s -wasm-lower-em-ehsjlj -S | FileCheck %s
; RUN: llc < %s

; Tests for cases when exception handling and setjmp/longjmp handling are mixed.

target datalayout = "e-m:e-p:32:32-i64:64-n32:64-S128"
target triple = "wasm32-unknown-unknown"

%struct.__jmp_buf_tag = type { [6 x i32], i32, [32 x i32] }

; There is a function call (@foo) that can either throw an exception or longjmp
; and there is also a setjmp call. When @foo throws, we have to check both for
; exception and longjmp and jump to exception or longjmp handling BB depending
; on the result.
define void @setjmp_longjmp_exception() personality i8* bitcast (i32 (...)* @__gxx_personality_v0 to i8*) {
; CHECK-LABEL: @setjmp_longjmp_exception
entry:
%buf = alloca [1 x %struct.__jmp_buf_tag], align 16
%arraydecay = getelementptr inbounds [1 x %struct.__jmp_buf_tag], [1 x %struct.__jmp_buf_tag]* %buf, i32 0, i32 0
%call = call i32 @setjmp(%struct.__jmp_buf_tag* %arraydecay) #0
invoke void @foo()
to label %try.cont unwind label %lpad

; CHECK: entry.split:
; CHECK: %[[CMP0:.*]] = icmp ne i32 %__THREW__.val, 0
; CHECK-NEXT: %__threwValue.val = load i32, i32* @__threwValue
; CHECK-NEXT: %[[CMP1:.*]] = icmp ne i32 %__threwValue.val, 0
; CHECK-NEXT: %[[CMP:.*]] = and i1 %[[CMP0]], %[[CMP1]]
; CHECK-NEXT: br i1 %[[CMP]], label %if.then1, label %if.else1

; This is exception checking part. %if.else1 leads here
; CHECK: entry.split.split:
; CHECK-NEXT: %[[CMP:.*]] = icmp eq i32 %__THREW__.val, 1
; CHECK-NEXT: br i1 %[[CMP]], label %lpad, label %try.cont

; longjmp checking part
; CHECK: if.then1:
; CHECK: call i32 @testSetjmp

lpad: ; preds = %entry
%0 = landingpad { i8*, i32 }
catch i8* null
%1 = extractvalue { i8*, i32 } %0, 0
%2 = extractvalue { i8*, i32 } %0, 1
%3 = call i8* @__cxa_begin_catch(i8* %1) #2
call void @__cxa_end_catch()
br label %try.cont

try.cont: ; preds = %entry, %lpad
ret void
}

; @foo can either throw an exception or longjmp. Because this function doesn't
; have any setjmp calls, we only handle exceptions in this function. But because
; sjlj is enabled, we check if the thrown value is longjmp and if so rethrow it
; by calling @emscripten_longjmp.
define void @rethrow_longjmp() personality i8* bitcast (i32 (...)* @__gxx_personality_v0 to i8*) {
; CHECK-LABEL: @rethrow_longjmp
entry:
invoke void @foo()
to label %try.cont unwind label %lpad
; CHECK: entry:
; CHECK: %cmp.eq.one = icmp eq i32 %__THREW__.val, 1
; CHECK-NEXT: %cmp.eq.zero = icmp eq i32 %__THREW__.val, 0
; CHECK-NEXT: %or = or i1 %cmp.eq.zero, %cmp.eq.one
; CHECK-NEXT: br i1 %or, label %tail, label %longjmp.rethrow

; CHECK: tail:
; CHECK-NEXT: %cmp = icmp eq i32 %__THREW__.val, 1
; CHECK-NEXT: br i1 %cmp, label %lpad, label %try.cont

; CHECK: longjmp.rethrow:
; CHECK-NEXT: %__threwValue.val = load i32, i32* @__threwValue, align 4
; CHECK-NEXT: call void @emscripten_longjmp(i32 %__THREW__.val, i32 %__threwValue.val)
; CHECK-NEXT: unreachable

lpad: ; preds = %entry
%0 = landingpad { i8*, i32 }
catch i8* null
%1 = extractvalue { i8*, i32 } %0, 0
%2 = extractvalue { i8*, i32 } %0, 1
%3 = call i8* @__cxa_begin_catch(i8* %1) #5
call void @__cxa_end_catch()
br label %try.cont

try.cont: ; preds = %entry, %lpad
ret void
}

; This function contains a setjmp call and no invoke, so we only handle longjmp
; here. But @foo can also throw an exception, so we check if an exception is
; thrown and if so rethrow it by calling @__resumeException.
define void @rethrow_exception() {
; CHECK-LABEL: @rethrow_exception
entry:
%buf = alloca [1 x %struct.__jmp_buf_tag], align 16
%arraydecay = getelementptr inbounds [1 x %struct.__jmp_buf_tag], [1 x %struct.__jmp_buf_tag]* %buf, i32 0, i32 0
%call = call i32 @setjmp(%struct.__jmp_buf_tag* %arraydecay) #0
%cmp = icmp ne i32 %call, 0
br i1 %cmp, label %return, label %if.end

if.end: ; preds = %entry
call void @foo()
br label %return

; CHECK: if.end:
; CHECK: %cmp.eq.one = icmp eq i32 %__THREW__.val, 1
; CHECK-NEXT: br i1 %cmp.eq.one, label %eh.rethrow, label %normal

; CHECK: normal:
; CHECK-NEXT: icmp ne i32 %__THREW__.val, 0

; CHECK: eh.rethrow:
; CHECK-NEXT: %exn = call i8* @__cxa_find_matching_catch_2()
; CHECK-NEXT: call void @__resumeException(i8* %exn)
; CHECK-NEXT: unreachable

return: ; preds = %entry, %if.end
ret void
}

declare void @foo()
; Function Attrs: returns_twice
declare i32 @setjmp(%struct.__jmp_buf_tag*)
; Function Attrs: noreturn
declare void @longjmp(%struct.__jmp_buf_tag*, i32)
declare i32 @__gxx_personality_v0(...)
declare i8* @__cxa_begin_catch(i8*)
declare void @__cxa_end_catch()

attributes #0 = { returns_twice }
attributes #1 = { noreturn }

0 comments on commit 509475b

Please sign in to comment.