diff --git a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h index 1575afa50e198..5a682e8c7b5eb 100644 --- a/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h +++ b/llvm/include/llvm/Transforms/IPO/FunctionSpecialization.h @@ -246,7 +246,7 @@ class FunctionSpecializer { std::function GetAC; SmallPtrSet Specializations; - SmallPtrSet FullySpecialized; + SmallPtrSet DeadFunctions; DenseMap FunctionMetrics; DenseMap FunctionGrowth; unsigned NGlobals = 0; @@ -270,6 +270,8 @@ class FunctionSpecializer { return InstCostVisitor(GetBFI, F, M.getDataLayout(), TTI, Solver); } + bool isDeadFunction(Function *F) { return DeadFunctions.contains(F); } + private: Constant *getPromotableAlloca(AllocaInst *Alloca, CallInst *Call); diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp index c876a47ef2129..a5fbba39ee283 100644 --- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp +++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp @@ -839,14 +839,24 @@ bool FunctionSpecializer::run() { } void FunctionSpecializer::removeDeadFunctions() { - for (Function *F : FullySpecialized) { + for (Function *F : DeadFunctions) { LLVM_DEBUG(dbgs() << "FnSpecialization: Removing dead function " << F->getName() << "\n"); if (FAM) FAM->clear(*F, F->getName()); + + // Remove all the callsites that were proven unreachable once, and replace + // them with poison. 
+ for (User *U : make_early_inc_range(F->users())) { + assert((isa(U) || isa(U)) && + "User of dead function must be call or invoke"); + Instruction *CS = cast(U); + CS->replaceAllUsesWith(PoisonValue::get(CS->getType())); + CS->eraseFromParent(); + } F->eraseFromParent(); } - FullySpecialized.clear(); + DeadFunctions.clear(); } /// Clone the function \p F and remove the ssa_copy intrinsics added by @@ -1207,8 +1217,11 @@ void FunctionSpecializer::updateCallSites(Function *F, const Spec *Begin, // If the function has been completely specialized, the original function // is no longer needed. Mark it unreachable. - if (NCallsLeft == 0 && Solver.isArgumentTrackedFunction(F)) { + // NOTE: If the address of a function is taken, we cannot treat it as a dead + // function. + if (NCallsLeft == 0 && Solver.isArgumentTrackedFunction(F) && + !F->hasAddressTaken()) { Solver.markFunctionUnreachable(F); - FullySpecialized.insert(F); + DeadFunctions.insert(F); } } diff --git a/llvm/lib/Transforms/IPO/SCCP.cpp b/llvm/lib/Transforms/IPO/SCCP.cpp index d50de34dfa482..e98a70f228ada 100644 --- a/llvm/lib/Transforms/IPO/SCCP.cpp +++ b/llvm/lib/Transforms/IPO/SCCP.cpp @@ -169,6 +169,10 @@ static bool runIPSCCP( for (Function &F : M) { if (F.isDeclaration()) continue; + // Skip the dead functions marked by FunctionSpecializer, to avoid removing + // blocks in dead functions. 
+ if (IsFuncSpecEnabled && Specializer.isDeadFunction(&F)) + continue; SmallVector BlocksToErase; diff --git a/llvm/test/Transforms/FunctionSpecialization/reachable-after-specialization.ll b/llvm/test/Transforms/FunctionSpecialization/reachable-after-specialization.ll new file mode 100644 index 0000000000000..ab2260aadab31 --- /dev/null +++ b/llvm/test/Transforms/FunctionSpecialization/reachable-after-specialization.ll @@ -0,0 +1,135 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; RUN: opt -passes=ipsccp --funcspec-min-function-size=1 -S < %s | FileCheck %s + +define i32 @caller() { +; CHECK-LABEL: define i32 @caller() { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[CALL1:%.*]] = call i32 @callee.specialized.1(i32 1) +; CHECK-NEXT: [[CALL2:%.*]] = call i32 @callee.specialized.2(i32 0) +; CHECK-NEXT: [[COND:%.*]] = icmp eq i32 undef, 0 +; CHECK-NEXT: br i1 [[COND]], label %[[COMMON_RET:.*]], label %[[IF_THEN:.*]] +; CHECK: [[COMMON_RET]]: +; CHECK-NEXT: ret i32 0 +; CHECK: [[IF_THEN]]: +; CHECK-NEXT: ret i32 0 +; +entry: + %call1 = call i32 @callee(i32 1) + %call2 = call i32 @callee(i32 0) + %cond = icmp eq i32 %call2, 0 + br i1 %cond, label %common.ret, label %if.then + +common.ret: ; preds = %entry + ret i32 0 + +if.then: ; preds = %entry + %unreachable_call = call i32 @callee(i32 2) + ret i32 %unreachable_call +} + +define internal i32 @callee(i32 %arg) { +entry: + br label %loop + +loop: ; preds = %ai, %entry + %add = or i32 0, 0 + %cond = icmp eq i32 %arg, 1 + br i1 %cond, label %exit, label %loop + +exit: ; preds = %ai + ret i32 0 +} + +declare void @other_user(ptr) + +define i32 @caller2() { +; CHECK-LABEL: define i32 @caller2() { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: call void @other_user(ptr @callee2) +; CHECK-NEXT: [[CALL1:%.*]] = call i32 @callee2.specialized.3(i32 1) +; CHECK-NEXT: [[CALL2:%.*]] = call i32 @callee2.specialized.4(i32 0) +; CHECK-NEXT: [[COND:%.*]] = icmp eq i32 undef, 0 +; 
CHECK-NEXT: br i1 [[COND]], label %[[COMMON_RET:.*]], label %[[IF_THEN:.*]] +; CHECK: [[COMMON_RET]]: +; CHECK-NEXT: ret i32 0 +; CHECK: [[IF_THEN]]: +; CHECK-NEXT: [[UNREACHABLE_CALL:%.*]] = call i32 @callee2.specialized.7(i32 2) +; CHECK-NEXT: ret i32 undef +; +entry: + call void @other_user(ptr @callee2) + %call1 = call i32 @callee2(i32 1) + %call2 = call i32 @callee2(i32 0) + %cond = icmp eq i32 %call2, 0 + br i1 %cond, label %common.ret, label %if.then + +common.ret: ; preds = %entry + ret i32 0 + +if.then: ; preds = %entry + %unreachable_call = call i32 @callee2(i32 2) + ret i32 %unreachable_call +} + +define internal i32 @callee2(i32 %arg) { +; CHECK-LABEL: define internal i32 @callee2( +; CHECK-SAME: i32 [[ARG:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: br label %[[LOOP:.*]] +; CHECK: [[LOOP]]: +; CHECK-NEXT: [[COND:%.*]] = icmp eq i32 [[ARG]], 1 +; CHECK-NEXT: br i1 [[COND]], label %[[EXIT:.*]], label %[[LOOP]] +; CHECK: [[EXIT]]: +; CHECK-NEXT: ret i32 0 +; +entry: + br label %loop + +loop: ; preds = %ai, %entry + %add = or i32 0, 0 + %cond = icmp eq i32 %arg, 1 + br i1 %cond, label %exit, label %loop + +exit: ; preds = %ai + ret i32 0 +} + +define i32 @caller3(i32 %arg) { +; CHECK-LABEL: define range(i32 2, 1) i32 @caller3( +; CHECK-SAME: i32 [[ARG:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[CALL1:%.*]] = call i32 @callee3.specialized.5(i32 0) +; CHECK-NEXT: [[CALL2:%.*]] = call i32 @callee3.specialized.6(i32 1) +; CHECK-NEXT: [[COND:%.*]] = icmp eq i32 undef, 0 +; CHECK-NEXT: br i1 [[COND]], label %[[COMMON_RET:.*]], label %[[IF_THEN:.*]] +; CHECK: [[COMMON_RET]]: +; CHECK-NEXT: ret i32 0 +; CHECK: [[IF_THEN]]: +; CHECK-NEXT: ret i32 poison +; +entry: + %call1 = call i32 @callee3(i32 0) + %call2 = call i32 @callee3(i32 1) + %cond = icmp eq i32 %call2, 0 + br i1 %cond, label %common.ret, label %if.then + +common.ret: ; preds = %entry + ret i32 0 + +if.then: ; preds = %entry + %unreachable_call = call i32 @callee3(i32 %arg) + ret i32 
%unreachable_call +} + +define internal i32 @callee3(i32 %arg) { +entry: + br label %loop + +loop: ; preds = %ai, %entry + %add = or i32 0, 0 + %cond = icmp ne i32 %arg, 1 + br i1 %cond, label %exit, label %loop + +exit: ; preds = %ai + ret i32 %arg +}