@@ -2654,6 +2654,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
26542654 ISD::AVGCEILU,
26552655 ISD::AVGFLOORS,
26562656 ISD::AVGFLOORU,
2657+ ISD::CTLZ,
2658+ ISD::CTTZ,
2659+ ISD::CTLZ_ZERO_UNDEF,
2660+ ISD::CTTZ_ZERO_UNDEF,
26572661 ISD::BITREVERSE,
26582662 ISD::ADD,
26592663 ISD::FADD,
@@ -55162,6 +55166,65 @@ static SDValue combineXor(SDNode *N, SelectionDAG &DAG,
5516255166 return combineFneg(N, DAG, DCI, Subtarget);
5516355167}
5516455168
55169+ // Fold i256/i512 CTLZ/CTTZ patterns to make use of AVX512
55170+ // vXi64 CTLZ/CTTZ and VECTOR_COMPRESS.
55171+ // Compute the CTLZ/CTTZ of each element, add the element's bit offset, compress
55172+ // the result to remove all zero elements (passthru is set to scalar bitwidth if
55173+ // all elements are zero) and extract the lowest compressed element.
55174+ static SDValue combineCTZ(SDNode *N, SelectionDAG &DAG,
55175+ TargetLowering::DAGCombinerInfo &DCI,
55176+ const X86Subtarget &Subtarget) {
55177+ EVT VT = N->getValueType(0);
55178+ SDValue N0 = N->getOperand(0);
55179+ unsigned Opc = N->getOpcode();
55180+ unsigned SizeInBits = VT.getSizeInBits();
55181+ assert((Opc == ISD::CTLZ || Opc == ISD::CTLZ_ZERO_UNDEF || Opc == ISD::CTTZ ||
55182+ Opc == ISD::CTTZ_ZERO_UNDEF) &&
55183+ "Unsupported bit count");
55184+
55185+ if (VT.isScalarInteger() && Subtarget.hasCDI() &&
55186+ ((SizeInBits == 512 && Subtarget.useAVX512Regs()) ||
55187+ (SizeInBits == 256 && Subtarget.hasVLX() &&
55188+ X86::mayFoldLoad(N0, Subtarget)))) {
55189+ MVT VecVT = MVT::getVectorVT(MVT::i64, SizeInBits / 64);
55190+ MVT BoolVT = VecVT.changeVectorElementType(MVT::i1);
55191+ SDValue Vec = DAG.getBitcast(VecVT, N0);
55192+ SDLoc DL(N);
55193+
55194+ SmallVector<int, 8> RevMask;
55195+ SmallVector<SDValue, 8> Offsets;
55196+ for (unsigned I = 0, E = VecVT.getVectorNumElements(); I != E; ++I) {
55197+ RevMask.push_back((int)((E - 1) - I));
55198+ Offsets.push_back(DAG.getConstant(I * 64, DL, MVT::i64));
55199+ }
55200+
55201+ // CTLZ - reverse the elements as we want the top non-zero element at the
55202+ // bottom for compression.
55203+ unsigned VecOpc = ISD::CTTZ;
55204+ if (Opc == ISD::CTLZ || Opc == ISD::CTLZ_ZERO_UNDEF) {
55205+ VecOpc = ISD::CTLZ;
55206+ Vec = DAG.getVectorShuffle(VecVT, DL, Vec, Vec, RevMask);
55207+ }
55208+
55209+ SDValue PassThrough = DAG.getUNDEF(VecVT);
55210+ if (Opc == ISD::CTLZ || Opc == ISD::CTTZ)
55211+ PassThrough = DAG.getConstant(SizeInBits, DL, VecVT);
55212+
55213+ SDValue IsNonZero = DAG.getSetCC(DL, BoolVT, Vec,
55214+ DAG.getConstant(0, DL, VecVT), ISD::SETNE);
55215+ SDValue Cnt = DAG.getNode(VecOpc, DL, VecVT, Vec);
55216+ Cnt = DAG.getNode(ISD::ADD, DL, VecVT, Cnt,
55217+ DAG.getBuildVector(VecVT, DL, Offsets));
55218+ Cnt = DAG.getNode(ISD::VECTOR_COMPRESS, DL, VecVT, Cnt, IsNonZero,
55219+ PassThrough);
55220+ Cnt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cnt,
55221+ DAG.getVectorIdxConstant(0, DL));
55222+ return DAG.getZExtOrTrunc(Cnt, DL, VT);
55223+ }
55224+
55225+ return SDValue();
55226+ }
55227+
5516555228static SDValue combineBITREVERSE(SDNode *N, SelectionDAG &DAG,
5516655229 TargetLowering::DAGCombinerInfo &DCI,
5516755230 const X86Subtarget &Subtarget) {
@@ -60885,6 +60948,10 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
6088560948 case ISD::AND: return combineAnd(N, DAG, DCI, Subtarget);
6088660949 case ISD::OR: return combineOr(N, DAG, DCI, Subtarget);
6088760950 case ISD::XOR: return combineXor(N, DAG, DCI, Subtarget);
60951+ case ISD::CTLZ:
60952+ case ISD::CTTZ:
60953+ case ISD::CTLZ_ZERO_UNDEF:
60954+ case ISD::CTTZ_ZERO_UNDEF:return combineCTZ(N, DAG, DCI, Subtarget);
6088860955 case ISD::BITREVERSE: return combineBITREVERSE(N, DAG, DCI, Subtarget);
6088960956 case ISD::AVGCEILS:
6089060957 case ISD::AVGCEILU:
0 commit comments