Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[InstCombine] Fold Minimum over Trailing/Leading Bits Counts #90402

Merged
merged 2 commits into from
Jul 13, 2024

Conversation

mskamp
Copy link
Contributor

@mskamp mskamp commented Apr 28, 2024

The new transformation folds umin(cttz(x), c) to cttz(x | (1 << c))
and umin(ctlz(x), c) to ctlz(x | ((1 << (bitwidth - 1)) >> c)). The
transformation is only implemented for constant c to not increase the
number of instructions.

The idea of the transformation is to set the c-th lowest (for cttz) or
highest (for ctlz) bit in the operand. In this way, the cttz or
ctlz instruction always returns at most c.

Alive2 proofs: https://alive2.llvm.org/ce/z/y8Hdb8

Fixes #90000

@mskamp mskamp requested a review from nikic as a code owner April 28, 2024 16:07
Copy link

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be
notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write
permissions for the repository. In which case you can instead tag reviewers by
name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review
by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate
is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 28, 2024

@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-llvm-transforms

Author: None (mskamp)

Changes

Fixes #90000


Full diff: https://github.com/llvm/llvm-project/pull/90402.diff

3 Files Affected:

  • (modified) llvm/include/llvm/IR/PatternMatch.h (+12)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+33)
  • (added) llvm/test/Transforms/InstCombine/umin_cttz_ctlz.ll (+355)
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 0b13b4aad9c326..36d64c88427883 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -2466,6 +2466,18 @@ inline typename m_Intrinsic_Ty<Opnd0>::Ty m_BSwap(const Opnd0 &Op0) {
   return m_Intrinsic<Intrinsic::bswap>(Op0);
 }
 
+template <typename Opnd0, typename Opnd1>
+inline typename m_Intrinsic_Ty<Opnd0, Opnd1>::Ty m_Ctlz(const Opnd0 &Op0,
+                                                        const Opnd1 &Op1) {
+  return m_Intrinsic<Intrinsic::ctlz>(Op0, Op1);
+}
+
+template <typename Opnd0, typename Opnd1>
+inline typename m_Intrinsic_Ty<Opnd0, Opnd1>::Ty m_Cttz(const Opnd0 &Op0,
+                                                        const Opnd1 &Op1) {
+  return m_Intrinsic<Intrinsic::cttz>(Op0, Op1);
+}
+
 template <typename Opnd0>
 inline typename m_Intrinsic_Ty<Opnd0>::Ty m_FAbs(const Opnd0 &Op0) {
   return m_Intrinsic<Intrinsic::fabs>(Op0);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index e5652458f150b5..db742fbe668cc3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1633,6 +1633,39 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       Value *Cmp = Builder.CreateICmpNE(I0, Zero);
       return CastInst::Create(Instruction::ZExt, Cmp, II->getType());
     }
+    // umin(cttz(x), const) --> cttz(x | (1 << const))
+    Value *X;
+    const APInt *Y;
+    Value *Z;
+    if (match(I0, m_OneUse(m_Cttz(m_Value(X), m_Value(Z)))) &&
+        match(I1, m_APInt(Y))) {
+      Value *CttzOp = X;
+      if (Y->ult(I1->getType()->getScalarType()->getIntegerBitWidth())) {
+        auto One = APInt::getOneBitSet(
+            I1->getType()->getScalarType()->getIntegerBitWidth(), 0);
+        Value *NewConst = ConstantInt::get(I1->getType(), One << *Y);
+        CttzOp = Builder.CreateOr(X, NewConst);
+      }
+      return CallInst::Create(Intrinsic::getDeclaration(II->getModule(),
+                                                        Intrinsic::cttz,
+                                                        II->getType()),
+                              {CttzOp, Z});
+    }
+    // umin(ctlz(x), const) --> ctlz(x | ((1 << (bitwidth - 1) >> const)))
+    if (match(I0, m_OneUse(m_Ctlz(m_Value(X), m_Value(Z)))) &&
+        match(I1, m_APInt(Y))) {
+      Value *CtlzOp = X;
+      if (Y->ult(I1->getType()->getScalarType()->getIntegerBitWidth())) {
+        auto Min = APInt::getSignedMinValue(
+            I1->getType()->getScalarType()->getIntegerBitWidth());
+        Value *NewConst = ConstantInt::get(I1->getType(), Min.lshr(*Y));
+        CtlzOp = Builder.CreateOr(X, NewConst);
+      }
+      return CallInst::Create(Intrinsic::getDeclaration(II->getModule(),
+                                                        Intrinsic::ctlz,
+                                                        II->getType()),
+                              {CtlzOp, Z});
+    }
     [[fallthrough]];
   }
   case Intrinsic::umax: {
diff --git a/llvm/test/Transforms/InstCombine/umin_cttz_ctlz.ll b/llvm/test/Transforms/InstCombine/umin_cttz_ctlz.ll
new file mode 100644
index 00000000000000..91f5b818c7ff9a
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/umin_cttz_ctlz.ll
@@ -0,0 +1,355 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+declare i1 @llvm.umin.i1(i1 %a, i1 %b)
+declare i8 @llvm.umin.i8(i8 %a, i8 %b)
+declare i16 @llvm.umin.i16(i16 %a, i16 %b)
+declare i32 @llvm.umin.i32(i32 %a, i32 %b)
+declare i64 @llvm.umin.i64(i64 %a, i64 %b)
+declare <2 x i32> @llvm.umin.v2i32(<2 x i32> %a, <2 x i32> %b)
+
+declare i1 @llvm.cttz.i1(i1, i1)
+declare i8 @llvm.cttz.i8(i8, i1)
+declare i16 @llvm.cttz.i16(i16, i1)
+declare i32 @llvm.cttz.i32(i32, i1)
+declare i64 @llvm.cttz.i64(i64, i1)
+declare <2 x i32> @llvm.cttz.v2i32(<2 x i32>, i1)
+
+declare i1 @llvm.ctlz.i1(i1, i1)
+declare i8 @llvm.ctlz.i8(i8, i1)
+declare i16 @llvm.ctlz.i16(i16, i1)
+declare i32 @llvm.ctlz.i32(i32, i1)
+declare i64 @llvm.ctlz.i64(i64, i1)
+declare <2 x i32> @llvm.ctlz.v2i32(<2 x i32>, i1)
+
+define i8 @umin_cttz_i8_zero_undefined(i8 %X) {
+; CHECK-LABEL: define i8 @umin_cttz_i8_zero_undefined(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = or i8 [[X]], 64
+; CHECK-NEXT:    [[RET:%.*]] = call range(i8 0, 7) i8 @llvm.cttz.i8(i8 [[TMP1]], i1 true)
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+  %cttz = call i8 @llvm.cttz.i8(i8 %X, i1 true)
+  %ret = call i8 @llvm.umin.i8(i8 %cttz, i8 6)
+  ret i8 %ret
+}
+
+define i8 @umin_cttz_i8_zero_defined(i8 %X) {
+; CHECK-LABEL: define i8 @umin_cttz_i8_zero_defined(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = or i8 [[X]], 64
+; CHECK-NEXT:    [[RET:%.*]] = call range(i8 0, 7) i8 @llvm.cttz.i8(i8 [[TMP1]], i1 true)
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+  %cttz = call i8 @llvm.cttz.i8(i8 %X, i1 false)
+  %ret = call i8 @llvm.umin.i8(i8 %cttz, i8 6)
+  ret i8 %ret
+}
+
+define i8 @umin_cttz_i8_commuted_zero_undefined(i8 %X) {
+; CHECK-LABEL: define i8 @umin_cttz_i8_commuted_zero_undefined(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = or i8 [[X]], 64
+; CHECK-NEXT:    [[RET:%.*]] = call range(i8 0, 7) i8 @llvm.cttz.i8(i8 [[TMP1]], i1 true)
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+  %cttz = call i8 @llvm.cttz.i8(i8 %X, i1 true)
+  %ret = call i8 @llvm.umin.i8(i8 6, i8 %cttz)
+  ret i8 %ret
+}
+
+define i8 @umin_cttz_i8_ge_bitwidth_zero_undefined(i8 %X) {
+; CHECK-LABEL: define i8 @umin_cttz_i8_ge_bitwidth_zero_undefined(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT:    [[CTTZ:%.*]] = call range(i8 0, 9) i8 @llvm.cttz.i8(i8 [[X]], i1 true)
+; CHECK-NEXT:    ret i8 [[CTTZ]]
+;
+  %cttz = call i8 @llvm.cttz.i8(i8 %X, i1 true)
+  %ret = call i8 @llvm.umin.i8(i8 %cttz, i8 10)
+  ret i8 %ret
+}
+
+define i16 @umin_cttz_i16_zero_undefined(i16 %X) {
+; CHECK-LABEL: define i16 @umin_cttz_i16_zero_undefined(
+; CHECK-SAME: i16 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = or i16 [[X]], 64
+; CHECK-NEXT:    [[RET:%.*]] = call range(i16 0, 7) i16 @llvm.cttz.i16(i16 [[TMP1]], i1 true)
+; CHECK-NEXT:    ret i16 [[RET]]
+;
+  %cttz = call i16 @llvm.cttz.i16(i16 %X, i1 true)
+  %ret = call i16 @llvm.umin.i16(i16 %cttz, i16 6)
+  ret i16 %ret
+}
+
+define i32 @umin_cttz_i32_zero_undefined(i32 %X) {
+; CHECK-LABEL: define i32 @umin_cttz_i32_zero_undefined(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = or i32 [[X]], 64
+; CHECK-NEXT:    [[RET:%.*]] = call range(i32 0, 7) i32 @llvm.cttz.i32(i32 [[TMP1]], i1 true)
+; CHECK-NEXT:    ret i32 [[RET]]
+;
+  %cttz = call i32 @llvm.cttz.i32(i32 %X, i1 true)
+  %ret = call i32 @llvm.umin.i32(i32 %cttz, i32 6)
+  ret i32 %ret
+}
+
+define i64 @umin_cttz_i64_zero_undefined(i64 %X) {
+; CHECK-LABEL: define i64 @umin_cttz_i64_zero_undefined(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = or i64 [[X]], 64
+; CHECK-NEXT:    [[RET:%.*]] = call range(i64 0, 7) i64 @llvm.cttz.i64(i64 [[TMP1]], i1 true)
+; CHECK-NEXT:    ret i64 [[RET]]
+;
+  %cttz = call i64 @llvm.cttz.i64(i64 %X, i1 true)
+  %ret = call i64 @llvm.umin.i64(i64 %cttz, i64 6)
+  ret i64 %ret
+}
+
+define i1 @umin_cttz_i1_zero_undefined(i1 %X) {
+; CHECK-LABEL: define i1 @umin_cttz_i1_zero_undefined(
+; CHECK-SAME: i1 [[X:%.*]]) {
+; CHECK-NEXT:    ret i1 false
+;
+  %cttz = call i1 @llvm.cttz.i1(i1 %X, i1 true)
+  %ret = call i1 @llvm.umin.i1(i1 %cttz, i1 1)
+  ret i1 %ret
+}
+
+define i1 @umin_cttz_i1_zero_defined(i1 %X) {
+; CHECK-LABEL: define i1 @umin_cttz_i1_zero_defined(
+; CHECK-SAME: i1 [[X:%.*]]) {
+; CHECK-NEXT:    [[CTTZ:%.*]] = xor i1 [[X]], true
+; CHECK-NEXT:    ret i1 [[CTTZ]]
+;
+  %cttz = call i1 @llvm.cttz.i1(i1 %X, i1 false)
+  %ret = call i1 @llvm.umin.i1(i1 %cttz, i1 1)
+  ret i1 %ret
+}
+
+define <2 x i32> @umin_cttz_2xi32_splat_zero_undefined(<2 x i32> %X) {
+; CHECK-LABEL: define <2 x i32> @umin_cttz_2xi32_splat_zero_undefined(
+; CHECK-SAME: <2 x i32> [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = or <2 x i32> [[X]], <i32 64, i32 64>
+; CHECK-NEXT:    [[RET:%.*]] = call range(i32 0, 7) <2 x i32> @llvm.cttz.v2i32(<2 x i32> [[TMP1]], i1 true)
+; CHECK-NEXT:    ret <2 x i32> [[RET]]
+;
+  %cttz = call <2 x i32> @llvm.cttz.v2i32(<2 x i32> %X, i1 true)
+  %ret = call <2 x i32> @llvm.umin.v2i32(<2 x i32> %cttz, <2 x i32> <i32 6, i32 6>)
+  ret <2 x i32> %ret
+}
+
+define <2 x i32> @umin_cttz_2xi32_splat_poison_zero_undefined(<2 x i32> %X) {
+; CHECK-LABEL: define <2 x i32> @umin_cttz_2xi32_splat_poison_zero_undefined(
+; CHECK-SAME: <2 x i32> [[X:%.*]]) {
+; CHECK-NEXT:    [[CTTZ:%.*]] = call range(i32 0, 33) <2 x i32> @llvm.cttz.v2i32(<2 x i32> [[X]], i1 true)
+; CHECK-NEXT:    [[RET:%.*]] = call <2 x i32> @llvm.umin.v2i32(<2 x i32> [[CTTZ]], <2 x i32> <i32 6, i32 poison>)
+; CHECK-NEXT:    ret <2 x i32> [[RET]]
+;
+  %cttz = call <2 x i32> @llvm.cttz.v2i32(<2 x i32> %X, i1 true)
+  %ret = call <2 x i32> @llvm.umin.v2i32(<2 x i32> %cttz, <2 x i32> <i32 6, i32 poison>)
+  ret <2 x i32> %ret
+}
+
+define <2 x i32> @umin_cttz_2xi32_no_splat_negative_zero_undefined(<2 x i32> %X) {
+; CHECK-LABEL: define <2 x i32> @umin_cttz_2xi32_no_splat_negative_zero_undefined(
+; CHECK-SAME: <2 x i32> [[X:%.*]]) {
+; CHECK-NEXT:    [[CTTZ:%.*]] = call range(i32 0, 33) <2 x i32> @llvm.cttz.v2i32(<2 x i32> [[X]], i1 true)
+; CHECK-NEXT:    [[RET:%.*]] = call <2 x i32> @llvm.umin.v2i32(<2 x i32> [[CTTZ]], <2 x i32> <i32 6, i32 0>)
+; CHECK-NEXT:    ret <2 x i32> [[RET]]
+;
+  %cttz = call <2 x i32> @llvm.cttz.v2i32(<2 x i32> %X, i1 true)
+  %ret = call <2 x i32> @llvm.umin.v2i32(<2 x i32> %cttz, <2 x i32> <i32 6, i32 0>)
+  ret <2 x i32> %ret
+}
+
+define i16 @umin_cttz_i16_negative_non_constant(i16 %X, i16 %Y) {
+; CHECK-LABEL: define i16 @umin_cttz_i16_negative_non_constant(
+; CHECK-SAME: i16 [[X:%.*]], i16 [[Y:%.*]]) {
+; CHECK-NEXT:    [[CTTZ:%.*]] = call range(i16 0, 17) i16 @llvm.cttz.i16(i16 [[X]], i1 true)
+; CHECK-NEXT:    [[RET:%.*]] = call i16 @llvm.umin.i16(i16 [[CTTZ]], i16 [[Y]])
+; CHECK-NEXT:    ret i16 [[RET]]
+;
+  %cttz = call i16 @llvm.cttz.i16(i16 %X, i1 true)
+  %ret = call i16 @llvm.umin.i16(i16 %cttz, i16 %Y)
+  ret i16 %ret
+}
+
+define i16 @umin_cttz_i16_negative_two_uses(i16 %X) {
+; CHECK-LABEL: define i16 @umin_cttz_i16_negative_two_uses(
+; CHECK-SAME: i16 [[X:%.*]]) {
+; CHECK-NEXT:    [[CTTZ:%.*]] = call range(i16 0, 17) i16 @llvm.cttz.i16(i16 [[X]], i1 true)
+; CHECK-NEXT:    [[OP0:%.*]] = call i16 @llvm.umin.i16(i16 [[CTTZ]], i16 6)
+; CHECK-NEXT:    [[RET:%.*]] = add nuw nsw i16 [[CTTZ]], [[OP0]]
+; CHECK-NEXT:    ret i16 [[RET]]
+;
+  %cttz = call i16 @llvm.cttz.i16(i16 %X, i1 true)
+  %op0 = call i16 @llvm.umin.i16(i16 %cttz, i16 6)
+  %ret = add i16 %cttz, %op0
+  ret i16 %ret
+}
+
+define i8 @umin_ctlz_i8_zero_undefined(i8 %X) {
+; CHECK-LABEL: define i8 @umin_ctlz_i8_zero_undefined(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = or i8 [[X]], 2
+; CHECK-NEXT:    [[RET:%.*]] = call range(i8 0, 7) i8 @llvm.ctlz.i8(i8 [[TMP1]], i1 true)
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+  %ctlz = call i8 @llvm.ctlz.i8(i8 %X, i1 true)
+  %ret = call i8 @llvm.umin.i8(i8 %ctlz, i8 6)
+  ret i8 %ret
+}
+
+define i8 @umin_ctlz_i8_zero_defined(i8 %X) {
+; CHECK-LABEL: define i8 @umin_ctlz_i8_zero_defined(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = or i8 [[X]], 2
+; CHECK-NEXT:    [[RET:%.*]] = call range(i8 0, 7) i8 @llvm.ctlz.i8(i8 [[TMP1]], i1 true)
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+  %ctlz = call i8 @llvm.ctlz.i8(i8 %X, i1 false)
+  %ret = call i8 @llvm.umin.i8(i8 %ctlz, i8 6)
+  ret i8 %ret
+}
+
+define i8 @umin_ctlz_i8_commuted_zero_undefined(i8 %X) {
+; CHECK-LABEL: define i8 @umin_ctlz_i8_commuted_zero_undefined(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = or i8 [[X]], 2
+; CHECK-NEXT:    [[RET:%.*]] = call range(i8 0, 7) i8 @llvm.ctlz.i8(i8 [[TMP1]], i1 true)
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+  %ctlz = call i8 @llvm.ctlz.i8(i8 %X, i1 true)
+  %ret = call i8 @llvm.umin.i8(i8 6, i8 %ctlz)
+  ret i8 %ret
+}
+
+define i8 @umin_ctlz_i8_ge_bitwidth_zero_undefined(i8 %X) {
+; CHECK-LABEL: define i8 @umin_ctlz_i8_ge_bitwidth_zero_undefined(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT:    [[CTLZ:%.*]] = call range(i8 0, 9) i8 @llvm.ctlz.i8(i8 [[X]], i1 true)
+; CHECK-NEXT:    ret i8 [[CTLZ]]
+;
+  %ctlz = call i8 @llvm.ctlz.i8(i8 %X, i1 true)
+  %ret = call i8 @llvm.umin.i8(i8 %ctlz, i8 10)
+  ret i8 %ret
+}
+
+define i16 @umin_ctlz_i16_zero_undefined(i16 %X) {
+; CHECK-LABEL: define i16 @umin_ctlz_i16_zero_undefined(
+; CHECK-SAME: i16 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = or i16 [[X]], 512
+; CHECK-NEXT:    [[RET:%.*]] = call range(i16 0, 7) i16 @llvm.ctlz.i16(i16 [[TMP1]], i1 true)
+; CHECK-NEXT:    ret i16 [[RET]]
+;
+  %ctlz = call i16 @llvm.ctlz.i16(i16 %X, i1 true)
+  %ret = call i16 @llvm.umin.i16(i16 %ctlz, i16 6)
+  ret i16 %ret
+}
+
+define i32 @umin_ctlz_i32_zero_undefined(i32 %X) {
+; CHECK-LABEL: define i32 @umin_ctlz_i32_zero_undefined(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = or i32 [[X]], 33554432
+; CHECK-NEXT:    [[RET:%.*]] = call range(i32 0, 7) i32 @llvm.ctlz.i32(i32 [[TMP1]], i1 true)
+; CHECK-NEXT:    ret i32 [[RET]]
+;
+  %ctlz = call i32 @llvm.ctlz.i32(i32 %X, i1 true)
+  %ret = call i32 @llvm.umin.i32(i32 %ctlz, i32 6)
+  ret i32 %ret
+}
+
+define i64 @umin_ctlz_i64_zero_undefined(i64 %X) {
+; CHECK-LABEL: define i64 @umin_ctlz_i64_zero_undefined(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = or i64 [[X]], 144115188075855872
+; CHECK-NEXT:    [[RET:%.*]] = call range(i64 0, 7) i64 @llvm.ctlz.i64(i64 [[TMP1]], i1 true)
+; CHECK-NEXT:    ret i64 [[RET]]
+;
+  %ctlz = call i64 @llvm.ctlz.i64(i64 %X, i1 true)
+  %ret = call i64 @llvm.umin.i64(i64 %ctlz, i64 6)
+  ret i64 %ret
+}
+
+define i1 @umin_ctlz_i1_zero_undefined(i1 %X) {
+; CHECK-LABEL: define i1 @umin_ctlz_i1_zero_undefined(
+; CHECK-SAME: i1 [[X:%.*]]) {
+; CHECK-NEXT:    ret i1 false
+;
+  %ctlz = call i1 @llvm.ctlz.i1(i1 %X, i1 true)
+  %ret = call i1 @llvm.umin.i1(i1 %ctlz, i1 1)
+  ret i1 %ret
+}
+
+define i1 @umin_ctlz_i1_zero_defined(i1 %X) {
+; CHECK-LABEL: define i1 @umin_ctlz_i1_zero_defined(
+; CHECK-SAME: i1 [[X:%.*]]) {
+; CHECK-NEXT:    [[CTLZ:%.*]] = xor i1 [[X]], true
+; CHECK-NEXT:    ret i1 [[CTLZ]]
+;
+  %ctlz = call i1 @llvm.ctlz.i1(i1 %X, i1 false)
+  %ret = call i1 @llvm.umin.i1(i1 %ctlz, i1 1)
+  ret i1 %ret
+}
+
+define <2 x i32> @umin_ctlz_2xi32_splat_zero_undefined(<2 x i32> %X) {
+; CHECK-LABEL: define <2 x i32> @umin_ctlz_2xi32_splat_zero_undefined(
+; CHECK-SAME: <2 x i32> [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = or <2 x i32> [[X]], <i32 33554432, i32 33554432>
+; CHECK-NEXT:    [[RET:%.*]] = call range(i32 0, 7) <2 x i32> @llvm.ctlz.v2i32(<2 x i32> [[TMP1]], i1 true)
+; CHECK-NEXT:    ret <2 x i32> [[RET]]
+;
+  %ctlz = call <2 x i32> @llvm.ctlz.v2i32(<2 x i32> %X, i1 true)
+  %ret = call <2 x i32> @llvm.umin.v2i32(<2 x i32> %ctlz, <2 x i32> <i32 6, i32 6>)
+  ret <2 x i32> %ret
+}
+
+define <2 x i32> @umin_ctlz_2xi32_splat_poison_zero_undefined(<2 x i32> %X) {
+; CHECK-LABEL: define <2 x i32> @umin_ctlz_2xi32_splat_poison_zero_undefined(
+; CHECK-SAME: <2 x i32> [[X:%.*]]) {
+; CHECK-NEXT:    [[CTLZ:%.*]] = call range(i32 0, 33) <2 x i32> @llvm.ctlz.v2i32(<2 x i32> [[X]], i1 true)
+; CHECK-NEXT:    [[RET:%.*]] = call <2 x i32> @llvm.umin.v2i32(<2 x i32> [[CTLZ]], <2 x i32> <i32 6, i32 poison>)
+; CHECK-NEXT:    ret <2 x i32> [[RET]]
+;
+  %ctlz = call <2 x i32> @llvm.ctlz.v2i32(<2 x i32> %X, i1 true)
+  %ret = call <2 x i32> @llvm.umin.v2i32(<2 x i32> %ctlz, <2 x i32> <i32 6, i32 poison>)
+  ret <2 x i32> %ret
+}
+
+define <2 x i32> @umin_ctlz_2xi32_no_splat_negative_zero_undefined(<2 x i32> %X) {
+; CHECK-LABEL: define <2 x i32> @umin_ctlz_2xi32_no_splat_negative_zero_undefined(
+; CHECK-SAME: <2 x i32> [[X:%.*]]) {
+; CHECK-NEXT:    [[CTLZ:%.*]] = call range(i32 0, 33) <2 x i32> @llvm.ctlz.v2i32(<2 x i32> [[X]], i1 true)
+; CHECK-NEXT:    [[RET:%.*]] = call <2 x i32> @llvm.umin.v2i32(<2 x i32> [[CTLZ]], <2 x i32> <i32 6, i32 0>)
+; CHECK-NEXT:    ret <2 x i32> [[RET]]
+;
+  %ctlz = call <2 x i32> @llvm.ctlz.v2i32(<2 x i32> %X, i1 true)
+  %ret = call <2 x i32> @llvm.umin.v2i32(<2 x i32> %ctlz, <2 x i32> <i32 6, i32 0>)
+  ret <2 x i32> %ret
+}
+
+define i16 @umin_ctlz_i16_negative_non_constant(i16 %X, i16 %Y) {
+; CHECK-LABEL: define i16 @umin_ctlz_i16_negative_non_constant(
+; CHECK-SAME: i16 [[X:%.*]], i16 [[Y:%.*]]) {
+; CHECK-NEXT:    [[CTLZ:%.*]] = call range(i16 0, 17) i16 @llvm.ctlz.i16(i16 [[X]], i1 true)
+; CHECK-NEXT:    [[RET:%.*]] = call i16 @llvm.umin.i16(i16 [[CTLZ]], i16 [[Y]])
+; CHECK-NEXT:    ret i16 [[RET]]
+;
+  %ctlz = call i16 @llvm.ctlz.i16(i16 %X, i1 true)
+  %ret = call i16 @llvm.umin.i16(i16 %ctlz, i16 %Y)
+  ret i16 %ret
+}
+
+define i16 @umin_ctlz_i16_negative_two_uses(i16 %X) {
+; CHECK-LABEL: define i16 @umin_ctlz_i16_negative_two_uses(
+; CHECK-SAME: i16 [[X:%.*]]) {
+; CHECK-NEXT:    [[CTTZ:%.*]] = call range(i16 0, 17) i16 @llvm.ctlz.i16(i16 [[X]], i1 true)
+; CHECK-NEXT:    [[OP0:%.*]] = call i16 @llvm.umin.i16(i16 [[CTTZ]], i16 6)
+; CHECK-NEXT:    [[RET:%.*]] = add nuw nsw i16 [[CTTZ]], [[OP0]]
+; CHECK-NEXT:    ret i16 [[RET]]
+;
+  %ctlz = call i16 @llvm.ctlz.i16(i16 %X, i1 true)
+  %op0 = call i16 @llvm.umin.i16(i16 %ctlz, i16 6)
+  %ret = add i16 %ctlz, %op0
+  ret i16 %ret
+}

Copy link
Collaborator

@RKSimon RKSimon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we use value tracking and support non-constant values and non-uniform vectors?

@nikic
Copy link
Contributor

nikic commented Apr 29, 2024

Could we use value tracking and support non-constant values and non-uniform vectors?

I don't think this transform is profitable for non-constant values.

@RKSimon
Copy link
Collaborator

RKSimon commented Apr 29, 2024

Could we use value tracking and support non-constant values and non-uniform vectors?

I don't think this transform is profitable for non-constant values.

But I'd still like this for non-uniform constant vectors

@mskamp mskamp force-pushed the fix_90000_umin_cttz_ctlz branch from c727b10 to 2907f4f Compare May 1, 2024 07:40
@mskamp
Copy link
Contributor Author

mskamp commented May 1, 2024

@nikic @RKSimon Thank you very much for your feedback. I've added support for non-uniform vectors to the implementation. Unfortunately, this has resulted in some substantial code changes but I've still tried to address your comments as much as possible.

nikic added a commit to nikic/llvm-project that referenced this pull request May 1, 2024
… (NFC)

The InstCombine contributor guide already says:

> Handle non-splat vector constants if doing so is free, but do
> not add handling for them if it adds any additional complexity
> to the code.

I would like to strengthen this guideline to explicitly forbid
asking contributors to implement non-splat support during code
reviews.

I've found that the outcome is pretty much always bad whenever
this request is made. Most recent example is in llvm#90402, which
initially had a reasonable implementation of a fold without
non-splat support. In response to reviewer feedback, it was
adjusted to use a more complex implementation that supports
non-splat vectors. Now I have the choice between accepting this
unnecessary complexity into InstCombine, or asking a first-time
contributor to undo their changes, which is really not something
I want to do.

Complex non-splat vector handling has done significant damage to
the InstCombine code base in the past (mostly due to the work of
a single contributor) and I am quite wary of repeating this
mistake.
@goldsteinn
Copy link
Contributor

For a PR like this, proofs for just i8 is sufficient. Also you don't need to declare llvm intrinsics. https://alive2.llvm.org/ce/z/y8Hdb8 would have been sufficient.

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp Outdated Show resolved Hide resolved
? ConstantInt::get(Ty, 1)
: ConstantInt::get(Ty, APInt::getSignedMinValue(BitWidth)),
ConstOp, DL),
Constant::getNullValue(Ty));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you go with this implementation, you'll also have to adjust the alive2 proofs in the PR description to be in terms of this implementation rather than assumes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this still hold for the new implementation? The branches are now explicitly stated in the code, which is why the proofs still rely on assumes.

nikic added a commit to nikic/llvm-project that referenced this pull request May 3, 2024
… (NFC)

The InstCombine contributor guide already says:

> Handle non-splat vector constants if doing so is free, but do
> not add handling for them if it adds any additional complexity
> to the code.

I would like to strengthen this guideline to explicitly forbid
asking contributors to implement non-splat support during code
reviews.

I've found that the outcome is pretty much always bad whenever
this request is made. Most recent example is in llvm#90402, which
initially had a reasonable implementation of a fold without
non-splat support. In response to reviewer feedback, it was
adjusted to use a more complex implementation that supports
non-splat vectors. Now I have the choice between accepting this
unnecessary complexity into InstCombine, or asking a first-time
contributor to undo their changes, which is really not something
I want to do.

Complex non-splat vector handling has done significant damage to
the InstCombine code base in the past (mostly due to the work of
a single contributor) and I am quite wary of repeating this
mistake.
@mskamp mskamp force-pushed the fix_90000_umin_cttz_ctlz branch from 2907f4f to d58f626 Compare May 4, 2024 20:39
@mskamp
Copy link
Contributor Author

mskamp commented Jun 21, 2024

Ping. :) Any more comments or suggestions for improvement on this?

@RKSimon RKSimon requested a review from dtcxzyw June 21, 2024 11:21
@mskamp mskamp force-pushed the fix_90000_umin_cttz_ctlz branch 2 times, most recently from 2dc1346 to 5502bef Compare June 21, 2024 15:35
Comment on lines 1462 to 1464
if (match(I1, m_CheckedInt(std::not_fn(LessBitWidth)))) {
return I0;
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't it handled by CVP/SCCP?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes but apparently only for scalars. In contrast, the transformation here also handles splat and non-splat constant vectors. At least that's the result of some quick tests I did; and after a quick glance at the code, I think that ValueLattice::markConstant only handles ConstantInt and therefore we get a „full“ constant range in CVP for vector constants.

If this case (all elements >= BitWidth) shouldn't be handled in InstCombine, I can of course remove this case here and only handle the case that all vector elements are < BitWidth. Should I do this?

Copy link
Contributor

@goldsteinn goldsteinn Jun 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Think the argument then would be to just add a separate fold for handling non-splat out of bounds shift case. This could easily fit in InstSimplify.

Personally agnostic to whether this stays or is removed. Although I doubt this will ever realistically be hit outside of a test.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've removed the handling for constants >= BitWidth. Since I agree that this case is probably never encountered, it's not worth the effort to maintain this.


if (I0->hasOneUse()) {
if (auto *II0 = dyn_cast<IntrinsicInst>(I0);
II0 && II0->getIntrinsicID() == IntrID) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make this match(I0, m_OneUse(m_Intrinsic<IntrID>(m_Value(X), m_Value(Z)))? Need to change the IntrID to a template param for that.

Also early return please.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This implementation was suggested above. I've nevertheless converted it to return early.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@goldsteinn Were you aware that this would prevent use of m_Intrinsic when you made that suggestion?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wasn't thinking that no, if you want to revert the suggestion I don't oppose.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've pushed an implementation that uses m_Intrinsic.

@mskamp mskamp force-pushed the fix_90000_umin_cttz_ctlz branch 2 times, most recently from bf63243 to ff7c786 Compare July 2, 2024 17:02
Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good to me, some nits.

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp Outdated Show resolved Hide resolved
llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp Outdated Show resolved Hide resolved
llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp Outdated Show resolved Hide resolved
foldMinimumOverTrailingOrLeadingZeroCount<Intrinsic::ctlz>(
I0, I1, DL, Builder)) {
return replaceInstUsesWith(*II, FoldedCtlz);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Omit braces.

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp Outdated Show resolved Hide resolved
llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp Outdated Show resolved Hide resolved
@nikic
Copy link
Contributor

nikic commented Jul 2, 2024

Looks like this needs a rebase.

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp Outdated Show resolved Hide resolved
)

The new transformation folds `umin(cttz(x), c)` to `cttz(x | (1 << c))`
and `umin(ctlz(x), c)` to `ctlz(x | ((1 << (bitwidth - 1)) >> c))`. The
transformation is only implemented for constant `c` to not increase the
number of instructions.

The idea of the transformation is to set the c-th lowest (for `cttz`) or
highest (for `ctlz`) bit in the operand. In this way, the `cttz` or
`ctlz` instruction always returns at most `c`.

Alive2 proofs: https://alive2.llvm.org/ce/z/y8Hdb8
@nikic nikic merged commit 949bbdc into llvm:main Jul 13, 2024
7 checks passed
Copy link

@mskamp Congratulations on having your first Pull Request (PR) merged into the LLVM Project!

Your changes will be combined with recent changes from other authors, then tested
by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR.

Please check whether problems have been caused by your change specifically, as
the builds can include changes from many authors. It is not uncommon for your
change to be included in a build that fails due to someone else's changes, or
infrastructure issues.

How to do this, and the rest of the post-merge process, is covered in detail here.

If your change does cause a problem, it may be reverted, or you can revert it yourself.
This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again.

If you don't get any reports, no action is required from you. Your changes are working as expected, well done!

aaryanshukla pushed a commit to aaryanshukla/llvm-project that referenced this pull request Jul 14, 2024
)

The new transformation folds `umin(cttz(x), c)` to `cttz(x | (1 << c))`
and `umin(ctlz(x), c)` to `ctlz(x | ((1 << (bitwidth - 1)) >> c))`. The
transformation is only implemented for constant `c` to not increase the
number of instructions.
    
The idea of the transformation is to set the c-th lowest (for `cttz`) or
highest (for `ctlz`) bit in the operand. In this way, the `cttz` or
`ctlz` instruction always returns at most `c`.
    
Alive2 proofs: https://alive2.llvm.org/ce/z/y8Hdb8

Fixes llvm#90000
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

@min(@ctz(x), y) can become @ctz(x | (1 << y))
6 participants