Skip to content

Commit

Permalink
[RISCV] Implement KCFI operand bundle lowering
Browse files Browse the repository at this point in the history
With `-fsanitize=kcfi` (Kernel Control-Flow Integrity), Clang emits
"kcfi" operand bundles to indirect call instructions. Similarly to
the target-specific lowering added in D119296, implement KCFI operand
bundle lowering for RISC-V.

This patch disables the generic KCFI pass for RISC-V in Clang, and
adds the KCFI machine function pass in `RISCVPassConfig::addPreSched2`
to emit target-specific `KCFI_CHECK` pseudo instructions before calls
that have KCFI operand bundles. The machine function pass also bundles
the instructions to ensure we emit the checks immediately before the
calls, which is not possible with the generic pass.

`KCFI_CHECK` instructions are lowered in `RISCVAsmPrinter` to a
contiguous code sequence that traps if the expected hash in the
operand bundle doesn't match the hash before the target function
address. This patch emits an `ebreak` instruction for error handling
to match the Linux kernel's `BUG()` implementation. Just like for X86,
we also emit trap locations to a `.kcfi_traps` section to support
error handling, as we cannot embed additional information to the trap
instruction itself.

Relands commit 62fa708 with fixed
tests.

Reviewed By: MaskRay

Differential Revision: https://reviews.llvm.org/D148385
  • Loading branch information
samitolvanen committed Jun 23, 2023
1 parent 4d60c65 commit 83835e2
Show file tree
Hide file tree
Showing 14 changed files with 356 additions and 2 deletions.
2 changes: 1 addition & 1 deletion clang/lib/CodeGen/BackendUtil.cpp
Expand Up @@ -631,7 +631,7 @@ static void addKCFIPass(const Triple &TargetTriple, const LangOptions &LangOpts,
PassBuilder &PB) {
// If the back-end supports KCFI operand bundle lowering, skip KCFIPass.
if (TargetTriple.getArch() == llvm::Triple::x86_64 ||
TargetTriple.isAArch64(64))
TargetTriple.isAArch64(64) || TargetTriple.isRISCV())
return;

// Ensure we lower KCFI operand bundles with -O0.
Expand Down
91 changes: 91 additions & 0 deletions llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp
Expand Up @@ -19,6 +19,7 @@
#include "RISCVMachineFunctionInfo.h"
#include "RISCVTargetMachine.h"
#include "TargetInfo/RISCVTargetInfo.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/BinaryFormat/ELF.h"
#include "llvm/CodeGen/AsmPrinter.h"
Expand Down Expand Up @@ -72,6 +73,7 @@ class RISCVAsmPrinter : public AsmPrinter {
typedef std::tuple<unsigned, uint32_t> HwasanMemaccessTuple;
std::map<HwasanMemaccessTuple, MCSymbol *> HwasanMemaccessSymbols;
void LowerHWASAN_CHECK_MEMACCESS(const MachineInstr &MI);
void LowerKCFI_CHECK(const MachineInstr &MI);
void EmitHwasanMemaccessSymbols(Module &M);

// Wrapper needed for tblgenned pseudo lowering.
Expand Down Expand Up @@ -150,6 +152,9 @@ void RISCVAsmPrinter::emitInstruction(const MachineInstr *MI) {
case RISCV::HWASAN_CHECK_MEMACCESS_SHORTGRANULES:
LowerHWASAN_CHECK_MEMACCESS(*MI);
return;
case RISCV::KCFI_CHECK:
LowerKCFI_CHECK(*MI);
return;
case RISCV::PseudoRVVInitUndefM1:
case RISCV::PseudoRVVInitUndefM2:
case RISCV::PseudoRVVInitUndefM4:
Expand Down Expand Up @@ -305,6 +310,92 @@ void RISCVAsmPrinter::LowerHWASAN_CHECK_MEMACCESS(const MachineInstr &MI) {
EmitToStreamer(*OutStreamer, MCInstBuilder(RISCV::PseudoCALL).addExpr(Expr));
}

void RISCVAsmPrinter::LowerKCFI_CHECK(const MachineInstr &MI) {
Register AddrReg = MI.getOperand(0).getReg();
assert(std::next(MI.getIterator())->isCall() &&
"KCFI_CHECK not followed by a call instruction");
assert(std::next(MI.getIterator())->getOperand(0).getReg() == AddrReg &&
"KCFI_CHECK call target doesn't match call operand");

// Temporary registers for comparing the hashes. If a register is used
// for the call target, or reserved by the user, we can clobber another
// temporary register as the check is immediately followed by the
// call. The check defaults to X6/X7, but can fall back to X28-X31 if
// needed.
unsigned ScratchRegs[] = {RISCV::X6, RISCV::X7};
unsigned NextReg = RISCV::X28;
auto isRegAvailable = [&](unsigned Reg) {
return Reg != AddrReg && !STI->isRegisterReservedByUser(Reg);
};
for (auto &Reg : ScratchRegs) {
if (isRegAvailable(Reg))
continue;
while (!isRegAvailable(NextReg))
++NextReg;
Reg = NextReg++;
if (Reg > RISCV::X31)
report_fatal_error("Unable to find scratch registers for KCFI_CHECK");
}

if (AddrReg == RISCV::X0) {
// Checking X0 makes no sense. Instead of emitting a load, zero
// ScratchRegs[0].
EmitToStreamer(*OutStreamer, MCInstBuilder(RISCV::ADDI)
.addReg(ScratchRegs[0])
.addReg(RISCV::X0)
.addImm(0));
} else {
// Adjust the offset for patchable-function-prefix. This assumes that
// patchable-function-prefix is the same for all functions.
int NopSize = STI->hasStdExtCOrZca() ? 2 : 4;
int64_t PrefixNops = 0;
(void)MI.getMF()
->getFunction()
.getFnAttribute("patchable-function-prefix")
.getValueAsString()
.getAsInteger(10, PrefixNops);

// Load the target function type hash.
EmitToStreamer(*OutStreamer, MCInstBuilder(RISCV::LW)
.addReg(ScratchRegs[0])
.addReg(AddrReg)
.addImm(-(PrefixNops * NopSize + 4)));
}

// Load the expected 32-bit type hash.
const int64_t Type = MI.getOperand(1).getImm();
const int64_t Hi20 = ((Type + 0x800) >> 12) & 0xFFFFF;
const int64_t Lo12 = SignExtend64<12>(Type);
if (Hi20) {
EmitToStreamer(
*OutStreamer,
MCInstBuilder(RISCV::LUI).addReg(ScratchRegs[1]).addImm(Hi20));
}
if (Lo12 || Hi20 == 0) {
EmitToStreamer(*OutStreamer,
MCInstBuilder((STI->hasFeature(RISCV::Feature64Bit) && Hi20)
? RISCV::ADDIW
: RISCV::ADDI)
.addReg(ScratchRegs[1])
.addReg(ScratchRegs[1])
.addImm(Lo12));
}

// Compare the hashes and trap if there's a mismatch.
MCSymbol *Pass = OutContext.createTempSymbol();
EmitToStreamer(*OutStreamer,
MCInstBuilder(RISCV::BEQ)
.addReg(ScratchRegs[0])
.addReg(ScratchRegs[1])
.addExpr(MCSymbolRefExpr::create(Pass, OutContext)));

MCSymbol *Trap = OutContext.createTempSymbol();
OutStreamer->emitLabel(Trap);
EmitToStreamer(*OutStreamer, MCInstBuilder(RISCV::EBREAK));
emitKCFITrapEntry(*MI.getMF(), Trap);
OutStreamer->emitLabel(Pass);
}

void RISCVAsmPrinter::EmitHwasanMemaccessSymbols(Module &M) {
if (HwasanMemaccessSymbols.empty())
return;
Expand Down
25 changes: 25 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Expand Up @@ -15395,17 +15395,24 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
if (Glue.getNode())
Ops.push_back(Glue);

assert((!CLI.CFIType || CLI.CB->isIndirectCall()) &&
"Unexpected CFI type for a direct call");

// Emit the call.
SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);

if (IsTailCall) {
MF.getFrameInfo().setHasTailCall();
SDValue Ret = DAG.getNode(RISCVISD::TAIL, DL, NodeTys, Ops);
if (CLI.CFIType)
Ret.getNode()->setCFIType(CLI.CFIType->getZExtValue());
DAG.addNoMergeSiteInfo(Ret.getNode(), CLI.NoMerge);
return Ret;
}

Chain = DAG.getNode(RISCVISD::CALL, DL, NodeTys, Ops);
if (CLI.CFIType)
Chain.getNode()->setCFIType(CLI.CFIType->getZExtValue());
DAG.addNoMergeSiteInfo(Chain.getNode(), CLI.NoMerge);
Glue = Chain.getValue(1);

Expand Down Expand Up @@ -16864,6 +16871,24 @@ bool RISCVTargetLowering::lowerInterleavedStore(StoreInst *SI,
return true;
}

/// Build the KCFI_CHECK pseudo for the indirect call at \p MBBI.
///
/// The call must be a PseudoCALLIndirect or PseudoTAILIndirect carrying a CFI
/// type. Its target operand is marked non-renamable, and the new KCFI_CHECK
/// takes that same register plus the call's CFI type as operands.
///
/// \returns the newly created KCFI_CHECK instruction.
MachineInstr *
RISCVTargetLowering::EmitKCFICheck(MachineBasicBlock &MBB,
                                   MachineBasicBlock::instr_iterator &MBBI,
                                   const TargetInstrInfo *TII) const {
  assert(MBBI->isCall() && MBBI->getCFIType() &&
         "Invalid call instruction for a KCFI check");
  assert(is_contained({RISCV::PseudoCALLIndirect, RISCV::PseudoTAILIndirect},
                      MBBI->getOpcode()));

  // Pin the call target register so the check and the call keep agreeing on
  // it.
  MachineOperand &CallTarget = MBBI->getOperand(0);
  CallTarget.setIsRenamable(false);

  auto Check =
      BuildMI(MBB, MBBI, MBBI->getDebugLoc(), TII->get(RISCV::KCFI_CHECK));
  Check.addReg(CallTarget.getReg());
  Check.addImm(MBBI->getCFIType());
  return Check.getInstr();
}

#define GET_REGISTER_MATCHER
#include "RISCVGenAsmMatcher.inc"

Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Expand Up @@ -759,6 +759,12 @@ class RISCVTargetLowering : public TargetLowering {
bool lowerInterleavedStore(StoreInst *SI, ShuffleVectorInst *SVI,
unsigned Factor) const override;

/// Report that this target lowers "kcfi" operand bundles itself.
bool supportKCFIBundles() const override { return true; }

/// Create the KCFI_CHECK pseudo instruction for the indirect call at \p MBBI
/// in \p MBB and return it.
MachineInstr *EmitKCFICheck(MachineBasicBlock &MBB,
MachineBasicBlock::instr_iterator &MBBI,
const TargetInstrInfo *TII) const override;

/// RISCVCCAssignFn - This target-specific function extends the default
/// CCValAssign with additional information used to lower RISC-V calling
/// conventions.
Expand Down
14 changes: 14 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
Expand Up @@ -1265,13 +1265,27 @@ unsigned RISCVInstrInfo::getInstSizeInBytes(const MachineInstr &MI) const {
}
}

if (Opcode == TargetOpcode::BUNDLE)
return getInstBundleLength(MI);

if (MI.getParent() && MI.getParent()->getParent()) {
if (isCompressibleInst(MI, STI))
return 2;
}
return get(Opcode).getSize();
}

unsigned RISCVInstrInfo::getInstBundleLength(const MachineInstr &MI) const {
unsigned Size = 0;
MachineBasicBlock::const_instr_iterator I = MI.getIterator();
MachineBasicBlock::const_instr_iterator E = MI.getParent()->instr_end();
while (++I != E && I->isInsideBundle()) {
assert(!I->isBundle() && "No nested bundle!");
Size += getInstSizeInBytes(*I);
}
return Size;
}

bool RISCVInstrInfo::isAsCheapAsAMove(const MachineInstr &MI) const {
const unsigned Opcode = MI.getOpcode();
switch (Opcode) {
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.h
Expand Up @@ -237,6 +237,9 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {

protected:
const RISCVSubtarget &STI;

private:
unsigned getInstBundleLength(const MachineInstr &MI) const;
};

namespace RISCV {
Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.td
Expand Up @@ -1887,6 +1887,13 @@ def HWASAN_CHECK_MEMACCESS_SHORTGRANULES
[(int_hwasan_check_memaccess_shortgranules X5, GPRJALR:$ptr,
(i32 timm:$accessinfo))]>;

// KCFI_CHECK pseudo: $ptr is the indirect call target register and $type is
// the expected type hash.
// This gets lowered into a 20-byte instruction sequence (at most): a load of
// the target's hash, up to two instructions to materialize the expected hash,
// a branch, and an ebreak.
// Defs lists every register the lowering may pick as a scratch register:
// X6/X7 by default, with X28-X31 as fallbacks.
let hasSideEffects = 0, mayLoad = 1, mayStore = 0,
Defs = [ X6, X7, X28, X29, X30, X31 ], Size = 20 in {
def KCFI_CHECK
: Pseudo<(outs), (ins GPRJALR:$ptr, i32imm:$type), []>, Sched<[]>;
}

/// Simple optimization
def : Pat<(XLenVT (add GPR:$rs1, (AddiPair:$rs2))),
(ADDI (ADDI GPR:$rs1, (AddiPairImmLarge AddiPair:$rs2)),
Expand Down
11 changes: 10 additions & 1 deletion llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
Expand Up @@ -76,6 +76,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeRISCVTarget() {
RegisterTargetMachine<RISCVTargetMachine> Y(getTheRISCV64Target());
auto *PR = PassRegistry::getPassRegistry();
initializeGlobalISel(*PR);
initializeKCFIPass(*PR);
initializeRISCVMakeCompressibleOptPass(*PR);
initializeRISCVGatherScatterLoweringPass(*PR);
initializeRISCVCodeGenPreparePass(*PR);
Expand Down Expand Up @@ -333,7 +334,10 @@ bool RISCVPassConfig::addGlobalInstructionSelect() {
return false;
}

void RISCVPassConfig::addPreSched2() {}
// Adds the KCFI machine function pass to the pipeline.
// NOTE(review): presumably this hook runs before the post-RA scheduler --
// confirm against TargetPassConfig.
void RISCVPassConfig::addPreSched2() {
// Emit KCFI checks for indirect calls.
addPass(createKCFIPass());
}

void RISCVPassConfig::addPreEmitPass() {
addPass(&BranchRelaxationPassID);
Expand All @@ -357,6 +361,11 @@ void RISCVPassConfig::addPreEmitPass2() {
// possibility for other passes to break the requirements for forward
// progress in the LR/SC block.
addPass(createRISCVExpandAtomicPseudoPass());

// KCFI indirect call checks are lowered to a bundle.
addPass(createUnpackMachineBundles([&](const MachineFunction &MF) {
return MF.getFunction().getParent()->getModuleFlag("kcfi");
}));
}

void RISCVPassConfig::addMachineSSAOptimization() {
Expand Down
2 changes: 2 additions & 0 deletions llvm/test/CodeGen/RISCV/O0-pipeline.ll
Expand Up @@ -51,6 +51,7 @@
; CHECK-NEXT: Machine Optimization Remark Emitter
; CHECK-NEXT: Prologue/Epilogue Insertion & Frame Finalization
; CHECK-NEXT: Post-RA pseudo instruction expansion pass
; CHECK-NEXT: Insert KCFI indirect call checks
; CHECK-NEXT: Analyze Machine Code For Garbage Collection
; CHECK-NEXT: Insert fentry calls
; CHECK-NEXT: Insert XRay ops
Expand All @@ -66,6 +67,7 @@
; CHECK-NEXT: Stack Frame Layout Analysis
; CHECK-NEXT: RISC-V pseudo instruction expansion pass
; CHECK-NEXT: RISC-V atomic pseudo instruction expansion pass
; CHECK-NEXT: Unpack machine instruction bundles
; CHECK-NEXT: Lazy Machine Block Frequency Analysis
; CHECK-NEXT: Machine Optimization Remark Emitter
; CHECK-NEXT: RISC-V Assembly Printer
Expand Down
2 changes: 2 additions & 0 deletions llvm/test/CodeGen/RISCV/O3-pipeline.ll
Expand Up @@ -155,6 +155,7 @@
; CHECK-NEXT: Tail Duplication
; CHECK-NEXT: Machine Copy Propagation Pass
; CHECK-NEXT: Post-RA pseudo instruction expansion pass
; CHECK-NEXT: Insert KCFI indirect call checks
; CHECK-NEXT: MachineDominator Tree Construction
; CHECK-NEXT: Machine Natural Loop Construction
; CHECK-NEXT: Post RA top-down list latency scheduler
Expand All @@ -180,6 +181,7 @@
; CHECK-NEXT: RISC-V Zcmp move merging pass
; CHECK-NEXT: RISC-V pseudo instruction expansion pass
; CHECK-NEXT: RISC-V atomic pseudo instruction expansion pass
; CHECK-NEXT: Unpack machine instruction bundles
; CHECK-NEXT: Lazy Machine Block Frequency Analysis
; CHECK-NEXT: Machine Optimization Remark Emitter
; CHECK-NEXT: RISC-V Assembly Printer
Expand Down
33 changes: 33 additions & 0 deletions llvm/test/CodeGen/RISCV/kcfi-isel-mir.ll
@@ -0,0 +1,33 @@
; NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py UTC_ARGS: --version 2
; RUN: llc -mtriple=riscv64 -stop-after=finalize-isel -verify-machineinstrs -o - %s | FileCheck %s
define void @f1(ptr noundef %x) !kcfi_type !1 {
; CHECK-LABEL: name: f1
; CHECK: bb.0 (%ir-block.0):
; CHECK-NEXT: liveins: $x10
; CHECK-NEXT: {{ $}}
; CHECK-NEXT: [[COPY:%[0-9]+]]:gprjalr = COPY $x10
; CHECK-NEXT: ADJCALLSTACKDOWN 0, 0, implicit-def dead $x2, implicit $x2
; CHECK-NEXT: PseudoCALLIndirect [[COPY]], csr_ilp32_lp64, implicit-def dead $x1, implicit-def $x2, cfi-type 12345678
; CHECK-NEXT: ADJCALLSTACKUP 0, 0, implicit-def dead $x2, implicit $x2
; CHECK-NEXT: PseudoRET
call void %x() [ "kcfi"(i32 12345678) ]
ret void
}

define void @f2(ptr noundef %x) #0 {
; CHECK-LABEL: name: f2
; CHECK: bb.0 (%ir-block.0):
; CHECK-NEXT: liveins: $x10
; CHECK-NEXT: {{ $}}
; CHECK-NEXT: [[COPY:%[0-9]+]]:gprtc = COPY $x10
; CHECK-NEXT: PseudoTAILIndirect [[COPY]], implicit $x2, cfi-type 12345678
tail call void %x() [ "kcfi"(i32 12345678) ]
ret void
}

attributes #0 = { "patchable-function-entry"="2" }

!llvm.module.flags = !{!0}

!0 = !{i32 4, !"kcfi", i32 1}
!1 = !{i32 12345678}
42 changes: 42 additions & 0 deletions llvm/test/CodeGen/RISCV/kcfi-mir.ll
@@ -0,0 +1,42 @@
; NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py UTC_ARGS: --version 2
; RUN: llc -mtriple=riscv64 -stop-after=kcfi -verify-machineinstrs -o - %s | FileCheck %s

define void @f1(ptr noundef %x) !kcfi_type !1 {
; CHECK-LABEL: name: f1
; CHECK: bb.0 (%ir-block.0):
; CHECK-NEXT: liveins: $x10, $x1
; CHECK-NEXT: {{ $}}
; CHECK-NEXT: $x2 = frame-setup ADDI $x2, -16
; CHECK-NEXT: frame-setup CFI_INSTRUCTION def_cfa_offset 16
; CHECK-NEXT: SD killed $x1, $x2, 8 :: (store (s64) into %stack.0)
; CHECK-NEXT: frame-setup CFI_INSTRUCTION offset $x1, -8
; CHECK-NEXT: BUNDLE implicit-def $x6, implicit-def $x7, implicit-def $x28, implicit-def $x29, implicit-def $x30, implicit-def $x31, implicit-def dead $x1, implicit-def $x2, implicit killed $x10 {
; CHECK-NEXT: KCFI_CHECK $x10, 12345678, implicit-def $x6, implicit-def $x7, implicit-def $x28, implicit-def $x29, implicit-def $x30, implicit-def $x31
; CHECK-NEXT: PseudoCALLIndirect killed $x10, csr_ilp32_lp64, implicit-def dead $x1, implicit-def $x2
; CHECK-NEXT: }
; CHECK-NEXT: $x1 = LD $x2, 8 :: (load (s64) from %stack.0)
; CHECK-NEXT: $x2 = frame-destroy ADDI $x2, 16
; CHECK-NEXT: PseudoRET
call void %x() [ "kcfi"(i32 12345678) ]
ret void
}

define void @f2(ptr noundef %x) #0 {
; CHECK-LABEL: name: f2
; CHECK: bb.0 (%ir-block.0):
; CHECK-NEXT: liveins: $x10
; CHECK-NEXT: {{ $}}
; CHECK-NEXT: BUNDLE implicit-def $x6, implicit-def $x7, implicit-def $x28, implicit-def $x29, implicit-def $x30, implicit-def $x31, implicit killed $x10, implicit $x2 {
; CHECK-NEXT: KCFI_CHECK $x10, 12345678, implicit-def $x6, implicit-def $x7, implicit-def $x28, implicit-def $x29, implicit-def $x30, implicit-def $x31
; CHECK-NEXT: PseudoTAILIndirect killed $x10, implicit $x2
; CHECK-NEXT: }
tail call void %x() [ "kcfi"(i32 12345678) ]
ret void
}

attributes #0 = { "patchable-function-entry"="2" }

!llvm.module.flags = !{!0}

!0 = !{i32 4, !"kcfi", i32 1}
!1 = !{i32 12345678}

0 comments on commit 83835e2

Please sign in to comment.