Skip to content

Commit

Permalink
[RISCV] Optimize 2x SELECT for floating-point types
Browse files Browse the repository at this point in the history
Including the following opcodes:
 Select_FPR16_Using_CC_GPR
 Select_FPR32_Using_CC_GPR
 Select_FPR64_Using_CC_GPR

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D127871
  • Loading branch information
ChunyuLiao committed Jun 28, 2022
1 parent 1919adb commit 1178992
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 0 deletions.
114 changes: 114 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Expand Up @@ -9708,6 +9708,109 @@ static MachineBasicBlock *emitQuietFCMP(MachineInstr &MI, MachineBasicBlock *BB,
return BB;
}

/// Lower a pair of cascaded floating-point select pseudos in a single step.
///
/// \p First and \p Second are two select pseudos of the same opcode where the
/// operand 5 of \p Second is the result of \p First, i.e. the pair computes
///
///   Second = Select_FPRX_ (lhs2, rhs2, cc2, T2,
///                          (Select_FPRX_ lhs1, rhs1, cc1, T1, F1))
///          = cc2(lhs2, rhs2) ? T2 : (cc1(lhs1, rhs1) ? T1 : F1)
///
/// Operand layout read by this function: 0 = result, 1/2 = compared registers,
/// 3 = RISCVCC condition code immediate, 4 = value used when the condition
/// holds, 5 = value used otherwise.
///
/// \param First     The inner select; its result feeds \p Second.
/// \param Second    The outer select; its result register receives the PHI.
/// \param ThisMBB   The block that currently contains both pseudos.
/// \param Subtarget Used to obtain the RISCVInstrInfo for building branches.
/// \returns The newly created sink block that holds the PHI and the
///          instructions that followed \p First in \p ThisMBB.
static MachineBasicBlock *
EmitLoweredCascadedSelect(MachineInstr &First, MachineInstr &Second,
                          MachineBasicBlock *ThisMBB,
                          const RISCVSubtarget &Subtarget) {
  // Lowering each select on its own would create two independent diamonds
  // (four new blocks and two PHIs). Lowering both together needs only three
  // new blocks and one PHI:
  //
  //   ThisMBB:   if (cc2) goto SinkMBB      ; falls through to FirstMBB
  //   FirstMBB:  if (cc1) goto SinkMBB      ; falls through to SecondMBB
  //   SecondMBB: (empty)                    ; falls through to SinkMBB
  //   SinkMBB:   Dest = PHI [T2, ThisMBB], [T1, FirstMBB], [F1, SecondMBB]
  //
  // The outer condition (cc2) must be tested first: when it holds the result
  // is T2 no matter what the inner condition evaluates to. Testing cc1 first
  // would yield T1 whenever both conditions are true.

  const RISCVInstrInfo &TII = *Subtarget.getInstrInfo();
  const DebugLoc &DL = First.getDebugLoc();
  const BasicBlock *LLVM_BB = ThisMBB->getBasicBlock();
  MachineFunction *F = ThisMBB->getParent();

  // Create the three new blocks and place them right after ThisMBB, in
  // fallthrough order.
  MachineBasicBlock *FirstMBB = F->CreateMachineBasicBlock(LLVM_BB);
  MachineBasicBlock *SecondMBB = F->CreateMachineBasicBlock(LLVM_BB);
  MachineBasicBlock *SinkMBB = F->CreateMachineBasicBlock(LLVM_BB);
  MachineFunction::iterator It = ++ThisMBB->getIterator();
  F->insert(It, FirstMBB);
  F->insert(It, SecondMBB);
  F->insert(It, SinkMBB);

  // Transfer the remainder of ThisMBB (everything after First, which includes
  // Second) and its successor edges to SinkMBB.
  SinkMBB->splice(SinkMBB->begin(), ThisMBB,
                  std::next(MachineBasicBlock::iterator(First)),
                  ThisMBB->end());
  SinkMBB->transferSuccessorsAndUpdatePHIs(ThisMBB);

  // Fallthrough block for ThisMBB.
  ThisMBB->addSuccessor(FirstMBB);
  // Fallthrough block for FirstMBB.
  FirstMBB->addSuccessor(SecondMBB);
  // Taken edges of the two conditional branches.
  ThisMBB->addSuccessor(SinkMBB);
  FirstMBB->addSuccessor(SinkMBB);
  // SecondMBB only falls through.
  SecondMBB->addSuccessor(SinkMBB);

  // Outer select: branch straight to the sink when its condition holds.
  auto SecondCC = static_cast<RISCVCC::CondCode>(Second.getOperand(3).getImm());
  Register SLHS = Second.getOperand(1).getReg();
  Register SRHS = Second.getOperand(2).getReg();
  BuildMI(ThisMBB, DL, TII.getBrCond(SecondCC))
      .addReg(SLHS)
      .addReg(SRHS)
      .addMBB(SinkMBB);

  // Inner select: only reached when the outer condition was false.
  auto FirstCC = static_cast<RISCVCC::CondCode>(First.getOperand(3).getImm());
  Register FLHS = First.getOperand(1).getReg();
  Register FRHS = First.getOperand(2).getReg();
  BuildMI(FirstMBB, DL, TII.getBrCond(FirstCC))
      .addReg(FLHS)
      .addReg(FRHS)
      .addMBB(SinkMBB);

  // Merge the three possible values into the outer select's result register.
  Register DestReg = Second.getOperand(0).getReg();
  Register Op2Reg4 = Second.getOperand(4).getReg();
  Register Op1Reg4 = First.getOperand(4).getReg();
  Register Op1Reg5 = First.getOperand(5).getReg();
  BuildMI(*SinkMBB, SinkMBB->begin(), DL, TII.get(RISCV::PHI), DestReg)
      .addReg(Op2Reg4)
      .addMBB(ThisMBB)
      .addReg(Op1Reg4)
      .addMBB(FirstMBB)
      .addReg(Op1Reg5)
      .addMBB(SecondMBB);

  // Now remove the Select_FPRX_s.
  First.eraseFromParent();
  Second.eraseFromParent();
  return SinkMBB;
}

static MachineBasicBlock *emitSelectPseudo(MachineInstr &MI,
MachineBasicBlock *BB,
const RISCVSubtarget &Subtarget) {
Expand Down Expand Up @@ -9735,6 +9838,10 @@ static MachineBasicBlock *emitSelectPseudo(MachineInstr &MI,
// previous selects in the sequence.
// These conditions could be further relaxed. See the X86 target for a
// related approach and more information.
//
// Select_FPRX_ (rs1, rs2, imm, rs4, (Select_FPRX_ rs1, rs2, imm, rs4, rs5))
// is checked here and handled by a separate function -
// EmitLoweredCascadedSelect.
Register LHS = MI.getOperand(1).getReg();
Register RHS = MI.getOperand(2).getReg();
auto CC = static_cast<RISCVCC::CondCode>(MI.getOperand(3).getImm());
Expand All @@ -9744,6 +9851,13 @@ static MachineBasicBlock *emitSelectPseudo(MachineInstr &MI,
SelectDests.insert(MI.getOperand(0).getReg());

MachineInstr *LastSelectPseudo = &MI;
auto Next = next_nodbg(MI.getIterator(), BB->instr_end());
if (MI.getOpcode() != RISCV::Select_GPR_Using_CC_GPR && Next != BB->end() &&
Next->getOpcode() == MI.getOpcode() &&
Next->getOperand(5).getReg() == MI.getOperand(0).getReg() &&
Next->getOperand(5).isKill()) {
return EmitLoweredCascadedSelect(MI, *Next, BB, Subtarget);
}

for (auto E = BB->end(), SequenceMBBI = MachineBasicBlock::iterator(MI);
SequenceMBBI != E; ++SequenceMBBI) {
Expand Down
76 changes: 76 additions & 0 deletions llvm/test/CodeGen/RISCV/select-optimize-multiple.ll
Expand Up @@ -533,3 +533,79 @@ entry:
%ret = add i32 %cond1, %cond2
ret i32 %ret
}

define float @CascadedSelect(float noundef %a) {
; RV32I-LABEL: CascadedSelect:
; RV32I: # %bb.0: # %entry
; RV32I-NEXT: fmv.w.x ft0, a0
; RV32I-NEXT: fmv.w.x ft1, zero
; RV32I-NEXT: flt.s a0, ft0, ft1
; RV32I-NEXT: bnez a0, .LBB8_3
; RV32I-NEXT: # %bb.1: # %entry
; RV32I-NEXT: lui a0, %hi(.LCPI8_0)
; RV32I-NEXT: flw ft1, %lo(.LCPI8_0)(a0)
; RV32I-NEXT: flt.s a0, ft1, ft0
; RV32I-NEXT: bnez a0, .LBB8_3
; RV32I-NEXT: # %bb.2: # %entry
; RV32I-NEXT: fmv.s ft1, ft0
; RV32I-NEXT: .LBB8_3: # %entry
; RV32I-NEXT: fmv.x.w a0, ft1
; RV32I-NEXT: ret
;
; RV32IBT-LABEL: CascadedSelect:
; RV32IBT: # %bb.0: # %entry
; RV32IBT-NEXT: fmv.w.x ft0, a0
; RV32IBT-NEXT: fmv.w.x ft1, zero
; RV32IBT-NEXT: flt.s a0, ft0, ft1
; RV32IBT-NEXT: bnez a0, .LBB8_3
; RV32IBT-NEXT: # %bb.1: # %entry
; RV32IBT-NEXT: lui a0, %hi(.LCPI8_0)
; RV32IBT-NEXT: flw ft1, %lo(.LCPI8_0)(a0)
; RV32IBT-NEXT: flt.s a0, ft1, ft0
; RV32IBT-NEXT: bnez a0, .LBB8_3
; RV32IBT-NEXT: # %bb.2: # %entry
; RV32IBT-NEXT: fmv.s ft1, ft0
; RV32IBT-NEXT: .LBB8_3: # %entry
; RV32IBT-NEXT: fmv.x.w a0, ft1
; RV32IBT-NEXT: ret
;
; RV64I-LABEL: CascadedSelect:
; RV64I: # %bb.0: # %entry
; RV64I-NEXT: fmv.w.x ft0, a0
; RV64I-NEXT: fmv.w.x ft1, zero
; RV64I-NEXT: flt.s a0, ft0, ft1
; RV64I-NEXT: bnez a0, .LBB8_3
; RV64I-NEXT: # %bb.1: # %entry
; RV64I-NEXT: lui a0, %hi(.LCPI8_0)
; RV64I-NEXT: flw ft1, %lo(.LCPI8_0)(a0)
; RV64I-NEXT: flt.s a0, ft1, ft0
; RV64I-NEXT: bnez a0, .LBB8_3
; RV64I-NEXT: # %bb.2: # %entry
; RV64I-NEXT: fmv.s ft1, ft0
; RV64I-NEXT: .LBB8_3: # %entry
; RV64I-NEXT: fmv.x.w a0, ft1
; RV64I-NEXT: ret
;
; RV64IBT-LABEL: CascadedSelect:
; RV64IBT: # %bb.0: # %entry
; RV64IBT-NEXT: fmv.w.x ft0, a0
; RV64IBT-NEXT: fmv.w.x ft1, zero
; RV64IBT-NEXT: flt.s a0, ft0, ft1
; RV64IBT-NEXT: bnez a0, .LBB8_3
; RV64IBT-NEXT: # %bb.1: # %entry
; RV64IBT-NEXT: lui a0, %hi(.LCPI8_0)
; RV64IBT-NEXT: flw ft1, %lo(.LCPI8_0)(a0)
; RV64IBT-NEXT: flt.s a0, ft1, ft0
; RV64IBT-NEXT: bnez a0, .LBB8_3
; RV64IBT-NEXT: # %bb.2: # %entry
; RV64IBT-NEXT: fmv.s ft1, ft0
; RV64IBT-NEXT: .LBB8_3: # %entry
; RV64IBT-NEXT: fmv.x.w a0, ft1
; RV64IBT-NEXT: ret
entry:
%cmp = fcmp ogt float %a, 1.000000e+00
%cmp1 = fcmp olt float %a, 0.000000e+00
%.a = select i1 %cmp1, float 0.000000e+00, float %a
%retval.0 = select i1 %cmp, float 1.000000e+00, float %.a
ret float %retval.0
}

0 comments on commit 1178992

Please sign in to comment.