Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[X86] Improve helper for simplifying demanded bits of compares #84360

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

goldsteinn
Copy link
Contributor

  • [X86] Add tests for folding icmp of v8i32 -> fcmp of v8f32 on AVX; NFC
  • [X86] Try Folding icmp of v8i32 -> fcmp of v8f32 on AVX
  • [X86] Improve helper for simplifying demanded bits of compares

@llvmbot
Copy link
Collaborator

llvmbot commented Mar 7, 2024

@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-backend-x86

Author: None (goldsteinn)

Changes
  • [X86] Add tests for folding icmp of v8i32 -> fcmp of v8f32 on AVX; NFC
  • [X86] Try Folding icmp of v8i32 -> fcmp of v8f32 on AVX
  • [X86] Improve helper for simplifying demanded bits of compares

Patch is 444.24 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/84360.diff

41 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+28)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+315-12)
  • (modified) llvm/lib/Target/X86/X86InstrInfo.cpp (+40)
  • (modified) llvm/lib/Target/X86/X86InstrInfo.h (+3)
  • (modified) llvm/test/CodeGen/X86/bitcast-int-to-vector-bool-sext.ll (+9-18)
  • (modified) llvm/test/CodeGen/X86/bitcast-int-to-vector-bool-zext.ll (+9-18)
  • (added) llvm/test/CodeGen/X86/cmpf-avx.ll (+250)
  • (modified) llvm/test/CodeGen/X86/combine-sse41-intrinsics.ll (+1-2)
  • (modified) llvm/test/CodeGen/X86/fpclamptosat_vec.ll (+40-68)
  • (modified) llvm/test/CodeGen/X86/i64-to-float.ll (+1-3)
  • (modified) llvm/test/CodeGen/X86/masked_compressstore.ll (+17-17)
  • (modified) llvm/test/CodeGen/X86/masked_expandload.ll (+24-24)
  • (modified) llvm/test/CodeGen/X86/masked_gather.ll (+64-70)
  • (modified) llvm/test/CodeGen/X86/masked_load.ll (+3-5)
  • (modified) llvm/test/CodeGen/X86/masked_store.ll (+36-33)
  • (modified) llvm/test/CodeGen/X86/masked_store_trunc.ll (+29-48)
  • (modified) llvm/test/CodeGen/X86/masked_store_trunc_ssat.ll (+260-347)
  • (modified) llvm/test/CodeGen/X86/masked_store_trunc_usat.ll (+53-104)
  • (modified) llvm/test/CodeGen/X86/nontemporal-loads.ll (+7-11)
  • (modified) llvm/test/CodeGen/X86/pr48215.ll (+9-10)
  • (modified) llvm/test/CodeGen/X86/pr81136.ll (+7-9)
  • (modified) llvm/test/CodeGen/X86/sat-add.ll (+3-10)
  • (modified) llvm/test/CodeGen/X86/setcc-lowering.ll (+4-3)
  • (modified) llvm/test/CodeGen/X86/srem-seteq-vec-nonsplat.ll (+52-59)
  • (modified) llvm/test/CodeGen/X86/ssub_sat_vec.ll (+221-264)
  • (modified) llvm/test/CodeGen/X86/v8i1-masks.ll (+6-10)
  • (modified) llvm/test/CodeGen/X86/var-permute-256.ll (+2-2)
  • (modified) llvm/test/CodeGen/X86/vec_saddo.ll (+45-45)
  • (modified) llvm/test/CodeGen/X86/vec_ssubo.ll (+45-45)
  • (modified) llvm/test/CodeGen/X86/vec_umulo.ll (+33-38)
  • (modified) llvm/test/CodeGen/X86/vector-constrained-fp-intrinsics.ll (+7-11)
  • (modified) llvm/test/CodeGen/X86/vector-pcmp.ll (+22-29)
  • (modified) llvm/test/CodeGen/X86/vector-popcnt-256-ult-ugt.ll (+1551-1584)
  • (modified) llvm/test/CodeGen/X86/vector-reduce-fmaximum.ll (-58)
  • (modified) llvm/test/CodeGen/X86/vector-reduce-or-bool.ll (+11-13)
  • (modified) llvm/test/CodeGen/X86/vector-reduce-xor-bool.ll (+10-12)
  • (modified) llvm/test/CodeGen/X86/vector-sext.ll (+3-6)
  • (modified) llvm/test/CodeGen/X86/vector-trunc-ssat.ll (+121-196)
  • (modified) llvm/test/CodeGen/X86/vector-trunc-usat.ll (+17-65)
  • (modified) llvm/test/CodeGen/X86/vector-unsigned-cmp.ll (+9-8)
  • (modified) llvm/test/CodeGen/X86/vsel-cmp-load.ll (+3-9)
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index a639cba5e35a80..2e1443b97d7a61 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -816,6 +816,18 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     }
     break;
   }
+  case ISD::SINT_TO_FP: {
+    EVT InnerVT = Op.getOperand(0).getValueType();
+    if (DemandedBits.isSignMask() &&
+        VT.getScalarSizeInBits() == InnerVT.getScalarSizeInBits())
+      return DAG.getBitcast(VT, Op.getOperand(0));
+    break;
+  }
+  case ISD::UINT_TO_FP: {
+    if (DemandedBits.isSignMask())
+      return DAG.getConstant(0, SDLoc(Op), VT);
+    break;
+  }
   case ISD::SIGN_EXTEND_INREG: {
     // If none of the extended bits are demanded, eliminate the sextinreg.
     SDValue Op0 = Op.getOperand(0);
@@ -2313,6 +2325,22 @@ bool TargetLowering::SimplifyDemandedBits(
     Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
     break;
   }
+  case ISD::SINT_TO_FP: {
+    EVT InnerVT = Op.getOperand(0).getValueType();
+    if (DemandedBits.isSignMask() &&
+        VT.getScalarSizeInBits() == InnerVT.getScalarSizeInBits())
+      return TLO.CombineTo(Op, TLO.DAG.getBitcast(VT, Op.getOperand(0)));
+
+    Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
+    break;
+  }
+  case ISD::UINT_TO_FP: {
+    if (DemandedBits.isSignMask())
+      return TLO.CombineTo(Op, TLO.DAG.getConstant(0, dl, VT));
+
+    Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
+    break;
+  }
   case ISD::SIGN_EXTEND_INREG: {
     SDValue Op0 = Op.getOperand(0);
     EVT ExVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 94c4bbc4a09993..240388657511fb 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -23352,6 +23352,136 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget,
     }
   }
 
+  // We get bad codegen for v8i32 compares on avx targets (without avx2) so if
+  // possible convert to a v8f32 compare.
+  if (VTOp0 == MVT::v8i32 && Subtarget.hasAVX() && !Subtarget.hasAVX2()) {
+    std::optional<KnownBits> KnownOps[2];
+    // Check if an op is known to be in a certain range.
+    auto OpInRange = [&DAG, Op, &KnownOps](unsigned OpNo, bool CmpLT,
+                                           const APInt Bound) {
+      if (!KnownOps[OpNo].has_value())
+        KnownOps[OpNo] = DAG.computeKnownBits(Op.getOperand(OpNo));
+
+      if (KnownOps[OpNo]->isUnknown())
+        return false;
+
+      std::optional<bool> Res;
+      if (CmpLT)
+        Res = KnownBits::ult(*KnownOps[OpNo], KnownBits::makeConstant(Bound));
+      else
+        Res = KnownBits::ugt(*KnownOps[OpNo], KnownBits::makeConstant(Bound));
+      return Res.has_value() && *Res;
+    };
+
+    bool OkayCvt = false;
+    bool OkayBitcast = false;
+
+    const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(MVT::f32);
+
+    // For cvt up to 1 << (Significand Precision), (1 << 24 for ieee float)
+    const APInt MaxConvertableCvt =
+        APInt::getOneBitSet(32, APFloat::semanticsPrecision(Sem));
+    // For bitcast up to (and including) first inf representation (0x7f800000 +
+    // 1 for ieee float)
+    const APInt MaxConvertableBitcast =
+        APFloat::getInf(Sem).bitcastToAPInt() + 1;
+    // For bitcast we also exclude de-norm values. This is absolutely necessary
+    // for strict semantic correctness, but DAZ (de-norm as zero) will break if
+    // we don't have this check.
+    const APInt MinConvertableBitcast =
+        APFloat::getSmallestNormalized(Sem).bitcastToAPInt() - 1;
+
+    assert(
+        MaxConvertableBitcast.getBitWidth() == 32 &&
+        MaxConvertableCvt == (1U << 24) &&
+        MaxConvertableBitcast == 0x7f800001 &&
+        MinConvertableBitcast.isNonNegative() &&
+        MaxConvertableBitcast.sgt(MinConvertableBitcast) &&
+        "This transform has only been verified to IEEE Single Precision Float");
+
+    // For bitcast we need both lhs/op1 u< MaxConvertableBitcast
+    // NB: It might be worth it to enable to bitcast version for unsigned avx2
+    // comparisons as they typically require multiple instructions to lower
+    // (they don't fit `vpcmpeq`/`vpcmpgt` well).
+    if (OpInRange(1, /*CmpLT*/ true, MaxConvertableBitcast) &&
+        OpInRange(1, /*CmpLT*/ false, MinConvertableBitcast) &&
+        OpInRange(0, /*CmpLT*/ true, MaxConvertableBitcast) &&
+        OpInRange(0, /*CmpLT*/ false, MinConvertableBitcast)) {
+      OkayBitcast = true;
+    }
+    // We want to convert icmp -> fcmp using `sitofp` iff one of the converts
+    // will be constant folded.
+    else if ((DAG.isConstantValueOfAnyType(peekThroughBitcasts(Op1)) ||
+              DAG.isConstantValueOfAnyType(peekThroughBitcasts(Op0)))) {
+      if (isUnsignedIntSetCC(Cond)) {
+        // For cvt + unsigned compare we need both lhs/rhs >= 0 and either lhs
+        // or rhs < MaxConvertableCvt
+
+        if (OpInRange(1, /*CmpLT*/ true, APInt::getSignedMinValue(32)) &&
+            OpInRange(0, /*CmpLT*/ true, APInt::getSignedMinValue(32)) &&
+            (OpInRange(1, /*CmpLT*/ true, MaxConvertableCvt) ||
+             OpInRange(0, /*CmpLT*/ true, MaxConvertableCvt)))
+          OkayCvt = true;
+      } else {
+        // For cvt + signed compare we need  abs(lhs) or abs(rhs) <
+        // MaxConvertableCvt
+        if (OpInRange(1, /*CmpLT*/ true, MaxConvertableCvt) ||
+            OpInRange(1, /*CmpLT*/ false, -MaxConvertableCvt) ||
+            OpInRange(0, /*CmpLT*/ true, MaxConvertableCvt) ||
+            OpInRange(0, /*CmpLT*/ false, -MaxConvertableCvt))
+          OkayCvt = true;
+      }
+    }
+    // TODO: If we can't prove any of the ranges, we could unconditionally lower
+    // `(icmp eq lhs, rhs)` as `(icmp eq (int_to_fp (xor lhs, rhs)), zero)`
+    if (OkayBitcast || OkayCvt) {
+      switch (Cond) {
+      default:
+        llvm_unreachable("Unexpected SETCC condition");
+        // Get the new FP condition. Note for the unsigned conditions we have
+        // verified its okay to convert to the signed version.
+      case ISD::SETULT:
+      case ISD::SETLT:
+        Cond = ISD::SETOLT;
+        break;
+      case ISD::SETUGT:
+      case ISD::SETGT:
+        Cond = ISD::SETOGT;
+        break;
+      case ISD::SETULE:
+      case ISD::SETLE:
+        Cond = ISD::SETOLE;
+        break;
+      case ISD::SETUGE:
+      case ISD::SETGE:
+        Cond = ISD::SETOGE;
+        break;
+      case ISD::SETEQ:
+        Cond = ISD::SETOEQ;
+        break;
+      case ISD::SETNE:
+        Cond = ISD::SETONE;
+        break;
+      }
+
+      MVT FpVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
+      SDNodeFlags Flags;
+      Flags.setNoNaNs(true);
+      Flags.setNoInfs(true);
+      Flags.setNoSignedZeros(true);
+      if (OkayBitcast) {
+        Op0 = DAG.getBitcast(FpVT, Op0);
+        Op1 = DAG.getBitcast(FpVT, Op1);
+      } else {
+        Op0 = DAG.getNode(ISD::SINT_TO_FP, dl, FpVT, Op0);
+        Op1 = DAG.getNode(ISD::SINT_TO_FP, dl, FpVT, Op1);
+      }
+      Op0->setFlags(Flags);
+      Op1->setFlags(Flags);
+      return DAG.getSetCC(dl, VT, Op0, Op1, Cond);
+    }
+  }
+
   // Break 256-bit integer vector compare into smaller ones.
   if (VT.is256BitVector() && !Subtarget.hasInt256())
     return splitIntVSETCC(VT, Op0, Op1, Cond, DAG, dl);
@@ -41163,6 +41293,154 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+// Simplify a decomposed (sext (setcc)). Assumes prior check that
+// bitwidth(sext)==bitwidth(setcc operands).
+static SDValue simplifySExtOfDecomposedSetCCImpl(
+    SelectionDAG &DAG, SDLoc &DL, ISD::CondCode CC, SDValue Op0, SDValue Op1,
+    const APInt &OriginalDemandedBits, const APInt &OriginalDemandedElts,
+    bool AllowNOT, unsigned Depth) {
+  // Possible TODO: We could handle any power of two demanded bit + unsigned
+  // comparison. There are no x86 specific comparisons that are unsigned so its
+  // unneeded.
+  if (!OriginalDemandedBits.isSignMask())
+    return SDValue();
+
+  EVT OpVT = Op0.getValueType();
+  // We need need nofpclass(nan inf nzero) to handle floats.
+  auto hasOkayFPFlags = [](SDValue Op) {
+    return Op->getFlags().hasNoNaNs() && Op->getFlags().hasNoInfs() &&
+           Op->getFlags().hasNoSignedZeros();
+  };
+
+  if (OpVT.isFloatingPoint() && !hasOkayFPFlags(Op0))
+    return SDValue();
+
+  auto ValsEq = [OpVT](const APInt &V0, APInt V1) -> bool {
+    if (OpVT.isFloatingPoint()) {
+      const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(OpVT);
+      return V0.eq(APFloat(Sem, V1).bitcastToAPInt());
+    }
+    return V0.eq(V1);
+  };
+
+  // Assume we canonicalized constants to Op1. That isn't always true but we
+  // call this function twice with inverted CC/Operands so its fine either way.
+  APInt Op1C;
+  unsigned ValWidth = OriginalDemandedBits.getBitWidth();
+  if (ISD::isConstantSplatVectorAllZeros(Op1.getNode())) {
+    Op1C = APInt::getZero(ValWidth);
+  } else if (ISD::isConstantSplatVectorAllOnes(Op1.getNode())) {
+    Op1C = APInt::getAllOnes(ValWidth);
+  } else if (auto *C = dyn_cast<ConstantFPSDNode>(Op1)) {
+    Op1C = C->getValueAPF().bitcastToAPInt();
+  } else if (auto *C = dyn_cast<ConstantSDNode>(Op1)) {
+    Op1C = C->getAPIntValue();
+  } else if (ISD::isConstantSplatVector(Op1.getNode(), Op1C)) {
+    // isConstantSplatVector sets `Op1C`.
+  } else {
+    return SDValue();
+  }
+
+  bool Not = false;
+  bool Okay = false;
+  assert(OriginalDemandedBits.getBitWidth() == Op1C.getBitWidth() &&
+         "Invalid constant operand");
+
+  switch (CC) {
+  case ISD::SETGE:
+  case ISD::SETOGE:
+    Not = true;
+    [[fallthrough]];
+  case ISD::SETLT:
+  case ISD::SETOLT:
+    // signbit(sext(x s< 0)) == signbit(x)
+    // signbit(sext(x s>= 0)) == signbit(~x)
+    Okay = ValsEq(Op1C, APInt::getZero(ValWidth));
+    // For float ops we need to ensure Op0 is de-norm. Otherwise DAZ can break
+    // this fold.
+    // NB: We only need de-norm check here, for the rest of the constants any
+    // relationship with a de-norm value and zero will be identical.
+    if (Okay && OpVT.isFloatingPoint()) {
+      // Values from integers are always normal.
+      if (Op0.getOpcode() == ISD::SINT_TO_FP ||
+          Op0.getOpcode() == ISD::UINT_TO_FP)
+        break;
+
+      // See if we can prove normal with known bits.
+      KnownBits Op0Known =
+          DAG.computeKnownBits(Op0, OriginalDemandedElts, Depth);
+      // Negative/positive doesn't matter.
+      Op0Known.One.clearSignBit();
+      Op0Known.Zero.clearSignBit();
+
+      // Get min normal value.
+      const fltSemantics &Sem = SelectionDAG::EVTToAPFloatSemantics(OpVT);
+      KnownBits MinNormal = KnownBits::makeConstant(
+          APFloat::getSmallestNormalized(Sem).bitcastToAPInt());
+      // Are we above de-norm range?
+      std::optional<bool> Op0Normal = KnownBits::uge(Op0Known, MinNormal);
+      Okay = Op0Normal.has_value() && *Op0Normal;
+    }
+    break;
+  case ISD::SETGT:
+  case ISD::SETOGT:
+    Not = true;
+    [[fallthrough]];
+  case ISD::SETLE:
+  case ISD::SETOLE:
+    // signbit(sext(x s<= -1)) == signbit(x)
+    // signbit(sext(x s> -1)) == signbit(~x)
+    Okay = ValsEq(Op1C, APInt::getAllOnes(ValWidth));
+    break;
+  case ISD::SETULT:
+    Not = true;
+    [[fallthrough]];
+  case ISD::SETUGE:
+    // signbit(sext(x u>= SIGNED_MIN)) == signbit(x)
+    // signbit(sext(x u< SIGNED_MIN)) == signbit(~x)
+    Okay = !OpVT.isFloatingPoint() && ValsEq(Op1C, OriginalDemandedBits);
+    break;
+  case ISD::SETULE:
+    Not = true;
+    [[fallthrough]];
+  case ISD::SETUGT:
+    // signbit(sext(x u> SIGNED_MAX)) == signbit(x)
+    // signbit(sext(x u<= SIGNED_MAX)) == signbit(~x)
+    Okay = !OpVT.isFloatingPoint() && ValsEq(Op1C, OriginalDemandedBits - 1);
+    break;
+  default:
+    break;
+  }
+
+  Okay = Not ? AllowNOT : Okay;
+  if (!Okay)
+    return SDValue();
+
+  if (!Not)
+    return Op0;
+
+  if (!OpVT.isFloatingPoint())
+    return DAG.getNOT(DL, Op0, OpVT);
+
+  // Possible TODO: We could use `fneg` to do not.
+  return SDValue();
+}
+
+static SDValue simplifySExtOfDecomposedSetCC(SelectionDAG &DAG, SDLoc &DL,
+                                             ISD::CondCode CC, SDValue Op0,
+                                             SDValue Op1,
+                                             const APInt &OriginalDemandedBits,
+                                             const APInt &OriginalDemandedElts,
+                                             bool AllowNOT, unsigned Depth) {
+  if (SDValue R = simplifySExtOfDecomposedSetCCImpl(
+          DAG, DL, CC, Op0, Op1, OriginalDemandedBits, OriginalDemandedElts,
+          AllowNOT, Depth))
+    return R;
+  return simplifySExtOfDecomposedSetCCImpl(
+      DAG, DL, ISD::getSetCCSwappedOperands(CC), Op1, Op0, OriginalDemandedBits,
+      OriginalDemandedElts, AllowNOT, Depth);
+}
+
 // Simplify variable target shuffle masks based on the demanded elements.
 // TODO: Handle DemandedBits in mask indices as well?
 bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetShuffle(
@@ -42342,13 +42620,26 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
     }
     break;
   }
-  case X86ISD::PCMPGT:
-    // icmp sgt(0, R) == ashr(R, BitWidth-1).
-    // iff we only need the sign bit then we can use R directly.
-    if (OriginalDemandedBits.isSignMask() &&
-        ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode()))
-      return TLO.CombineTo(Op, Op.getOperand(1));
+  case X86ISD::PCMPGT: {
+    SDLoc DL(Op);
+    if (SDValue R = simplifySExtOfDecomposedSetCC(
+            TLO.DAG, DL, ISD::SETGT, Op.getOperand(0), Op.getOperand(1),
+            OriginalDemandedBits, OriginalDemandedElts,
+            /*AllowNOT*/ true, Depth))
+      return TLO.CombineTo(Op, R);
+    break;
+  }
+  case X86ISD::CMPP: {
+    SDLoc DL(Op);
+    ISD::CondCode CC = X86::getCondForCMPPImm(
+        cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue());
+    if (SDValue R = simplifySExtOfDecomposedSetCC(
+            TLO.DAG, DL, CC, Op.getOperand(0), Op.getOperand(1),
+            OriginalDemandedBits, OriginalDemandedElts,
+            !(TLO.LegalOperations() && TLO.LegalTypes()), Depth))
+      return TLO.CombineTo(Op, R);
     break;
+  }
   case X86ISD::MOVMSK: {
     SDValue Src = Op.getOperand(0);
     MVT SrcVT = Src.getSimpleValueType();
@@ -42532,13 +42823,25 @@ SDValue X86TargetLowering::SimplifyMultipleUseDemandedBitsForTargetNode(
     if (DemandedBits.isSignMask())
       return Op.getOperand(0);
     break;
-  case X86ISD::PCMPGT:
-    // icmp sgt(0, R) == ashr(R, BitWidth-1).
-    // iff we only need the sign bit then we can use R directly.
-    if (DemandedBits.isSignMask() &&
-        ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode()))
-      return Op.getOperand(1);
+  case X86ISD::PCMPGT: {
+    SDLoc DL(Op);
+    if (SDValue R = simplifySExtOfDecomposedSetCC(
+            DAG, DL, ISD::SETGT, Op.getOperand(0), Op.getOperand(1),
+            DemandedBits, DemandedElts, /*AllowNOT*/ false, Depth))
+      return R;
+    break;
+  }
+  case X86ISD::CMPP: {
+    SDLoc DL(Op);
+    ISD::CondCode CC = X86::getCondForCMPPImm(
+        cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue());
+    if (SDValue R = simplifySExtOfDecomposedSetCC(DAG, DL, CC, Op.getOperand(0),
+                                                  Op.getOperand(1),
+                                                  DemandedBits, DemandedElts,
+                                                  /*AllowNOT*/ false, Depth))
+      return R;
     break;
+  }
   case X86ISD::BLENDV: {
     // BLENDV: Cond (MSB) ? LHS : RHS
     SDValue Cond = Op.getOperand(0);
diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp
index 3f0557e651f89b..2e331efd9c3d0b 100644
--- a/llvm/lib/Target/X86/X86InstrInfo.cpp
+++ b/llvm/lib/Target/X86/X86InstrInfo.cpp
@@ -3349,6 +3349,46 @@ unsigned X86::getVPCMPImmForCond(ISD::CondCode CC) {
   }
 }
 
+ISD::CondCode X86::getCondForCMPPImm(unsigned Imm) {
+  assert(Imm <= 0x1f && "Invalid CMPP Imm");
+  switch (Imm & 0xf) {
+  default:
+    llvm_unreachable("Invalid CMPP Imm");
+  case 0:
+    return ISD::SETOEQ;
+  case 1:
+    return ISD::SETOLT;
+  case 2:
+    return ISD::SETOLE;
+  case 3:
+    return ISD::SETUO;
+  case 4:
+    return ISD::SETUNE;
+  case 5:
+    return ISD::SETUGE;
+  case 6:
+    return ISD::SETUGT;
+  case 7:
+    return ISD::SETO;
+  case 8:
+    return ISD::SETUEQ;
+  case 9:
+    return ISD::SETULT;
+  case 10:
+    return ISD::SETULE;
+  case 11:
+    return ISD::SETFALSE;
+  case 12:
+    return ISD::SETONE;
+  case 13:
+    return ISD::SETOGE;
+  case 14:
+    return ISD::SETOGT;
+  case 15:
+    return ISD::SETTRUE;
+  }
+}
+
 /// Get the VPCMP immediate if the operands are swapped.
 unsigned X86::getSwappedVPCMPImm(unsigned Imm) {
   switch (Imm) {
diff --git a/llvm/lib/Target/X86/X86InstrInfo.h b/llvm/lib/Target/X86/X86InstrInfo.h
index 0e5fcbeda08f79..4569a74aab54eb 100644
--- a/llvm/lib/Target/X86/X86InstrInfo.h
+++ b/llvm/lib/Target/X86/X86InstrInfo.h
@@ -68,6 +68,9 @@ CondCode GetOppositeBranchCondition(CondCode CC);
 /// Get the VPCMP immediate for the given condition.
 unsigned getVPCMPImmForCond(ISD::CondCode CC);
 
+/// Get the CondCode from a CMPP immediate.
+ISD::CondCode getCondForCMPPImm(unsigned Imm);
+
 /// Get the VPCMP immediate if the opcodes are swapped.
 unsigned getSwappedVPCMPImm(unsigned Imm);
 
diff --git a/llvm/test/CodeGen/X86/bitcast-int-to-vector-bool-sext.ll b/llvm/test/CodeGen/X86/bitcast-int-to-vector-bool-sext.ll
index 6255621d870e12..eef2b3db5d694e 100644
--- a/llvm/test/CodeGen/X86/bitcast-int-to-vector-bool-sext.ll
+++ b/llvm/test/CodeGen/X86/bitcast-int-to-vector-bool-sext.ll
@@ -256,12 +256,9 @@ define <8 x i32> @ext_i8_8i32(i8 %a0) {
 ; AVX1-NEXT:    vmovd %edi, %xmm0
 ; AVX1-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
 ; AVX1-NEXT:    vinsertf128 $1, %xmm0, %ymm0, %ymm0
-; AVX1-NEXT:    vmovaps {{.*#+}} ymm1 = [1,2,4,8,16,32,64,128]
-; AVX1-NEXT:    vandps %ymm1, %ymm0, %ymm0
-; AVX1-NEXT:    vpcmpeqd %xmm1, %xmm0, %xmm1
-; AVX1-NEXT:    vextractf128 $1, %ymm0, %xmm0
-; AVX1-NEXT:    vpcmpeqd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; AVX1-NEXT:    vinsertf128 $1, %xmm0, %ymm1, %ymm0
+; AVX1-NEXT:    vandps {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm0
+; AVX1-NEXT:    vcvtdq2ps %ymm0, %ymm0
+; AVX1-NEXT:    vcmpeqps {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm0
 ; AVX1-NEXT:    retq
 ;
 ; AVX2-LABEL: ext_i8_8i32:
@@ -487,18 +484,12 @@ define <16 x i32> @ext_i16_16i32(i16 %a0) {
 ; AVX1-NEXT:    vmovd %edi, %xmm0
 ; AVX1-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,0,0,0]
 ; AVX1-NEXT:    vinsertf128 $1, %xmm0, %ymm0, %ymm1
-; AVX1-NEXT:    vmovaps {{.*#+}} ymm0 = [1,2,4,8,16,32,64,128]
-; AVX1-NEXT:    vandps %ymm0, %ymm1, %ymm2
-; AVX1-NEXT:    vpcmpeqd %xmm0, %xmm2, %xmm0
-; AVX1-NEXT:    vextractf128 $1, %ymm2, %xmm2
-; AVX1-NEXT:    vpcmpeqd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm2, %xmm2
-; AVX1-NEXT:    vinsertf128 $1, %xmm2, %ymm0, %ymm0
-; AVX1-NEXT:    vmovaps {{.*#+}} ymm2 = [256,512,1024,2048,4096,8192,16384,32768]
-; AVX1-NEXT:    vandps %ymm2, %ymm1, %ymm1
-; AVX1-NEXT:    vpcmpeqd %xmm2, %xmm1, %xmm2
-; AVX1-NEXT:    vextractf128 $1, %ymm1, %xmm1
-; AVX1-NEXT:    vpcmpeqd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
-; AVX1-NEXT:    vinsertf128 $1, %xmm1, %ymm2, %ymm1
+; AVX1-NEXT:    vandps {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm1, %ymm0
+; AVX1-NEXT:    vcvtdq2ps %ymm0, %ymm0
+; AVX1-NEXT:    vcmpeqps {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm0
+; AVX1-NEXT:    vandps {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm1, %ymm1
+; AVX1-NEXT:    vcvtdq2ps %ymm1, %ymm1
+; AVX1-NEXT:    vcmpeqps {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm1, %ymm1
 ; AVX1-NEXT:    retq
 ;
 ; AVX2-LABEL: ext_i16_16i32:
diff --git a/llvm/test/CodeGen/X86/bitcast-int-to-vector-bool-zext.ll b/llvm/test/CodeGen/X86/bitcast-int-to-vector-bool-zext.ll
index d2794df731b65d..5c810797bd2b75 100644
--- a/llvm/test/CodeGen/X86/bitcast-int-to-vector-bool-zext.ll
+++ b/llvm/test/CodeGen/X86/bitcast-int-to-vector-bool-zext.ll
@@ -320,12 +320,9 @@ define <8 x i32> @ext_i8_8i32(i8 %a0) {
 ; AVX1-NEXT:    vm...
[truncated]

@goldsteinn goldsteinn changed the title goldsteinn/simplify x86 cmp [X86] Improve helper for simplifying demanded bits of compares Mar 7, 2024
@goldsteinn
Copy link
Contributor Author

This is split from: #82290

The reason this one comes second, is that we aren't really able to test the cmpp simplifications without #82290 as we never get the necessary FP flags without it.

The plan is to get #82290 in, then rebase this and submit it to clean up the regressions that #82290 causes without it.

Copy link
Collaborator

@RKSimon RKSimon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few minors

APFloat::getSmallestNormalized(Sem).bitcastToAPInt());
// Are we above de-norm range?
std::optional<bool> Op0Normal = KnownBits::uge(Op0Known, MinNormal);
Okay = Op0Normal.has_value() && *Op0Normal;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay = Op0Normal.value_or(false);

@@ -41293,6 +41293,154 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG,
return SDValue();
}

// Simplify a decomposed (sext (setcc)). Assumes prior check that
// bitwidth(sext)==bitwidth(setcc operands).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose this could technically be moved to TargetLowering.cpp so any ZeroOrNegativeOneBooleanContent target could use it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ill port

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So there where no test changes outside of X86.

I would prefer to keep in x86 for now, im a bit worried about whether the float semantics will translate to other targets.
Ill leave a note in simplfyDemandedBits that if we want to extend ISD::SETCC it may be worth porting.

return SDValue();
}

static SDValue simplifySExtOfDecomposedSetCC(SelectionDAG &DAG, SDLoc &DL,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const SDLoc &DL

// Simplify a decomposed (sext (setcc)). Assumes prior check that
// bitwidth(sext)==bitwidth(setcc operands).
static SDValue simplifySExtOfDecomposedSetCCImpl(
SelectionDAG &DAG, SDLoc &DL, ISD::CondCode CC, SDValue Op0, SDValue Op1,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const SDLoc &DL

break;
}

Okay = Not ? AllowNOT : Okay;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if Okay == false && AllowNOT == true?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was a bad bug, patch is no longer so sexy.

@goldsteinn goldsteinn force-pushed the goldsteinn/simplify-x86-cmp branch 3 times, most recently from ccc6087 to f6b0d17 Compare March 15, 2024 18:43
@goldsteinn
Copy link
Contributor Author

Rebase

return SDValue();
}

static SDValue simplifySExtOfDecomposedSetCC(SelectionDAG &DAG, SDLoc &DL,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const SDLoc &

Fixes: llvm#82242

The idea is that AVX doesn't support comparisons for `v8i32` so it
splits the comparison into 2x `v4i32` comparisons + reconstruction of
the `v8i32`.

By converting to a float, we can handle the comparison with 1/2
instructions (1 if we can `bitcast`, 2 if we need to cast with
`sitofp`).

The Proofs: https://alive2.llvm.org/ce/z/AJDdQ8
Timeout, but they can be reproduced locally.
We currently only handle a single case for `pcmpgt`. This patch
extends that to work for `cmpp` and handles comparitors more
generically.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:X86 llvm:SelectionDAG SelectionDAGISel as well
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants