Skip to content

Commit

Permalink
[RISCV] Lower vector CTLZ_ZERO_UNDEF/CTTZ_ZERO_UNDEF by converting to…
Browse files Browse the repository at this point in the history
… FP and extracting the exponent.

If we have a large enough floating point type that can exactly
represent the integer value, we can convert the value to FP and
use the exponent to calculate the leading/trailing zeros.

The exponent will contain log2 of the value plus the exponent bias.
We can then remove the bias and convert from log2 to leading/trailing
zeros.

This doesn't work for zero since the exponent of zero is zero so we
can only do this for CTLZ_ZERO_UNDEF/CTTZ_ZERO_UNDEF. If we need
a value for zero we can use a vmseq and a vmerge to handle it.

We need to be careful to make sure the floating point type is legal.
If it isn't we'll continue using the integer expansion. We could split the vector
and concatenate the results but that needs some additional work and evaluation.

Differential Revision: https://reviews.llvm.org/D111904
  • Loading branch information
topperc committed Nov 17, 2021
1 parent 103cc91 commit 0274be2
Show file tree
Hide file tree
Showing 6 changed files with 6,944 additions and 3,667 deletions.
8 changes: 4 additions & 4 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Expand Up @@ -7083,8 +7083,8 @@ SDValue TargetLowering::expandCTLZ(SDNode *Node, SelectionDAG &DAG) const {
SDValue CTLZ = DAG.getNode(ISD::CTLZ_ZERO_UNDEF, dl, VT, Op);
SDValue Zero = DAG.getConstant(0, dl, VT);
SDValue SrcIsZero = DAG.getSetCC(dl, SetCCVT, Op, Zero, ISD::SETEQ);
return DAG.getNode(ISD::SELECT, dl, VT, SrcIsZero,
DAG.getConstant(NumBitsPerElt, dl, VT), CTLZ);
return DAG.getSelect(dl, VT, SrcIsZero,
DAG.getConstant(NumBitsPerElt, dl, VT), CTLZ);
}

// Only expand vector types if we have the appropriate vector bit operations.
Expand Down Expand Up @@ -7132,8 +7132,8 @@ SDValue TargetLowering::expandCTTZ(SDNode *Node, SelectionDAG &DAG) const {
SDValue CTTZ = DAG.getNode(ISD::CTTZ_ZERO_UNDEF, dl, VT, Op);
SDValue Zero = DAG.getConstant(0, dl, VT);
SDValue SrcIsZero = DAG.getSetCC(dl, SetCCVT, Op, Zero, ISD::SETEQ);
return DAG.getNode(ISD::SELECT, dl, VT, SrcIsZero,
DAG.getConstant(NumBitsPerElt, dl, VT), CTTZ);
return DAG.getSelect(dl, VT, SrcIsZero,
DAG.getConstant(NumBitsPerElt, dl, VT), CTTZ);
}

// Only expand vector types if we have the appropriate vector bit operations.
Expand Down
79 changes: 79 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Expand Up @@ -630,6 +630,18 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setLoadExtAction(ISD::SEXTLOAD, OtherVT, VT, Expand);
setLoadExtAction(ISD::ZEXTLOAD, OtherVT, VT, Expand);
}

// Lower CTLZ_ZERO_UNDEF and CTTZ_ZERO_UNDEF if we have a floating point
// type that can represent the value exactly.
if (VT.getVectorElementType() != MVT::i64) {
MVT FloatEltVT =
VT.getVectorElementType() == MVT::i32 ? MVT::f64 : MVT::f32;
EVT FloatVT = MVT::getVectorVT(FloatEltVT, VT.getVectorElementCount());
if (isTypeLegal(FloatVT)) {
setOperationAction(ISD::CTLZ_ZERO_UNDEF, VT, Custom);
setOperationAction(ISD::CTTZ_ZERO_UNDEF, VT, Custom);
}
}
}

// Expand various CCs to best match the RVV ISA, which natively supports UNE
Expand Down Expand Up @@ -848,6 +860,19 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,

for (unsigned VPOpc : IntegerVPOps)
setOperationAction(VPOpc, VT, Custom);

// Lower CTLZ_ZERO_UNDEF and CTTZ_ZERO_UNDEF if we have a floating point
// type that can represent the value exactly.
if (VT.getVectorElementType() != MVT::i64) {
MVT FloatEltVT =
VT.getVectorElementType() == MVT::i32 ? MVT::f64 : MVT::f32;
EVT FloatVT =
MVT::getVectorVT(FloatEltVT, VT.getVectorElementCount());
if (isTypeLegal(FloatVT)) {
setOperationAction(ISD::CTLZ_ZERO_UNDEF, VT, Custom);
setOperationAction(ISD::CTTZ_ZERO_UNDEF, VT, Custom);
}
}
}

for (MVT VT : MVT::fp_fixedlen_vector_valuetypes()) {
Expand Down Expand Up @@ -2323,6 +2348,57 @@ static SDValue getRVVFPExtendOrRound(SDValue Op, MVT VT, MVT ContainerVT,
return DAG.getNode(RVVOpc, DL, ContainerVT, Op, Mask, VL);
}

// Lower CTLZ_ZERO_UNDEF or CTTZ_ZERO_UNDEF by converting to FP and extracting
// the exponent.
static SDValue lowerCTLZ_CTTZ_ZERO_UNDEF(SDValue Op, SelectionDAG &DAG) {
MVT VT = Op.getSimpleValueType();
unsigned EltSize = VT.getScalarSizeInBits();
SDValue Src = Op.getOperand(0);
SDLoc DL(Op);

// We need a FP type that can represent the value.
// TODO: Use f16 for i8 when possible?
MVT FloatEltVT = EltSize == 32 ? MVT::f64 : MVT::f32;
MVT FloatVT = MVT::getVectorVT(FloatEltVT, VT.getVectorElementCount());

// Legal types should have been checked in the RISCVTargetLowering
// constructor.
// TODO: Splitting may make sense in some cases.
assert(DAG.getTargetLoweringInfo().isTypeLegal(FloatVT) &&
"Expected legal float type!");

// For CTTZ_ZERO_UNDEF, we need to extract the lowest set bit using X & -X.
// The trailing zero count is equal to log2 of this single bit value.
if (Op.getOpcode() == ISD::CTTZ_ZERO_UNDEF) {
SDValue Neg =
DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), Src);
Src = DAG.getNode(ISD::AND, DL, VT, Src, Neg);
}

// We have a legal FP type, convert to it.
SDValue FloatVal = DAG.getNode(ISD::UINT_TO_FP, DL, FloatVT, Src);
// Bitcast to integer and shift the exponent to the LSB.
EVT IntVT = FloatVT.changeVectorElementTypeToInteger();
SDValue Bitcast = DAG.getBitcast(IntVT, FloatVal);
unsigned ShiftAmt = FloatEltVT == MVT::f64 ? 52 : 23;
SDValue Shift = DAG.getNode(ISD::SRL, DL, IntVT, Bitcast,
DAG.getConstant(ShiftAmt, DL, IntVT));
// Truncate back to original type to allow vnsrl.
SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, Shift);
// The exponent contains log2 of the value in biased form.
unsigned ExponentBias = FloatEltVT == MVT::f64 ? 1023 : 127;

// For trailing zeros, we just need to subtract the bias.
if (Op.getOpcode() == ISD::CTTZ_ZERO_UNDEF)
return DAG.getNode(ISD::SUB, DL, VT, Trunc,
DAG.getConstant(ExponentBias, DL, VT));

// For leading zeros, we need to remove the bias and convert from log2 to
// leading zeros. We can do this by subtracting from (Bias + (EltSize - 1)).
unsigned Adjust = ExponentBias + (EltSize - 1);
return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(Adjust, DL, VT), Trunc);
}

// While RVV has alignment restrictions, we should always be able to load as a
// legal equivalently-sized byte-typed vector instead. This method is
// responsible for re-expressing a ISD::LOAD via a correctly-aligned type. If
Expand Down Expand Up @@ -2941,6 +3017,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return lowerToScalableOp(Op, DAG, RISCVISD::FMAXNUM_VL);
case ISD::ABS:
return lowerABS(Op, DAG);
case ISD::CTLZ_ZERO_UNDEF:
case ISD::CTTZ_ZERO_UNDEF:
return lowerCTLZ_CTTZ_ZERO_UNDEF(Op, DAG);
case ISD::VSELECT:
return lowerFixedLengthVectorSelectToRVV(Op, DAG);
case ISD::FCOPYSIGN:
Expand Down

0 comments on commit 0274be2

Please sign in to comment.