Skip to content

Commit

Permalink
[AArch64] Add intrinsic to count trailing zero elements
Browse files Browse the repository at this point in the history
This patch introduces an experimental intrinsic for counting the
trailing zero elements in a vector. The intrinsic has generic expansion
in SelectionDAGBuilder, and for AArch64 there is a pattern which matches
to brkb & cntp instructions where SVE is enabled.

The intrinsic has a second operand, is_zero_poison, similar to the
existing cttz intrinsic.

These changes have been split out from D158291.
  • Loading branch information
kmclaughlin-arm committed Oct 31, 2023
1 parent 00a8314 commit 3b786f2
Show file tree
Hide file tree
Showing 13 changed files with 972 additions and 0 deletions.
39 changes: 39 additions & 0 deletions llvm/docs/LangRef.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18497,6 +18497,45 @@ Arguments:
Both arguments must be vectors of the same type whereby their logical
concatenation matches the result type.

'``llvm.experimental.cttz.elts``' Intrinsic
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Syntax:
"""""""

This is an overloaded intrinsic. You can use ```llvm.experimental.cttz.elts```
on any vector of integer elements, both fixed width and scalable.

::

declare i8 @llvm.experimental.cttz.elts.i8.v8i1(<8 x i1> <src>, i1 <is_zero_poison>)

Overview:
"""""""""

The '``llvm.experimental.cttz.elts``' intrinsic counts the number of trailing
zero elements of a vector.

Arguments:
""""""""""

The first argument is the vector to be counted. This argument must be a vector
with integer element type. The return type must also be an integer type which is
wide enough to hold the maximum number of elements of the source vector. The
behaviour of this intrinsic is undefined if the return type is not wide enough
for the number of elements in the input vector.

The second argument is a constant flag that indicates whether the intrinsic
returns a valid result if the first argument is all zero. If the first argument
is all zero and the second argument is true, the result is poison.

Semantics:
""""""""""

The '``llvm.experimental.cttz.elts``' intrinsic counts the trailing (least
significant) zero elements in a vector. If ``src == 0`` the result is the
number of elements in the input vector.

'``llvm.experimental.vector.splice``' Intrinsic
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,10 @@ class TargetLoweringBase {
return true;
}

/// Return true if the @llvm.experimental.cttz.elts intrinsic should be
/// expanded using generic code in SelectionDAGBuilder.
virtual bool shouldExpandCttzElements(EVT VT) const { return true; }

// Return true if op(vecreduce(x), vecreduce(y)) should be reassociated to
// vecreduce(op(x, y)) for the reduction opcode RedOpc.
virtual bool shouldReassociateReduction(unsigned RedOpc, EVT VT) const {
Expand Down
5 changes: 5 additions & 0 deletions llvm/include/llvm/IR/Intrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -2182,6 +2182,11 @@ def int_experimental_get_vector_length:
[IntrNoMem, IntrNoSync, IntrWillReturn,
ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>]>;

def int_experimental_cttz_elts:
DefaultAttrsIntrinsic<[llvm_anyint_ty],
[llvm_anyvector_ty, llvm_i1_ty],
[IntrNoMem, IntrNoSync, IntrWillReturn, ImmArg<ArgIndex<1>>]>;

def int_experimental_vp_splice:
DefaultAttrsIntrinsic<[llvm_anyvector_ty],
[LLVMMatchType<0>,
Expand Down
56 changes: 56 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7514,6 +7514,62 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
setValue(&I, Trunc);
return;
}
case Intrinsic::experimental_cttz_elts: {
auto DL = getCurSDLoc();
SDValue Op = getValue(I.getOperand(0));
EVT OpVT = Op.getValueType();

if (!TLI.shouldExpandCttzElements(OpVT)) {
visitTargetIntrinsic(I, Intrinsic);
return;
}

if (OpVT.getScalarType() != MVT::i1) {
// Compare the input vector elements to zero & use to count trailing zeros
SDValue AllZero = DAG.getConstant(0, DL, OpVT);
OpVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
OpVT.getVectorElementCount());
Op = DAG.getSetCC(DL, OpVT, Op, AllZero, ISD::SETNE);
}

// Find the smallest "sensible" element type to use for the expansion.
ConstantRange CR(
APInt(64, OpVT.getVectorElementCount().getKnownMinValue()));
if (OpVT.isScalableVT())
CR = CR.umul_sat(getVScaleRange(I.getCaller(), 64));

// If the zero-is-poison flag is set, we can assume the upper limit
// of the result is VF-1.
if (!cast<ConstantSDNode>(getValue(I.getOperand(1)))->isZero())
CR = CR.subtract(APInt(64, 1));

unsigned EltWidth = I.getType()->getScalarSizeInBits();
EltWidth = std::min(EltWidth, (unsigned)CR.getActiveBits());
EltWidth = std::max(llvm::bit_ceil(EltWidth), (unsigned)8);

MVT NewEltTy = MVT::getIntegerVT(EltWidth);

// Create the new vector type & get the vector length
EVT NewVT = EVT::getVectorVT(*DAG.getContext(), NewEltTy,
OpVT.getVectorElementCount());

SDValue VL =
DAG.getElementCount(DL, NewEltTy, OpVT.getVectorElementCount());

SDValue StepVec = DAG.getStepVector(DL, NewVT);
SDValue SplatVL = DAG.getSplat(NewVT, DL, VL);
SDValue StepVL = DAG.getNode(ISD::SUB, DL, NewVT, SplatVL, StepVec);
SDValue Ext = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, Op);
SDValue And = DAG.getNode(ISD::AND, DL, NewVT, StepVL, Ext);
SDValue Max = DAG.getNode(ISD::VECREDUCE_UMAX, DL, NewEltTy, And);
SDValue Sub = DAG.getNode(ISD::SUB, DL, NewEltTy, VL, Max);

EVT RetTy = TLI.getValueType(DAG.getDataLayout(), I.getType());
SDValue Ret = DAG.getZExtOrTrunc(Sub, DL, RetTy);

setValue(&I, Ret);
return;
}
case Intrinsic::vector_insert: {
SDValue Vec = getValue(I.getOperand(0));
SDValue SubVec = getValue(I.getOperand(1));
Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1791,6 +1791,10 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
return false;
}

bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
return !Subtarget->hasSVEorSME() || VT != MVT::nxv16i1;
}

void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT,
bool StreamingSVE) {
assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
Expand Down Expand Up @@ -2634,6 +2638,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::MRRS)
MAKE_CASE(AArch64ISD::MSRR)
MAKE_CASE(AArch64ISD::RSHRNB_I)
MAKE_CASE(AArch64ISD::CTTZ_ELTS)
}
#undef MAKE_CASE
return nullptr;
Expand Down Expand Up @@ -5338,6 +5343,12 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
}
return SDValue();
}
case Intrinsic::experimental_cttz_elts: {
SDValue NewCttzElts =
DAG.getNode(AArch64ISD::CTTZ_ELTS, dl, MVT::i64, Op.getOperand(1));

return DAG.getZExtOrTrunc(NewCttzElts, dl, Op.getValueType());
}
}
}

Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,8 @@ enum NodeType : unsigned {
PTEST_ANY,
PTRUE,

CTTZ_ELTS,

BITREVERSE_MERGE_PASSTHRU,
BSWAP_MERGE_PASSTHRU,
REVH_MERGE_PASSTHRU,
Expand Down Expand Up @@ -927,6 +929,8 @@ class AArch64TargetLowering : public TargetLowering {

bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const override;

bool shouldExpandCttzElements(EVT VT) const override;

/// If a change in streaming mode is required on entry to/return from a
/// function call it emits and returns the corresponding SMSTART or SMSTOP node.
/// \p Entry tells whether this is before/after the Call, which is necessary
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,9 @@ def AArch64rshrnb_pf : PatFrags<(ops node:$rs, node:$i),
[(AArch64rshrnb node:$rs, node:$i),
(int_aarch64_sve_rshrnb node:$rs, node:$i)]>;

def AArch64CttzElts : SDNode<"AArch64ISD::CTTZ_ELTS", SDTypeProfile<1, 1,
[SDTCisInt<0>, SDTCisVec<1>]>, []>;

// Match add node and also treat an 'or' node is as an 'add' if the or'ed operands
// have no common bits.
def add_and_or_is_add : PatFrags<(ops node:$lhs, node:$rhs),
Expand Down
16 changes: 16 additions & 0 deletions llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -1964,6 +1964,11 @@ let Predicates = [HasSVEorSME] in {
defm CNTW_XPiI : sve_int_count<0b100, "cntw", int_aarch64_sve_cntw>;
defm CNTD_XPiI : sve_int_count<0b110, "cntd", int_aarch64_sve_cntd>;
defm CNTP_XPP : sve_int_pcount_pred<0b0000, "cntp", int_aarch64_sve_cntp>;

def : Pat<(i64 (AArch64CttzElts nxv16i1:$Op1)),
(i64 (!cast<Instruction>(CNTP_XPP_B)
(nxv16i1 (!cast<Instruction>(BRKB_PPzP) (PTRUE_B 31), nxv16i1:$Op1)),
(nxv16i1 (!cast<Instruction>(BRKB_PPzP) (PTRUE_B 31), nxv16i1:$Op1))))>;
}

defm INCB_XPiI : sve_int_pred_pattern_a<0b000, "incb", add, int_aarch64_sve_cntb>;
Expand Down Expand Up @@ -2049,6 +2054,17 @@ let Predicates = [HasSVEorSME] in {
defm INCP_ZP : sve_int_count_v<0b10000, "incp">;
defm DECP_ZP : sve_int_count_v<0b10100, "decp">;

def : Pat<(i64 (add GPR64:$Op1, (i64 (AArch64CttzElts nxv16i1:$Op2)))),
(i64 (!cast<Instruction>(INCP_XP_B)
(nxv16i1 (!cast<Instruction>(BRKB_PPzP) (PTRUE_B 31), nxv16i1:$Op2)),
GPR64:$Op1))>;

def : Pat<(i32 (add GPR32:$Op1, (trunc (i64 (AArch64CttzElts nxv16i1:$Op2))))),
(i32 (EXTRACT_SUBREG (i64 (!cast<Instruction>(INCP_XP_B)
(nxv16i1 (!cast<Instruction>(BRKB_PPzP) (PTRUE_B 31), nxv16i1:$Op2)),
(INSERT_SUBREG (i64 (IMPLICIT_DEF)), GPR32:$Op1, sub_32))),
sub_32))>;

defm INDEX_RR : sve_int_index_rr<"index", AArch64mul_p_oneuse>;
defm INDEX_IR : sve_int_index_ir<"index", AArch64mul_p, AArch64mul_p_oneuse>;
defm INDEX_RI : sve_int_index_ri<"index">;
Expand Down

0 comments on commit 3b786f2

Please sign in to comment.