Skip to content

Commit

Permalink
[RISCV] Add DAG combine to fold (fp_to_int_sat (ffloor X)) -> (select…
Browse files Browse the repository at this point in the history
… X == nan, 0, (fcvt X, rdn))

Similar for ceil, trunc, round, and roundeven. This allows us to use
static rounding modes to avoid a libcall.

This is similar to D116771, but for the saturating conversions.

This optimization is done for AArch64 as isel patterns.
RISCV doesn't have instructions for ceil/floor/trunc/round/roundeven
so the operations don't stick around until isel to enable a pattern
match. Thus I've implemented a DAG combine.

I'm only handling saturating to i64 or i32. This could be extended
to other sizes in the future.

Reviewed By: asb

Differential Revision: https://reviews.llvm.org/D116864
  • Loading branch information
topperc committed Jan 20, 2022
1 parent 6b92bb4 commit 94e69fb
Show file tree
Hide file tree
Showing 4 changed files with 2,927 additions and 10 deletions.
87 changes: 77 additions & 10 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1051,6 +1051,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setTargetDAGCombine(ISD::ZERO_EXTEND);
setTargetDAGCombine(ISD::FP_TO_SINT);
setTargetDAGCombine(ISD::FP_TO_UINT);
setTargetDAGCombine(ISD::FP_TO_SINT_SAT);
setTargetDAGCombine(ISD::FP_TO_UINT_SAT);
}
if (Subtarget.hasVInstructions()) {
setTargetDAGCombine(ISD::FCOPYSIGN);
Expand Down Expand Up @@ -7180,13 +7182,24 @@ static SDValue combineMUL_VLToVWMUL(SDNode *N, SDValue Op0, SDValue Op1,
return DAG.getNode(WMulOpc, DL, VT, Op0, Op1, Mask, VL);
}

static RISCVFPRndMode::RoundingMode matchRoundingOp(SDValue Op) {
switch (Op.getOpcode()) {
case ISD::FROUNDEVEN: return RISCVFPRndMode::RNE;
case ISD::FTRUNC: return RISCVFPRndMode::RTZ;
case ISD::FFLOOR: return RISCVFPRndMode::RDN;
case ISD::FCEIL: return RISCVFPRndMode::RUP;
case ISD::FROUND: return RISCVFPRndMode::RMM;
}

return RISCVFPRndMode::Invalid;
}

// Fold
// (fp_to_int (froundeven X)) -> fcvt X, rne
// (fp_to_int (ftrunc X)) -> fcvt X, rtz
// (fp_to_int (ffloor X)) -> fcvt X, rdn
// (fp_to_int (fceil X)) -> fcvt X, rup
// (fp_to_int (fround X)) -> fcvt X, rmm
// FIXME: We should also do this for fp_to_int_sat.
static SDValue performFP_TO_INTCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const RISCVSubtarget &Subtarget) {
Expand All @@ -7210,16 +7223,9 @@ static SDValue performFP_TO_INTCombine(SDNode *N,
if (Src.getValueType() == MVT::f16 && !Subtarget.hasStdExtZfh())
return SDValue();

RISCVFPRndMode::RoundingMode FRM;
switch (Src->getOpcode()) {
default:
RISCVFPRndMode::RoundingMode FRM = matchRoundingOp(Src);
if (FRM == RISCVFPRndMode::Invalid)
return SDValue();
case ISD::FROUNDEVEN: FRM = RISCVFPRndMode::RNE; break;
case ISD::FTRUNC: FRM = RISCVFPRndMode::RTZ; break;
case ISD::FFLOOR: FRM = RISCVFPRndMode::RDN; break;
case ISD::FCEIL: FRM = RISCVFPRndMode::RUP; break;
case ISD::FROUND: FRM = RISCVFPRndMode::RMM; break;
}

bool IsSigned = N->getOpcode() == ISD::FP_TO_SINT;

Expand All @@ -7235,6 +7241,64 @@ static SDValue performFP_TO_INTCombine(SDNode *N,
return DAG.getNode(ISD::TRUNCATE, DL, VT, FpToInt);
}

// Fold
// (fp_to_int_sat (froundeven X)) -> (select X == nan, 0, (fcvt X, rne))
// (fp_to_int_sat (ftrunc X)) -> (select X == nan, 0, (fcvt X, rtz))
// (fp_to_int_sat (ffloor X)) -> (select X == nan, 0, (fcvt X, rdn))
// (fp_to_int_sat (fceil X)) -> (select X == nan, 0, (fcvt X, rup))
// (fp_to_int_sat (fround X)) -> (select X == nan, 0, (fcvt X, rmm))
static SDValue performFP_TO_INT_SATCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const RISCVSubtarget &Subtarget) {
SelectionDAG &DAG = DCI.DAG;
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
MVT XLenVT = Subtarget.getXLenVT();

// Only handle XLen types. Other types narrower than XLen will eventually be
// legalized to XLenVT.
EVT DstVT = N->getValueType(0);
if (DstVT != XLenVT)
return SDValue();

SDValue Src = N->getOperand(0);

// Ensure the FP type is also legal.
if (!TLI.isTypeLegal(Src.getValueType()))
return SDValue();

// Don't do this for f16 with Zfhmin and not Zfh.
if (Src.getValueType() == MVT::f16 && !Subtarget.hasStdExtZfh())
return SDValue();

EVT SatVT = cast<VTSDNode>(N->getOperand(1))->getVT();

RISCVFPRndMode::RoundingMode FRM = matchRoundingOp(Src);
if (FRM == RISCVFPRndMode::Invalid)
return SDValue();

bool IsSigned = N->getOpcode() == ISD::FP_TO_SINT_SAT;

unsigned Opc;
if (SatVT == DstVT)
Opc = IsSigned ? RISCVISD::FCVT_X : RISCVISD::FCVT_XU;
else if (DstVT == MVT::i64 && SatVT == MVT::i32)
Opc = IsSigned ? RISCVISD::FCVT_W_RV64 : RISCVISD::FCVT_WU_RV64;
else
return SDValue();
// FIXME: Support other SatVTs by clamping before or after the conversion.

Src = Src.getOperand(0);

SDLoc DL(N);
SDValue FpToInt = DAG.getNode(Opc, DL, XLenVT, Src,
DAG.getTargetConstant(FRM, DL, XLenVT));

// RISCV FP-to-int conversions saturate to the destination register size, but
// don't produce 0 for nan.
SDValue ZeroInt = DAG.getConstant(0, DL, DstVT);
return DAG.getSelectCC(DL, Src, Src, ZeroInt, FpToInt, ISD::CondCode::SETUO);
}

SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
SelectionDAG &DAG = DCI.DAG;
Expand Down Expand Up @@ -7548,6 +7612,9 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
case ISD::FP_TO_SINT:
case ISD::FP_TO_UINT:
return performFP_TO_INTCombine(N, DCI, Subtarget);
case ISD::FP_TO_SINT_SAT:
case ISD::FP_TO_UINT_SAT:
return performFP_TO_INT_SATCombine(N, DCI, Subtarget);
case ISD::FCOPYSIGN: {
EVT VT = N->getValueType(0);
if (!VT.isVector())
Expand Down
Loading

0 comments on commit 94e69fb

Please sign in to comment.