
[InstCombine] Add folds for (add/sub/disjoint_or/icmp C, (ctpop (not x))) #77859

Closed

Conversation

goldsteinn
Contributor

@goldsteinn goldsteinn commented Jan 12, 2024

  • [InstCombine] Add tests for folding (add/sub/disjoint_or/icmp C, (ctpop (not x))); NFC
  • [InstCombine] Add folds for (add/sub/disjoint_or/icmp C, (ctpop (not x)))

(ctpop (not x)) <-> (sub nuw nsw BitWidth(x), (ctpop x)). The
sub expression can sometimes be constant folded depending on the use
case of (ctpop (not x)).

This patch adds folds for the following cases:

(add/sub/disjoint_or C, (ctpop (not x)))
-> (add/sub/disjoint_or C', (ctpop x))
(cmp pred C, (ctpop (not x)))
-> (cmp swapped_pred C', (ctpop x))

Where C' depends on how we constant fold C with BitWidth(x) for
the given opcode.

Proofs: https://alive2.llvm.org/ce/z/qUgfF3

@llvmbot
Collaborator

llvmbot commented Jan 12, 2024

@llvm/pr-subscribers-llvm-transforms

Author: None (goldsteinn)

Changes
  • Add tests for folding (add/sub/disjoint_or/icmp C, (ctpop (not x))); NFC
  • Add folds for (add/sub/disjoint_or/icmp C, (ctpop (not x)))

Full diff: https://github.com/llvm/llvm-project/pull/77859.diff

6 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp (+6)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+3)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+3)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+4)
  • (modified) llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (+77)
  • (added) llvm/test/Transforms/InstCombine/fold-ctpop-of-not.ll (+166)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index c7e6f32c5406a6..8a00b75a1f7404 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1683,6 +1683,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
     }
   }
 
+  if (Instruction *R = tryFoldInstWithCtpopWithNot(&I))
+    return R;
+
   // TODO(jingyue): Consider willNotOverflowSignedAdd and
   // willNotOverflowUnsignedAdd to reduce the number of invocations of
   // computeKnownBits.
@@ -2445,6 +2448,9 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
     }
   }
 
+  if (Instruction *R = tryFoldInstWithCtpopWithNot(&I))
+    return R;
+
   if (Instruction *R = foldSubOfMinMax(I, Builder))
     return R;
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 0620752e321394..de06fb8badf817 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3398,6 +3398,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
   if (Instruction *R = foldBinOpShiftWithShift(I))
     return R;
 
+  if (Instruction *R = tryFoldInstWithCtpopWithNot(&I))
+    return R;
+
   Value *X, *Y;
   const APInt *CV;
   if (match(&I, m_c_Or(m_OneUse(m_Xor(m_Value(X), m_APInt(CV))), m_Value(Y))) &&
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 7c1aff445524de..8c0fd662255130 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1323,6 +1323,9 @@ Instruction *InstCombinerImpl::foldICmpWithConstant(ICmpInst &Cmp) {
       return replaceInstUsesWith(Cmp, NewPhi);
     }
 
+  if (Instruction *R = tryFoldInstWithCtpopWithNot(&Cmp))
+    return R;
+
   return nullptr;
 }
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 21c61bd990184d..c24b6e3a5b33c0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -505,6 +505,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   Value *SimplifySelectsFeedingBinaryOp(BinaryOperator &I, Value *LHS,
                                         Value *RHS);
 
+  // If `I` has operand `(ctpop (not x))`, fold `I` with `(sub nuw nsw
+  // BitWidth(x), (ctpop x))`.
+  Instruction *tryFoldInstWithCtpopWithNot(Instruction *I);
+
   // (Binop1 (Binop2 (logic_shift X, C), C1), (logic_shift Y, C))
   //    -> (logic_shift (Binop1 (Binop2 X, inv_logic_shift(C1, C)), Y), C)
   // (Binop1 (Binop2 (logic_shift X, Amt), Mask), (logic_shift Y, Amt))
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 7f2018b3a19958..732ab7ad8b3223 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -740,6 +740,83 @@ static Value *tryFactorization(BinaryOperator &I, const SimplifyQuery &SQ,
   return RetVal;
 }
 
+// If `I` has one Const operand and the other matches `(ctpop (not x))`,
+// replace `(ctpop (not x))` with `(sub nuw nsw BitWidth(x), (ctpop x))`.
+// This is only useful if the new subtract can fold, so we only handle the
+// following cases:
+//    1) (add/sub/disjoint_or C, (ctpop (not x)))
+//        -> (add/sub/disjoint_or C', (ctpop x))
+//    2) (cmp pred C, (ctpop (not x)))
+//        -> (cmp swapped_pred C', (ctpop x))
+Instruction *InstCombinerImpl::tryFoldInstWithCtpopWithNot(Instruction *I) {
+  unsigned Opc = I->getOpcode();
+  unsigned ConstIdx = 1;
+  switch (Opc) {
+  default:
+    return nullptr;
+    // (ctpop (not x)) <-> (sub nuw nsw BitWidth(x), (ctpop x))
+    // We can fold the BitWidth(x) with add/sub/icmp as long as the other
+    // operand is a constant.
+  case Instruction::Sub:
+    ConstIdx = 0;
+    break;
+  case Instruction::Or:
+    if (!match(I, m_DisjointOr(m_Value(), m_Value())))
+      return nullptr;
+    [[fallthrough]];
+  case Instruction::Add:
+  case Instruction::ICmp:
+    break;
+  }
+  // Find ctpop.
+  auto *Ctpop = dyn_cast<IntrinsicInst>(I->getOperand(1 - ConstIdx));
+  if (Ctpop == nullptr)
+    return nullptr;
+  if (Ctpop->getIntrinsicID() != Intrinsic::ctpop)
+    return nullptr;
+  Constant *C;
+  // Check other operand is ImmConstant.
+  if (!match(I->getOperand(ConstIdx), m_ImmConstant(C)))
+    return nullptr;
+
+  Type *Ty = Ctpop->getType();
+  Constant *BitWidthC = ConstantInt::get(Ty, Ty->getScalarSizeInBits());
+  // Need extra check for icmp. Note if this check fails it generally means
+  // the icmp will simplify to true/false.
+  if (Opc == Instruction::ICmp && !cast<ICmpInst>(I)->isEquality() &&
+      !ConstantExpr::getICmp(ICmpInst::ICMP_UGT, C, BitWidthC)->isZeroValue())
+    return nullptr;
+
+  Value *Op = Ctpop->getArgOperand(0);
+  // Check we can invert `(not x)` for free.
+  Value *NotOp = getFreelyInverted(Op, Op->hasOneUse(), &Builder);
+  if (NotOp == nullptr)
+    return nullptr;
+  Value *CtpopOfNotOp = Builder.CreateIntrinsic(Ty, Intrinsic::ctpop, NotOp);
+
+  Value *R = nullptr;
+
+  // Do the transformation here to avoid potentially introducing an infinite
+  // loop.
+  switch (Opc) {
+  case Instruction::Sub:
+    R = Builder.CreateAdd(CtpopOfNotOp, ConstantExpr::getSub(C, BitWidthC));
+    break;
+  case Instruction::Or:
+  case Instruction::Add:
+    R = Builder.CreateSub(ConstantExpr::getAdd(C, BitWidthC), CtpopOfNotOp);
+    break;
+  case Instruction::ICmp:
+    R = Builder.CreateICmp(cast<ICmpInst>(I)->getSwappedPredicate(),
+                           CtpopOfNotOp, ConstantExpr::getSub(BitWidthC, C));
+    break;
+  default:
+    llvm_unreachable("Unhandled Opcode");
+  }
+  assert(R != nullptr);
+  return replaceInstUsesWith(*I, R);
+}
+
 // (Binop1 (Binop2 (logic_shift X, C), C1), (logic_shift Y, C))
 //   IFF
 //    1) the logic_shifts match
diff --git a/llvm/test/Transforms/InstCombine/fold-ctpop-of-not.ll b/llvm/test/Transforms/InstCombine/fold-ctpop-of-not.ll
new file mode 100644
index 00000000000000..9fa3bb66bb7f10
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/fold-ctpop-of-not.ll
@@ -0,0 +1,166 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+declare i8 @llvm.ctpop.i8(i8)
+declare <2 x i8> @llvm.ctpop.v2i8(<2 x i8>)
+
+define i8 @fold_sub_c_ctpop(i8 %x) {
+; CHECK-LABEL: @fold_sub_c_ctpop(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.ctpop.i8(i8 [[X:%.*]]), !range [[RNG0:![0-9]+]]
+; CHECK-NEXT:    [[R:%.*]] = add nuw nsw i8 [[TMP1]], 4
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %nx = xor i8 %x, -1
+  %cnt = call i8 @llvm.ctpop.i8(i8 %nx)
+  %r = sub i8 12, %cnt
+  ret i8 %r
+}
+
+define i8 @fold_sub_var_ctpop_fail(i8 %x, i8 %y) {
+; CHECK-LABEL: @fold_sub_var_ctpop_fail(
+; CHECK-NEXT:    [[NX:%.*]] = xor i8 [[X:%.*]], -1
+; CHECK-NEXT:    [[CNT:%.*]] = call i8 @llvm.ctpop.i8(i8 [[NX]]), !range [[RNG0]]
+; CHECK-NEXT:    [[R:%.*]] = sub i8 [[Y:%.*]], [[CNT]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %nx = xor i8 %x, -1
+  %cnt = call i8 @llvm.ctpop.i8(i8 %nx)
+  %r = sub i8 %y, %cnt
+  ret i8 %r
+}
+
+define <2 x i8> @fold_sub_ctpop_c(<2 x i8> %x) {
+; CHECK-LABEL: @fold_sub_ctpop_c(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> [[X:%.*]]), !range [[RNG0]]
+; CHECK-NEXT:    [[R:%.*]] = sub nuw nsw <2 x i8> <i8 -55, i8 -56>, [[TMP1]]
+; CHECK-NEXT:    ret <2 x i8> [[R]]
+;
+  %nx = xor <2 x i8> %x, <i8 -1, i8 -1>
+  %cnt = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> %nx)
+  %r = sub <2 x i8> %cnt, <i8 63, i8 64>
+  ret <2 x i8> %r
+}
+
+define i8 @fold_add_ctpop_c(i8 %x) {
+; CHECK-LABEL: @fold_add_ctpop_c(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.ctpop.i8(i8 [[X:%.*]]), !range [[RNG0]]
+; CHECK-NEXT:    [[R:%.*]] = sub nuw nsw i8 71, [[TMP1]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %nx = xor i8 %x, -1
+  %cnt = call i8 @llvm.ctpop.i8(i8 %nx)
+  %r = add i8 %cnt, 63
+  ret i8 %r
+}
+
+define i8 @fold_distjoint_or_ctpop_c(i8 %x) {
+; CHECK-LABEL: @fold_distjoint_or_ctpop_c(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.ctpop.i8(i8 [[X:%.*]]), !range [[RNG0]]
+; CHECK-NEXT:    [[R:%.*]] = sub nuw nsw i8 72, [[TMP1]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %nx = xor i8 %x, -1
+  %cnt = call i8 @llvm.ctpop.i8(i8 %nx)
+  %r = or i8 %cnt, 64
+  ret i8 %r
+}
+
+define i8 @fold_or_ctpop_c_fail(i8 %x) {
+; CHECK-LABEL: @fold_or_ctpop_c_fail(
+; CHECK-NEXT:    [[NX:%.*]] = xor i8 [[X:%.*]], -1
+; CHECK-NEXT:    [[CNT:%.*]] = call i8 @llvm.ctpop.i8(i8 [[NX]]), !range [[RNG0]]
+; CHECK-NEXT:    [[R:%.*]] = or i8 [[CNT]], 65
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %nx = xor i8 %x, -1
+  %cnt = call i8 @llvm.ctpop.i8(i8 %nx)
+  %r = or i8 %cnt, 65
+  ret i8 %r
+}
+
+define i8 @fold_add_ctpop_var_fail(i8 %x, i8 %y) {
+; CHECK-LABEL: @fold_add_ctpop_var_fail(
+; CHECK-NEXT:    [[NX:%.*]] = xor i8 [[X:%.*]], -1
+; CHECK-NEXT:    [[CNT:%.*]] = call i8 @llvm.ctpop.i8(i8 [[NX]]), !range [[RNG0]]
+; CHECK-NEXT:    [[R:%.*]] = add i8 [[CNT]], [[Y:%.*]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %nx = xor i8 %x, -1
+  %cnt = call i8 @llvm.ctpop.i8(i8 %nx)
+  %r = add i8 %cnt, %y
+  ret i8 %r
+}
+
+define i1 @fold_cmp_eq_ctpop_c(i8 %x) {
+; CHECK-LABEL: @fold_cmp_eq_ctpop_c(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.ctpop.i8(i8 [[X:%.*]]), !range [[RNG0]]
+; CHECK-NEXT:    [[R:%.*]] = icmp eq i8 [[TMP1]], 6
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %nx = xor i8 %x, -1
+  %cnt = call i8 @llvm.ctpop.i8(i8 %nx)
+  %r = icmp eq i8 %cnt, 2
+  ret i1 %r
+}
+
+define <2 x i1> @fold_cmp_ne_ctpop_c(<2 x i8> %x) {
+; CHECK-LABEL: @fold_cmp_ne_ctpop_c(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> [[X:%.*]]), !range [[RNG0]]
+; CHECK-NEXT:    [[R:%.*]] = icmp ne <2 x i8> [[TMP1]], <i8 -36, i8 5>
+; CHECK-NEXT:    ret <2 x i1> [[R]]
+;
+  %nx = xor <2 x i8> %x, <i8 -1, i8 -1>
+  %cnt = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> %nx)
+  %r = icmp ne <2 x i8> %cnt, <i8 44, i8 3>
+  ret <2 x i1> %r
+}
+
+define <2 x i1> @fold_cmp_ne_ctpop_var_fail(<2 x i8> %x, <2 x i8> %y) {
+; CHECK-LABEL: @fold_cmp_ne_ctpop_var_fail(
+; CHECK-NEXT:    [[NX:%.*]] = xor <2 x i8> [[X:%.*]], <i8 -1, i8 -1>
+; CHECK-NEXT:    [[CNT:%.*]] = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> [[NX]]), !range [[RNG0]]
+; CHECK-NEXT:    [[R:%.*]] = icmp ne <2 x i8> [[CNT]], [[Y:%.*]]
+; CHECK-NEXT:    ret <2 x i1> [[R]]
+;
+  %nx = xor <2 x i8> %x, <i8 -1, i8 -1>
+  %cnt = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> %nx)
+  %r = icmp ne <2 x i8> %cnt, %y
+  ret <2 x i1> %r
+}
+
+define i1 @fold_cmp_ult_ctpop_c(i8 %x) {
+; CHECK-LABEL: @fold_cmp_ult_ctpop_c(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.ctpop.i8(i8 [[X:%.*]]), !range [[RNG0]]
+; CHECK-NEXT:    [[R:%.*]] = icmp ugt i8 [[TMP1]], 3
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %nx = xor i8 %x, -1
+  %cnt = call i8 @llvm.ctpop.i8(i8 %nx)
+  %r = icmp ult i8 %cnt, 5
+  ret i1 %r
+}
+
+define <2 x i1> @fold_cmp_ugt_ctpop_c(<2 x i8> %x) {
+; CHECK-LABEL: @fold_cmp_ugt_ctpop_c(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> [[X:%.*]]), !range [[RNG0]]
+; CHECK-NEXT:    [[R:%.*]] = icmp ult <2 x i8> [[TMP1]], <i8 0, i8 2>
+; CHECK-NEXT:    ret <2 x i1> [[R]]
+;
+  %nx = xor <2 x i8> %x, <i8 -1, i8 -1>
+  %cnt = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> %nx)
+  %r = icmp ugt <2 x i8> %cnt, <i8 8, i8 6>
+  ret <2 x i1> %r
+}
+
+define <2 x i1> @fold_cmp_ugt_ctpop_c_out_of_range_fail(<2 x i8> %x) {
+; CHECK-LABEL: @fold_cmp_ugt_ctpop_c_out_of_range_fail(
+; CHECK-NEXT:    [[NX:%.*]] = xor <2 x i8> [[X:%.*]], <i8 -1, i8 -1>
+; CHECK-NEXT:    [[CNT:%.*]] = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> [[NX]]), !range [[RNG0]]
+; CHECK-NEXT:    [[R:%.*]] = icmp ugt <2 x i8> [[CNT]], <i8 2, i8 10>
+; CHECK-NEXT:    ret <2 x i1> [[R]]
+;
+  %nx = xor <2 x i8> %x, <i8 -1, i8 -1>
+  %cnt = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> %nx)
+  %r = icmp ugt <2 x i8> %cnt, <i8 2, i8 10>
+  ret <2 x i1> %r
+}

@goldsteinn goldsteinn changed the title users/goldsteinn/ctpop of not Add folds for (add/sub/disjoint_or/icmp C, (ctpop (not x))) Jan 12, 2024
dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Jan 12, 2024
@goldsteinn goldsteinn changed the title Add folds for (add/sub/disjoint_or/icmp C, (ctpop (not x))) [InstCombine] Add folds for (add/sub/disjoint_or/icmp C, (ctpop (not x))) Jan 12, 2024
… x)))`

`(ctpop (not x))` <-> `(sub nuw nsw BitWidth(x), (ctpop x))`. The
`sub` expression can sometimes be constant folded depending on the use
case of `(ctpop (not x))`.

This patch adds folds for the following cases:

`(add/sub/disjoint_or C, (ctpop (not x)))`
    -> `(add/sub/disjoint_or C', (ctpop x))`
`(cmp pred C, (ctpop (not x)))`
    -> `(cmp swapped_pred C', (ctpop x))`

Where `C'` depends on how we constant fold `C` with `BitWidth(x)` for
the given opcode.

Proofs: https://alive2.llvm.org/ce/z/qUgfF3
Member

@dtcxzyw dtcxzyw left a comment


LGTM. Please wait for approval from @nikic.

Contributor

@nikic nikic left a comment


LGTM


Type *Ty = Op->getType();
Constant *BitWidthC = ConstantInt::get(Ty, Ty->getScalarSizeInBits());
// Need extra check for icmp. Note if this check is it generally means the
Contributor


I think something is missing after "if this check is"?

justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
… x)))`

`(ctpop (not x))` <-> `(sub nuw nsw BitWidth(x), (ctpop x))`. The
`sub` expression can sometimes be constant folded depending on the use
case of `(ctpop (not x))`.

This patch adds folds for the following cases:

`(add/sub/disjoint_or C, (ctpop (not x)))`
    -> `(add/sub/disjoint_or C', (ctpop x))`
`(cmp pred C, (ctpop (not x)))`
    -> `(cmp swapped_pred C', (ctpop x))`

Where `C'` depends on how we constant fold `C` with `BitWidth(x)` for
the given opcode.

Proofs: https://alive2.llvm.org/ce/z/qUgfF3

Closes llvm#77859