Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ValueTracking] Compute known FPClass from signbit idiom #80740

Merged
merged 4 commits into from
Feb 14, 2024

Conversation

dtcxzyw
Copy link
Member

@dtcxzyw dtcxzyw commented Feb 5, 2024

This patch improves computeKnownFPClass by using context-sensitive information from DomConditionCache.
The motivation of this patch is to optimize the following case found in fmt/format.h:

define float @test(float %x, i1 %cond) {
  %i32 = bitcast float %x to i32
  %cmp = icmp slt i32 %i32, 0
  br i1 %cmp, label %if.then1, label %if.else

if.then1:
  %fneg = fneg float %x
  br label %if.end

if.else:
  br i1 %cond, label %if.then2, label %if.end

if.then2:
  br label %if.end

if.end:
  %value = phi float [ %fneg, %if.then1 ], [ %x, %if.then2 ], [ %x, %if.else ]
  %ret = call float @llvm.fabs.f32(float %value)
  ret float %ret
}

We can prove the sign bit of %value is always zero. Then the fabs can be eliminated.

This pattern also exists in cpython/duckdb/oiio/openexr.

Compile-time impact: https://llvm-compile-time-tracker.com/compare.php?from=f82e0809ba12170e2f648f8a1ac01e78ef06c958&to=041218bf5491996edd828cc15b3aec5a59ddc636&stat=instructions:u

stage1-O3 stage1-ReleaseThinLTO stage1-ReleaseLTO-g stage1-O0-g stage2-O3 stage2-O0-g stage2-clang
-0.00% +0.01% +0.00% -0.03% +0.00% +0.00% +0.02%

@llvmbot
Copy link
Collaborator

llvmbot commented Feb 5, 2024

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-llvm-analysis

Author: Yingwei Zheng (dtcxzyw)

Changes

This patch improves computeKnownFPClass by using context-sensitive information from DomConditionCache.
The motivation of this patch is to optimize the following case found in fmt/format.h:

define float @<!-- -->test(float %x, i1 %cond) {
  %i32 = bitcast float %x to i32
  %cmp = icmp slt i32 %i32, 0
  br i1 %cmp, label %if.then1, label %if.else

if.then1:
  %fneg = fneg float %x
  br label %if.end

if.else:
  br i1 %cond, label %if.then2, label %if.end

if.then2:
  br label %if.end

if.end:
  %value = phi float [ %fneg, %if.then1 ], [ %x, %if.then2 ], [ %x, %if.else ]
  %ret = call float @<!-- -->llvm.fabs.f32(float %value)
  ret float %ret
}

We can prove the sign bit of %value is always zero. Then the fabs can be eliminated.

This pattern also exists in cpython/duckdb/oiio/openexr.


Full diff: https://github.com/llvm/llvm-project/pull/80740.diff

3 Files Affected:

  • (modified) llvm/lib/Analysis/DomConditionCache.cpp (+3)
  • (modified) llvm/lib/Analysis/ValueTracking.cpp (+67-13)
  • (added) llvm/test/Transforms/InstCombine/fpclass-from-dom-cond.ll (+79)
diff --git a/llvm/lib/Analysis/DomConditionCache.cpp b/llvm/lib/Analysis/DomConditionCache.cpp
index c7f4cab4158880..7c3d23e26d1183 100644
--- a/llvm/lib/Analysis/DomConditionCache.cpp
+++ b/llvm/lib/Analysis/DomConditionCache.cpp
@@ -51,6 +51,9 @@ static void findAffectedValues(Value *Cond,
       // Handle (A + C1) u< C2, which is the canonical form of A > C3 && A < C4.
       if (match(A, m_Add(m_Value(X), m_ConstantInt())))
         AddAffected(X);
+      // Handle icmp slt/sgt (bitcast X to int) 0/-1
+      if (match(A, m_BitCast(m_Value(X))))
+        Affected.push_back(X);
     }
   }
 }
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 58db81f470130e..b3315c0bedd874 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -4213,9 +4213,56 @@ llvm::fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
   return fcmpImpliesClass(Pred, F, LHS, *ConstRHS, LookThroughSrc);
 }
 
-static FPClassTest computeKnownFPClassFromAssumes(const Value *V,
-                                                  const SimplifyQuery &Q) {
-  FPClassTest KnownFromAssume = fcAllFlags;
+static KnownFPClass computeKnownFPClassFromContext(const Value *V,
+                                                   const SimplifyQuery &Q) {
+  KnownFPClass KnownFromContext;
+
+  if (!Q.CxtI)
+    return KnownFromContext;
+
+  if (Q.DC && Q.DT) {
+    auto computeKnownFPClassFromCmp = [&](CmpInst::Predicate Pred, Value *LHS,
+                                          Value *RHS) {
+      if (match(LHS, m_BitCast(m_Specific(V)))) {
+        Type *SrcType = V->getType();
+        Type *DstType = LHS->getType();
+
+        // Make sure the bitcast doesn't change between scalar and vector and
+        // doesn't change the number of vector elements.
+        if (SrcType->isVectorTy() == DstType->isVectorTy() &&
+            SrcType->getScalarSizeInBits() == DstType->getScalarSizeInBits()) {
+          // TODO: move IsSignBitCheck to ValueTracking
+          if ((Pred == ICmpInst::ICMP_SLT && match(RHS, m_Zero())) ||
+              (Pred == ICmpInst::ICMP_SLE && match(RHS, m_AllOnes())))
+            KnownFromContext.signBitMustBeOne();
+          else if (Pred == ICmpInst::ICMP_SGT && match(RHS, m_AllOnes()) ||
+                   (Pred == ICmpInst::ICMP_SGE && match(RHS, m_Zero())))
+            KnownFromContext.signBitMustBeZero();
+        }
+      }
+    };
+
+    // Handle dominating conditions.
+    for (BranchInst *BI : Q.DC->conditionsFor(V)) {
+      // TODO: handle fcmps
+      auto *Cmp = dyn_cast<ICmpInst>(BI->getCondition());
+      if (!Cmp)
+        continue;
+
+      BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
+      if (Q.DT->dominates(Edge0, Q.CxtI->getParent()))
+        computeKnownFPClassFromCmp(Cmp->getPredicate(), Cmp->getOperand(0),
+                                   Cmp->getOperand(1));
+
+      BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(1));
+      if (Q.DT->dominates(Edge1, Q.CxtI->getParent()))
+        computeKnownFPClassFromCmp(Cmp->getInversePredicate(),
+                                   Cmp->getOperand(0), Cmp->getOperand(1));
+    }
+  }
+
+  if (!Q.AC)
+    return KnownFromContext;
 
   // Try to restrict the floating-point classes based on information from
   // assumptions.
@@ -4242,16 +4289,16 @@ static FPClassTest computeKnownFPClassFromAssumes(const Value *V,
         auto [CmpVal, MaskIfTrue, MaskIfFalse] =
             fcmpImpliesClass(Pred, *F, LHS, *CRHS, LHS != V);
         if (CmpVal == V)
-          KnownFromAssume &= MaskIfTrue;
+          KnownFromContext.knownNot(~MaskIfTrue);
       }
     } else if (match(I->getArgOperand(0),
                      m_Intrinsic<Intrinsic::is_fpclass>(
                          m_Value(LHS), m_ConstantInt(ClassVal)))) {
-      KnownFromAssume &= static_cast<FPClassTest>(ClassVal);
+      KnownFromContext.knownNot(~static_cast<FPClassTest>(ClassVal));
     }
   }
 
-  return KnownFromAssume;
+  return KnownFromContext;
 }
 
 void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
@@ -4359,10 +4406,8 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
       KnownNotFromFlags |= fcInf;
   }
 
-  if (Q.AC) {
-    FPClassTest AssumedClasses = computeKnownFPClassFromAssumes(V, Q);
-    KnownNotFromFlags |= ~AssumedClasses;
-  }
+  KnownFPClass AssumedClasses = computeKnownFPClassFromContext(V, Q);
+  KnownNotFromFlags |= ~AssumedClasses.KnownFPClasses;
 
   // We no longer need to find out about these bits from inputs if we can
   // assume this from flags/attributes.
@@ -4370,6 +4415,12 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
 
   auto ClearClassesFromFlags = make_scope_exit([=, &Known] {
     Known.knownNot(KnownNotFromFlags);
+    if (!Known.SignBit && AssumedClasses.SignBit) {
+      if (*AssumedClasses.SignBit)
+        Known.signBitMustBeOne();
+      else
+        Known.signBitMustBeZero();
+    }
   });
 
   if (!Op)
@@ -5271,7 +5322,8 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
 
       bool First = true;
 
-      for (Value *IncValue : P->incoming_values()) {
+      for (const Use &U : P->operands()) {
+        Value *IncValue = U.get();
         // Skip direct self references.
         if (IncValue == P)
           continue;
@@ -5280,8 +5332,10 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
         // Recurse, but cap the recursion to two levels, because we don't want
         // to waste time spinning around in loops. We need at least depth 2 to
         // detect known sign bits.
-        computeKnownFPClass(IncValue, DemandedElts, InterestedClasses, KnownSrc,
-                            PhiRecursionLimit, Q);
+        computeKnownFPClass(
+            IncValue, DemandedElts, InterestedClasses, KnownSrc,
+            PhiRecursionLimit,
+            Q.getWithInstruction(P->getIncomingBlock(U)->getTerminator()));
 
         if (First) {
           Known = KnownSrc;
diff --git a/llvm/test/Transforms/InstCombine/fpclass-from-dom-cond.ll b/llvm/test/Transforms/InstCombine/fpclass-from-dom-cond.ll
new file mode 100644
index 00000000000000..7338fa176843a6
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/fpclass-from-dom-cond.ll
@@ -0,0 +1,79 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -S -passes=instcombine < %s | FileCheck %s
+
+define float @test_signbit_check(float %x, i1 %cond) {
+; CHECK-LABEL: define float @test_signbit_check(
+; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT:    [[I32:%.*]] = bitcast float [[X]] to i32
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[I32]], 0
+; CHECK-NEXT:    br i1 [[CMP]], label [[IF_THEN1:%.*]], label [[IF_ELSE:%.*]]
+; CHECK:       if.then1:
+; CHECK-NEXT:    [[FNEG:%.*]] = fneg float [[X]]
+; CHECK-NEXT:    br label [[IF_END:%.*]]
+; CHECK:       if.else:
+; CHECK-NEXT:    br i1 [[COND]], label [[IF_THEN2:%.*]], label [[IF_END]]
+; CHECK:       if.then2:
+; CHECK-NEXT:    br label [[IF_END]]
+; CHECK:       if.end:
+; CHECK-NEXT:    [[VALUE:%.*]] = phi float [ [[FNEG]], [[IF_THEN1]] ], [ [[X]], [[IF_THEN2]] ], [ [[X]], [[IF_ELSE]] ]
+; CHECK-NEXT:    ret float [[VALUE]]
+;
+  %i32 = bitcast float %x to i32
+  %cmp = icmp slt i32 %i32, 0
+  br i1 %cmp, label %if.then1, label %if.else
+
+if.then1:
+  %fneg = fneg float %x
+  br label %if.end
+
+if.else:
+  br i1 %cond, label %if.then2, label %if.end
+
+if.then2:
+  br label %if.end
+
+if.end:
+  %value = phi float [ %fneg, %if.then1 ], [ %x, %if.then2 ], [ %x, %if.else ]
+  %ret = call float @llvm.fabs.f32(float %value)
+  ret float %ret
+}
+
+define float @test_signbit_check_fail(float %x, i1 %cond) {
+; CHECK-LABEL: define float @test_signbit_check_fail(
+; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT:    [[I32:%.*]] = bitcast float [[X]] to i32
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[I32]], 0
+; CHECK-NEXT:    br i1 [[CMP]], label [[IF_THEN1:%.*]], label [[IF_ELSE:%.*]]
+; CHECK:       if.then1:
+; CHECK-NEXT:    [[FNEG:%.*]] = fneg float [[X]]
+; CHECK-NEXT:    br label [[IF_END:%.*]]
+; CHECK:       if.else:
+; CHECK-NEXT:    br i1 [[COND]], label [[IF_THEN2:%.*]], label [[IF_END]]
+; CHECK:       if.then2:
+; CHECK-NEXT:    [[FNEG2:%.*]] = fneg float [[X]]
+; CHECK-NEXT:    br label [[IF_END]]
+; CHECK:       if.end:
+; CHECK-NEXT:    [[VALUE:%.*]] = phi float [ [[FNEG]], [[IF_THEN1]] ], [ [[FNEG2]], [[IF_THEN2]] ], [ [[X]], [[IF_ELSE]] ]
+; CHECK-NEXT:    [[RET:%.*]] = call float @llvm.fabs.f32(float [[VALUE]])
+; CHECK-NEXT:    ret float [[RET]]
+;
+  %i32 = bitcast float %x to i32
+  %cmp = icmp slt i32 %i32, 0
+  br i1 %cmp, label %if.then1, label %if.else
+
+if.then1:
+  %fneg = fneg float %x
+  br label %if.end
+
+if.else:
+  br i1 %cond, label %if.then2, label %if.end
+
+if.then2:
+  %fneg2 = fneg float %x
+  br label %if.end
+
+if.end:
+  %value = phi float [ %fneg, %if.then1 ], [ %fneg2, %if.then2 ], [ %x, %if.else ]
+  %ret = call float @llvm.fabs.f32(float %value)
+  ret float %ret
+}

dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Feb 5, 2024
llvm/lib/Analysis/ValueTracking.cpp Outdated Show resolved Hide resolved
llvm/lib/Analysis/ValueTracking.cpp Outdated Show resolved Hide resolved
Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd generally expect that assume and condition handling reuses the same logic. It seems like now assumes and conditions will handle different (even disjoint) cases?

@dtcxzyw dtcxzyw force-pushed the perf/compute-known-fpclass-from-ctx branch from fa9f293 to 9f02fb4 Compare February 7, 2024 04:56
@dtcxzyw
Copy link
Member Author

dtcxzyw commented Feb 7, 2024

I'd generally expect that assume and condition handling reuses the same logic. It seems like now assumes and conditions will handle different (even disjoint) cases?

I have added fcmp/is.fpclass support (without tests).

@dtcxzyw
Copy link
Member Author

dtcxzyw commented Feb 7, 2024

I will open a new PR for fcmp/is.fpclass support to reduce review effort.

@dtcxzyw dtcxzyw changed the title [ValueTracking] Compute known FPClass from dominating condition [ValueTracking] Compute known FPClass from dominating condition (draft) Feb 7, 2024
dtcxzyw added a commit that referenced this pull request Feb 7, 2024
…0764)

This patch introduces a matching helper `m_ElementWiseBitCast`, which is
used for matching element-wise int <-> fp casts.
The motivation of this patch is to avoid duplicating checks in
#80740 and
#80414.
@arsenm
Copy link
Contributor

arsenm commented Feb 8, 2024

Can you fix the title to make it clear this is the one with the signbit special case?

@dtcxzyw dtcxzyw changed the title [ValueTracking] Compute known FPClass from dominating condition (draft) [ValueTracking] Compute known FPClass from signbit idiom Feb 8, 2024
@dtcxzyw dtcxzyw marked this pull request as draft February 8, 2024 09:09
@dtcxzyw dtcxzyw force-pushed the perf/compute-known-fpclass-from-ctx branch from 0556b5c to 53cb8d8 Compare February 13, 2024 13:34
@dtcxzyw dtcxzyw marked this pull request as ready for review February 13, 2024 13:36
@dtcxzyw dtcxzyw requested a review from arsenm February 13, 2024 13:36
llvm/lib/Analysis/DomConditionCache.cpp Outdated Show resolved Hide resolved
llvm/lib/Analysis/ValueTracking.cpp Outdated Show resolved Hide resolved
llvm/lib/Analysis/DomConditionCache.cpp Outdated Show resolved Hide resolved
dtcxzyw added a commit that referenced this pull request Feb 14, 2024
…NFC. (#81704)

This patch moves the `isSignBitCheck` helper into ValueTracking to reuse
the logic in ValueTracking/InstSimplify.

Addresses the comment
#80740 (comment).
@dtcxzyw dtcxzyw force-pushed the perf/compute-known-fpclass-from-ctx branch from 4027ee9 to 041218b Compare February 14, 2024 11:20
dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Feb 14, 2024
@dtcxzyw dtcxzyw merged commit 16a0629 into llvm:main Feb 14, 2024
4 checks passed
@dtcxzyw dtcxzyw deleted the perf/compute-known-fpclass-from-ctx branch February 14, 2024 12:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants