Skip to content

Commit

Permalink
[VP][RISCV] Add vp.cttz.elts intrinsic and its RISC-V codegen (#90502)
Browse files Browse the repository at this point in the history
This intrinsic is the VP version of `experimental.cttz.elts`.
  • Loading branch information
mshockwave committed Apr 30, 2024
1 parent df513f8 commit 539f626
Show file tree
Hide file tree
Showing 13 changed files with 459 additions and 2 deletions.
48 changes: 48 additions & 0 deletions llvm/docs/LangRef.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24001,6 +24001,54 @@ Examples:
%also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> poison


.. _int_vp_cttz_elts:

'``llvm.vp.cttz.elts.*``' Intrinsics
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Syntax:
"""""""
This is an overloaded intrinsic. You can use ``llvm.vp.cttz.elts`` on any
vector of integer elements, both fixed width and scalable.

::

declare i32 @llvm.vp.cttz.elts.i32.v16i32 (<16 x i32> <op>, i1 <is_zero_poison>, <16 x i1> <mask>, i32 <vector_length>)
declare i64 @llvm.vp.cttz.elts.i64.nxv4i32 (<vscale x 4 x i32> <op>, i1 <is_zero_poison>, <vscale x 4 x i1> <mask>, i32 <vector_length>)
declare i64 @llvm.vp.cttz.elts.i64.v256i1 (<256 x i1> <op>, i1 <is_zero_poison>, <256 x i1> <mask>, i32 <vector_length>)

Overview:
"""""""""

The '``llvm.vp.cttz.elts``' intrinsic counts the number of trailing zero
elements of a vector. This is basically the vector-predicated version of
'``llvm.experimental.cttz.elts``'.

Arguments:
""""""""""

The first argument is the vector to be counted. This argument must be a vector
with integer element type. The return type must also be an integer type which is
wide enough to hold the maximum number of elements of the source vector. The
behavior of this intrinsic is undefined if the return type is not wide enough
for the number of elements in the input vector.

The second argument is a constant flag that indicates whether the result is
poison when the first argument is all zero (true means the result is poison).

The third argument is the vector mask and has the same number of elements as the
input vector type. The fourth argument is the explicit vector length of the
operation.

Semantics:
""""""""""

The '``llvm.vp.cttz.elts``' intrinsic counts the trailing (least
significant / lowest-numbered) zero elements of the first argument, considering
only the enabled lanes. If the first argument is all zero and the second
argument is true, the result is poison. If the first argument is all zero and
the second argument is false, the result is the explicit vector length (i.e.
the fourth argument).

.. _int_vp_sadd_sat:

'``llvm.vp.sadd.sat.*``' Intrinsics
Expand Down
5 changes: 5 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -5307,6 +5307,11 @@ class TargetLowering : public TargetLoweringBase {
/// \returns The expansion result or SDValue() if it fails.
SDValue expandVPCTTZ(SDNode *N, SelectionDAG &DAG) const;

/// Expand VP_CTTZ_ELTS/VP_CTTZ_ELTS_ZERO_UNDEF nodes into generic VP
/// operations: a compare against zero for non-i1 sources, a select between a
/// step vector and a splat of the EVL, and an unsigned-min reduction.
/// \param N Node to expand
/// \param DAG SelectionDAG used to build the replacement nodes
/// \returns The expansion result or SDValue() if it fails.
SDValue expandVPCTTZElements(SDNode *N, SelectionDAG &DAG) const;

/// Expand ABS nodes. Expands vector/scalar ABS nodes,
/// vector nodes can only succeed if all operations are legal/custom.
/// (ABS x) -> (XOR (ADD x, (SRA x, type_size)), (SRA x, type_size))
Expand Down
6 changes: 6 additions & 0 deletions llvm/include/llvm/IR/Intrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -2255,6 +2255,12 @@ let IntrProperties = [IntrNoMem, IntrNoSync, IntrWillReturn, ImmArg<ArgIndex<1>>
llvm_i1_ty,
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
llvm_i32_ty]>;

// llvm.vp.cttz.elts(x, is_zero_poison, mask, vl)
// Result: overloaded integer type.
// Operands, in order: the overloaded source vector, the i1 is_zero_poison
// flag, the i1 mask with the same element count as the source vector, and the
// i32 explicit vector length.
def int_vp_cttz_elts : DefaultAttrsIntrinsic<[ llvm_anyint_ty ],
[ llvm_anyvector_ty,
llvm_i1_ty,
LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>,
llvm_i32_ty]>;
}

def int_get_active_lane_mask:
Expand Down
9 changes: 9 additions & 0 deletions llvm/include/llvm/IR/VPIntrinsics.def
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,15 @@ BEGIN_REGISTER_VP_SDNODE(VP_CTTZ_ZERO_UNDEF, -1, vp_cttz_zero_undef, 1, 2)
END_REGISTER_VP_SDNODE(VP_CTTZ_ZERO_UNDEF)
END_REGISTER_VP_INTRINSIC(vp_cttz)

// llvm.vp.cttz.elts(x,is_zero_poison,mask,vl)
BEGIN_REGISTER_VP_INTRINSIC(vp_cttz_elts, 2, 3)
VP_PROPERTY_NO_FUNCTIONAL
BEGIN_REGISTER_VP_SDNODE(VP_CTTZ_ELTS, 0, vp_cttz_elts, 1, 2)
END_REGISTER_VP_SDNODE(VP_CTTZ_ELTS)
BEGIN_REGISTER_VP_SDNODE(VP_CTTZ_ELTS_ZERO_UNDEF, 0, vp_cttz_elts_zero_undef, 1, 2)
END_REGISTER_VP_SDNODE(VP_CTTZ_ELTS_ZERO_UNDEF)
END_REGISTER_VP_INTRINSIC(vp_cttz_elts)

// llvm.vp.fshl(x,y,z,mask,vlen)
BEGIN_REGISTER_VP(vp_fshl, 3, 4, VP_FSHL, -1)
VP_PROPERTY_FUNCTIONAL_INTRINSIC(fshl)
Expand Down
9 changes: 9 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1227,6 +1227,11 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
Action = TLI.getOperationAction(
Node->getOpcode(), Node->getOperand(1).getValueType());
break;
case ISD::VP_CTTZ_ELTS:
case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
Action = TLI.getOperationAction(Node->getOpcode(),
Node->getOperand(0).getValueType());
break;
default:
if (Node->getOpcode() >= ISD::BUILTIN_OP_END) {
Action = TLI.getCustomOperationAction(*Node);
Expand Down Expand Up @@ -4282,6 +4287,10 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
case ISD::VECREDUCE_FMINIMUM:
Results.push_back(TLI.expandVecReduce(Node, DAG));
break;
case ISD::VP_CTTZ_ELTS:
case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
Results.push_back(TLI.expandVPCTTZElements(Node, DAG));
break;
case ISD::GLOBAL_OFFSET_TABLE:
case ISD::GlobalAddress:
case ISD::GlobalTLSAddress:
Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
case ISD::VP_CTTZ:
case ISD::CTTZ_ZERO_UNDEF:
case ISD::CTTZ: Res = PromoteIntRes_CTTZ(N); break;
case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
case ISD::VP_CTTZ_ELTS:
Res = PromoteIntRes_VP_CttzElements(N);
break;
case ISD::EXTRACT_VECTOR_ELT:
Res = PromoteIntRes_EXTRACT_VECTOR_ELT(N); break;
case ISD::LOAD: Res = PromoteIntRes_LOAD(cast<LoadSDNode>(N)); break;
Expand Down Expand Up @@ -724,6 +728,12 @@ SDValue DAGTypeLegalizer::PromoteIntRes_CTTZ(SDNode *N) {
N->getOperand(2));
}

/// Promote the integer result of a VP_CTTZ_ELTS / VP_CTTZ_ELTS_ZERO_UNDEF
/// node. The node is rebuilt with the same opcode and the same operands; only
/// the result type changes, to the type the target reports for transforming
/// the original result type.
SDValue DAGTypeLegalizer::PromoteIntRes_VP_CttzElements(SDNode *N) {
  EVT PromotedVT =
      TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
  return DAG.getNode(N->getOpcode(), SDLoc(N), PromotedVT, N->ops());
}

SDValue DAGTypeLegalizer::PromoteIntRes_EXTRACT_VECTOR_ELT(SDNode *N) {
SDLoc dl(N);
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteIntRes_CTLZ(SDNode *N);
SDValue PromoteIntRes_CTPOP_PARITY(SDNode *N);
SDValue PromoteIntRes_CTTZ(SDNode *N);
SDValue PromoteIntRes_VP_CttzElements(SDNode *N);
SDValue PromoteIntRes_EXTRACT_VECTOR_ELT(SDNode *N);
SDValue PromoteIntRes_FP_TO_XINT(SDNode *N);
SDValue PromoteIntRes_FP_TO_XINT_SAT(SDNode *N);
Expand Down Expand Up @@ -912,6 +913,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue SplitVecOp_FP_ROUND(SDNode *N);
SDValue SplitVecOp_FPOpDifferentTypes(SDNode *N);
SDValue SplitVecOp_FP_TO_XINT_SAT(SDNode *N);
SDValue SplitVecOp_VP_CttzElements(SDNode *N);

//===--------------------------------------------------------------------===//
// Vector Widening Support: LegalizeVectorTypes.cpp
Expand Down Expand Up @@ -1019,6 +1021,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue WidenVecOp_VECREDUCE_SEQ(SDNode *N);
SDValue WidenVecOp_VP_REDUCE(SDNode *N);
SDValue WidenVecOp_ExpOp(SDNode *N);
SDValue WidenVecOp_VP_CttzElements(SDNode *N);

/// Helper function to generate a set of operations to perform
/// a vector operation for a wider type.
Expand Down
42 changes: 42 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3098,6 +3098,10 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::VP_REDUCE_FMIN:
Res = SplitVecOp_VP_REDUCE(N, OpNo);
break;
case ISD::VP_CTTZ_ELTS:
case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
Res = SplitVecOp_VP_CttzElements(N);
break;
}

// If the result is null, the sub-method took care of registering results etc.
Expand Down Expand Up @@ -4056,6 +4060,29 @@ SDValue DAGTypeLegalizer::SplitVecOp_FP_TO_XINT_SAT(SDNode *N) {
return DAG.getNode(ISD::CONCAT_VECTORS, dl, ResVT, Lo, Hi);
}

/// Split the vector operand of a VP_CTTZ_ELTS / VP_CTTZ_ELTS_ZERO_UNDEF node
/// into a low and a high half and merge the two per-half counts into one
/// scalar result.
///
/// The low half is always counted with VP_CTTZ_ELTS (never the ZERO_UNDEF
/// form) because its result is compared against the low EVL to decide whether
/// the count stops in the low half. The high half keeps the opcode of \p N.
SDValue DAGTypeLegalizer::SplitVecOp_VP_CttzElements(SDNode *N) {
  SDLoc Loc(N);
  EVT CountVT = N->getValueType(0);
  SDValue Vec = N->getOperand(0);

  // Break the source vector, its mask and its EVL into low/high pieces.
  SDValue LoVec, HiVec;
  GetSplitVector(Vec, LoVec, HiVec);
  auto [LoMask, HiMask] = SplitMask(N->getOperand(1));
  auto [LoEVL, HiEVL] =
      DAG.SplitEVL(N->getOperand(2), Vec.getValueType(), Loc);

  // The low EVL expressed in the result type, so it can be compared with and
  // added to the element counts.
  SDValue LoEVLAsCount = DAG.getZExtOrTrunc(LoEVL, Loc, CountVT);

  // Count(Lo) != EVL(Lo)  =>  the answer is Count(Lo).
  // Count(Lo) == EVL(Lo)  =>  the answer is EVL(Lo) + Count(Hi).
  SDValue LoCount =
      DAG.getNode(ISD::VP_CTTZ_ELTS, Loc, CountVT, LoVec, LoMask, LoEVL);
  SDValue StopsInLo = DAG.getSetCC(Loc, getSetCCResultType(CountVT), LoCount,
                                   LoEVLAsCount, ISD::SETNE);
  SDValue HiCount =
      DAG.getNode(N->getOpcode(), Loc, CountVT, HiVec, HiMask, HiEVL);
  SDValue SpansBoth =
      DAG.getNode(ISD::ADD, Loc, CountVT, LoEVLAsCount, HiCount);
  return DAG.getSelect(Loc, CountVT, StopsInLo, LoCount, SpansBoth);
}

//===----------------------------------------------------------------------===//
// Result Vector Widening
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -6161,6 +6188,10 @@ bool DAGTypeLegalizer::WidenVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::VP_REDUCE_FMIN:
Res = WidenVecOp_VP_REDUCE(N);
break;
case ISD::VP_CTTZ_ELTS:
case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
Res = WidenVecOp_VP_CttzElements(N);
break;
}

// If Res is null, the sub-method took care of registering the result.
Expand Down Expand Up @@ -6924,6 +6955,17 @@ SDValue DAGTypeLegalizer::WidenVecOp_VSELECT(SDNode *N) {
DAG.getVectorIdxConstant(0, DL));
}

/// Widen the vector operand of a VP_CTTZ_ELTS / VP_CTTZ_ELTS_ZERO_UNDEF node.
/// The source vector and the mask are replaced by their widened forms; the
/// EVL operand, the result type and the node flags are forwarded unchanged.
SDValue DAGTypeLegalizer::WidenVecOp_VP_CttzElements(SDNode *N) {
  SDValue WideSrc = GetWidenedVector(N->getOperand(0));
  // The mask must end up with the same element count as the widened source.
  SDValue WideMask = GetWidenedMask(
      N->getOperand(1), WideSrc.getValueType().getVectorElementCount());

  // NOTE(review): presumably the padding lanes stay out of the count because
  // the unchanged EVL never exceeds the original element count -- confirm.
  SDValue Ops[] = {WideSrc, WideMask, N->getOperand(2)};
  return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), Ops,
                     N->getFlags());
}

//===----------------------------------------------------------------------===//
// Vector Widening Utilities
//===----------------------------------------------------------------------===//
Expand Down
9 changes: 8 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8076,6 +8076,11 @@ static unsigned getISDForVPIntrinsic(const VPIntrinsic &VPIntrin) {
ResOPC = IsZeroUndef ? ISD::VP_CTTZ_ZERO_UNDEF : ISD::VP_CTTZ;
break;
}
case Intrinsic::vp_cttz_elts: {
bool IsZeroPoison = cast<ConstantInt>(VPIntrin.getArgOperand(1))->isOne();
ResOPC = IsZeroPoison ? ISD::VP_CTTZ_ELTS_ZERO_UNDEF : ISD::VP_CTTZ_ELTS;
break;
}
#define HELPER_MAP_VPID_TO_VPSD(VPID, VPSD) \
case Intrinsic::VPID: \
ResOPC = ISD::VPSD; \
Expand Down Expand Up @@ -8428,7 +8433,9 @@ void SelectionDAGBuilder::visitVectorPredicationIntrinsic(
case ISD::VP_CTLZ:
case ISD::VP_CTLZ_ZERO_UNDEF:
case ISD::VP_CTTZ:
case ISD::VP_CTTZ_ZERO_UNDEF: {
case ISD::VP_CTTZ_ZERO_UNDEF:
case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
case ISD::VP_CTTZ_ELTS: {
SDValue Result =
DAG.getNode(Opcode, DL, VTs, {OpValues[0], OpValues[2], OpValues[3]});
setValue(&VPIntrin, Result);
Expand Down
33 changes: 33 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9074,6 +9074,39 @@ SDValue TargetLowering::expandVPCTTZ(SDNode *Node, SelectionDAG &DAG) const {
return DAG.getNode(ISD::VP_CTPOP, dl, VT, Tmp, Mask, VL);
}

/// Generic expansion of VP_CTTZ_ELTS / VP_CTTZ_ELTS_ZERO_UNDEF.
///
/// The emitted sequence is:
///   %cond  = to_bool_vec %source
///   %splat = splat /*val=*/VL
///   %tz    = step_vector
///   %v     = vp.select %cond, /*true=*/%tz, /*false=*/%splat
///   %r     = vp.reduce.umin %v
///
/// \param N   Node to expand; operands are (source vector, mask, EVL).
/// \param DAG SelectionDAG used to create the replacement nodes.
/// \returns   The VP_REDUCE_UMIN node that computes the result.
SDValue TargetLowering::expandVPCTTZElements(SDNode *N,
                                             SelectionDAG &DAG) const {
  SDLoc Loc(N);
  SDValue Vec = N->getOperand(0);
  SDValue Mask = N->getOperand(1);
  SDValue EVL = N->getOperand(2);

  auto &Ctx = *DAG.getContext();
  EVT VecVT = Vec.getValueType();
  EVT CountVT = N->getValueType(0);
  const auto NumElts = VecVT.getVectorElementCount();
  // Vector of result-typed integers with one lane per source element.
  EVT CountVecVT = EVT::getVectorVT(Ctx, CountVT, NumElts);

  // A non-i1 source is first turned into a per-lane "is non-zero" predicate.
  SDValue IsNonZero = Vec;
  if (VecVT.getScalarType() != MVT::i1) {
    SDValue Zeros = DAG.getConstant(0, Loc, VecVT);
    EVT BoolVecVT = EVT::getVectorVT(Ctx, MVT::i1, NumElts);
    IsNonZero = DAG.getNode(ISD::VP_SETCC, Loc, BoolVecVT, Vec, Zeros,
                            DAG.getCondCode(ISD::SETNE), Mask, EVL);
  }

  // Non-zero lanes contribute their own index, zero lanes contribute the EVL.
  SDValue EVLAsCount = DAG.getZExtOrTrunc(EVL, Loc, CountVT);
  SDValue EVLSplat = DAG.getSplat(CountVecVT, Loc, EVLAsCount);
  SDValue LaneIndices = DAG.getStepVector(Loc, CountVecVT);
  SDValue Candidates = DAG.getNode(ISD::VP_SELECT, Loc, CountVecVT, IsNonZero,
                                   LaneIndices, EVLSplat, EVL);

  // Unsigned-min reduction over the candidates, started at the EVL.
  return DAG.getNode(ISD::VP_REDUCE_UMIN, Loc, CountVT, EVLAsCount, Candidates,
                     Mask, EVL);
}

SDValue TargetLowering::expandABS(SDNode *N, SelectionDAG &DAG,
bool IsNegative) const {
SDLoc dl(N);
Expand Down
46 changes: 45 additions & 1 deletion llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::VP_SMAX, ISD::VP_UMIN, ISD::VP_UMAX,
ISD::VP_ABS, ISD::EXPERIMENTAL_VP_REVERSE, ISD::EXPERIMENTAL_VP_SPLICE,
ISD::VP_SADDSAT, ISD::VP_UADDSAT, ISD::VP_SSUBSAT,
ISD::VP_USUBSAT};
ISD::VP_USUBSAT, ISD::VP_CTTZ_ELTS, ISD::VP_CTTZ_ELTS_ZERO_UNDEF};

static const unsigned FloatingPointVPOps[] = {
ISD::VP_FADD, ISD::VP_FSUB, ISD::VP_FMUL,
Expand Down Expand Up @@ -759,6 +759,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
{ISD::SELECT_CC, ISD::VSELECT, ISD::VP_MERGE, ISD::VP_SELECT}, VT,
Expand);

setOperationAction({ISD::VP_CTTZ_ELTS, ISD::VP_CTTZ_ELTS_ZERO_UNDEF}, VT,
Custom);

setOperationAction({ISD::VP_AND, ISD::VP_OR, ISD::VP_XOR}, VT, Custom);

setOperationAction(
Expand Down Expand Up @@ -5341,6 +5344,44 @@ RISCVTargetLowering::lowerCTLZ_CTTZ_ZERO_UNDEF(SDValue Op,
return Res;
}

/// Custom-lower VP_CTTZ_ELTS / VP_CTTZ_ELTS_ZERO_UNDEF to RISCVISD::VFIRST_VL.
///
/// Fixed-length sources and masks are first moved into their scalable
/// container types, and non-i1 sources are compared against zero with
/// RISCVISD::SETCC_VL to obtain a mask vector. For VP_CTTZ_ELTS a negative
/// VFIRST_VL result is replaced by the EVL.
///
/// \param Op  Node to lower; operands are (source vector, mask, EVL).
/// \param DAG SelectionDAG used to create the replacement nodes.
SDValue RISCVTargetLowering::lowerVPCttzElements(SDValue Op,
                                                 SelectionDAG &DAG) const {
  SDLoc Loc(Op);
  MVT XLenVT = Subtarget.getXLenVT();
  SDValue Vec = Op->getOperand(0);
  SDValue Mask = Op->getOperand(1);
  SDValue EVL = Op->getOperand(2);
  MVT VecVT = Vec.getSimpleValueType();

  // Fixed-length vectors are handled through their scalable container type.
  if (VecVT.isFixedLengthVector()) {
    MVT ContainerVT = getContainerForFixedLengthVector(VecVT);
    Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
    Mask = convertToScalableVector(getMaskTypeFor(ContainerVT), Mask, DAG,
                                   Subtarget);
    VecVT = ContainerVT;
  }

  // A non-i1 source is first turned into a per-lane "is non-zero" mask.
  if (VecVT.getScalarType() != MVT::i1) {
    SDValue Zeros = DAG.getConstant(0, Loc, VecVT);
    MVT BoolVecVT = MVT::getVectorVT(MVT::i1, VecVT.getVectorElementCount());
    Vec = DAG.getNode(RISCVISD::SETCC_VL, Loc, BoolVecVT,
                      {Vec, Zeros, DAG.getCondCode(ISD::SETNE),
                       DAG.getUNDEF(BoolVecVT), Mask, EVL});
  }

  SDValue FirstSet =
      DAG.getNode(RISCVISD::VFIRST_VL, Loc, XLenVT, Vec, Mask, EVL);
  if (Op->getOpcode() == ISD::VP_CTTZ_ELTS_ZERO_UNDEF) {
    // The all-zero result is poison here, so the -1 can be returned as is.
    return FirstSet;
  }

  // For VP_CTTZ_ELTS, map the -1 result to the EVL.
  SDValue Zero = DAG.getConstant(0, Loc, XLenVT);
  SDValue IsNegative = DAG.getSetCC(Loc, XLenVT, FirstSet, Zero, ISD::SETLT);
  SDValue Count = DAG.getSelect(Loc, XLenVT, IsNegative, EVL, FirstSet);
  return DAG.getNode(ISD::TRUNCATE, Loc, Op.getValueType(), Count);
}

// While RVV has alignment restrictions, we should always be able to load as a
// legal equivalently-sized byte-typed vector instead. This method is
// responsible for re-expressing a ISD::LOAD via a correctly-aligned type. If
Expand Down Expand Up @@ -6595,6 +6636,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
if (Op.getOperand(1).getValueType().getVectorElementType() == MVT::i1)
return lowerVectorMaskVecReduction(Op, DAG, /*IsVP*/ true);
return lowerVPREDUCE(Op, DAG);
case ISD::VP_CTTZ_ELTS:
case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
return lowerVPCttzElements(Op, DAG);
case ISD::UNDEF: {
MVT ContainerVT = getContainerForFixedLengthVector(Op.getSimpleValueType());
return convertFromScalableVector(Op.getSimpleValueType(),
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -961,6 +961,7 @@ class RISCVTargetLowering : public TargetLowering {
SDValue lowerVPFPIntConvOp(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVPStridedLoad(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVPStridedStore(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVPCttzElements(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFixedLengthVectorExtendToRVV(SDValue Op, SelectionDAG &DAG,
unsigned ExtendOpc) const;
SDValue lowerGET_ROUNDING(SDValue Op, SelectionDAG &DAG) const;
Expand Down
Loading

0 comments on commit 539f626

Please sign in to comment.