Skip to content

Commit

Permalink
[RISCV] Improve vector fround lowering by changing FRM.
Browse files Browse the repository at this point in the history
This is a follow up to D133238 which did this for ceil/floor.

Reviewed By: arcbbb, frasercrmck

Differential Revision: https://reviews.llvm.org/D133335
  • Loading branch information
topperc committed Sep 6, 2022
1 parent acb767f commit 5d30565
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 229 deletions.
137 changes: 27 additions & 110 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1836,11 +1836,24 @@ static SDValue lowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG,
return DAG.getSelectCC(DL, Src, Src, ZeroInt, FpToInt, ISD::CondCode::SETUO);
}

// Expand vector FTRUNC, FCEIL, and FFLOOR by converting to the integer domain
// and back. Taking care to avoid converting values that are nan or already
// correct.
static SDValue lowerFTRUNC_FCEIL_FFLOOR(SDValue Op, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
static RISCVFPRndMode::RoundingMode matchRoundingOp(unsigned Opc) {
switch (Opc) {
case ISD::FROUNDEVEN: return RISCVFPRndMode::RNE;
case ISD::FTRUNC: return RISCVFPRndMode::RTZ;
case ISD::FFLOOR: return RISCVFPRndMode::RDN;
case ISD::FCEIL: return RISCVFPRndMode::RUP;
case ISD::FROUND: return RISCVFPRndMode::RMM;
}

return RISCVFPRndMode::Invalid;
}

// Expand vector FTRUNC, FCEIL, FFLOOR, and FROUND by converting to the integer
// domain/ and back. Taking care to avoid converting values that are nan or
// already correct.
static SDValue
lowerFTRUNC_FCEIL_FFLOOR_FROUND(SDValue Op, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
MVT VT = Op.getSimpleValueType();
assert(VT.isVector() && "Unexpected type");

Expand Down Expand Up @@ -1892,15 +1905,14 @@ static SDValue lowerFTRUNC_FCEIL_FFLOOR(SDValue Op, SelectionDAG &DAG,
default:
llvm_unreachable("Unexpected opcode");
case ISD::FCEIL:
Truncated =
DAG.getNode(RISCVISD::VFCVT_X_F_VL, DL, IntVT, Src, Mask,
DAG.getTargetConstant(RISCVFPRndMode::RUP, DL, XLenVT), VL);
break;
case ISD::FFLOOR:
Truncated =
DAG.getNode(RISCVISD::VFCVT_X_F_VL, DL, IntVT, Src, Mask,
DAG.getTargetConstant(RISCVFPRndMode::RDN, DL, XLenVT), VL);
case ISD::FROUND: {
RISCVFPRndMode::RoundingMode FRM = matchRoundingOp(Op.getOpcode());
assert(FRM != RISCVFPRndMode::Invalid);
Truncated = DAG.getNode(RISCVISD::VFCVT_X_F_VL, DL, IntVT, Src, Mask,
DAG.getTargetConstant(FRM, DL, XLenVT), VL);
break;
}
case ISD::FTRUNC:
Truncated = DAG.getNode(RISCVISD::FP_TO_SINT_VL, DL, IntVT, Src, Mask, VL);
break;
Expand All @@ -1919,88 +1931,6 @@ static SDValue lowerFTRUNC_FCEIL_FFLOOR(SDValue Op, SelectionDAG &DAG,
return convertFromScalableVector(VT, Truncated, DAG, Subtarget);
}

// ISD::FROUND is defined to round to nearest with ties rounding away from 0.
// This mode isn't supported in vector hardware on RISCV. But as long as we
// aren't compiling with trapping math, we can emulate this with
// floor(X + copysign(nextafter(0.5, 0.0), X)).
// FIXME: Could be shorter by changing rounding mode, but we don't have FRM
// dependencies modeled yet.
static SDValue lowerFROUND(SDValue Op, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
MVT VT = Op.getSimpleValueType();
assert(VT.isVector() && "Unexpected type");

SDLoc DL(Op);

SDValue Src = Op.getOperand(0);

MVT ContainerVT = VT;
if (VT.isFixedLengthVector()) {
ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
}

SDValue TrueMask, VL;
std::tie(TrueMask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);

// Freeze the source since we are increasing the number of uses.
Src = DAG.getFreeze(Src);

// We do the conversion on the absolute value and fix the sign at the end.
SDValue Abs =
DAG.getNode(RISCVISD::FABS_VL, DL, ContainerVT, Src, TrueMask, VL);

// Determine the largest integer that can be represented exactly. This and
// values larger than it don't have any fractional bits so don't need to
// be converted.
const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(ContainerVT);
unsigned Precision = APFloat::semanticsPrecision(FltSem);
APFloat MaxVal = APFloat(FltSem);
MaxVal.convertFromAPInt(APInt::getOneBitSet(Precision, Precision - 1),
/*IsSigned*/ false, APFloat::rmNearestTiesToEven);
SDValue MaxValNode =
DAG.getConstantFP(MaxVal, DL, ContainerVT.getVectorElementType());
SDValue MaxValSplat = DAG.getNode(RISCVISD::VFMV_V_F_VL, DL, ContainerVT,
DAG.getUNDEF(ContainerVT), MaxValNode, VL);

// If abs(Src) was larger than MaxVal or nan, keep it.
MVT SetccVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
SDValue Mask = DAG.getNode(RISCVISD::SETCC_VL, DL, SetccVT,
{Abs, MaxValSplat, DAG.getCondCode(ISD::SETOLT),
DAG.getUNDEF(SetccVT), TrueMask, VL});

bool Ignored;
APFloat Point5Pred = APFloat(0.5f);
Point5Pred.convert(FltSem, APFloat::rmNearestTiesToEven, &Ignored);
Point5Pred.next(/*nextDown*/ true);
SDValue SplatVal =
DAG.getConstantFP(Point5Pred, DL, ContainerVT.getVectorElementType());
SDValue Splat = DAG.getNode(RISCVISD::VFMV_V_F_VL, DL, ContainerVT,
DAG.getUNDEF(ContainerVT), SplatVal, VL);

// Add the adjustment.
SDValue Adjust = DAG.getNode(RISCVISD::FADD_VL, DL, ContainerVT, Abs, Splat,
DAG.getUNDEF(ContainerVT), Mask, VL);

// Truncate to integer and convert back to fp.
MVT IntVT = ContainerVT.changeVectorElementTypeToInteger();
SDValue Truncated =
DAG.getNode(RISCVISD::FP_TO_SINT_VL, DL, IntVT, Adjust, Mask, VL);

Truncated = DAG.getNode(RISCVISD::SINT_TO_FP_VL, DL, ContainerVT, Truncated,
Mask, VL);

// Restore the original sign and merge the original source to masked off
// lanes.
Truncated = DAG.getNode(RISCVISD::FCOPYSIGN_VL, DL, ContainerVT, Truncated,
Src, Src, Mask, VL);

if (!VT.isFixedLengthVector())
return Truncated;

return convertFromScalableVector(VT, Truncated, DAG, Subtarget);
}

struct VIDSequence {
int64_t StepNumerator;
unsigned StepDenominator;
Expand Down Expand Up @@ -3493,9 +3423,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
case ISD::FTRUNC:
case ISD::FCEIL:
case ISD::FFLOOR:
return lowerFTRUNC_FCEIL_FFLOOR(Op, DAG, Subtarget);
case ISD::FROUND:
return lowerFROUND(Op, DAG, Subtarget);
return lowerFTRUNC_FCEIL_FFLOOR_FROUND(Op, DAG, Subtarget);
case ISD::VECREDUCE_ADD:
case ISD::VECREDUCE_UMAX:
case ISD::VECREDUCE_SMAX:
Expand Down Expand Up @@ -8844,18 +8773,6 @@ static SDValue combineMUL_VLToVWMUL_VL(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(WMulOpc, DL, VT, Op0, Op1, Merge, Mask, VL);
}

static RISCVFPRndMode::RoundingMode matchRoundingOp(SDValue Op) {
switch (Op.getOpcode()) {
case ISD::FROUNDEVEN: return RISCVFPRndMode::RNE;
case ISD::FTRUNC: return RISCVFPRndMode::RTZ;
case ISD::FFLOOR: return RISCVFPRndMode::RDN;
case ISD::FCEIL: return RISCVFPRndMode::RUP;
case ISD::FROUND: return RISCVFPRndMode::RMM;
}

return RISCVFPRndMode::Invalid;
}

// Fold
// (fp_to_int (froundeven X)) -> fcvt X, rne
// (fp_to_int (ftrunc X)) -> fcvt X, rtz
Expand Down Expand Up @@ -8885,7 +8802,7 @@ static SDValue performFP_TO_INTCombine(SDNode *N,
if (Src.getValueType() == MVT::f16 && !Subtarget.hasStdExtZfh())
return SDValue();

RISCVFPRndMode::RoundingMode FRM = matchRoundingOp(Src);
RISCVFPRndMode::RoundingMode FRM = matchRoundingOp(Src.getOpcode());
if (FRM == RISCVFPRndMode::Invalid)
return SDValue();

Expand Down Expand Up @@ -8934,7 +8851,7 @@ static SDValue performFP_TO_INT_SATCombine(SDNode *N,

EVT SatVT = cast<VTSDNode>(N->getOperand(1))->getVT();

RISCVFPRndMode::RoundingMode FRM = matchRoundingOp(Src);
RISCVFPRndMode::RoundingMode FRM = matchRoundingOp(Src.getOpcode());
if (FRM == RISCVFPRndMode::Invalid)
return SDValue();

Expand Down
32 changes: 16 additions & 16 deletions llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,22 +312,22 @@ static const CostTblEntry VectorIntrinsicCostTable[]{
{Intrinsic::trunc, MVT::nxv2f64, 7},
{Intrinsic::trunc, MVT::nxv4f64, 7},
{Intrinsic::trunc, MVT::nxv8f64, 7},
{Intrinsic::round, MVT::v2f32, 10},
{Intrinsic::round, MVT::v4f32, 10},
{Intrinsic::round, MVT::v8f32, 10},
{Intrinsic::round, MVT::v16f32, 10},
{Intrinsic::round, MVT::nxv2f32, 10},
{Intrinsic::round, MVT::nxv4f32, 10},
{Intrinsic::round, MVT::nxv8f32, 10},
{Intrinsic::round, MVT::nxv16f32, 10},
{Intrinsic::round, MVT::v2f64, 10},
{Intrinsic::round, MVT::v4f64, 10},
{Intrinsic::round, MVT::v8f64, 10},
{Intrinsic::round, MVT::v16f64, 10},
{Intrinsic::round, MVT::nxv1f64, 10},
{Intrinsic::round, MVT::nxv2f64, 10},
{Intrinsic::round, MVT::nxv4f64, 10},
{Intrinsic::round, MVT::nxv8f64, 10},
{Intrinsic::round, MVT::v2f32, 9},
{Intrinsic::round, MVT::v4f32, 9},
{Intrinsic::round, MVT::v8f32, 9},
{Intrinsic::round, MVT::v16f32, 9},
{Intrinsic::round, MVT::nxv2f32, 9},
{Intrinsic::round, MVT::nxv4f32, 9},
{Intrinsic::round, MVT::nxv8f32, 9},
{Intrinsic::round, MVT::nxv16f32, 9},
{Intrinsic::round, MVT::v2f64, 9},
{Intrinsic::round, MVT::v4f64, 9},
{Intrinsic::round, MVT::v8f64, 9},
{Intrinsic::round, MVT::v16f64, 9},
{Intrinsic::round, MVT::nxv1f64, 9},
{Intrinsic::round, MVT::nxv2f64, 9},
{Intrinsic::round, MVT::nxv4f64, 9},
{Intrinsic::round, MVT::nxv8f64, 9},
{Intrinsic::fabs, MVT::v2f32, 1},
{Intrinsic::fabs, MVT::v4f32, 1},
{Intrinsic::fabs, MVT::v8f32, 1},
Expand Down
32 changes: 16 additions & 16 deletions llvm/test/Analysis/CostModel/RISCV/fround.ll
Original file line number Diff line number Diff line change
Expand Up @@ -219,23 +219,23 @@ define void @nearbyint() {
define void @round() {
; CHECK-LABEL: 'round'
; CHECK-NEXT: Cost Model: Found an estimated cost of 10 for instruction: %1 = call float @llvm.round.f32(float undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 10 for instruction: %2 = call <2 x float> @llvm.round.v2f32(<2 x float> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 10 for instruction: %3 = call <4 x float> @llvm.round.v4f32(<4 x float> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 10 for instruction: %4 = call <8 x float> @llvm.round.v8f32(<8 x float> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 10 for instruction: %5 = call <16 x float> @llvm.round.v16f32(<16 x float> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 10 for instruction: %6 = call <vscale x 2 x float> @llvm.round.nxv2f32(<vscale x 2 x float> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 10 for instruction: %7 = call <vscale x 4 x float> @llvm.round.nxv4f32(<vscale x 4 x float> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 10 for instruction: %8 = call <vscale x 8 x float> @llvm.round.nxv8f32(<vscale x 8 x float> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 10 for instruction: %9 = call <vscale x 16 x float> @llvm.round.nxv16f32(<vscale x 16 x float> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 9 for instruction: %2 = call <2 x float> @llvm.round.v2f32(<2 x float> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 9 for instruction: %3 = call <4 x float> @llvm.round.v4f32(<4 x float> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 9 for instruction: %4 = call <8 x float> @llvm.round.v8f32(<8 x float> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 9 for instruction: %5 = call <16 x float> @llvm.round.v16f32(<16 x float> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 9 for instruction: %6 = call <vscale x 2 x float> @llvm.round.nxv2f32(<vscale x 2 x float> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 9 for instruction: %7 = call <vscale x 4 x float> @llvm.round.nxv4f32(<vscale x 4 x float> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 9 for instruction: %8 = call <vscale x 8 x float> @llvm.round.nxv8f32(<vscale x 8 x float> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 9 for instruction: %9 = call <vscale x 16 x float> @llvm.round.nxv16f32(<vscale x 16 x float> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 10 for instruction: %10 = call double @llvm.round.f64(double undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 10 for instruction: %11 = call <2 x double> @llvm.round.v2f64(<2 x double> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 10 for instruction: %12 = call <4 x double> @llvm.round.v4f64(<4 x double> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 10 for instruction: %13 = call <8 x double> @llvm.round.v8f64(<8 x double> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 10 for instruction: %14 = call <16 x double> @llvm.round.v16f64(<16 x double> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 10 for instruction: %15 = call <vscale x 1 x double> @llvm.round.nxv1f64(<vscale x 1 x double> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 10 for instruction: %16 = call <vscale x 2 x double> @llvm.round.nxv2f64(<vscale x 2 x double> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 10 for instruction: %17 = call <vscale x 4 x double> @llvm.round.nxv4f64(<vscale x 4 x double> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 10 for instruction: %18 = call <vscale x 8 x double> @llvm.round.nxv8f64(<vscale x 8 x double> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 9 for instruction: %11 = call <2 x double> @llvm.round.v2f64(<2 x double> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 9 for instruction: %12 = call <4 x double> @llvm.round.v4f64(<4 x double> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 9 for instruction: %13 = call <8 x double> @llvm.round.v8f64(<8 x double> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 9 for instruction: %14 = call <16 x double> @llvm.round.v16f64(<16 x double> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 9 for instruction: %15 = call <vscale x 1 x double> @llvm.round.nxv1f64(<vscale x 1 x double> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 9 for instruction: %16 = call <vscale x 2 x double> @llvm.round.nxv2f64(<vscale x 2 x double> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 9 for instruction: %17 = call <vscale x 4 x double> @llvm.round.nxv4f64(<vscale x 4 x double> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 9 for instruction: %18 = call <vscale x 8 x double> @llvm.round.nxv8f64(<vscale x 8 x double> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 1 for instruction: ret void
;
call float @llvm.round.f32(float undef)
Expand Down
21 changes: 9 additions & 12 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp.ll
Original file line number Diff line number Diff line change
Expand Up @@ -2171,12 +2171,11 @@ define void @round_v8f16(<8 x half>* %x) {
; CHECK-NEXT: vle16.v v8, (a0)
; CHECK-NEXT: lui a1, %hi(.LCPI100_0)
; CHECK-NEXT: flh ft0, %lo(.LCPI100_0)(a1)
; CHECK-NEXT: lui a1, %hi(.LCPI100_1)
; CHECK-NEXT: flh ft1, %lo(.LCPI100_1)(a1)
; CHECK-NEXT: vfabs.v v9, v8
; CHECK-NEXT: vmflt.vf v0, v9, ft0
; CHECK-NEXT: vfadd.vf v9, v9, ft1, v0.t
; CHECK-NEXT: vfcvt.rtz.x.f.v v9, v9, v0.t
; CHECK-NEXT: fsrmi a1, 4
; CHECK-NEXT: vfcvt.x.f.v v9, v8, v0.t
; CHECK-NEXT: fsrm a1
; CHECK-NEXT: vfcvt.f.x.v v9, v9, v0.t
; CHECK-NEXT: vfsgnj.vv v8, v9, v8, v0.t
; CHECK-NEXT: vse16.v v8, (a0)
Expand All @@ -2195,12 +2194,11 @@ define void @round_v4f32(<4 x float>* %x) {
; CHECK-NEXT: vle32.v v8, (a0)
; CHECK-NEXT: lui a1, %hi(.LCPI101_0)
; CHECK-NEXT: flw ft0, %lo(.LCPI101_0)(a1)
; CHECK-NEXT: lui a1, %hi(.LCPI101_1)
; CHECK-NEXT: flw ft1, %lo(.LCPI101_1)(a1)
; CHECK-NEXT: vfabs.v v9, v8
; CHECK-NEXT: vmflt.vf v0, v9, ft0
; CHECK-NEXT: vfadd.vf v9, v9, ft1, v0.t
; CHECK-NEXT: vfcvt.rtz.x.f.v v9, v9, v0.t
; CHECK-NEXT: fsrmi a1, 4
; CHECK-NEXT: vfcvt.x.f.v v9, v8, v0.t
; CHECK-NEXT: fsrm a1
; CHECK-NEXT: vfcvt.f.x.v v9, v9, v0.t
; CHECK-NEXT: vfsgnj.vv v8, v9, v8, v0.t
; CHECK-NEXT: vse32.v v8, (a0)
Expand All @@ -2219,12 +2217,11 @@ define void @round_v2f64(<2 x double>* %x) {
; CHECK-NEXT: vle64.v v8, (a0)
; CHECK-NEXT: lui a1, %hi(.LCPI102_0)
; CHECK-NEXT: fld ft0, %lo(.LCPI102_0)(a1)
; CHECK-NEXT: lui a1, %hi(.LCPI102_1)
; CHECK-NEXT: fld ft1, %lo(.LCPI102_1)(a1)
; CHECK-NEXT: vfabs.v v9, v8
; CHECK-NEXT: vmflt.vf v0, v9, ft0
; CHECK-NEXT: vfadd.vf v9, v9, ft1, v0.t
; CHECK-NEXT: vfcvt.rtz.x.f.v v9, v9, v0.t
; CHECK-NEXT: fsrmi a1, 4
; CHECK-NEXT: vfcvt.x.f.v v9, v8, v0.t
; CHECK-NEXT: fsrm a1
; CHECK-NEXT: vfcvt.f.x.v v9, v9, v0.t
; CHECK-NEXT: vfsgnj.vv v8, v9, v8, v0.t
; CHECK-NEXT: vse64.v v8, (a0)
Expand Down
Loading

0 comments on commit 5d30565

Please sign in to comment.