Skip to content

Commit b6dd511

Browse files
authored
[X86] AVX512 optimised CTLZ/CTTZ implementations for i256/i512 scalars (#164671)
Make use of AVX512 VPLZCNT/VPOPCNT to perform the big integer bit counts per vector element and then use VPCOMPRESS to extract the first non-zero element result. There's more we can do here (widen/split other vector widths etc.) - but this is a good starting point.
1 parent 62d1a08 commit b6dd511

File tree

2 files changed

+370
-418
lines changed

2 files changed

+370
-418
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2654,6 +2654,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
26542654
ISD::AVGCEILU,
26552655
ISD::AVGFLOORS,
26562656
ISD::AVGFLOORU,
2657+
ISD::CTLZ,
2658+
ISD::CTTZ,
2659+
ISD::CTLZ_ZERO_UNDEF,
2660+
ISD::CTTZ_ZERO_UNDEF,
26572661
ISD::BITREVERSE,
26582662
ISD::ADD,
26592663
ISD::FADD,
@@ -55162,6 +55166,65 @@ static SDValue combineXor(SDNode *N, SelectionDAG &DAG,
5516255166
return combineFneg(N, DAG, DCI, Subtarget);
5516355167
}
5516455168

55169+
// Fold i256/i512 CTLZ/CTTZ patterns to make use of AVX512
55170+
// vXi64 CTLZ/CTTZ and VECTOR_COMPRESS.
55171+
// Compute the CTLZ/CTTZ of each element, add the element's bit offset, compress
55172+
// the result to remove all zero elements (passthru is set to scalar bitwidth if
55173+
// all elements are zero) and extract the lowest compressed element.
55174+
static SDValue combineCTZ(SDNode *N, SelectionDAG &DAG,
55175+
TargetLowering::DAGCombinerInfo &DCI,
55176+
const X86Subtarget &Subtarget) {
55177+
EVT VT = N->getValueType(0);
55178+
SDValue N0 = N->getOperand(0);
55179+
unsigned Opc = N->getOpcode();
55180+
unsigned SizeInBits = VT.getSizeInBits();
55181+
assert((Opc == ISD::CTLZ || Opc == ISD::CTLZ_ZERO_UNDEF || Opc == ISD::CTTZ ||
55182+
Opc == ISD::CTTZ_ZERO_UNDEF) &&
55183+
"Unsupported bit count");
55184+
55185+
if (VT.isScalarInteger() && Subtarget.hasCDI() &&
55186+
((SizeInBits == 512 && Subtarget.useAVX512Regs()) ||
55187+
(SizeInBits == 256 && Subtarget.hasVLX() &&
55188+
X86::mayFoldLoad(N0, Subtarget)))) {
55189+
MVT VecVT = MVT::getVectorVT(MVT::i64, SizeInBits / 64);
55190+
MVT BoolVT = VecVT.changeVectorElementType(MVT::i1);
55191+
SDValue Vec = DAG.getBitcast(VecVT, N0);
55192+
SDLoc DL(N);
55193+
55194+
SmallVector<int, 8> RevMask;
55195+
SmallVector<SDValue, 8> Offsets;
55196+
for (unsigned I = 0, E = VecVT.getVectorNumElements(); I != E; ++I) {
55197+
RevMask.push_back((int)((E - 1) - I));
55198+
Offsets.push_back(DAG.getConstant(I * 64, DL, MVT::i64));
55199+
}
55200+
55201+
// CTLZ - reverse the elements as we want the top non-zero element at the
55202+
// bottom for compression.
55203+
unsigned VecOpc = ISD::CTTZ;
55204+
if (Opc == ISD::CTLZ || Opc == ISD::CTLZ_ZERO_UNDEF) {
55205+
VecOpc = ISD::CTLZ;
55206+
Vec = DAG.getVectorShuffle(VecVT, DL, Vec, Vec, RevMask);
55207+
}
55208+
55209+
SDValue PassThrough = DAG.getUNDEF(VecVT);
55210+
if (Opc == ISD::CTLZ || Opc == ISD::CTTZ)
55211+
PassThrough = DAG.getConstant(SizeInBits, DL, VecVT);
55212+
55213+
SDValue IsNonZero = DAG.getSetCC(DL, BoolVT, Vec,
55214+
DAG.getConstant(0, DL, VecVT), ISD::SETNE);
55215+
SDValue Cnt = DAG.getNode(VecOpc, DL, VecVT, Vec);
55216+
Cnt = DAG.getNode(ISD::ADD, DL, VecVT, Cnt,
55217+
DAG.getBuildVector(VecVT, DL, Offsets));
55218+
Cnt = DAG.getNode(ISD::VECTOR_COMPRESS, DL, VecVT, Cnt, IsNonZero,
55219+
PassThrough);
55220+
Cnt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cnt,
55221+
DAG.getVectorIdxConstant(0, DL));
55222+
return DAG.getZExtOrTrunc(Cnt, DL, VT);
55223+
}
55224+
55225+
return SDValue();
55226+
}
55227+
5516555228
static SDValue combineBITREVERSE(SDNode *N, SelectionDAG &DAG,
5516655229
TargetLowering::DAGCombinerInfo &DCI,
5516755230
const X86Subtarget &Subtarget) {
@@ -60885,6 +60948,10 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
6088560948
case ISD::AND: return combineAnd(N, DAG, DCI, Subtarget);
6088660949
case ISD::OR: return combineOr(N, DAG, DCI, Subtarget);
6088760950
case ISD::XOR: return combineXor(N, DAG, DCI, Subtarget);
60951+
case ISD::CTLZ:
60952+
case ISD::CTTZ:
60953+
case ISD::CTLZ_ZERO_UNDEF:
60954+
case ISD::CTTZ_ZERO_UNDEF:return combineCTZ(N, DAG, DCI, Subtarget);
6088860955
case ISD::BITREVERSE: return combineBITREVERSE(N, DAG, DCI, Subtarget);
6088960956
case ISD::AVGCEILS:
6089060957
case ISD::AVGCEILU:

0 commit comments

Comments
 (0)