diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 700ab797b2f69f..126a27167d2dd2 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -56365,12 +56365,21 @@ SDValue X86TargetLowering::expandIndirectJTBranch(const SDLoc &dl, int JTI, SelectionDAG &DAG) const { const Module *M = DAG.getMachineFunction().getMMI().getModule(); - Metadata *IsCFProtectionSupported = M->getModuleFlag("cf-protection-branch"); - if (IsCFProtectionSupported) { - // In case control-flow branch protection is enabled, we need to add - // notrack prefix to the indirect branch. - // In order to do that we create NT_BRIND SDNode. - // Upon ISEL, the pattern will convert it to jmp with NoTrack prefix. + + uint64_t CFProtectionBranchLevel = 0; + if (Metadata *CFProtectionBranchEnabled = + M->getModuleFlag("cf-protection-branch")) + CFProtectionBranchLevel = + cast(CFProtectionBranchEnabled) + ->getValue() + ->getUniqueInteger() + .getLimitedValue(); + + if (CFProtectionBranchLevel == 1) { + // In case control-flow branch protection is enabled but we are not + // protecting jump table branches, we need to add notrack prefix to the + // indirect branch. In order to do that we create NT_BRIND SDNode. Upon + // ISEL, the pattern will convert it to jmp with NoTrack prefix. SDValue JTInfo = DAG.getJumpTableDebugInfo(JTI, Value, dl); return DAG.getNode(X86ISD::NT_BRIND, dl, MVT::Other, JTInfo, Addr); } diff --git a/llvm/lib/Target/X86/X86IndirectBranchTracking.cpp b/llvm/lib/Target/X86/X86IndirectBranchTracking.cpp index 785bdd83cd998b..233ffcf4df4121 100644 --- a/llvm/lib/Target/X86/X86IndirectBranchTracking.cpp +++ b/llvm/lib/Target/X86/X86IndirectBranchTracking.cpp @@ -22,6 +22,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/CodeGen/MachineFunctionPass.h" #include "llvm/CodeGen/MachineInstrBuilder.h" +#include "llvm/CodeGen/MachineJumpTableInfo.h" #include "llvm/CodeGen/MachineModuleInfo.h" using namespace llvm; @@ -117,7 +118,13 @@ bool X86IndirectBranchTrackingPass::runOnMachineFunction(MachineFunction &MF) { const Module *M = MF.getMMI().getModule(); // Check that the cf-protection-branch is enabled. - Metadata *isCFProtectionSupported = M->getModuleFlag("cf-protection-branch"); + uint64_t CFProtectionLevel = 0; + if (Metadata *isCFProtectionSupported = + M->getModuleFlag("cf-protection-branch")) + CFProtectionLevel = cast(isCFProtectionSupported) + ->getValue() + ->getUniqueInteger() + .getLimitedValue(); // NB: We need to enable IBT in jitted code if JIT compiler is CET // enabled. @@ -128,7 +135,7 @@ bool X86IndirectBranchTrackingPass::runOnMachineFunction(MachineFunction &MF) { #else bool isJITwithCET = false; #endif - if (!isCFProtectionSupported && !IndirectBranchTracking && !isJITwithCET) + if (!CFProtectionLevel && !IndirectBranchTracking && !isJITwithCET) return false; // True if the current MF was changed and false otherwise. @@ -186,5 +193,15 @@ bool X86IndirectBranchTrackingPass::runOnMachineFunction(MachineFunction &MF) { } } } + + // If strong CF protections are enabled (Level > 1), then add ENDBRs to all + // jump table BBs, since jump table branches will not have the 'notrack' + // prefix. + if (CFProtectionLevel > 1) + if (const MachineJumpTableInfo *JTI = MF.getJumpTableInfo()) + for (const MachineJumpTableEntry &JTE : JTI->getJumpTables()) + for (MachineBasicBlock *MBB : JTE.MBBs) + Changed |= addENDBR(*MBB, MBB->begin()); + return Changed; } diff --git a/llvm/test/CodeGen/X86/indirect-branch-tracking-jt.ll b/llvm/test/CodeGen/X86/indirect-branch-tracking-jt.ll new file mode 100644 index 00000000000000..0f4ad8695f583a --- /dev/null +++ b/llvm/test/CodeGen/X86/indirect-branch-tracking-jt.ll @@ -0,0 +1,46 @@ +; RUN: sed 's/level/0/g' %s | llc -o - -mtriple=x86_64-unknown-unknown | FileCheck %s --check-prefixes=ALL,LEVEL0 +; RUN: sed 's/level/1/g' %s | llc -o - -mtriple=x86_64-unknown-unknown | FileCheck %s --check-prefixes=ALL,LEVEL1 +; RUN: sed 's/level/2/g' %s | llc -o - -mtriple=x86_64-unknown-unknown | FileCheck %s --check-prefixes=ALL,LEVEL2 +; Check the following for IBT protection of jump tables: +; - if cf-protection-branch == 0, then no ENDBRANCH instructions are inserted and the indirect jump table branch does not have a NOTRACK prefix. +; - if cf-protection-branch == 1, then an ENDBRANCH is inserted at function entry but *not* in any jump table BB, and the indirect jump table branch *has* a NOTRACK prefix. +; - if cf-protection-branch >= 2, then an ENDBRANCH is inserted at function antry *and* in the jump table BBs, and the indirect jump table branch *does not have* a NOTRACK prefix. + +define void @foo(i32 %x) { +; ALL-LABEL: foo +; LEVEL0-NOT: endbr64 +; LEVEL1: endbr64 +; LEVEL2: endbr64 +; LEVEL0: jmpq * +; LEVEL1: notrack jmpq * +; LEVEL2: jmpq * +; ALL-LABEL: .LBB0_2: +; LEVEL0-NOT: endbr64 +; LEVEL1-NOT: endbr64 +; LEVEL2: endbr64 +; ALL: retq + switch i32 %x, label %sw.default [ + i32 0, label %sw.bb + i32 1, label %sw.bb1 + i32 2, label %sw.bb2 + i32 3, label %sw.bb3 + ] + +sw.bb: + ret void + +sw.bb1: + ret void + +sw.bb2: + ret void + +sw.bb3: + ret void + +sw.default: + ret void +} + +!llvm.module.flags = !{!0} +!0 = !{i32 8, !"cf-protection-branch", i32 level}