Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ValueTracking: Identify implied fp classes by general fcmp #65983

Closed
wants to merge 17 commits into from

Conversation

arsenm
Copy link
Contributor

@arsenm arsenm commented Sep 11, 2023

Previously we could recognize exact class tests performed by an fcmp with special values (0s, infs and smallest normal). Expand this to recognize the implied classes by a compare with a general constant. e.g. fcmp ogt x, 1 implies positive and non-0.

The API should be better merged with fcmpToClassTest but that made the diff way bigger, will try to do that in a future patch.

Previously we could recognize exact class tests performed by
an fcmp with special values (0s, infs and smallest normal).
Expand this to recognize the implied classes by a compare with a general
constant. e.g. fcmp ogt x, 1 implies positive and non-0.

The API should be better merged with fcmpToClassTest but that
made the diff way bigger, will try to do that in a future
patch.
@llvmbot
Copy link
Collaborator

llvmbot commented Sep 11, 2023

@llvm/pr-subscribers-llvm-analysis

Changes

Previously we could recognize exact class tests performed by an fcmp with special values (0s, infs and smallest normal). Expand this to recognize the implied classes by a compare with a general constant. e.g. fcmp ogt x, 1 implies positive and non-0.

The API should be better merged with fcmpToClassTest but that made the diff way bigger, will try to do that in a future patch.

Patch is 114.81 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/65983.diff

4 Files Affected:

  • (modified) llvm/include/llvm/Analysis/ValueTracking.h (+21)
  • (modified) llvm/lib/Analysis/ValueTracking.cpp (+159-19)
  • (modified) llvm/test/Transforms/Attributor/nofpclass-implied-by-fcmp.ll (+160-160)
  • (modified) llvm/test/Transforms/InstSimplify/assume-fcmp-constant-implies-class.ll (+90-180)
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index 695f2fecae885b7..6a2fac1961070e4 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -235,6 +235,27 @@ std::pair fcmpToClassTest(CmpInst::Predicate Pred,
                                                 const APFloat *ConstRHS,
                                                 bool LookThroughSrc = true);
 
+/// Compute the possible floating-point classes that \p LHS could be based on an
+/// fcmp returning true. Returns { TestedValue, ClassesIfTrue, ClassesIfFalse }
+///
+/// If the compare returns an exact class test, ClassesIfTrue == ~ClassesIfFalse
+///
+/// This is a less exact version of fcmpToClassTest (e.g. fcmpToClassTest will
+/// only succeed for a test of x > 0 implies positive, but not x > 1).
+///
+/// If \p LookThroughSrc is true, consider the input value when computing the
+/// mask. This may look through sign bit operations.
+///
+/// If \p LookThroughSrc is false, ignore the source value (i.e. the first pair
+/// element will always be LHS.
+///
+std::tuple
+fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
+                 const APFloat *ConstRHS, bool LookThroughSrc = true);
+std::tuple
+fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
+                 Value *RHS, bool LookThroughSrc = true);
+
 struct KnownFPClass {
   /// Floating-point classes the value could be one of.
   FPClassTest KnownFPClasses = fcAllFlags;
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index c4153b824c37e0a..f7ae8af36199f80 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -4217,6 +4217,140 @@ llvm::fcmpToClassTest(FCmpInst::Predicate Pred, const Function &F, Value *LHS,
   return {Src, Mask};
 }
 
+std::tuple
+llvm::fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
+                       const APFloat *ConstRHS, bool LookThroughSrc) {
+  auto [Val, ClassMask] =
+      fcmpToClassTest(Pred, F, LHS, ConstRHS, LookThroughSrc);
+  if (Val)
+    return {Val, ClassMask, ~ClassMask};
+
+  FPClassTest RHSClass = ConstRHS->classify();
+  assert((RHSClass == fcPosNormal || RHSClass == fcNegNormal ||
+          RHSClass == fcPosSubnormal || RHSClass == fcNegSubnormal) &&
+         "should have been recognized as an exact class test");
+
+  const bool IsNegativeRHS = (RHSClass & fcNegative) == RHSClass;
+  const bool IsPositiveRHS = (RHSClass & fcPositive) == RHSClass;
+
+  assert(IsNegativeRHS == ConstRHS->isNegative());
+  assert(IsPositiveRHS == !ConstRHS->isNegative());
+
+  Value *Src = LHS;
+  const bool IsFabs = LookThroughSrc && match(LHS, m_FAbs(m_Value(Src)));
+
+  if (IsFabs)
+    RHSClass = llvm::fneg(RHSClass);
+
+  if (Pred == FCmpInst::FCMP_OEQ)
+    return {Src, RHSClass, fcAllFlags};
+
+  if (Pred == FCmpInst::FCMP_UEQ) {
+    FPClassTest Class = RHSClass | fcNan;
+    return {Src, Class, ~Class};
+  }
+
+  if (Pred == FCmpInst::FCMP_ONE)
+    return {Src, ~fcNan, RHSClass};
+
+  if (Pred == FCmpInst::FCMP_UNE)
+    return {Src, fcAllFlags, RHSClass | fcNan};
+
+  if (IsNegativeRHS) {
+    // TODO: Handle fneg(fabs)
+    if (IsFabs) {
+      // fabs(x) o> -k -> fcmp ord x, x
+      // fabs(x) u> -k -> true
+      // fabs(x) o< -k -> false
+      // fabs(x) u< -k -> fcmp uno x, x
+      switch (Pred) {
+      case FCmpInst::FCMP_OGT:
+      case FCmpInst::FCMP_OGE:
+        return {Src, ~fcNan, fcNan};
+      case FCmpInst::FCMP_UGT:
+      case FCmpInst::FCMP_UGE:
+        return {Src, fcAllFlags, fcNone};
+      case FCmpInst::FCMP_OLT:
+      case FCmpInst::FCMP_OLE:
+        return {Src, fcNone, fcAllFlags};
+      case FCmpInst::FCMP_ULT:
+      case FCmpInst::FCMP_ULE:
+        return {Src, fcNan, ~fcNan};
+      default:
+        break;
+      }
+
+      return {nullptr, fcAllFlags, fcAllFlags};
+    }
+
+    FPClassTest ClassesLE = fcNegInf | fcNegNormal;
+    FPClassTest ClassesGE = fcPositive | fcNegZero | fcNegSubnormal;
+
+    if (ConstRHS->isDenormal())
+      ClassesLE |= fcNegSubnormal;
+    else
+      ClassesGE |= fcNegNormal;
+
+    switch (Pred) {
+    case FCmpInst::FCMP_OGT:
+    case FCmpInst::FCMP_OGE:
+      return {Src, ClassesGE, ~ClassesGE | RHSClass};
+    case FCmpInst::FCMP_UGT:
+    case FCmpInst::FCMP_UGE:
+      return {Src, ClassesGE | fcNan, ~(ClassesGE | fcNan) | RHSClass};
+    case FCmpInst::FCMP_OLT:
+    case FCmpInst::FCMP_OLE:
+      return {Src, ClassesLE, ~ClassesLE | RHSClass};
+    case FCmpInst::FCMP_ULT:
+    case FCmpInst::FCMP_ULE:
+      return {Src, ClassesLE | fcNan, ~(ClassesLE | fcNan) | RHSClass};
+    default:
+      break;
+    }
+  } else if (IsPositiveRHS) {
+    FPClassTest ClassesGE = fcPosNormal | fcPosInf;
+    FPClassTest ClassesLE = fcNegative | fcPosZero | fcPosNormal;
+    if (ConstRHS->isDenormal())
+      ClassesGE |= fcPosNormal;
+    else
+      ClassesLE |= fcPosSubnormal;
+
+    FPClassTest FalseClasses = RHSClass;
+    if (IsFabs) {
+      ClassesGE = llvm::fabs(ClassesGE);
+      ClassesLE = llvm::fabs(ClassesLE);
+    }
+
+    switch (Pred) {
+    case FCmpInst::FCMP_OGT:
+    case FCmpInst::FCMP_OGE:
+      return {Src, ClassesGE, ~ClassesGE | FalseClasses};
+    case FCmpInst::FCMP_UGT:
+    case FCmpInst::FCMP_UGE:
+      return {Src, ClassesGE | fcNan, ~(ClassesGE | fcNan) | FalseClasses};
+    case FCmpInst::FCMP_OLT:
+    case FCmpInst::FCMP_OLE:
+      return {Src, ClassesLE, ~ClassesLE | FalseClasses};
+    case FCmpInst::FCMP_ULT:
+    case FCmpInst::FCMP_ULE:
+      return {Src, ClassesLE | fcNan, ~(ClassesLE | fcNan) | FalseClasses};
+    default:
+      break;
+    }
+  }
+
+  return {nullptr, fcAllFlags, fcAllFlags};
+}
+
+std::tuple
+llvm::fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
+                       Value *RHS, bool LookThroughSrc) {
+  const APFloat *ConstRHS;
+  if (!match(RHS, m_APFloatAllowUndef(ConstRHS)))
+    return {nullptr, fcAllFlags, fcNone};
+  return fcmpImpliesClass(Pred, F, LHS, ConstRHS, LookThroughSrc);
+}
+
 static FPClassTest computeKnownFPClassFromAssumes(const Value *V,
                                                   const SimplifyQuery &Q) {
   FPClassTest KnownFromAssume = fcAllFlags;
@@ -4241,18 +4375,21 @@ static FPClassTest computeKnownFPClassFromAssumes(const Value *V,
     Value *LHS, *RHS;
     uint64_t ClassVal = 0;
     if (match(I->getArgOperand(0), m_FCmp(Pred, m_Value(LHS), m_Value(RHS)))) {
-      auto [TestedValue, TestedMask] =
-          fcmpToClassTest(Pred, *F, LHS, RHS, true);
-      // First see if we can fold in fabs/fneg into the test.
-      if (TestedValue == V)
-        KnownFromAssume &= TestedMask;
-      else {
-        // Try again without the lookthrough if we found a different source
-        // value.
-        auto [TestedValue, TestedMask] =
-            fcmpToClassTest(Pred, *F, LHS, RHS, false);
-        if (TestedValue == V)
-          KnownFromAssume &= TestedMask;
+      const APFloat *CRHS;
+      if (match(RHS, m_APFloat(CRHS))) {
+        // First see if we can fold in fabs/fneg into the test.
+        auto [CmpVal, MaskIfTrue, MaskIfFalse] =
+            fcmpImpliesClass(Pred, *F, LHS, CRHS, true);
+        if (CmpVal == V)
+          KnownFromAssume &= MaskIfTrue;
+        else {
+          // Try again without the lookthrough if we found a different source
+          // value.
+          auto [CmpVal, MaskIfTrue, MaskIfFalse] =
+              fcmpImpliesClass(Pred, *F, LHS, CRHS, false);
+          if (CmpVal == V)
+            KnownFromAssume &= MaskIfTrue;
+        }
       }
     } else if (match(I->getArgOperand(0),
                      m_Intrinsic(
@@ -4400,7 +4537,8 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
     FPClassTest FilterRHS = fcAllFlags;
 
     Value *TestedValue = nullptr;
-    FPClassTest TestedMask = fcNone;
+    FPClassTest MaskIfTrue = fcAllFlags;
+    FPClassTest MaskIfFalse = fcAllFlags;
     uint64_t ClassVal = 0;
     const Function *F = cast(Op)->getFunction();
     CmpInst::Predicate Pred;
@@ -4412,20 +4550,22 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
       // TODO: In some degenerate cases we can infer something if we try again
       // without looking through sign operations.
       bool LookThroughFAbsFNeg = CmpLHS != LHS && CmpLHS != RHS;
-      std::tie(TestedValue, TestedMask) =
-          fcmpToClassTest(Pred, *F, CmpLHS, CmpRHS, LookThroughFAbsFNeg);
+      std::tie(TestedValue, MaskIfTrue, MaskIfFalse) =
+          fcmpImpliesClass(Pred, *F, CmpLHS, CmpRHS, LookThroughFAbsFNeg);
     } else if (match(Cond,
                      m_Intrinsic(
                          m_Value(TestedValue), m_ConstantInt(ClassVal)))) {
-      TestedMask = static_cast(ClassVal);
+      FPClassTest TestedMask = static_cast(ClassVal);
+      MaskIfTrue = TestedMask;
+      MaskIfFalse = ~TestedMask;
     }
 
     if (TestedValue == LHS) {
       // match !isnan(x) ? x : y
-      FilterLHS = TestedMask;
-    } else if (TestedValue == RHS) {
+      FilterLHS = MaskIfTrue;
+    } else if (TestedValue == RHS) { // && IsExactClass
       // match !isnan(x) ? y : x
-      FilterRHS = ~TestedMask;
+      FilterRHS = MaskIfFalse;
     }
 
     KnownFPClass Known2;
diff --git a/llvm/test/Transforms/Attributor/nofpclass-implied-by-fcmp.ll b/llvm/test/Transforms/Attributor/nofpclass-implied-by-fcmp.ll
index 396b8c84fc898c9..b3dce48868d7dbf 100644
--- a/llvm/test/Transforms/Attributor/nofpclass-implied-by-fcmp.ll
+++ b/llvm/test/Transforms/Attributor/nofpclass-implied-by-fcmp.ll
@@ -11,7 +11,7 @@ declare void @llvm.assume(i1 noundef)
 
 ; can't be +inf
 define float @clamp_is_ogt_1_to_1(float %arg) {
-; CHECK-LABEL: define float @clamp_is_ogt_1_to_1(
+; CHECK-LABEL: define nofpclass(pinf) float @clamp_is_ogt_1_to_1(
 ; CHECK-SAME: float [[ARG:%.*]]) #[[ATTR2:[0-9]+]] {
 ; CHECK-NEXT:    [[IS_OGT_1:%.*]] = fcmp ogt float [[ARG]], 1.000000e+00
 ; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[IS_OGT_1]], float 1.000000e+00, float [[ARG]]
@@ -23,7 +23,7 @@ define float @clamp_is_ogt_1_to_1(float %arg) {
 }
 
 define float @clamp_is_ogt_1_to_1_commute(float %arg) {
-; CHECK-LABEL: define float @clamp_is_ogt_1_to_1_commute(
+; CHECK-LABEL: define nofpclass(pinf) float @clamp_is_ogt_1_to_1_commute(
 ; CHECK-SAME: float [[ARG:%.*]]) #[[ATTR2]] {
 ; CHECK-NEXT:    [[IS_ULE_1:%.*]] = fcmp ule float [[ARG]], 1.000000e+00
 ; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[IS_ULE_1]], float [[ARG]], float 1.000000e+00
@@ -36,7 +36,7 @@ define float @clamp_is_ogt_1_to_1_commute(float %arg) {
 
 ; can't be +inf or nan
 define float @clamp_is_ugt_1_to_1(float %arg) {
-; CHECK-LABEL: define float @clamp_is_ugt_1_to_1(
+; CHECK-LABEL: define nofpclass(nan pinf) float @clamp_is_ugt_1_to_1(
 ; CHECK-SAME: float [[ARG:%.*]]) #[[ATTR2]] {
 ; CHECK-NEXT:    [[IS_UGT_1:%.*]] = fcmp ugt float [[ARG]], 1.000000e+00
 ; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[IS_UGT_1]], float 1.000000e+00, float [[ARG]]
@@ -49,7 +49,7 @@ define float @clamp_is_ugt_1_to_1(float %arg) {
 
 ; can't be +inf or nan
 define float @clamp_is_ugt_1_to_1_commute(float %arg) {
-; CHECK-LABEL: define float @clamp_is_ugt_1_to_1_commute(
+; CHECK-LABEL: define nofpclass(nan pinf) float @clamp_is_ugt_1_to_1_commute(
 ; CHECK-SAME: float [[ARG:%.*]]) #[[ATTR2]] {
 ; CHECK-NEXT:    [[IS_OLE_1:%.*]] = fcmp ole float [[ARG]], 1.000000e+00
 ; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[IS_OLE_1]], float [[ARG]], float 1.000000e+00
@@ -62,7 +62,7 @@ define float @clamp_is_ugt_1_to_1_commute(float %arg) {
 
 ; can't be +inf
 define float @clamp_is_oge_1_to_1(float %arg) {
-; CHECK-LABEL: define float @clamp_is_oge_1_to_1(
+; CHECK-LABEL: define nofpclass(pinf) float @clamp_is_oge_1_to_1(
 ; CHECK-SAME: float [[ARG:%.*]]) #[[ATTR2]] {
 ; CHECK-NEXT:    [[IS_OGE_1:%.*]] = fcmp oge float [[ARG]], 1.000000e+00
 ; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[IS_OGE_1]], float 1.000000e+00, float [[ARG]]
@@ -74,7 +74,7 @@ define float @clamp_is_oge_1_to_1(float %arg) {
 }
 
 define float @clamp_is_oge_1_to_1_commute(float %arg) {
-; CHECK-LABEL: define float @clamp_is_oge_1_to_1_commute(
+; CHECK-LABEL: define nofpclass(pinf) float @clamp_is_oge_1_to_1_commute(
 ; CHECK-SAME: float [[ARG:%.*]]) #[[ATTR2]] {
 ; CHECK-NEXT:    [[IS_ULT_1:%.*]] = fcmp ult float [[ARG]], 1.000000e+00
 ; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[IS_ULT_1]], float [[ARG]], float 1.000000e+00
@@ -87,7 +87,7 @@ define float @clamp_is_oge_1_to_1_commute(float %arg) {
 
 ; can't be +inf or nan
 define float @clamp_is_uge_1_to_1(float %arg) {
-; CHECK-LABEL: define float @clamp_is_uge_1_to_1(
+; CHECK-LABEL: define nofpclass(nan pinf) float @clamp_is_uge_1_to_1(
 ; CHECK-SAME: float [[ARG:%.*]]) #[[ATTR2]] {
 ; CHECK-NEXT:    [[IS_UGT_1:%.*]] = fcmp uge float [[ARG]], 1.000000e+00
 ; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[IS_UGT_1]], float 1.000000e+00, float [[ARG]]
@@ -100,7 +100,7 @@ define float @clamp_is_uge_1_to_1(float %arg) {
 
 ; can't be negative, zero, or denormal
 define float @clamp_is_olt_1_to_1(float %arg) {
-; CHECK-LABEL: define float @clamp_is_olt_1_to_1(
+; CHECK-LABEL: define nofpclass(ninf zero sub nnorm) float @clamp_is_olt_1_to_1(
 ; CHECK-SAME: float [[ARG:%.*]]) #[[ATTR2]] {
 ; CHECK-NEXT:    [[IS_OLT_1:%.*]] = fcmp olt float [[ARG]], 1.000000e+00
 ; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[IS_OLT_1]], float 1.000000e+00, float [[ARG]]
@@ -113,7 +113,7 @@ define float @clamp_is_olt_1_to_1(float %arg) {
 
 ; can't be negative, zero, or denormal
 define float @clamp_is_olt_1_to_1_commute(float %arg) {
-; CHECK-LABEL: define float @clamp_is_olt_1_to_1_commute(
+; CHECK-LABEL: define nofpclass(ninf zero sub nnorm) float @clamp_is_olt_1_to_1_commute(
 ; CHECK-SAME: float [[ARG:%.*]]) #[[ATTR2]] {
 ; CHECK-NEXT:    [[IS_UGE_1:%.*]] = fcmp uge float [[ARG]], 1.000000e+00
 ; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[IS_UGE_1]], float [[ARG]], float 1.000000e+00
@@ -126,7 +126,7 @@ define float @clamp_is_olt_1_to_1_commute(float %arg) {
 
 ; can't be negative or zero, nan or denormal
 define float @clamp_is_ult_1_to_1(float %arg) {
-; CHECK-LABEL: define float @clamp_is_ult_1_to_1(
+; CHECK-LABEL: define nofpclass(nan ninf zero sub nnorm) float @clamp_is_ult_1_to_1(
 ; CHECK-SAME: float [[ARG:%.*]]) #[[ATTR2]] {
 ; CHECK-NEXT:    [[IS_ULT_1:%.*]] = fcmp ult float [[ARG]], 1.000000e+00
 ; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[IS_ULT_1]], float 1.000000e+00, float [[ARG]]
@@ -139,7 +139,7 @@ define float @clamp_is_ult_1_to_1(float %arg) {
 
 ; can't be negative or zero, nan or denormal
 define float @clamp_is_ult_1_to_1_commute(float %arg) {
-; CHECK-LABEL: define float @clamp_is_ult_1_to_1_commute(
+; CHECK-LABEL: define nofpclass(nan ninf zero sub nnorm) float @clamp_is_ult_1_to_1_commute(
 ; CHECK-SAME: float [[ARG:%.*]]) #[[ATTR2]] {
 ; CHECK-NEXT:    [[IS_OGE_1:%.*]] = fcmp oge float [[ARG]], 1.000000e+00
 ; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[IS_OGE_1]], float [[ARG]], float 1.000000e+00
@@ -152,7 +152,7 @@ define float @clamp_is_ult_1_to_1_commute(float %arg) {
 
 ; can't be negative, zero or denormal
 define float @clamp_is_ole_1_to_1(float %arg) {
-; CHECK-LABEL: define float @clamp_is_ole_1_to_1(
+; CHECK-LABEL: define nofpclass(ninf zero sub nnorm) float @clamp_is_ole_1_to_1(
 ; CHECK-SAME: float [[ARG:%.*]]) #[[ATTR2]] {
 ; CHECK-NEXT:    [[IS_OLE_1:%.*]] = fcmp ole float [[ARG]], 1.000000e+00
 ; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[IS_OLE_1]], float 1.000000e+00, float [[ARG]]
@@ -165,7 +165,7 @@ define float @clamp_is_ole_1_to_1(float %arg) {
 
 ; can't be negative or zero, nan or denormal
 define float @clamp_is_ule_1_to_1(float %arg) {
-; CHECK-LABEL: define float @clamp_is_ule_1_to_1(
+; CHECK-LABEL: define nofpclass(nan ninf zero sub nnorm) float @clamp_is_ule_1_to_1(
 ; CHECK-SAME: float [[ARG:%.*]]) #[[ATTR2]] {
 ; CHECK-NEXT:    [[IS_ULE_1:%.*]] = fcmp ule float [[ARG]], 1.000000e+00
 ; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[IS_ULE_1]], float 1.000000e+00, float [[ARG]]
@@ -178,7 +178,7 @@ define float @clamp_is_ule_1_to_1(float %arg) {
 
 ; can't be negative or denormal
 define float @clamp_is_olt_1_to_0(float %arg) {
-; CHECK-LABEL: define float @clamp_is_olt_1_to_0(
+; CHECK-LABEL: define nofpclass(ninf nzero sub nnorm) float @clamp_is_olt_1_to_0(
 ; CHECK-SAME: float [[ARG:%.*]]) #[[ATTR2]] {
 ; CHECK-NEXT:    [[IS_OLT_1:%.*]] = fcmp olt float [[ARG]], 1.000000e+00
 ; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[IS_OLT_1]], float 0.000000e+00, float [[ARG]]
@@ -191,7 +191,7 @@ define float @clamp_is_olt_1_to_0(float %arg) {
 
 ; can't be negative, nan or denormal
 define float @clamp_is_ult_1_to_0(float %arg) {
-; CHECK-LABEL: define float @clamp_is_ult_1_to_0(
+; CHECK-LABEL: define nofpclass(ninf nzero sub nnorm) float @clamp_is_ult_1_to_0(
 ; CHECK-SAME: float [[ARG:%.*]]) #[[ATTR2]] {
 ; CHECK-NEXT:    [[IS_ULT_1:%.*]] = fcmp olt float [[ARG]], 1.000000e+00
 ; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[IS_ULT_1]], float 0.000000e+00, float [[ARG]]
@@ -204,7 +204,7 @@ define float @clamp_is_ult_1_to_0(float %arg) {
 
 ; can't be negative or denormal
 define float @clamp_is_ole_1_to_0(float %arg) {
-; CHECK-LABEL: define float @clamp_is_ole_1_to_0(
+; CHECK-LABEL: define nofpclass(ninf nzero sub nnorm) float @clamp_is_ole_1_to_0(
 ; CHECK-SAME: float [[ARG:%.*]]) #[[ATTR2]] {
 ; CHECK-NEXT:    [[IS_OLE_1:%.*]] = fcmp ole float [[ARG]], 1.000000e+00
 ; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[IS_OLE_1]], float 0.000000e+00, float [[ARG]]
@@ -217,7 +217,7 @@ define float @clamp_is_ole_1_to_0(float %arg) {
 
 ; can't be negative or denormal
 define float @clamp_is_ole_1_to_0_commute(float %arg) {
-; CHECK-LABEL: define float @clamp_is_ole_1_to_0_commute(
+; CHECK-LABEL: define nofpclass(ninf nzero sub nnorm) float @clamp_is_ole_1_to_0_commute(
 ; CHECK-SAME: float [[ARG:%.*]]) #[[ATTR2]] {
 ; CHECK-NEXT:    [[IS_UGT_1:%.*]] = fcmp ugt float [[ARG]], 1.000000e+00
 ; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[IS_UGT_1]], float [[ARG]], float 0.000000e+00
@@ -230,7 +230,7 @@ define float @clamp_is_ole_1_to_0_commute(float %arg) {
 
 ; can't be negative or denormal
 define float @clamp_is_ule_1_to_0(float %arg) {
-; CHECK-LABEL: define float @clamp_is_ule_1_to_0(
+; CHECK-LABEL: define nofpclass(ninf nzero sub nnorm) float @clamp_is_ule_1_to_0(
 ; CHECK-SAME: float [[ARG:%.*]]) #[[ATTR2]] {
 ; CHECK-NEXT:    [[IS_ULE_1:%.*]] = fcmp ole float [[ARG]], 1.000000e+00
 ; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[IS_ULE_1]], float 0.000000e+00, float [[ARG]]
@@ -243,7 +243,7 @@ define float @clamp_is_ule_1_to_0(float %arg) {
 
 ; can't be positive, zero or denormal
 define float @clamp_is_ogt_neg1_to_neg1(float %arg) {
-; CHECK-LABEL: define float @clamp_is_ogt_neg1_to_neg1(
+; CHECK-LABEL: define nofpclass(pinf zero sub pnorm) float @clamp_is_ogt_neg1_to_neg1(
 ; CHECK-SAME: float [[ARG:%.*]]) #[[ATTR2]] {
 ; CHECK-NEXT:    [[IS_OGT_NEG1:%.*]] = fcmp ogt float [[ARG]], -1.000000e+00
 ; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[IS_OGT_NEG1]], float -1.000000e+00, float [[ARG]]
@@ -256,7 +256,7 @@ define float @clamp_is_ogt_neg1_to_neg1(float %arg) {
 
 ; can't be positive, zero, nan or denormal
 define float @clamp_is_ugt_neg1_to_neg1(float %arg) {
-; CHECK-LABEL: define float @clamp_is_ugt_neg1_to_neg1(
+; CHECK-LABEL: define nofpclass(nan pinf zero sub pnorm) float @clamp_is_ugt_neg1_to_neg1(
 ; CHECK-SAME: float [[ARG:%.*]]) #[[ATTR2]] {
 ; CHECK-NEXT:    [[IS_UGT_NEG1:%.*]] = fcmp ugt float [[ARG]], -1.000000e+00
 ; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[IS_UGT_NEG1]], float -1.000000e+00, float [[ARG]]
@@ -269,7 +269,7 @@ define float @clamp_is_ugt_neg1_to_neg1(float ...

@llvmbot
Copy link
Collaborator

llvmbot commented Sep 15, 2023

@llvm/pr-subscribers-backend-amdgpu

@llvm/pr-subscribers-llvm-transforms

Changes Previously we could recognize exact class tests performed by an fcmp with special values (0s, infs and smallest normal). Expand this to recognize the implied classes by a compare with a general constant. e.g. fcmp ogt x, 1 implies positive and non-0.

The API should be better merged with fcmpToClassTest but that made the diff way bigger, will try to do that in a future patch.

Patch is 126.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/65983.diff

6 Files Affected:

  • (modified) llvm/include/llvm/Analysis/ValueTracking.h (+34)
  • (modified) llvm/lib/Analysis/ValueTracking.cpp (+240-67)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp (+1-1)
  • (modified) llvm/test/Transforms/Attributor/nofpclass-implied-by-fcmp.ll (+160-160)
  • (modified) llvm/test/Transforms/InstSimplify/assume-fcmp-constant-implies-class.ll (+90-180)
  • (modified) llvm/unittests/Analysis/ValueTrackingTest.cpp (+4-4)
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index 695f2fecae885b7..bc571a702f717d4 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -235,6 +235,40 @@ std::pair&lt;Value *, FPClassTest&gt; fcmpToClassTest(CmpInst::Predicate Pred,
                                                 const APFloat *ConstRHS,
                                                 bool LookThroughSrc = true);
 
+/// Compute the possible floating-point classes that \p LHS could be based on
+/// fcmp \Pred \p LHS, \p RHS.
+///
+/// Returns { TestedValue, ClassesIfTrue, ClassesIfFalse }
+///
+/// If the compare returns an exact class test, ClassesIfTrue == ~ClassesIfFalse
+///
+/// This is a less exact version of fcmpToClassTest (e.g. fcmpToClassTest will
+/// only succeed for a test of x &gt; 0 implies positive, but not x &gt; 1).
+///
+/// If \p LookThroughSrc is true, consider the input value when computing the
+/// mask. This may look through sign bit operations.
+///
+/// If \p LookThroughSrc is false, ignore the source value (i.e. the first pair
+/// element will always be LHS.
+///
+std::tuple&lt;Value *, FPClassTest, FPClassTest&gt;
+fcmpImpliesClass(CmpInst::Predicate Pred, const Function &amp;F, Value *LHS,
+                 Value *RHS, bool LookThroughSrc = true);
+std::tuple&lt;Value *, FPClassTest, FPClassTest&gt;
+fcmpImpliesClass(CmpInst::Predicate Pred, const Function &amp;F, Value *LHS,
+                 FPClassTest RHS, bool LookThroughSrc = true);
+std::tuple&lt;Value *, FPClassTest, FPClassTest&gt;
+fcmpImpliesClass(CmpInst::Predicate Pred, const Function &amp;F, Value *LHS,
+                 const APFloat &amp;RHS, bool LookThroughSrc = true);
+
+#if 0
+inline std::tuple&lt;Value *, FPClassTest, FPClassTest&gt;
+fcmpImpliesClass(CmpInst::Predicate Pred, const Function &amp;F, Value *LHS,
+                 const APFloat &amp;RHS, bool LookThroughSrc = true) {
+  return fcmpImpliesClass(Pred, F, LHS, RHS.classify(), LookThroughSrc);
+}
+#endif
+
 struct KnownFPClass {
   /// Floating-point classes the value could be one of.
   FPClassTest KnownFPClasses = fcAllFlags;
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index c4153b824c37e0a..ccacc75d50c67af 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -4016,67 +4016,106 @@ std::pair&lt;Value *, FPClassTest&gt; llvm::fcmpToClassTest(FCmpInst::Predicate Pred,
 std::pair&lt;Value *, FPClassTest&gt;
 llvm::fcmpToClassTest(FCmpInst::Predicate Pred, const Function &amp;F, Value *LHS,
                       const APFloat *ConstRHS, bool LookThroughSrc) {
+
+  auto [Src, ClassIfTrue, ClassIfFalse] = fcmpImpliesClass(Pred,
+                                                           F,
+                                                           LHS,
+                                                           *ConstRHS, LookThroughSrc);
+  if (Src &amp;&amp; ClassIfTrue == ~ClassIfFalse)
+    return {Src, ClassIfTrue};
+  return {nullptr, fcAllFlags};
+}
+
+/// Return the return value for fcmpImpliesClass for a compare that produces an
+/// exact class test.
+static std::tuple&lt;Value *, FPClassTest, FPClassTest&gt;
+exactClass(Value *V, FPClassTest M) {
+  return {V, M, ~M};
+}
+
+std::tuple&lt;Value *, FPClassTest, FPClassTest&gt;
+llvm::fcmpImpliesClass(CmpInst::Predicate Pred, const Function &amp;F, Value *LHS,
+                       FPClassTest RHSClass, bool LookThroughSrc) {
+  assert(RHSClass != fcNone);
+
+  FPClassTest OrigClass = RHSClass;
+
+  Value *Src = LHS;
+  const bool IsNegativeRHS = (RHSClass &amp; fcNegative) == RHSClass;
+  const bool IsPositiveRHS = (RHSClass &amp; fcPositive) == RHSClass;
+  const bool IsNaN = (RHSClass &amp; ~fcNan) == fcNone;
+
+  if (IsNaN) {
+    // fcmp o__ x, nan -&gt; false
+    // fcmp u__ x, nan -&gt; true
+    return exactClass(Src, CmpInst::isOrdered(Pred) ? fcNone : fcAllFlags);
+  }
+
   // fcmp ord x, zero|normal|subnormal|inf -&gt; ~fcNan
-  if (Pred == FCmpInst::FCMP_ORD &amp;&amp; !ConstRHS-&gt;isNaN())
-    return {LHS, ~fcNan};
+  if (Pred == FCmpInst::FCMP_ORD)
+    return {Src, ~fcNan, fcNan};
 
   // fcmp uno x, zero|normal|subnormal|inf -&gt; fcNan
-  if (Pred == FCmpInst::FCMP_UNO &amp;&amp; !ConstRHS-&gt;isNaN())
-    return {LHS, fcNan};
+  if (Pred == FCmpInst::FCMP_UNO)
+    return {Src, fcNan, ~fcNan};
 
-  if (ConstRHS-&gt;isZero()) {
+  const bool IsFabs = LookThroughSrc &amp;&amp; match(LHS, m_FAbs(m_Value(Src)));
+  if (IsFabs)
+    RHSClass = llvm::fneg(RHSClass);
+
+  const bool IsZero = (RHSClass &amp; fcZero) == RHSClass;
+  if (IsZero) {
     // Compares with fcNone are only exactly equal to fcZero if input denormals
     // are not flushed.
     // TODO: Handle DAZ by expanding masks to cover subnormal cases.
     if (Pred != FCmpInst::FCMP_ORD &amp;&amp; Pred != FCmpInst::FCMP_UNO &amp;&amp;
         !inputDenormalIsIEEE(F, LHS-&gt;getType()))
-      return {nullptr, fcNone};
+      return {nullptr, fcAllFlags, fcAllFlags};
 
     switch (Pred) {
     case FCmpInst::FCMP_OEQ: // Match x == 0.0
-      return {LHS, fcZero};
+      return exactClass(Src, fcZero);
     case FCmpInst::FCMP_UEQ: // Match isnan(x) || (x == 0.0)
-      return {LHS, fcZero | fcNan};
+      return exactClass(Src, fcZero | fcNan);
     case FCmpInst::FCMP_UNE: // Match (x != 0.0)
-      return {LHS, ~fcZero};
+      return exactClass(Src, ~fcZero);
     case FCmpInst::FCMP_ONE: // Match !isnan(x) &amp;&amp; x != 0.0
-      return {LHS, ~fcNan &amp; ~fcZero};
+      return exactClass(Src, ~fcNan &amp; ~fcZero);
     case FCmpInst::FCMP_ORD:
       // Canonical form of ord/uno is with a zero. We could also handle
       // non-canonical other non-NaN constants or LHS == RHS.
-      return {LHS, ~fcNan};
+      return exactClass(Src, ~fcNan);
     case FCmpInst::FCMP_UNO:
-      return {LHS, fcNan};
+      return exactClass(Src, fcNan);
     case FCmpInst::FCMP_OGT: // x &gt; 0
-      return {LHS, fcPosSubnormal | fcPosNormal | fcPosInf};
+      return exactClass(Src, fcPosSubnormal | fcPosNormal | fcPosInf);
     case FCmpInst::FCMP_UGT: // isnan(x) || x &gt; 0
-      return {LHS, fcPosSubnormal | fcPosNormal | fcPosInf | fcNan};
+      return exactClass(Src, fcPosSubnormal | fcPosNormal | fcPosInf | fcNan);
     case FCmpInst::FCMP_OGE: // x &gt;= 0
-      return {LHS, fcPositive | fcNegZero};
+      return exactClass(Src, fcPositive | fcNegZero);
     case FCmpInst::FCMP_UGE: // isnan(x) || x &gt;= 0
-      return {LHS, fcPositive | fcNegZero | fcNan};
+      return exactClass(Src, fcPositive | fcNegZero | fcNan);
     case FCmpInst::FCMP_OLT: // x &lt; 0
-      return {LHS, fcNegSubnormal | fcNegNormal | fcNegInf};
+      return exactClass(Src, fcNegSubnormal | fcNegNormal | fcNegInf);
     case FCmpInst::FCMP_ULT: // isnan(x) || x &lt; 0
-      return {LHS, fcNegSubnormal | fcNegNormal | fcNegInf | fcNan};
+      return exactClass(Src, fcNegSubnormal | fcNegNormal | fcNegInf | fcNan);
     case FCmpInst::FCMP_OLE: // x &lt;= 0
-      return {LHS, fcNegative | fcPosZero};
+      return exactClass(Src, fcNegative | fcPosZero);
     case FCmpInst::FCMP_ULE: // isnan(x) || x &lt;= 0
-      return {LHS, fcNegative | fcPosZero | fcNan};
+      return exactClass(Src, fcNegative | fcPosZero | fcNan);
     default:
       break;
     }
 
-    return {nullptr, fcNone};
+    return {nullptr, fcAllFlags, fcAllFlags};
   }
 
-  Value *Src = LHS;
-  const bool IsFabs = LookThroughSrc &amp;&amp; match(LHS, m_FAbs(m_Value(Src)));
+  const bool IsDenormalRHS = (RHSClass &amp; fcSubnormal) == RHSClass;
 
-  // Compute the test mask that would return true for the ordered comparisons.
-  FPClassTest Mask;
+  const bool IsInf = (RHSClass &amp; fcInf) == RHSClass;
+  if (IsInf) {
+    FPClassTest Mask = fcAllFlags;
 
-  if (ConstRHS-&gt;isInfinity()) {
     switch (Pred) {
     case FCmpInst::FCMP_OEQ:
     case FCmpInst::FCMP_UNE: {
@@ -4091,8 +4130,7 @@ llvm::fcmpToClassTest(FCmpInst::Predicate Pred, const Function &amp;F, Value *LHS,
       //   fcmp une fabs(x), +inf -&gt; is_fpclass x, ~fcInf
       //   fcmp une x, -inf -&gt; is_fpclass x, ~fcNegInf
       //   fcmp une fabs(x), -inf -&gt; is_fpclass x, fcAllFlags -&gt; true
-
-      if (ConstRHS-&gt;isNegative()) {
+      if (IsNegativeRHS) {
         Mask = fcNegInf;
         if (IsFabs)
           Mask = fcNone;
@@ -4101,7 +4139,6 @@ llvm::fcmpToClassTest(FCmpInst::Predicate Pred, const Function &amp;F, Value *LHS,
         if (IsFabs)
           Mask |= fcNegInf;
       }
-
       break;
     }
     case FCmpInst::FCMP_ONE:
@@ -4116,7 +4153,7 @@ llvm::fcmpToClassTest(FCmpInst::Predicate Pred, const Function &amp;F, Value *LHS,
       //   fcmp ueq (fabs x), +inf -&gt; is_fpclass x, fcInf|fcNan
       //   fcmp ueq x, -inf -&gt; is_fpclass x, fcNegInf|fcNan
       //   fcmp ueq fabs(x), -inf -&gt; is_fpclass x, fcNan
-      if (ConstRHS-&gt;isNegative()) {
+      if (IsNegativeRHS) {
         Mask = ~fcNegInf &amp; ~fcNan;
         if (IsFabs)
           Mask = ~fcNan;
@@ -4130,7 +4167,7 @@ llvm::fcmpToClassTest(FCmpInst::Predicate Pred, const Function &amp;F, Value *LHS,
     }
     case FCmpInst::FCMP_OLT:
     case FCmpInst::FCMP_UGE: {
-      if (ConstRHS-&gt;isNegative()) {
+      if (IsNegativeRHS) {
         // No value is ordered and less than negative infinity.
         // All values are unordered with or at least negative infinity.
         // fcmp olt x, -inf -&gt; false
@@ -4150,8 +4187,8 @@ llvm::fcmpToClassTest(FCmpInst::Predicate Pred, const Function &amp;F, Value *LHS,
     }
     case FCmpInst::FCMP_OGE:
     case FCmpInst::FCMP_ULT: {
-      if (ConstRHS-&gt;isNegative()) // TODO
-        return {nullptr, fcNone};
+      if (IsNegativeRHS) // TODO
+        return {nullptr, fcAllFlags, fcAllFlags};
 
       // fcmp oge fabs(x), +inf -&gt; fcInf
       // fcmp oge x, +inf -&gt; fcPosInf
@@ -4164,17 +4201,138 @@ llvm::fcmpToClassTest(FCmpInst::Predicate Pred, const Function &amp;F, Value *LHS,
     }
     case FCmpInst::FCMP_OGT:
     case FCmpInst::FCMP_ULE: {
-      if (ConstRHS-&gt;isNegative())
-        return {nullptr, fcNone};
+      if (IsNegativeRHS)
+        return {nullptr, fcAllFlags, fcAllFlags};
 
       // No value is ordered and greater than infinity.
       Mask = fcNone;
       break;
     }
     default:
-      return {nullptr, fcNone};
+      return {nullptr, fcAllFlags, fcAllFlags};
+    }
+
+    // Invert the comparison for the unordered cases.
+    if (FCmpInst::isUnordered(Pred))
+      Mask = ~Mask;
+
+    return exactClass(Src, Mask);
+  }
+
+  if (Pred == FCmpInst::FCMP_OEQ) {
+    return {Src, RHSClass, fcAllFlags};
+
+  if (Pred == FCmpInst::FCMP_UEQ) {
+    FPClassTest Class = RHSClass | fcNan;
+    return {Src, Class, ~fcNan};
+  }
+
+  if (Pred == FCmpInst::FCMP_ONE)
+    return {Src, ~fcNan, RHSClass};
+
+  if (Pred == FCmpInst::FCMP_UNE)
+    return {Src, fcAllFlags, RHSClass};
+
+  assert((RHSClass == fcPosNormal || RHSClass == fcNegNormal || RHSClass == fcNormal ||
+          RHSClass == fcPosSubnormal || RHSClass == fcNegSubnormal || RHSClass == fcSubnormal) &amp;&amp;
+         &quot;should have been recognized as an exact class test&quot;);
+
+  if (IsNegativeRHS) {
+    // TODO: Handle fneg(fabs)
+    if (IsFabs) {
+      // fabs(x) o&gt; -k -&gt; fcmp ord x, x
+      // fabs(x) u&gt; -k -&gt; true
+      // fabs(x) o&lt; -k -&gt; false
+      // fabs(x) u&lt; -k -&gt; fcmp uno x, x
+      switch (Pred) {
+      case FCmpInst::FCMP_OGT:
+      case FCmpInst::FCMP_OGE:
+        return {Src, ~fcNan, fcNan};
+      case FCmpInst::FCMP_UGT:
+      case FCmpInst::FCMP_UGE:
+        return {Src, fcAllFlags, fcNone};
+      case FCmpInst::FCMP_OLT:
+      case FCmpInst::FCMP_OLE:
+        return {Src, fcNone, fcAllFlags};
+      case FCmpInst::FCMP_ULT:
+      case FCmpInst::FCMP_ULE:
+        return {Src, fcNan, ~fcNan};
+      default:
+        break;
+      }
+
+      return {nullptr, fcAllFlags, fcAllFlags};
+    }
+
+    FPClassTest ClassesLE = fcNegInf | fcNegNormal;
+    FPClassTest ClassesGE = fcPositive | fcNegZero | fcNegSubnormal;
+
+    if (IsDenormalRHS)
+      ClassesLE |= fcNegSubnormal;
+    else
+      ClassesGE |= fcNegNormal;
+
+    switch (Pred) {
+    case FCmpInst::FCMP_OGT:
+    case FCmpInst::FCMP_OGE:
+      return {Src, ClassesGE, ~ClassesGE | RHSClass};
+    case FCmpInst::FCMP_UGT:
+    case FCmpInst::FCMP_UGE:
+      return {Src, ClassesGE | fcNan, ~(ClassesGE | fcNan) | RHSClass};
+    case FCmpInst::FCMP_OLT:
+    case FCmpInst::FCMP_OLE:
+      return {Src, ClassesLE, ~ClassesLE | RHSClass};
+    case FCmpInst::FCMP_ULT:
+    case FCmpInst::FCMP_ULE:
+      return {Src, ClassesLE | fcNan, ~(ClassesLE | fcNan) | RHSClass};
+    default:
+      break;
+    }
+  } else if (IsPositiveRHS) {
+    FPClassTest ClassesGE = fcPosNormal | fcPosInf;
+    FPClassTest ClassesLE = fcNegative | fcPosZero | fcPosNormal;
+    if (IsDenormalRHS)
+      ClassesGE |= fcPosNormal;
+    else
+      ClassesLE |= fcPosSubnormal;
+
+    FPClassTest FalseClasses = RHSClass;
+    if (IsFabs) {
+      ClassesGE = llvm::inverse_fabs(ClassesGE);
+      ClassesLE = llvm::inverse_fabs(ClassesLE);
+    }
+
+    switch (Pred) {
+    case FCmpInst::FCMP_OGT:
+    case FCmpInst::FCMP_OGE:
+      return {Src, ClassesGE, ~ClassesGE | FalseClasses};
+    case FCmpInst::FCMP_UGT:
+    case FCmpInst::FCMP_UGE:
+      return {Src, ClassesGE | fcNan, ~(ClassesGE | fcNan) | FalseClasses};
+    case FCmpInst::FCMP_OLT:
+    case FCmpInst::FCMP_OLE:
+      return {Src, ClassesLE, ~ClassesLE | FalseClasses};
+    case FCmpInst::FCMP_ULT:
+    case FCmpInst::FCMP_ULE:
+      return {Src, ClassesLE | fcNan, ~(ClassesLE | fcNan) | FalseClasses};
+    default:
+      break;
     }
-  } else if (ConstRHS-&gt;isSmallestNormalized() &amp;&amp; !ConstRHS-&gt;isNegative()) {
+  }
+
+  return {nullptr, fcAllFlags, fcAllFlags};
+}
+
+std::tuple&lt;Value *, FPClassTest, FPClassTest&gt;
+llvm::fcmpImpliesClass(CmpInst::Predicate Pred, const Function &amp;F, Value *LHS,
+                       const APFloat &amp;ConstRHS, bool LookThroughSrc) {
+  // We can refine checks against smallest normal / largest denormal to an
+  // exact class test.
+  if (!ConstRHS.isNegative() &amp;&amp; ConstRHS.isSmallestNormalized()) {
+    Value *Src = LHS;
+    const bool IsFabs = LookThroughSrc &amp;&amp; match(LHS, m_FAbs(m_Value(Src)));
+
+    FPClassTest Mask;
     // Match pattern that&#x27;s used in __builtin_isnormal.
     switch (Pred) {
     case FCmpInst::FCMP_OLT:
@@ -4201,20 +4359,28 @@ llvm::fcmpToClassTest(FCmpInst::Predicate Pred, const Function &amp;F, Value *LHS,
       break;
     }
     default:
-      return {nullptr, fcNone};
+      // xxx - am i tested
+      return {nullptr, fcAllFlags, fcAllFlags};
+      return fcmpImpliesClass(Pred, F, LHS, ConstRHS.classify(), LookThroughSrc);
     }
-  } else if (ConstRHS-&gt;isNaN()) {
-    // fcmp o__ x, nan -&gt; false
-    // fcmp u__ x, nan -&gt; true
-    Mask = fcNone;
-  } else
-    return {nullptr, fcNone};
 
-  // Invert the comparison for the unordered cases.
-  if (FCmpInst::isUnordered(Pred))
-    Mask = ~Mask;
+    // Invert the comparison for the unordered cases.
+    if (FCmpInst::isUnordered(Pred))
+      Mask = ~Mask;
 
-  return {Src, Mask};
+    return exactClass(Src, Mask);
+  }
+
+  return fcmpImpliesClass(Pred, F, LHS, ConstRHS.classify(), LookThroughSrc);
+}
+
+std::tuple&lt;Value *, FPClassTest, FPClassTest&gt;
+llvm::fcmpImpliesClass(CmpInst::Predicate Pred, const Function &amp;F, Value *LHS,
+                       Value *RHS, bool LookThroughSrc) {
+  const APFloat *ConstRHS;
+  if (!match(RHS, m_APFloatAllowUndef(ConstRHS)))
+    return {nullptr, fcAllFlags, fcNone};
+  return fcmpImpliesClass(Pred, F, LHS, *ConstRHS, LookThroughSrc);
 }
 
 static FPClassTest computeKnownFPClassFromAssumes(const Value *V,
@@ -4241,18 +4407,22 @@ static FPClassTest computeKnownFPClassFromAssumes(const Value *V,
     Value *LHS, *RHS;
     uint64_t ClassVal = 0;
     if (match(I-&gt;getArgOperand(0), m_FCmp(Pred, m_Value(LHS), m_Value(RHS)))) {
-      auto [TestedValue, TestedMask] =
-          fcmpToClassTest(Pred, *F, LHS, RHS, true);
-      // First see if we can fold in fabs/fneg into the test.
-      if (TestedValue == V)
-        KnownFromAssume &amp;= TestedMask;
-      else {
-        // Try again without the lookthrough if we found a different source
-        // value.
-        auto [TestedValue, TestedMask] =
-            fcmpToClassTest(Pred, *F, LHS, RHS, false);
-        if (TestedValue == V)
-          KnownFromAssume &amp;= TestedMask;
+      const APFloat *CRHS;
+      if (match(RHS, m_APFloat(CRHS))) {
+        // First see if we can fold in fabs/fneg into the test.
+        auto [CmpVal, MaskIfTrue, MaskIfFalse] =
+            fcmpImpliesClass(Pred, *F, LHS, *CRHS, true);
+        if (CmpVal == V)
+          KnownFromAssume &amp;= (MaskIfTrue &amp; ~MaskIfFalse);
+
+        else {
+          // Try again without the lookthrough if we found a different source
+          // value.
+          auto [CmpVal, MaskIfTrue, MaskIfFalse] =
+              fcmpImpliesClass(Pred, *F, LHS, *CRHS, false);
+          if (CmpVal == V)
+            KnownFromAssume &amp;= (MaskIfTrue &amp; ~MaskIfFalse);
+        }
       }
     } else if (match(I-&gt;getArgOperand(0),
                      m_Intrinsic&lt;Intrinsic::is_fpclass&gt;(
@@ -4400,7 +4570,8 @@ void computeKnownFPClass(const Value *V, const APInt &amp;DemandedElts,
     FPClassTest FilterRHS = fcAllFlags;
 
     Value *TestedValue = nullptr;
-    FPClassTest TestedMask = fcNone;
+    FPClassTest MaskIfTrue = fcAllFlags;
+    FPClassTest MaskIfFalse = fcAllFlags;
     uint64_t ClassVal = 0;
     const Function *F = cast&lt;Instruction&gt;(Op)-&gt;getFunction();
     CmpInst::Predicate Pred;
@@ -4412,20 +4583,22 @@ void computeKnownFPClass(const Value *V, const APInt &amp;DemandedElts,
       // TODO: In some degenerate cases we can infer something if we try again
       // without looking through sign operations.
       bool LookThroughFAbsFNeg = CmpLHS != LHS &amp;&amp; CmpLHS != RHS;
-      std::tie(TestedValue, TestedMask) =
-          fcmpToClassTest(Pred, *F, CmpLHS, CmpRHS, LookThroughFAbsFNeg);
+      std::tie(TestedValue, MaskIfTrue, MaskIfFalse) =
+          fcmpImpliesClass(Pred, *F, CmpLHS, CmpRHS, LookThroughFAbsFNeg);
     } else if (match(Cond,
                      m_Intrinsic&lt;Intrinsic::is_fpclass&gt;(
                          m_Value(TestedValue), m_ConstantInt(ClassVal)))) {
-      TestedMask = static_cast&lt;FPClassTest&gt;(ClassVal);
+      FPClassTest TestedMask = static_cast&lt;FPClassTest&gt;(ClassVal);
+      MaskIfTrue = TestedMask;
+      MaskIfFalse = ~TestedMask;
     }
 
     if (TestedValue == LHS) {
       // match !isnan(x) ? x : y
-      FilterLHS = TestedMask;
-    } else if (TestedValue == RHS) {
+      FilterLHS = MaskIfTrue;
+    } else if (TestedValue == RHS) { // &amp;&amp; IsExactClass
       // match !isnan(x) ? y : x
-      FilterRHS = ~TestedMask;
+      FilterRHS = MaskIfFalse;
     }
 
     KnownFPClass Known2;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
index 421fc41ad3b62c3..e20f4adeab33dd2 ...

@arsenm
Copy link
Contributor Author

arsenm commented Sep 15, 2023

Broke this and needs rebase

@arsenm arsenm closed this Sep 15, 2023
@arsenm arsenm deleted the fcmp-implies-class-1 branch September 15, 2023 12:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants