diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 19ef1f2f18ec1..18b5f6f929288 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -21547,6 +21547,20 @@ MCPhysReg RVVArgDispatcher::getNextPhysReg() { return AllocatedPhysRegs[CurIdx++]; } +SDValue RISCVTargetLowering::expandIndirectJTBranch(const SDLoc &dl, + SDValue Value, SDValue Addr, + int JTI, + SelectionDAG &DAG) const { + if (Subtarget.hasStdExtZicfilp()) { + // When Zicfilp enabled, we need to use software guarded branch for jump + // table branch. + SDValue JTInfo = DAG.getJumpTableDebugInfo(JTI, Value, dl); + return DAG.getNode(RISCVISD::SW_GUARDED_BRIND, dl, MVT::Other, JTInfo, + Addr); + } + return TargetLowering::expandIndirectJTBranch(dl, Value, Addr, JTI, DAG); +} + namespace llvm::RISCVVIntrinsicsTable { #define GET_RISCVVIntrinsicsTable_IMPL diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h index 78f99e70c083a..afc317f94daef 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.h +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -400,6 +400,10 @@ enum NodeType : unsigned { CZERO_EQZ, // vt.maskc for XVentanaCondOps. CZERO_NEZ, // vt.maskcn for XVentanaCondOps. + /// Software guarded BRIND node. Operand 0 is the chain operand and + /// operand 1 is the target address. + SW_GUARDED_BRIND, + // FP to 32 bit int conversions for RV64. These are used to keep track of the // result being sign extended to 64 bit. These saturate out of range inputs. STRICT_FCVT_W_RV64 = ISD::FIRST_TARGET_STRICTFP_OPCODE, @@ -869,6 +873,9 @@ class RISCVTargetLowering : public TargetLowering { bool supportKCFIBundles() const override { return true; } + SDValue expandIndirectJTBranch(const SDLoc &dl, SDValue Value, SDValue Addr, + int JTI, SelectionDAG &DAG) const override; + MachineInstr *EmitKCFICheck(MachineBasicBlock &MBB, MachineBasicBlock::instr_iterator &MBBI, const TargetInstrInfo *TII) const override; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td index b867eccf42664..9d574edb4e6d1 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td @@ -69,6 +69,8 @@ def riscv_brcc : SDNode<"RISCVISD::BR_CC", SDT_RISCVBrCC, def riscv_tail : SDNode<"RISCVISD::TAIL", SDT_RISCVCall, [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue, SDNPVariadic]>; +def riscv_sw_guarded_brind : SDNode<"RISCVISD::SW_GUARDED_BRIND", + SDTBrind, [SDNPHasChain]>; def riscv_sllw : SDNode<"RISCVISD::SLLW", SDT_RISCVIntBinOpW>; def riscv_sraw : SDNode<"RISCVISD::SRAW", SDT_RISCVIntBinOpW>; def riscv_srlw : SDNode<"RISCVISD::SRLW", SDT_RISCVIntBinOpW>; @@ -1454,9 +1456,12 @@ def PseudoBRIND : Pseudo<(outs), (ins GPRJALR:$rs1, simm12:$imm12), []>, PseudoInstExpansion<(JALR X0, GPR:$rs1, simm12:$imm12)>; let Predicates = [HasStdExtZicfilp], - isBarrier = 1, isBranch = 1, isIndirectBranch = 1, isTerminator = 1 in + isBarrier = 1, isBranch = 1, isIndirectBranch = 1, isTerminator = 1 in { def PseudoBRINDNonX7 : Pseudo<(outs), (ins GPRJALRNonX7:$rs1, simm12:$imm12), []>, PseudoInstExpansion<(JALR X0, GPR:$rs1, simm12:$imm12)>; +def PseudoBRINDX7 : Pseudo<(outs), (ins GPRX7:$rs1, simm12:$imm12), []>, + PseudoInstExpansion<(JALR X0, GPR:$rs1, simm12:$imm12)>; +} // For Zicfilp, need to avoid using X7/T2 for indirect branches which need // landing pad. @@ -1464,6 +1469,10 @@ let Predicates = [HasStdExtZicfilp] in { def : Pat<(brind GPRJALRNonX7:$rs1), (PseudoBRINDNonX7 GPRJALRNonX7:$rs1, 0)>; def : Pat<(brind (add GPRJALRNonX7:$rs1, simm12:$imm12)), (PseudoBRINDNonX7 GPRJALRNonX7:$rs1, simm12:$imm12)>; + +def : Pat<(riscv_sw_guarded_brind GPRX7:$rs1), (PseudoBRINDX7 GPRX7:$rs1, 0)>; +def : Pat<(riscv_sw_guarded_brind (add GPRX7:$rs1, simm12:$imm12)), + (PseudoBRINDX7 GPRX7:$rs1, simm12:$imm12)>; } let Predicates = [NoStdExtZicfilp] in { diff --git a/llvm/lib/Target/RISCV/RISCVRegisterInfo.td b/llvm/lib/Target/RISCV/RISCVRegisterInfo.td index 90e62dc39e6a8..b12634c24622f 100644 --- a/llvm/lib/Target/RISCV/RISCVRegisterInfo.td +++ b/llvm/lib/Target/RISCV/RISCVRegisterInfo.td @@ -167,6 +167,8 @@ def GPRNoX0 : GPRRegisterClass<(sub GPR, X0)>; def GPRNoX0X2 : GPRRegisterClass<(sub GPR, X0, X2)>; +def GPRX7 : GPRRegisterClass<(add X7)>; + // Don't use X1 or X5 for JALR since that is a hint to pop the return address // stack on some microarchitectures. Also remove the reserved registers X0, X2, // X3, and X4 as it reduces the number of register classes that get synthesized diff --git a/llvm/test/CodeGen/RISCV/jumptable-swguarded.ll b/llvm/test/CodeGen/RISCV/jumptable-swguarded.ll new file mode 100644 index 0000000000000..9d57ca74cd78a --- /dev/null +++ b/llvm/test/CodeGen/RISCV/jumptable-swguarded.ll @@ -0,0 +1,105 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc -mtriple riscv32 -mattr=+experimental-zicfilp < %s | FileCheck %s +; RUN: llc -mtriple riscv64 -mattr=+experimental-zicfilp < %s | FileCheck %s +; RUN: llc -mtriple riscv32 < %s | FileCheck %s --check-prefix=NO-ZICFILP +; RUN: llc -mtriple riscv64 < %s | FileCheck %s --check-prefix=NO-ZICFILP + +; Test using t2 to jump table branch. +define void @above_threshold(i32 signext %in, ptr %out) nounwind { +; CHECK-LABEL: above_threshold: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: addi a0, a0, -1 +; CHECK-NEXT: li a2, 5 +; CHECK-NEXT: bltu a2, a0, .LBB0_9 +; CHECK-NEXT: # %bb.1: # %entry +; CHECK-NEXT: slli a0, a0, 2 +; CHECK-NEXT: lui a2, %hi(.LJTI0_0) +; CHECK-NEXT: addi a2, a2, %lo(.LJTI0_0) +; CHECK-NEXT: add a0, a0, a2 +; CHECK-NEXT: lw t2, 0(a0) +; CHECK-NEXT: jr t2 +; CHECK-NEXT: .LBB0_2: # %bb1 +; CHECK-NEXT: li a0, 4 +; CHECK-NEXT: j .LBB0_8 +; CHECK-NEXT: .LBB0_3: # %bb5 +; CHECK-NEXT: li a0, 100 +; CHECK-NEXT: j .LBB0_8 +; CHECK-NEXT: .LBB0_4: # %bb3 +; CHECK-NEXT: li a0, 2 +; CHECK-NEXT: j .LBB0_8 +; CHECK-NEXT: .LBB0_5: # %bb4 +; CHECK-NEXT: li a0, 1 +; CHECK-NEXT: j .LBB0_8 +; CHECK-NEXT: .LBB0_6: # %bb2 +; CHECK-NEXT: li a0, 3 +; CHECK-NEXT: j .LBB0_8 +; CHECK-NEXT: .LBB0_7: # %bb6 +; CHECK-NEXT: li a0, 200 +; CHECK-NEXT: .LBB0_8: # %exit +; CHECK-NEXT: sw a0, 0(a1) +; CHECK-NEXT: .LBB0_9: # %exit +; CHECK-NEXT: ret +; +; NO-ZICFILP-LABEL: above_threshold: +; NO-ZICFILP: # %bb.0: # %entry +; NO-ZICFILP-NEXT: addi a0, a0, -1 +; NO-ZICFILP-NEXT: li a2, 5 +; NO-ZICFILP-NEXT: bltu a2, a0, .LBB0_9 +; NO-ZICFILP-NEXT: # %bb.1: # %entry +; NO-ZICFILP-NEXT: slli a0, a0, 2 +; NO-ZICFILP-NEXT: lui a2, %hi(.LJTI0_0) +; NO-ZICFILP-NEXT: addi a2, a2, %lo(.LJTI0_0) +; NO-ZICFILP-NEXT: add a0, a0, a2 +; NO-ZICFILP-NEXT: lw a0, 0(a0) +; NO-ZICFILP-NEXT: jr a0 +; NO-ZICFILP-NEXT: .LBB0_2: # %bb1 +; NO-ZICFILP-NEXT: li a0, 4 +; NO-ZICFILP-NEXT: j .LBB0_8 +; NO-ZICFILP-NEXT: .LBB0_3: # %bb5 +; NO-ZICFILP-NEXT: li a0, 100 +; NO-ZICFILP-NEXT: j .LBB0_8 +; NO-ZICFILP-NEXT: .LBB0_4: # %bb3 +; NO-ZICFILP-NEXT: li a0, 2 +; NO-ZICFILP-NEXT: j .LBB0_8 +; NO-ZICFILP-NEXT: .LBB0_5: # %bb4 +; NO-ZICFILP-NEXT: li a0, 1 +; NO-ZICFILP-NEXT: j .LBB0_8 +; NO-ZICFILP-NEXT: .LBB0_6: # %bb2 +; NO-ZICFILP-NEXT: li a0, 3 +; NO-ZICFILP-NEXT: j .LBB0_8 +; NO-ZICFILP-NEXT: .LBB0_7: # %bb6 +; NO-ZICFILP-NEXT: li a0, 200 +; NO-ZICFILP-NEXT: .LBB0_8: # %exit +; NO-ZICFILP-NEXT: sw a0, 0(a1) +; NO-ZICFILP-NEXT: .LBB0_9: # %exit +; NO-ZICFILP-NEXT: ret +entry: + switch i32 %in, label %exit [ + i32 1, label %bb1 + i32 2, label %bb2 + i32 3, label %bb3 + i32 4, label %bb4 + i32 5, label %bb5 + i32 6, label %bb6 + ] +bb1: + store i32 4, ptr %out + br label %exit +bb2: + store i32 3, ptr %out + br label %exit +bb3: + store i32 2, ptr %out + br label %exit +bb4: + store i32 1, ptr %out + br label %exit +bb5: + store i32 100, ptr %out + br label %exit +bb6: + store i32 200, ptr %out + br label %exit +exit: + ret void +}