diff --git a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h index daa88981d3bf6..fcb384ec36133 100644 --- a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h +++ b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h @@ -19,6 +19,7 @@ class CallBase; class CastInst; class Function; class MDNode; +class Value; /// Return true if the given indirect call site can be made to call \p Callee. /// @@ -73,6 +74,15 @@ CallBase &promoteCallWithIfThenElse(CallBase &CB, Function *Callee, /// bool tryPromoteCall(CallBase &CB); +/// Predicate and clone the given call site. +/// +/// This function creates an if-then-else structure at the location of the call +/// site. The "if" condition compares the call site's called value to the given +/// callee. The original call site is moved into the "else" block, and a clone +/// of the call site is placed in the "then" block. The cloned instruction is +/// returned. +CallBase &versionCallSite(CallBase &CB, Value *Callee, MDNode *BranchWeights); + } // end namespace llvm #endif // LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp index fab080f3e133a..d4b669e72460f 100644 --- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -79,6 +79,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/ModuleSummaryIndexYAML.h" @@ -95,6 +96,7 @@ #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/FunctionAttrs.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/CallPromotionUtils.h" #include "llvm/Transforms/Utils/Evaluator.h" #include #include @@ -163,13 +165,19 @@ static cl::list cl::desc("Prevent function(s) from being 
devirtualized"), cl::Hidden, cl::ZeroOrMore, cl::CommaSeparated); -/// Mechanism to add runtime checking of devirtualization decisions, trapping on -/// any that are not correct. Useful for debugging undefined behavior leading to -/// failures with WPD. -static cl::opt - CheckDevirt("wholeprogramdevirt-check", cl::init(false), cl::Hidden, - cl::ZeroOrMore, - cl::desc("Add code to trap on incorrect devirtualizations")); +/// Mechanism to add runtime checking of devirtualization decisions, optionally +/// trapping or falling back to indirect call on any that are not correct. +/// Trapping mode is useful for debugging undefined behavior leading to failures +/// with WPD. Fallback mode is useful for ensuring safety when whole program +/// visibility may be compromised. +enum WPDCheckMode { None, Trap, Fallback }; +static cl::opt DevirtCheckMode( + "wholeprogramdevirt-check", cl::Hidden, cl::ZeroOrMore, + cl::desc("Type of checking for incorrect devirtualizations"), + cl::values(clEnumValN(WPDCheckMode::None, "none", "No checking"), + clEnumValN(WPDCheckMode::Trap, "trap", "Trap when incorrect"), + clEnumValN(WPDCheckMode::Fallback, "fallback", + "Fallback to indirect when incorrect"))); namespace { struct PatternList { @@ -1140,10 +1148,10 @@ void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo, Value *Callee = Builder.CreateBitCast(TheFn, CB.getCalledOperand()->getType()); - // If checking is enabled, add support to compare the virtual function - // pointer to the devirtualized target. In case of a mismatch, perform a - // debug trap. - if (CheckDevirt) { + // If trap checking is enabled, add support to compare the virtual + // function pointer to the devirtualized target. In case of a mismatch, + // perform a debug trap. 
+ if (DevirtCheckMode == WPDCheckMode::Trap) { auto *Cond = Builder.CreateICmpNE(CB.getCalledOperand(), Callee); Instruction *ThenTerm = SplitBlockAndInsertIfThen(Cond, &CB, /*Unreachable=*/false); @@ -1153,8 +1161,38 @@ void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo, CallTrap->setDebugLoc(CB.getDebugLoc()); } - // Devirtualize. - CB.setCalledOperand(Callee); + // If fallback checking is enabled, add support to compare the virtual + // function pointer to the devirtualized target. In case of a mismatch, + // fall back to indirect call. + if (DevirtCheckMode == WPDCheckMode::Fallback) { + MDNode *Weights = + MDBuilder(M.getContext()).createBranchWeights((1U << 20) - 1, 1); + // Version the indirect call site. If the called value is equal to the + // given callee, 'NewInst' will be executed, otherwise the original call + // site will be executed. + CallBase &NewInst = versionCallSite(CB, Callee, Weights); + NewInst.setCalledOperand(Callee); + // Since the new call site is direct, we must clear metadata that + // is only appropriate for indirect calls. This includes !prof and + // !callees metadata. + NewInst.setMetadata(LLVMContext::MD_prof, nullptr); + NewInst.setMetadata(LLVMContext::MD_callees, nullptr); + // Additionally, we should remove them from the fallback indirect call, + // so that we don't attempt to perform indirect call promotion later. + CB.setMetadata(LLVMContext::MD_prof, nullptr); + CB.setMetadata(LLVMContext::MD_callees, nullptr); + } + + // In either trapping or non-checking mode, devirtualize original call. + else { + // Devirtualize unconditionally. + CB.setCalledOperand(Callee); + // Since the call site is now direct, we must clear metadata that + // is only appropriate for indirect calls. This includes !prof and + // !callees metadata. + CB.setMetadata(LLVMContext::MD_prof, nullptr); + CB.setMetadata(LLVMContext::MD_callees, nullptr); + } // This use is no longer unsafe. 
if (VCallSite.NumUnsafeUses) diff --git a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp index 56b6e4bc46a51..e530afc277db3 100644 --- a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp +++ b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp @@ -279,8 +279,8 @@ static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) { /// ; The original call instruction stays in its original block. /// %t0 = musttail call i32 %ptr() /// ret %t0 -static CallBase &versionCallSite(CallBase &CB, Value *Callee, - MDNode *BranchWeights) { +CallBase &llvm::versionCallSite(CallBase &CB, Value *Callee, + MDNode *BranchWeights) { IRBuilder<> Builder(&CB); CallBase *OrigInst = &CB; diff --git a/llvm/test/ThinLTO/X86/devirt.ll b/llvm/test/ThinLTO/X86/devirt.ll index 66adec0becef8..9ba1dc2d77d67 100644 --- a/llvm/test/ThinLTO/X86/devirt.ll +++ b/llvm/test/ThinLTO/X86/devirt.ll @@ -154,7 +154,10 @@ entry: ; Check that the call was devirtualized. ; CHECK-IR: %call = tail call i32 @_ZN1A1nEi - %call = tail call i32 %fptr1(%struct.A* nonnull %obj, i32 %a) + ; Ensure !prof and !callees metadata for indirect call promotion removed. 
+ ; CHECK-IR-NOT: prof + ; CHECK-IR-NOT: callees + %call = tail call i32 %fptr1(%struct.A* nonnull %obj, i32 %a), !prof !5, !callees !6 %3 = bitcast i8** %vtable to i32 (%struct.A*, i32)** %fptr22 = load i32 (%struct.A*, i32)*, i32 (%struct.A*, i32)** %3, align 8 @@ -207,3 +210,5 @@ attributes #0 = { noinline optnone } !2 = !{i64 16, !"_ZTS1C"} !3 = !{i64 16, !4} !4 = distinct !{} +!5 = !{!"VP", i32 0, i64 1, i64 1621563287929432257, i64 1} +!6 = !{i32 (%struct.A*, i32)* @_ZN1A1nEi} diff --git a/llvm/test/ThinLTO/X86/devirt_check.ll b/llvm/test/ThinLTO/X86/devirt_check.ll index 0ede1e1f5ddb3..a16c828ae94ca 100644 --- a/llvm/test/ThinLTO/X86/devirt_check.ll +++ b/llvm/test/ThinLTO/X86/devirt_check.ll @@ -1,21 +1,33 @@ ; REQUIRES: x86-registered-target ; Test that devirtualization option -wholeprogramdevirt-check adds code to check -; that the devirtualization decision was correct and trap if not. +; that the devirtualization decision was correct and trap or fallback if not. ; The vtables have vcall_visibility metadata with hidden visibility, to enable ; devirtualization. ; Generate unsplit module with summary for ThinLTO index-based WPD. ; RUN: opt -thinlto-bc -o %t2.o %s + +; Check first in trapping mode. ; RUN: llvm-lto2 run %t2.o -save-temps -use-new-pm -pass-remarks=. \ -; RUN: -wholeprogramdevirt-check \ +; RUN: -wholeprogramdevirt-check=trap \ ; RUN: -o %t3 \ ; RUN: -r=%t2.o,test,px \ ; RUN: -r=%t2.o,_ZN1A1nEi,p \ ; RUN: -r=%t2.o,_ZN1B1fEi,p \ ; RUN: -r=%t2.o,_ZTV1B,px 2>&1 | FileCheck %s --check-prefix=REMARK -; RUN: llvm-dis %t3.1.4.opt.bc -o - | FileCheck %s --check-prefix=CHECK-IR +; RUN: llvm-dis %t3.1.4.opt.bc -o - | FileCheck %s --check-prefix=CHECK --check-prefix=TRAP + +; Check next in fallback mode. +; RUN: llvm-lto2 run %t2.o -save-temps -use-new-pm -pass-remarks=. 
\ +; RUN: -wholeprogramdevirt-check=fallback \ +; RUN: -o %t3 \ +; RUN: -r=%t2.o,test,px \ +; RUN: -r=%t2.o,_ZN1A1nEi,p \ +; RUN: -r=%t2.o,_ZN1B1fEi,p \ +; RUN: -r=%t2.o,_ZTV1B,px 2>&1 | FileCheck %s --check-prefix=REMARK +; RUN: llvm-dis %t3.1.4.opt.bc -o - | FileCheck %s --check-prefix=CHECK --check-prefix=FALLBACK ; REMARK-DAG: single-impl: devirtualized a call to _ZN1A1nEi @@ -28,7 +40,7 @@ target triple = "x86_64-grtev4-linux-gnu" @_ZTV1B = constant { [4 x i8*] } { [4 x i8*] [i8* null, i8* undef, i8* bitcast (i32 (%struct.B*, i32)* @_ZN1B1fEi to i8*), i8* bitcast (i32 (%struct.A*, i32)* @_ZN1A1nEi to i8*)] }, !type !0, !type !1, !vcall_visibility !5 -; CHECK-IR-LABEL: define i32 @test +; CHECK-LABEL: define i32 @test define i32 @test(%struct.A* %obj, i32 %a) { entry: %0 = bitcast %struct.A* %obj to i8*** @@ -42,19 +54,40 @@ entry: ; Check that the call was devirtualized, but preceeded by a check guarding ; a trap if the function pointer doesn't match. - ; CHECK-IR: %.not = icmp eq i32 (%struct.A*, i32)* %fptr1, @_ZN1A1nEi - ; CHECK-IR: br i1 %.not, label %3, label %2 - ; CHECK-IR: 2: - ; CHECK-IR: tail call void @llvm.debugtrap() - ; CHECK-IR: br label %3 - ; CHECK-IR: 3: - ; CHECK-IR: tail call i32 @_ZN1A1nEi - %call = tail call i32 %fptr1(%struct.A* nonnull %obj, i32 %a) + ; TRAP: %.not = icmp eq i32 (%struct.A*, i32)* %fptr1, @_ZN1A1nEi + ; Ensure !prof and !callees metadata for indirect call promotion removed. + ; TRAP-NOT: prof + ; TRAP-NOT: callees + ; TRAP: br i1 %.not, label %3, label %2 + ; TRAP: 2: + ; TRAP: tail call void @llvm.debugtrap() + ; TRAP: br label %3 + ; TRAP: 3: + ; TRAP: tail call i32 @_ZN1A1nEi + ; Check that the call was devirtualized, but preceded by a check guarding + ; a fallback if the function pointer doesn't match. 
+ ; FALLBACK: %2 = icmp eq i32 (%struct.A*, i32)* %fptr1, @_ZN1A1nEi + ; FALLBACK: br i1 %2, label %if.true.direct_targ, label %if.false.orig_indirect + ; FALLBACK: if.true.direct_targ: + ; FALLBACK: tail call i32 @_ZN1A1nEi + ; Ensure !prof and !callees metadata for indirect call promotion removed. + ; FALLBACK-NOT: prof + ; FALLBACK-NOT: callees + ; FALLBACK: br label %if.end.icp + ; FALLBACK: if.false.orig_indirect: + ; FALLBACK: tail call i32 %fptr1 + ; Ensure !prof and !callees metadata for indirect call promotion removed. + ; In particular, if left on the fallback indirect call ICP may perform an + ; additional round of promotion. + ; FALLBACK-NOT: prof + ; FALLBACK-NOT: callees + ; FALLBACK: br label %if.end.icp + %call = tail call i32 %fptr1(%struct.A* nonnull %obj, i32 %a), !prof !6, !callees !7 ret i32 %call } -; CHECK-IR-LABEL: ret i32 -; CHECK-IR-LABEL: } +; CHECK-LABEL: ret i32 +; CHECK-LABEL: } declare i1 @llvm.type.test(i8*, metadata) declare void @llvm.assume(i1) @@ -75,3 +108,5 @@ attributes #0 = { noinline optnone } !3 = !{i64 16, !4} !4 = distinct !{} !5 = !{i64 1} +!6 = !{!"VP", i32 0, i64 1, i64 1621563287929432257, i64 1} +!7 = !{i32 (%struct.A*, i32)* @_ZN1A1nEi}