Skip to content

Commit

Permalink
[RISCV] Support FP_TO_S/UINT_SAT for i32 and i64.
Browse files Browse the repository at this point in the history
The fcvt fp to integer instructions saturate if their input is
infinity or out of range, but the instructions produce a maximum
integer for nan instead of 0 required for the ISD opcodes.

This means we can use the instructions to do the saturating
conversion, but we'll need to fix up the nan case at the end.

We can probably improve the i8 and i16 default codegen as well,
but I'll leave that for a follow up.

Reviewed By: luismarques

Differential Revision: https://reviews.llvm.org/D107230
  • Loading branch information
topperc committed Aug 7, 2021
1 parent 47a889c commit d4ee84c
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 402 deletions.
34 changes: 34 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Expand Up @@ -376,6 +376,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
}

if (Subtarget.hasStdExtF()) {
setOperationAction(ISD::FP_TO_UINT_SAT, XLenVT, Custom);
setOperationAction(ISD::FP_TO_SINT_SAT, XLenVT, Custom);

setOperationAction(ISD::FLT_ROUNDS_, XLenVT, Custom);
setOperationAction(ISD::SET_ROUNDING, MVT::Other, Custom);
}
Expand Down Expand Up @@ -1379,6 +1382,32 @@ bool RISCVTargetLowering::isShuffleMaskLegal(ArrayRef<int> M, EVT VT) const {
return false;
}

/// Custom lowering for ISD::FP_TO_SINT_SAT and ISD::FP_TO_UINT_SAT.
///
/// RISCV FP-to-int conversions saturate to the destination register size, but
/// don't produce 0 for nan. We can use a conversion instruction and fix the
/// nan case with a compare and a select.
///
/// \param Op   The FP_TO_[SU]INT_SAT node. Operand 0 is the floating point
///             source, operand 1 is a VTSDNode holding the saturation type.
/// \param DAG  The SelectionDAG used to build the replacement nodes.
/// \returns The lowered value, or an empty SDValue when the combination of
///          result type and saturation type is not handled here.
static SDValue lowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG) {
  SDValue Src = Op.getOperand(0);

  EVT DstVT = Op.getValueType();
  EVT SatVT = cast<VTSDNode>(Op.getOperand(1))->getVT();

  bool IsSigned = Op.getOpcode() == ISD::FP_TO_SINT_SAT;
  unsigned Opc;
  if (SatVT == DstVT)
    Opc = IsSigned ? RISCVISD::FCVT_X_RTZ : RISCVISD::FCVT_XU_RTZ;
  else if (DstVT == MVT::i64 && SatVT == MVT::i32)
    Opc = IsSigned ? RISCVISD::FCVT_W_RTZ_RV64 : RISCVISD::FCVT_WU_RTZ_RV64;
  else
    return SDValue();
  // FIXME: Support other SatVTs by clamping before or after the conversion.

  SDLoc DL(Op);
  SDValue FpToInt = DAG.getNode(Opc, DL, DstVT, Src);

  // FCVT_WU_RTZ_RV64 produces its 32-bit result sign extended to 64 bits, but
  // an i64 FP_TO_UINT_SAT that saturates to i32 must yield a value in
  // [0, 2^32 - 1]. Clear the upper bits so inputs >= 2^31 are not returned
  // with the high half set.
  if (Opc == RISCVISD::FCVT_WU_RTZ_RV64)
    FpToInt = DAG.getZeroExtendInReg(FpToInt, DL, MVT::i32);

  // An unordered compare of Src with itself is true only for nan; select 0 in
  // that case and the converted value otherwise.
  SDValue ZeroInt = DAG.getConstant(0, DL, DstVT);
  return DAG.getSelectCC(DL, Src, Src, ZeroInt, FpToInt, ISD::CondCode::SETUO);
}

static SDValue lowerSPLAT_VECTOR(SDValue Op, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
MVT VT = Op.getSimpleValueType();
Expand Down Expand Up @@ -2517,6 +2546,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
Src = DAG.getNode(RVVOpc, DL, ContainerVT, Src, Mask, VL);
return convertFromScalableVector(VT, Src, DAG, Subtarget);
}
case ISD::FP_TO_SINT_SAT:
case ISD::FP_TO_UINT_SAT:
return lowerFP_TO_INT_SAT(Op, DAG);
case ISD::VECREDUCE_ADD:
case ISD::VECREDUCE_UMAX:
case ISD::VECREDUCE_SMAX:
Expand Down Expand Up @@ -8385,6 +8417,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(FMV_X_ANYEXTH)
NODE_NAME_CASE(FMV_W_X_RV64)
NODE_NAME_CASE(FMV_X_ANYEXTW_RV64)
NODE_NAME_CASE(FCVT_X_RTZ)
NODE_NAME_CASE(FCVT_XU_RTZ)
NODE_NAME_CASE(FCVT_W_RTZ_RV64)
NODE_NAME_CASE(FCVT_WU_RTZ_RV64)
NODE_NAME_CASE(READ_CYCLE_WIDE)
Expand Down
8 changes: 7 additions & 1 deletion llvm/lib/Target/RISCV/RISCVISelLowering.h
Expand Up @@ -84,8 +84,14 @@ enum NodeType : unsigned {
FMV_X_ANYEXTH,
FMV_W_X_RV64,
FMV_X_ANYEXTW_RV64,
// FP to XLen int conversions. Corresponds to fcvt.l(u).s/d/h on RV64 and
// fcvt.w(u).s/d/h on RV32. Unlike FP_TO_S/UINT these saturate out of
// range inputs. These are used for FP_TO_S/UINT_SAT lowering.
FCVT_X_RTZ,
FCVT_XU_RTZ,
// FP to 32 bit int conversions for RV64. These are used to keep track of the
// result being sign extended to 64 bit.
// result being sign extended to 64 bit. These saturate out of range inputs.
// Used for FP_TO_S/UINT and FP_TO_S/UINT_SAT lowering.
FCVT_W_RTZ_RV64,
FCVT_WU_RTZ_RV64,
// READ_CYCLE_WIDE - A read of the 64-bit cycle CSR on a 32-bit target
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfoD.td
Expand Up @@ -331,6 +331,10 @@ def : Pat<(f64 (fpimm0)), (FCVT_D_W (i32 X0))>;
def : Pat<(i32 (fp_to_sint FPR64:$rs1)), (FCVT_W_D FPR64:$rs1, 0b001)>;
def : Pat<(i32 (fp_to_uint FPR64:$rs1)), (FCVT_WU_D FPR64:$rs1, 0b001)>;

// Saturating double->[u]int32. Selects the riscv_fcvt_x(u)_rtz nodes to
// FCVT_W(U)_D with rounding-mode operand 0b001 (printed as "rtz").
def : Pat<(i32 (riscv_fcvt_x_rtz FPR64:$rs1)), (FCVT_W_D $rs1, 0b001)>;
def : Pat<(i32 (riscv_fcvt_xu_rtz FPR64:$rs1)), (FCVT_WU_D $rs1, 0b001)>;

// double->int32 with current rounding mode.
def : Pat<(i32 (lrint FPR64:$rs1)), (FCVT_W_D $rs1, 0b111)>;

Expand Down Expand Up @@ -361,6 +365,10 @@ def : Pat<(riscv_fcvt_wu_rtz_rv64 FPR64:$rs1), (FCVT_WU_D $rs1, 0b001)>;
def : Pat<(sint_to_fp (i64 (sexti32 (i64 GPR:$rs1)))), (FCVT_D_W $rs1)>;
def : Pat<(uint_to_fp (i64 (zexti32 (i64 GPR:$rs1)))), (FCVT_D_WU $rs1)>;

// Saturating double->[u]int64. Selects the riscv_fcvt_x(u)_rtz nodes to
// FCVT_L(U)_D with rounding-mode operand 0b001 (printed as "rtz").
def : Pat<(i64 (riscv_fcvt_x_rtz FPR64:$rs1)), (FCVT_L_D $rs1, 0b001)>;
def : Pat<(i64 (riscv_fcvt_xu_rtz FPR64:$rs1)), (FCVT_LU_D $rs1, 0b001)>;

// double->[u]int64. Round-to-zero must be used.
def : Pat<(i64 (fp_to_sint FPR64:$rs1)), (FCVT_L_D FPR64:$rs1, 0b001)>;
def : Pat<(i64 (fp_to_uint FPR64:$rs1)), (FCVT_LU_D FPR64:$rs1, 0b001)>;
Expand Down
14 changes: 14 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfoF.td
Expand Up @@ -21,6 +21,8 @@ def SDT_RISCVFMV_X_ANYEXTW_RV64
: SDTypeProfile<1, 1, [SDTCisVT<0, i64>, SDTCisVT<1, f32>]>;
def STD_RISCVFCVT_W_RV64
: SDTypeProfile<1, 1, [SDTCisVT<0, i64>, SDTCisFP<1>]>;
// Type profile for the FP to XLen integer conversion nodes: one result of
// type XLenVT and one floating point operand.
def STD_RISCVFCVT_X
: SDTypeProfile<1, 1, [SDTCisVT<0, XLenVT>, SDTCisFP<1>]>;

def riscv_fmv_w_x_rv64
: SDNode<"RISCVISD::FMV_W_X_RV64", SDT_RISCVFMV_W_X_RV64>;
Expand All @@ -30,6 +32,10 @@ def riscv_fcvt_w_rtz_rv64
: SDNode<"RISCVISD::FCVT_W_RTZ_RV64", STD_RISCVFCVT_W_RV64>;
def riscv_fcvt_wu_rtz_rv64
: SDNode<"RISCVISD::FCVT_WU_RTZ_RV64", STD_RISCVFCVT_W_RV64>;
// Signed and unsigned FP to XLen integer conversion nodes. They map to
// RISCVISD::FCVT_X_RTZ / RISCVISD::FCVT_XU_RTZ, which are used when lowering
// FP_TO_SINT_SAT / FP_TO_UINT_SAT.
def riscv_fcvt_x_rtz
: SDNode<"RISCVISD::FCVT_X_RTZ", STD_RISCVFCVT_X>;
def riscv_fcvt_xu_rtz
: SDNode<"RISCVISD::FCVT_XU_RTZ", STD_RISCVFCVT_X>;

//===----------------------------------------------------------------------===//
// Operand and SDNode transformation definitions.
Expand Down Expand Up @@ -379,6 +385,10 @@ def : Pat<(i32 (bitconvert FPR32:$rs1)), (FMV_X_W FPR32:$rs1)>;
def : Pat<(i32 (fp_to_sint FPR32:$rs1)), (FCVT_W_S $rs1, 0b001)>;
def : Pat<(i32 (fp_to_uint FPR32:$rs1)), (FCVT_WU_S $rs1, 0b001)>;

// Saturating float->[u]int32. Selects the riscv_fcvt_x(u)_rtz nodes to
// FCVT_W(U)_S with rounding-mode operand 0b001.
def : Pat<(i32 (riscv_fcvt_x_rtz FPR32:$rs1)), (FCVT_W_S $rs1, 0b001)>;
def : Pat<(i32 (riscv_fcvt_xu_rtz FPR32:$rs1)), (FCVT_WU_S $rs1, 0b001)>;

// float->int32 with current rounding mode.
def : Pat<(i32 (lrint FPR32:$rs1)), (FCVT_W_S $rs1, 0b111)>;

Expand Down Expand Up @@ -407,6 +417,10 @@ def : Pat<(riscv_fcvt_wu_rtz_rv64 FPR32:$rs1), (FCVT_WU_S $rs1, 0b001)>;
def : Pat<(i64 (fp_to_sint FPR32:$rs1)), (FCVT_L_S $rs1, 0b001)>;
def : Pat<(i64 (fp_to_uint FPR32:$rs1)), (FCVT_LU_S $rs1, 0b001)>;

// Saturating float->[u]int64. Selects the riscv_fcvt_x(u)_rtz nodes to
// FCVT_L(U)_S with rounding-mode operand 0b001.
def : Pat<(i64 (riscv_fcvt_x_rtz FPR32:$rs1)), (FCVT_L_S $rs1, 0b001)>;
def : Pat<(i64 (riscv_fcvt_xu_rtz FPR32:$rs1)), (FCVT_LU_S $rs1, 0b001)>;

// float->int64 with current rounding mode.
def : Pat<(i64 (lrint FPR32:$rs1)), (FCVT_L_S $rs1, 0b111)>;
def : Pat<(i64 (llrint FPR32:$rs1)), (FCVT_L_S $rs1, 0b111)>;
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td
Expand Up @@ -338,6 +338,10 @@ let Predicates = [HasStdExtZfh, IsRV32] in {
def : Pat<(i32 (fp_to_sint FPR16:$rs1)), (FCVT_W_H $rs1, 0b001)>;
def : Pat<(i32 (fp_to_uint FPR16:$rs1)), (FCVT_WU_H $rs1, 0b001)>;

// Saturating half->[u]int32. Selects the riscv_fcvt_x(u)_rtz nodes to
// FCVT_W(U)_H with rounding-mode operand 0b001.
def : Pat<(i32 (riscv_fcvt_x_rtz FPR16:$rs1)), (FCVT_W_H $rs1, 0b001)>;
def : Pat<(i32 (riscv_fcvt_xu_rtz FPR16:$rs1)), (FCVT_WU_H $rs1, 0b001)>;

// half->int32 with current rounding mode.
def : Pat<(i32 (lrint FPR16:$rs1)), (FCVT_W_H $rs1, 0b111)>;

Expand All @@ -360,6 +364,10 @@ def : Pat<(riscv_fcvt_wu_rtz_rv64 FPR16:$rs1), (FCVT_WU_H $rs1, 0b001)>;
def : Pat<(i64 (fp_to_sint FPR16:$rs1)), (FCVT_L_H $rs1, 0b001)>;
def : Pat<(i64 (fp_to_uint FPR16:$rs1)), (FCVT_LU_H $rs1, 0b001)>;

// Saturating half->[u]int64. Selects the riscv_fcvt_x(u)_rtz nodes to
// FCVT_L(U)_H with rounding-mode operand 0b001.
def : Pat<(i64 (riscv_fcvt_x_rtz FPR16:$rs1)), (FCVT_L_H $rs1, 0b001)>;
def : Pat<(i64 (riscv_fcvt_xu_rtz FPR16:$rs1)), (FCVT_LU_H $rs1, 0b001)>;

// half->int64 with current rounding mode.
def : Pat<(i64 (lrint FPR16:$rs1)), (FCVT_L_H $rs1, 0b111)>;
def : Pat<(i64 (llrint FPR16:$rs1)), (FCVT_L_H $rs1, 0b111)>;
Expand Down
84 changes: 23 additions & 61 deletions llvm/test/CodeGen/RISCV/double-convert.ll
Expand Up @@ -84,12 +84,6 @@ define i32 @fcvt_w_d_sat(double %a) nounwind {
; RV32IFD-NEXT: addi sp, sp, 16
; RV32IFD-NEXT: ret
; RV32IFD-NEXT: .LBB3_2:
; RV32IFD-NEXT: lui a0, %hi(.LCPI3_0)
; RV32IFD-NEXT: fld ft1, %lo(.LCPI3_0)(a0)
; RV32IFD-NEXT: lui a0, %hi(.LCPI3_1)
; RV32IFD-NEXT: fld ft2, %lo(.LCPI3_1)(a0)
; RV32IFD-NEXT: fmax.d ft0, ft0, ft1
; RV32IFD-NEXT: fmin.d ft0, ft0, ft2
; RV32IFD-NEXT: fcvt.w.d a0, ft0, rtz
; RV32IFD-NEXT: addi sp, sp, 16
; RV32IFD-NEXT: ret
Expand All @@ -103,13 +97,7 @@ define i32 @fcvt_w_d_sat(double %a) nounwind {
; RV64IFD-NEXT: mv a0, zero
; RV64IFD-NEXT: ret
; RV64IFD-NEXT: .LBB3_2:
; RV64IFD-NEXT: lui a0, %hi(.LCPI3_0)
; RV64IFD-NEXT: fld ft1, %lo(.LCPI3_0)(a0)
; RV64IFD-NEXT: lui a0, %hi(.LCPI3_1)
; RV64IFD-NEXT: fld ft2, %lo(.LCPI3_1)(a0)
; RV64IFD-NEXT: fmax.d ft0, ft0, ft1
; RV64IFD-NEXT: fmin.d ft0, ft0, ft2
; RV64IFD-NEXT: fcvt.l.d a0, ft0, rtz
; RV64IFD-NEXT: fcvt.w.d a0, ft0, rtz
; RV64IFD-NEXT: ret
start:
%0 = tail call i32 @llvm.fptosi.sat.i32.f64(double %a)
Expand Down Expand Up @@ -182,24 +170,27 @@ define i32 @fcvt_wu_d_sat(double %a) nounwind {
; RV32IFD-NEXT: sw a0, 8(sp)
; RV32IFD-NEXT: sw a1, 12(sp)
; RV32IFD-NEXT: fld ft0, 8(sp)
; RV32IFD-NEXT: lui a0, %hi(.LCPI6_0)
; RV32IFD-NEXT: fld ft1, %lo(.LCPI6_0)(a0)
; RV32IFD-NEXT: fcvt.d.w ft2, zero
; RV32IFD-NEXT: fmax.d ft0, ft0, ft2
; RV32IFD-NEXT: fmin.d ft0, ft0, ft1
; RV32IFD-NEXT: feq.d a0, ft0, ft0
; RV32IFD-NEXT: bnez a0, .LBB6_2
; RV32IFD-NEXT: # %bb.1: # %start
; RV32IFD-NEXT: mv a0, zero
; RV32IFD-NEXT: addi sp, sp, 16
; RV32IFD-NEXT: ret
; RV32IFD-NEXT: .LBB6_2:
; RV32IFD-NEXT: fcvt.wu.d a0, ft0, rtz
; RV32IFD-NEXT: addi sp, sp, 16
; RV32IFD-NEXT: ret
;
; RV64IFD-LABEL: fcvt_wu_d_sat:
; RV64IFD: # %bb.0: # %start
; RV64IFD-NEXT: lui a1, %hi(.LCPI6_0)
; RV64IFD-NEXT: fld ft0, %lo(.LCPI6_0)(a1)
; RV64IFD-NEXT: fmv.d.x ft1, a0
; RV64IFD-NEXT: fmv.d.x ft2, zero
; RV64IFD-NEXT: fmax.d ft1, ft1, ft2
; RV64IFD-NEXT: fmin.d ft0, ft1, ft0
; RV64IFD-NEXT: fcvt.lu.d a0, ft0, rtz
; RV64IFD-NEXT: fmv.d.x ft0, a0
; RV64IFD-NEXT: feq.d a0, ft0, ft0
; RV64IFD-NEXT: bnez a0, .LBB6_2
; RV64IFD-NEXT: # %bb.1: # %start
; RV64IFD-NEXT: mv a0, zero
; RV64IFD-NEXT: ret
; RV64IFD-NEXT: .LBB6_2:
; RV64IFD-NEXT: fcvt.wu.d a0, ft0, rtz
; RV64IFD-NEXT: ret
start:
%0 = tail call i32 @llvm.fptoui.sat.i32.f64(double %a)
Expand Down Expand Up @@ -370,33 +361,14 @@ define i64 @fcvt_l_d_sat(double %a) nounwind {
;
; RV64IFD-LABEL: fcvt_l_d_sat:
; RV64IFD: # %bb.0: # %start
; RV64IFD-NEXT: lui a1, %hi(.LCPI12_0)
; RV64IFD-NEXT: fld ft1, %lo(.LCPI12_0)(a1)
; RV64IFD-NEXT: fmv.d.x ft0, a0
; RV64IFD-NEXT: fle.d a0, ft1, ft0
; RV64IFD-NEXT: addi a1, zero, -1
; RV64IFD-NEXT: feq.d a0, ft0, ft0
; RV64IFD-NEXT: bnez a0, .LBB12_2
; RV64IFD-NEXT: # %bb.1: # %start
; RV64IFD-NEXT: slli a0, a1, 63
; RV64IFD-NEXT: j .LBB12_3
; RV64IFD-NEXT: mv a0, zero
; RV64IFD-NEXT: ret
; RV64IFD-NEXT: .LBB12_2:
; RV64IFD-NEXT: fcvt.l.d a0, ft0, rtz
; RV64IFD-NEXT: .LBB12_3: # %start
; RV64IFD-NEXT: lui a2, %hi(.LCPI12_1)
; RV64IFD-NEXT: fld ft1, %lo(.LCPI12_1)(a2)
; RV64IFD-NEXT: flt.d a2, ft1, ft0
; RV64IFD-NEXT: bnez a2, .LBB12_6
; RV64IFD-NEXT: # %bb.4: # %start
; RV64IFD-NEXT: feq.d a1, ft0, ft0
; RV64IFD-NEXT: beqz a1, .LBB12_7
; RV64IFD-NEXT: .LBB12_5: # %start
; RV64IFD-NEXT: ret
; RV64IFD-NEXT: .LBB12_6:
; RV64IFD-NEXT: srli a0, a1, 1
; RV64IFD-NEXT: feq.d a1, ft0, ft0
; RV64IFD-NEXT: bnez a1, .LBB12_5
; RV64IFD-NEXT: .LBB12_7: # %start
; RV64IFD-NEXT: mv a0, zero
; RV64IFD-NEXT: ret
start:
%0 = tail call i64 @llvm.fptosi.sat.i64.f64(double %a)
Expand Down Expand Up @@ -469,23 +441,13 @@ define i64 @fcvt_lu_d_sat(double %a) nounwind {
; RV64IFD-LABEL: fcvt_lu_d_sat:
; RV64IFD: # %bb.0: # %start
; RV64IFD-NEXT: fmv.d.x ft0, a0
; RV64IFD-NEXT: fmv.d.x ft1, zero
; RV64IFD-NEXT: fle.d a0, ft1, ft0
; RV64IFD-NEXT: feq.d a0, ft0, ft0
; RV64IFD-NEXT: bnez a0, .LBB14_2
; RV64IFD-NEXT: # %bb.1: # %start
; RV64IFD-NEXT: mv a1, zero
; RV64IFD-NEXT: j .LBB14_3
; RV64IFD-NEXT: mv a0, zero
; RV64IFD-NEXT: ret
; RV64IFD-NEXT: .LBB14_2:
; RV64IFD-NEXT: fcvt.lu.d a1, ft0, rtz
; RV64IFD-NEXT: .LBB14_3: # %start
; RV64IFD-NEXT: lui a0, %hi(.LCPI14_0)
; RV64IFD-NEXT: fld ft1, %lo(.LCPI14_0)(a0)
; RV64IFD-NEXT: flt.d a2, ft1, ft0
; RV64IFD-NEXT: addi a0, zero, -1
; RV64IFD-NEXT: bnez a2, .LBB14_5
; RV64IFD-NEXT: # %bb.4: # %start
; RV64IFD-NEXT: mv a0, a1
; RV64IFD-NEXT: .LBB14_5: # %start
; RV64IFD-NEXT: fcvt.lu.d a0, ft0, rtz
; RV64IFD-NEXT: ret
start:
%0 = tail call i64 @llvm.fptoui.sat.i64.f64(double %a)
Expand Down

0 comments on commit d4ee84c

Please sign in to comment.