Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Add support for ignoring poisons in getSplatValue #89155

Closed
wants to merge 1 commit into from

Conversation

goldsteinn
Copy link
Contributor

This is a followup to #88217

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 17, 2024

@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-llvm-transforms

Author: None (goldsteinn)

Changes

This is a followup to #88217


Patch is 59.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89155.diff

27 Files Affected:

  • (modified) llvm/include/llvm/IR/Constant.h (+2-1)
  • (modified) llvm/include/llvm/IR/Constants.h (+2-1)
  • (modified) llvm/include/llvm/IR/PatternMatch.h (+14-10)
  • (modified) llvm/lib/IR/Constants.cpp (+22-9)
  • (modified) llvm/test/Analysis/ValueTracking/known-non-zero.ll (+15-19)
  • (modified) llvm/test/Transforms/InstCombine/add.ll (+9-12)
  • (modified) llvm/test/Transforms/InstCombine/and.ll (+13-15)
  • (modified) llvm/test/Transforms/InstCombine/ashr-lshr.ll (+4-4)
  • (modified) llvm/test/Transforms/InstCombine/cast.ll (+6-6)
  • (modified) llvm/test/Transforms/InstCombine/getelementptr.ll (+6-8)
  • (modified) llvm/test/Transforms/InstCombine/icmp-logical.ll (+7-11)
  • (modified) llvm/test/Transforms/InstCombine/icmp-range.ll (+15-17)
  • (modified) llvm/test/Transforms/InstCombine/icmp-shr.ll (+4-6)
  • (modified) llvm/test/Transforms/InstCombine/lshr.ll (+2-3)
  • (modified) llvm/test/Transforms/InstCombine/minmax-fold.ll (+3-4)
  • (modified) llvm/test/Transforms/InstCombine/mul-inseltpoison.ll (+3-3)
  • (modified) llvm/test/Transforms/InstCombine/mul.ll (+15-16)
  • (modified) llvm/test/Transforms/InstCombine/opaque-ptr.ll (+2-3)
  • (modified) llvm/test/Transforms/InstCombine/or.ll (+5-9)
  • (modified) llvm/test/Transforms/InstCombine/saturating-add-sub.ll (+1-3)
  • (modified) llvm/test/Transforms/InstCombine/shift-amount-reassociation-with-truncation-ashr.ll (+3-3)
  • (modified) llvm/test/Transforms/InstCombine/shift.ll (+13-14)
  • (modified) llvm/test/Transforms/InstCombine/shl-sub.ll (+1-2)
  • (modified) llvm/test/Transforms/InstCombine/sub.ll (+6-9)
  • (modified) llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll (+7-7)
  • (modified) llvm/test/Transforms/InstCombine/vector-casts-inseltpoison.ll (+1-2)
  • (modified) llvm/test/Transforms/InstCombine/vector-casts.ll (+1-2)
diff --git a/llvm/include/llvm/IR/Constant.h b/llvm/include/llvm/IR/Constant.h
index 778764062227cb..ec2ddb5564e73f 100644
--- a/llvm/include/llvm/IR/Constant.h
+++ b/llvm/include/llvm/IR/Constant.h
@@ -148,7 +148,8 @@ class Constant : public User {
   /// If all elements of the vector constant have the same value, return that
   /// value. Otherwise, return nullptr. Ignore undefined elements by setting
   /// AllowUndefs to true.
-  Constant *getSplatValue(bool AllowUndefs = false) const;
+  Constant *getSplatValue(bool AllowUndefs = false,
+                          bool AllowPoisons = false) const;
 
   /// If C is a constant integer then return its value, otherwise C must be a
   /// vector of constant integers, all equal, and the common value is returned.
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index 4290ef4486c6f4..4af0756de463da 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -534,7 +534,8 @@ class ConstantVector final : public ConstantAggregate {
   /// If all elements of the vector constant have the same value, return that
   /// value. Otherwise, return nullptr. Ignore undefined elements by setting
   /// AllowUndefs to true.
-  Constant *getSplatValue(bool AllowUndefs = false) const;
+  Constant *getSplatValue(bool AllowUndefs = false,
+                          bool AllowPoisons = false) const;
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast:
   static bool classof(const Value *V) {
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 98cc0e50376981..08ac9275463603 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -255,8 +255,8 @@ struct apint_match {
     }
     if (V->getType()->isVectorTy())
       if (const auto *C = dyn_cast<Constant>(V))
-        if (auto *CI =
-                dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowUndef))) {
+        if (auto *CI = dyn_cast_or_null<ConstantInt>(
+                C->getSplatValue(AllowUndef, /*AllowPoisons=*/true))) {
           Res = &CI->getValue();
           return true;
         }
@@ -280,8 +280,8 @@ struct apfloat_match {
     }
     if (V->getType()->isVectorTy())
       if (const auto *C = dyn_cast<Constant>(V))
-        if (auto *CI =
-                dyn_cast_or_null<ConstantFP>(C->getSplatValue(AllowUndef))) {
+        if (auto *CI = dyn_cast_or_null<ConstantFP>(
+                C->getSplatValue(AllowUndef, /*AllowPoisons=*/true))) {
           Res = &CI->getValueAPF();
           return true;
         }
@@ -353,7 +353,8 @@ struct cstval_pred_ty : public Predicate {
       return this->isValue(CV->getValue());
     if (const auto *VTy = dyn_cast<VectorType>(V->getType())) {
       if (const auto *C = dyn_cast<Constant>(V)) {
-        if (const auto *CV = dyn_cast_or_null<ConstantVal>(C->getSplatValue()))
+        if (const auto *CV = dyn_cast_or_null<ConstantVal>(
+                C->getSplatValue(/*AllowUndefs=*/false, /*AllowPoisons=*/true)))
           return this->isValue(CV->getValue());
 
         // Number of elements of a scalable vector unknown at compile time
@@ -406,7 +407,8 @@ template <typename Predicate> struct api_pred_ty : public Predicate {
       }
     if (V->getType()->isVectorTy())
       if (const auto *C = dyn_cast<Constant>(V))
-        if (auto *CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue()))
+        if (auto *CI = dyn_cast_or_null<ConstantInt>(
+                C->getSplatValue(/*AllowUndefs=*/false, /*AllowPoisons=*/true)))
           if (this->isValue(CI->getValue())) {
             Res = &CI->getValue();
             return true;
@@ -432,8 +434,8 @@ template <typename Predicate> struct apf_pred_ty : public Predicate {
       }
     if (V->getType()->isVectorTy())
       if (const auto *C = dyn_cast<Constant>(V))
-        if (auto *CI = dyn_cast_or_null<ConstantFP>(
-                C->getSplatValue(/* AllowUndef */ true)))
+        if (auto *CI = dyn_cast_or_null<ConstantFP>(C->getSplatValue(
+                /*AllowUndefs=*/true, /*AllowPoisons=*/true)))
           if (this->isValue(CI->getValue())) {
             Res = &CI->getValue();
             return true;
@@ -892,7 +894,8 @@ template <bool AllowUndefs> struct specific_intval {
     const auto *CI = dyn_cast<ConstantInt>(V);
     if (!CI && V->getType()->isVectorTy())
       if (const auto *C = dyn_cast<Constant>(V))
-        CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowUndefs));
+        CI = dyn_cast_or_null<ConstantInt>(
+            C->getSplatValue(AllowUndefs, /*AllowPoisons=*/true));
 
     return CI && APInt::isSameValue(CI->getValue(), Val);
   }
@@ -907,7 +910,8 @@ template <bool AllowUndefs> struct specific_intval64 {
     const auto *CI = dyn_cast<ConstantInt>(V);
     if (!CI && V->getType()->isVectorTy())
       if (const auto *C = dyn_cast<Constant>(V))
-        CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowUndefs));
+        CI = dyn_cast_or_null<ConstantInt>(
+            C->getSplatValue(AllowUndefs, /*AllowPoisons=*/true));
 
     return CI && CI->getValue() == Val;
   }
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index 45b359a94b3ab7..28e291a59eb89a 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -715,7 +715,7 @@ static bool constantIsDead(const Constant *C, bool RemoveDeadUsers) {
     ReplaceableMetadataImpl::SalvageDebugInfo(*C);
     const_cast<Constant *>(C)->destroyConstant();
   }
-  
+
   return true;
 }
 
@@ -1696,14 +1696,14 @@ void ConstantVector::destroyConstantImpl() {
   getType()->getContext().pImpl->VectorConstants.remove(this);
 }
 
-Constant *Constant::getSplatValue(bool AllowUndefs) const {
+Constant *Constant::getSplatValue(bool AllowUndefs, bool AllowPoisons) const {
   assert(this->getType()->isVectorTy() && "Only valid for vectors!");
   if (isa<ConstantAggregateZero>(this))
     return getNullValue(cast<VectorType>(getType())->getElementType());
   if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this))
     return CV->getSplatValue();
   if (const ConstantVector *CV = dyn_cast<ConstantVector>(this))
-    return CV->getSplatValue(AllowUndefs);
+    return CV->getSplatValue(AllowUndefs, AllowPoisons);
 
   // Check if this is a constant expression splat of the form returned by
   // ConstantVector::getSplat()
@@ -1728,7 +1728,8 @@ Constant *Constant::getSplatValue(bool AllowUndefs) const {
   return nullptr;
 }
 
-Constant *ConstantVector::getSplatValue(bool AllowUndefs) const {
+Constant *ConstantVector::getSplatValue(bool AllowUndefs,
+                                        bool AllowPoisons) const {
   // Check out first element.
   Constant *Elt = getOperand(0);
   // Then make sure all remaining elements point to the same value.
@@ -1737,16 +1738,28 @@ Constant *ConstantVector::getSplatValue(bool AllowUndefs) const {
     if (OpC == Elt)
       continue;
 
-    // Strict mode: any mismatch is not a splat.
-    if (!AllowUndefs)
+    if (!AllowPoisons && !AllowUndefs)
       return nullptr;
 
-    // Allow undefs mode: ignore undefined elements.
-    if (isa<UndefValue>(OpC))
+    if (isa<PoisonValue>(OpC)) {
+      assert(isa<UndefValue>(OpC));
+      // Strict mode: any mismatch is not a splat.
+      if (!AllowPoisons && !AllowUndefs)
+        return nullptr;
+      // Allow poisons mode: ignore poison elements.
+      continue;
+    } else if (isa<UndefValue>(OpC)) {
+      // Strict mode: any mismatch is not a splat.
+      if (!AllowUndefs)
+        return nullptr;
+      // Allow undefs/poisons mode: ignore undefined elements.
       continue;
+    }
 
     // If we do not have a defined element yet, use the current operand.
-    if (isa<UndefValue>(Elt))
+    if (AllowPoisons && isa<PoisonValue>(Elt))
+      Elt = OpC;
+    else if (AllowUndefs && isa<UndefValue>(Elt))
       Elt = OpC;
 
     if (OpC != Elt)
diff --git a/llvm/test/Analysis/ValueTracking/known-non-zero.ll b/llvm/test/Analysis/ValueTracking/known-non-zero.ll
index 0159050d925c3e..0ee09a131658de 100644
--- a/llvm/test/Analysis/ValueTracking/known-non-zero.ll
+++ b/llvm/test/Analysis/ValueTracking/known-non-zero.ll
@@ -1189,11 +1189,7 @@ define <2 x i1> @cmp_excludes_zero_with_nonsplat_vec_wundef(<2 x i8> %a, <2 x i8
 
 define <2 x i1> @cmp_excludes_zero_with_nonsplat_vec_wpoison(<2 x i8> %a, <2 x i8> %b) {
 ; CHECK-LABEL: @cmp_excludes_zero_with_nonsplat_vec_wpoison(
-; CHECK-NEXT:    [[C:%.*]] = icmp sge <2 x i8> [[A:%.*]], <i8 1, i8 poison>
-; CHECK-NEXT:    [[S:%.*]] = select <2 x i1> [[C]], <2 x i8> [[A]], <2 x i8> <i8 4, i8 5>
-; CHECK-NEXT:    [[AND:%.*]] = or <2 x i8> [[S]], [[B:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = icmp eq <2 x i8> [[AND]], zeroinitializer
-; CHECK-NEXT:    ret <2 x i1> [[R]]
+; CHECK-NEXT:    ret <2 x i1> zeroinitializer
 ;
   %c = icmp sge <2 x i8> %a, <i8 1, i8 poison>
   %s = select <2 x i1> %c, <2 x i8> %a, <2 x i8> <i8 4, i8 5>
@@ -1314,8 +1310,8 @@ define i1 @range_attr(i8 range(i8 1, 0) %x, i8 %y) {
 
 define i1 @neg_range_attr(i8 range(i8 -1, 1) %x, i8 %y) {
 ; CHECK-LABEL: @neg_range_attr(
-; CHECK-NEXT:    [[I:%.*]] = or i8 [[Y:%.*]], [[X:%.*]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[I]], 0
+; CHECK-NEXT:    [[OR:%.*]] = or i8 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[OR]], 0
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %or = or i8 %y, %x
@@ -1328,7 +1324,7 @@ declare range(i8 -1, 1) i8 @returns_contain_zero_range_helper()
 
 define i1 @range_return(i8 %y) {
 ; CHECK-LABEL: @range_return(
-; CHECK-NEXT:    [[I:%.*]] = call i8 @returns_non_zero_range_helper()
+; CHECK-NEXT:    [[X:%.*]] = call i8 @returns_non_zero_range_helper()
 ; CHECK-NEXT:    ret i1 false
 ;
   %x = call i8 @returns_non_zero_range_helper()
@@ -1339,8 +1335,8 @@ define i1 @range_return(i8 %y) {
 
 define i1 @neg_range_return(i8 %y) {
 ; CHECK-LABEL: @neg_range_return(
-; CHECK-NEXT:    [[I:%.*]] = call i8 @returns_contain_zero_range_helper()
-; CHECK-NEXT:    [[OR:%.*]] = or i8 [[Y:%.*]], [[I]]
+; CHECK-NEXT:    [[X:%.*]] = call i8 @returns_contain_zero_range_helper()
+; CHECK-NEXT:    [[OR:%.*]] = or i8 [[Y:%.*]], [[X]]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[OR]], 0
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
@@ -1354,7 +1350,7 @@ declare i8 @returns_i8_helper()
 
 define i1 @range_call(i8 %y) {
 ; CHECK-LABEL: @range_call(
-; CHECK-NEXT:    [[I:%.*]] = call range(i8 1, 0) i8 @returns_i8_helper()
+; CHECK-NEXT:    [[X:%.*]] = call range(i8 1, 0) i8 @returns_i8_helper()
 ; CHECK-NEXT:    ret i1 false
 ;
   %x = call range(i8 1, 0) i8 @returns_i8_helper()
@@ -1365,8 +1361,8 @@ define i1 @range_call(i8 %y) {
 
 define i1 @neg_range_call(i8 %y) {
 ; CHECK-LABEL: @neg_range_call(
-; CHECK-NEXT:    [[I:%.*]] = call range(i8 -1, 1) i8 @returns_i8_helper()
-; CHECK-NEXT:    [[OR:%.*]] = or i8 [[Y:%.*]], [[I]]
+; CHECK-NEXT:    [[X:%.*]] = call range(i8 -1, 1) i8 @returns_i8_helper()
+; CHECK-NEXT:    [[OR:%.*]] = or i8 [[Y:%.*]], [[X]]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[OR]], 0
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
@@ -1401,7 +1397,7 @@ declare range(i8 -1, 1) <2 x i8> @returns_contain_zero_range_helper_vec()
 
 define <2 x i1> @range_return_vec(<2 x i8> %y) {
 ; CHECK-LABEL: @range_return_vec(
-; CHECK-NEXT:    [[I:%.*]] = call <2 x i8> @returns_non_zero_range_helper_vec()
+; CHECK-NEXT:    [[X:%.*]] = call <2 x i8> @returns_non_zero_range_helper_vec()
 ; CHECK-NEXT:    ret <2 x i1> <i1 true, i1 true>
 ;
   %x = call <2 x i8> @returns_non_zero_range_helper_vec()
@@ -1412,8 +1408,8 @@ define <2 x i1> @range_return_vec(<2 x i8> %y) {
 
 define <2 x i1> @neg_range_return_vec(<2 x i8> %y) {
 ; CHECK-LABEL: @neg_range_return_vec(
-; CHECK-NEXT:    [[I:%.*]] = call <2 x i8> @returns_contain_zero_range_helper_vec()
-; CHECK-NEXT:    [[OR:%.*]] = or <2 x i8> [[Y:%.*]], [[I]]
+; CHECK-NEXT:    [[X:%.*]] = call <2 x i8> @returns_contain_zero_range_helper_vec()
+; CHECK-NEXT:    [[OR:%.*]] = or <2 x i8> [[Y:%.*]], [[X]]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <2 x i8> [[OR]], zeroinitializer
 ; CHECK-NEXT:    ret <2 x i1> [[CMP]]
 ;
@@ -1427,7 +1423,7 @@ declare <2 x i8> @returns_i8_helper_vec()
 
 define <2 x i1> @range_call_vec(<2 x i8> %y) {
 ; CHECK-LABEL: @range_call_vec(
-; CHECK-NEXT:    [[I:%.*]] = call range(i8 1, 0) <2 x i8> @returns_i8_helper_vec()
+; CHECK-NEXT:    [[X:%.*]] = call range(i8 1, 0) <2 x i8> @returns_i8_helper_vec()
 ; CHECK-NEXT:    ret <2 x i1> <i1 true, i1 true>
 ;
   %x = call range(i8 1, 0) <2 x i8> @returns_i8_helper_vec()
@@ -1438,8 +1434,8 @@ define <2 x i1> @range_call_vec(<2 x i8> %y) {
 
 define <2 x i1> @neg_range_call_vec(<2 x i8> %y) {
 ; CHECK-LABEL: @neg_range_call_vec(
-; CHECK-NEXT:    [[I:%.*]] = call range(i8 -1, 1) <2 x i8> @returns_i8_helper_vec()
-; CHECK-NEXT:    [[OR:%.*]] = or <2 x i8> [[Y:%.*]], [[I]]
+; CHECK-NEXT:    [[X:%.*]] = call range(i8 -1, 1) <2 x i8> @returns_i8_helper_vec()
+; CHECK-NEXT:    [[OR:%.*]] = or <2 x i8> [[Y:%.*]], [[X]]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <2 x i8> [[OR]], zeroinitializer
 ; CHECK-NEXT:    ret <2 x i1> [[CMP]]
 ;
diff --git a/llvm/test/Transforms/InstCombine/add.ll b/llvm/test/Transforms/InstCombine/add.ll
index 39b4ad80550889..28d5b0bedc082f 100644
--- a/llvm/test/Transforms/InstCombine/add.ll
+++ b/llvm/test/Transforms/InstCombine/add.ll
@@ -1198,8 +1198,7 @@ define <2 x i32> @test44_vec_non_matching(<2 x i32> %A) {
 
 define <2 x i32> @test44_vec_poison(<2 x i32> %A) {
 ; CHECK-LABEL: @test44_vec_poison(
-; CHECK-NEXT:    [[B:%.*]] = or <2 x i32> [[A:%.*]], <i32 123, i32 poison>
-; CHECK-NEXT:    [[C:%.*]] = add nsw <2 x i32> [[B]], <i32 -123, i32 poison>
+; CHECK-NEXT:    [[C:%.*]] = and <2 x i32> [[A:%.*]], <i32 -124, i32 -124>
 ; CHECK-NEXT:    ret <2 x i32> [[C]]
 ;
   %B = or <2 x i32> %A, <i32 123, i32 poison>
@@ -3139,9 +3138,7 @@ define <2 x i32> @dec_zext_add_nonzero_vec_poison1(<2 x i8> %x) {
 define <2 x i32> @dec_zext_add_nonzero_vec_poison2(<2 x i8> %x) {
 ; CHECK-LABEL: @dec_zext_add_nonzero_vec_poison2(
 ; CHECK-NEXT:    [[O:%.*]] = or <2 x i8> [[X:%.*]], <i8 8, i8 8>
-; CHECK-NEXT:    [[A:%.*]] = add nsw <2 x i8> [[O]], <i8 -1, i8 -1>
-; CHECK-NEXT:    [[B:%.*]] = zext <2 x i8> [[A]] to <2 x i32>
-; CHECK-NEXT:    [[C:%.*]] = add nuw nsw <2 x i32> [[B]], <i32 1, i32 poison>
+; CHECK-NEXT:    [[C:%.*]] = zext <2 x i8> [[O]] to <2 x i32>
 ; CHECK-NEXT:    ret <2 x i32> [[C]]
 ;
   %o = or <2 x i8> %x, <i8 8, i8 8>
@@ -4018,8 +4015,8 @@ define i32 @add_reduce_sqr_sum_varC_invalid2(i32 %a, i32 %b) {
 
 define i32 @fold_sext_addition_or_disjoint(i8 %x) {
 ; CHECK-LABEL: @fold_sext_addition_or_disjoint(
-; CHECK-NEXT:    [[SE:%.*]] = sext i8 [[XX:%.*]] to i32
-; CHECK-NEXT:    [[R:%.*]] = add nsw i32 [[SE]], 1246
+; CHECK-NEXT:    [[TMP1:%.*]] = sext i8 [[X:%.*]] to i32
+; CHECK-NEXT:    [[R:%.*]] = add nsw i32 [[TMP1]], 1246
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %xx = or disjoint i8 %x, 12
@@ -4043,8 +4040,8 @@ define i32 @fold_sext_addition_fail(i8 %x) {
 
 define i32 @fold_zext_addition_or_disjoint(i8 %x) {
 ; CHECK-LABEL: @fold_zext_addition_or_disjoint(
-; CHECK-NEXT:    [[SE:%.*]] = zext i8 [[XX:%.*]] to i32
-; CHECK-NEXT:    [[R:%.*]] = add nuw nsw i32 [[SE]], 1246
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i8 [[X:%.*]] to i32
+; CHECK-NEXT:    [[R:%.*]] = add nuw nsw i32 [[TMP1]], 1246
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %xx = or disjoint i8 %x, 12
@@ -4055,9 +4052,9 @@ define i32 @fold_zext_addition_or_disjoint(i8 %x) {
 
 define i32 @fold_zext_addition_or_disjoint2(i8 %x) {
 ; CHECK-LABEL: @fold_zext_addition_or_disjoint2(
-; CHECK-NEXT:    [[XX:%.*]] = add nuw i8 [[X:%.*]], 4
-; CHECK-NEXT:    [[SE:%.*]] = zext i8 [[XX]] to i32
-; CHECK-NEXT:    ret i32 [[SE]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add nuw i8 [[X:%.*]], 4
+; CHECK-NEXT:    [[R:%.*]] = zext i8 [[TMP1]] to i32
+; CHECK-NEXT:    ret i32 [[R]]
 ;
   %xx = or disjoint i8 %x, 18
   %se = zext i8 %xx to i32
diff --git a/llvm/test/Transforms/InstCombine/and.ll b/llvm/test/Transforms/InstCombine/and.ll
index b5250fc1a7849d..738a4a6a4cfbfc 100644
--- a/llvm/test/Transforms/InstCombine/and.ll
+++ b/llvm/test/Transforms/InstCombine/and.ll
@@ -754,9 +754,9 @@ define <2 x i64> @test36_uniform(<2 x i32> %X) {
 
 define <2 x i64> @test36_poison(<2 x i32> %X) {
 ; CHECK-LABEL: @test36_poison(
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext <2 x i32> [[X:%.*]] to <2 x i64>
-; CHECK-NEXT:    [[ZSUB:%.*]] = add nuw nsw <2 x i64> [[ZEXT]], <i64 7, i64 poison>
-; CHECK-NEXT:    [[RES:%.*]] = and <2 x i64> [[ZSUB]], <i64 240, i64 poison>
+; CHECK-NEXT:    [[TMP1:%.*]] = add <2 x i32> [[X:%.*]], <i32 7, i32 7>
+; CHECK-NEXT:    [[TMP2:%.*]] = and <2 x i32> [[TMP1]], <i32 240, i32 240>
+; CHECK-NEXT:    [[RES:%.*]] = zext nneg <2 x i32> [[TMP2]] to <2 x i64>
 ; CHECK-NEXT:    ret <2 x i64> [[RES]]
 ;
   %zext = zext <2 x i32> %X to <2 x i64>
@@ -1681,8 +1681,8 @@ define <2 x i8> @flip_masked_bit_uniform(<2 x i8> %A) {
 
 define <2 x i8> @flip_masked_bit_poison(<2 x i8> %A) {
 ; CHECK-LABEL: @flip_masked_bit_poison(
-; CHECK-NEXT:    [[TMP1:%.*]] = xor <2 x i8> [[A:%.*]], <i8 -1, i8 -1>
-; CHECK-NEXT:    [[C:%.*]] = and <2 x i8> [[TMP1]], <i8 16, i8 poison>
+; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i8> [[A:%.*]], <i8 16, i8 poison>
+; CHECK-NEXT:    [[C:%.*]] = xor <2 x i8> [[TMP1]], <i8 16, i8 16>
 ; CHECK-NEXT:    ret <2 x i8> [[C]]
 ;
   %B = add <2 x i8> %A, <i8 16, i8 poison>
@@ -1960,8 +1960,8 @@ define i16 @invert_signbit_splat_mask(i8 %x, i16 %y) {
 define <2 x i16> @invert_signbit_splat_mask_commute(<2 x i5> %x, <2 x i16> %p) {
 ; CHECK-LABEL: @invert_signbit_splat_mask_commute(
 ; CHECK-NEXT:    [[Y:%.*]] = mul <2 x i16> [[P:%.*]], [[P]]
-; CHECK-NEXT:    [[ISNEG:%.*]] = icmp slt <2 x i5> [[X:%.*]], zeroinitializer
-; CHECK-NEXT:    [[R:%.*]] = select <2 x i1> [[ISNEG]], <2 x i16> zeroinitializer, <2 x i16> [[Y]]
+; CHECK-NEXT:    [[ISNOTNEG:%.*]] = icmp sgt <2 x i5> [[X:%.*]], <i5 -1, i5 -1>
+; CHECK-NEXT:    [[R:%.*]] = select <2 x i1> [[ISNOTNEG]], <2 x i16> [[Y]], <2 x i16> zeroinitializer
 ; CHECK-NEXT:    ret <2 x i16> [[R]]
 ;
   %y = mul <2 x i16> %p, %p ; thwart complexity-based canonicalization
@@ -2122,7 +2122,7 @@ define <3 x i16> @shl_lshr_pow2_const_case1_non_uniform_vec_negative(<3 x i16> %
 
 define <3 x i16> @shl_lshr_pow2_const_case1_poison1_vec(<3 x i16> %x) {
 ; CHECK-LABEL: @shl_lshr_pow2_const_case1_poison1_vec(
-; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq <3 x i16> [[X:%.*]], <i16 8, i16 4, i16 4>
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq <3 x i16> [[X:%.*]], <i16 4, i16 4, i16 4>
 ; CHECK-NEXT:    [[R:%.*]] = select <3 x i1> [[TMP1]], <3 x i16> <i16 8, i16 8, i16 8>, <3 x i16> zeroinitializer
 ; CHECK-NEXT:    ret <3 x i16> [[R]]
 ;
@@ -2146,9 +2146,8 @@ define <3 x i16> @shl_lshr_pow2_const_case1_poison2_vec(<3 x i16> %x) {
 
 define <3 x i16> @shl_lshr_pow2_const_case1_poison3_vec(<3 x i16> %x) {
 ; CHECK-LABEL: @shl_lshr_pow2_const_case1_poison3_vec(
-; CHECK-NEXT:    [[SHL:%.*]] = shl <3 x i16> <i16 16, i16 16, i16 16>, [[X:%.*]]
-; CHECK-NEXT:    [[LSHR:%.*]] = lshr <3 x i16> [[SHL]], <i16 5, i16 5, i16 5>
-; CHECK-NEXT:    [[R:%.*]] = and <3 x i16> [[LSHR]], <i16 poison, i16 8, i16 8>
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq <3 x i16> [[X:%.*]], <i16 4, i16 4, i16 4>
+; CHECK-NEXT:    [[R:%.*]] = select <3 x i1> [[TMP1]], <3 x i16> <i16 8, i16 8, i16 8>, <3 x i16> zeroinitializer
 ; CHECK-NEXT:    ret <3 x i16> [[R]]
 ;
   %shl = shl <3 x i16> <i16 16, i16 16, i16 16>, %x
@@ -2418,7 +2417,7 @@ define <3 x i16> @lshr_shl_pow2_const_case1_non_uniform_vec_negative(<3 x i16> %
 
 define <3 x i16> @lshr_shl_pow2_const_case1_poison1_vec(<3 x i16> %x) {
 ; CHECK-LABEL: @lshr_shl_pow2_const_case1_poison1_vec(
-; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq <3 x i16> [[X:%.*]], <i16 -1, i16 12, i16 12>
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq <3 x i16> [[X:%.*]], <i16 12, i16 12, i16 12>
 ; CHECK-NEXT:    [[R:%.*]] = select <3 x i1> [[TMP1]], <3 x i16> <i16 128, i16 128, i16 128>, <3 x i16> zeroinitializer
 ; CHECK-NEXT:    ret <3 x i16> [[R]]
 ;
@@ -2443,9 +2442,8 @@ define <3 x i16> @lshr_shl_pow2_const_case1_poison2_vec(<3 x i16> %x) {
 
 define <3 x i16> @lshr_shl_pow2_const_case1_poison3_vec(<3 x i16> %x) {
 ; CHECK-LABEL: @lshr_shl_pow2_const_case1_poison3_vec(
-; CHECK-NEXT:    [[LSHR:%.*]] = lshr <3 x i16> <i16 8192, i16 8192, i16 8192>, [[X:%.*]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl <3 x i16> [[LSHR]], <i16 6, i16 6, i16 6>
-; CHECK-NEXT: ...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 17, 2024

@llvm/pr-subscribers-llvm-analysis

Author: None (goldsteinn)

Changes

This is a followup to #88217


Patch is 59.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89155.diff

27 Files Affected:

  • (modified) llvm/include/llvm/IR/Constant.h (+2-1)
  • (modified) llvm/include/llvm/IR/Constants.h (+2-1)
  • (modified) llvm/include/llvm/IR/PatternMatch.h (+14-10)
  • (modified) llvm/lib/IR/Constants.cpp (+22-9)
  • (modified) llvm/test/Analysis/ValueTracking/known-non-zero.ll (+15-19)
  • (modified) llvm/test/Transforms/InstCombine/add.ll (+9-12)
  • (modified) llvm/test/Transforms/InstCombine/and.ll (+13-15)
  • (modified) llvm/test/Transforms/InstCombine/ashr-lshr.ll (+4-4)
  • (modified) llvm/test/Transforms/InstCombine/cast.ll (+6-6)
  • (modified) llvm/test/Transforms/InstCombine/getelementptr.ll (+6-8)
  • (modified) llvm/test/Transforms/InstCombine/icmp-logical.ll (+7-11)
  • (modified) llvm/test/Transforms/InstCombine/icmp-range.ll (+15-17)
  • (modified) llvm/test/Transforms/InstCombine/icmp-shr.ll (+4-6)
  • (modified) llvm/test/Transforms/InstCombine/lshr.ll (+2-3)
  • (modified) llvm/test/Transforms/InstCombine/minmax-fold.ll (+3-4)
  • (modified) llvm/test/Transforms/InstCombine/mul-inseltpoison.ll (+3-3)
  • (modified) llvm/test/Transforms/InstCombine/mul.ll (+15-16)
  • (modified) llvm/test/Transforms/InstCombine/opaque-ptr.ll (+2-3)
  • (modified) llvm/test/Transforms/InstCombine/or.ll (+5-9)
  • (modified) llvm/test/Transforms/InstCombine/saturating-add-sub.ll (+1-3)
  • (modified) llvm/test/Transforms/InstCombine/shift-amount-reassociation-with-truncation-ashr.ll (+3-3)
  • (modified) llvm/test/Transforms/InstCombine/shift.ll (+13-14)
  • (modified) llvm/test/Transforms/InstCombine/shl-sub.ll (+1-2)
  • (modified) llvm/test/Transforms/InstCombine/sub.ll (+6-9)
  • (modified) llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll (+7-7)
  • (modified) llvm/test/Transforms/InstCombine/vector-casts-inseltpoison.ll (+1-2)
  • (modified) llvm/test/Transforms/InstCombine/vector-casts.ll (+1-2)
diff --git a/llvm/include/llvm/IR/Constant.h b/llvm/include/llvm/IR/Constant.h
index 778764062227cb..ec2ddb5564e73f 100644
--- a/llvm/include/llvm/IR/Constant.h
+++ b/llvm/include/llvm/IR/Constant.h
@@ -148,7 +148,8 @@ class Constant : public User {
   /// If all elements of the vector constant have the same value, return that
   /// value. Otherwise, return nullptr. Ignore undefined elements by setting
   /// AllowUndefs to true.
-  Constant *getSplatValue(bool AllowUndefs = false) const;
+  Constant *getSplatValue(bool AllowUndefs = false,
+                          bool AllowPoisons = false) const;
 
   /// If C is a constant integer then return its value, otherwise C must be a
   /// vector of constant integers, all equal, and the common value is returned.
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index 4290ef4486c6f4..4af0756de463da 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -534,7 +534,8 @@ class ConstantVector final : public ConstantAggregate {
   /// If all elements of the vector constant have the same value, return that
   /// value. Otherwise, return nullptr. Ignore undefined elements by setting
   /// AllowUndefs to true.
-  Constant *getSplatValue(bool AllowUndefs = false) const;
+  Constant *getSplatValue(bool AllowUndefs = false,
+                          bool AllowPoisons = false) const;
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast:
   static bool classof(const Value *V) {
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 98cc0e50376981..08ac9275463603 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -255,8 +255,8 @@ struct apint_match {
     }
     if (V->getType()->isVectorTy())
       if (const auto *C = dyn_cast<Constant>(V))
-        if (auto *CI =
-                dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowUndef))) {
+        if (auto *CI = dyn_cast_or_null<ConstantInt>(
+                C->getSplatValue(AllowUndef, /*AllowPoisons=*/true))) {
           Res = &CI->getValue();
           return true;
         }
@@ -280,8 +280,8 @@ struct apfloat_match {
     }
     if (V->getType()->isVectorTy())
       if (const auto *C = dyn_cast<Constant>(V))
-        if (auto *CI =
-                dyn_cast_or_null<ConstantFP>(C->getSplatValue(AllowUndef))) {
+        if (auto *CI = dyn_cast_or_null<ConstantFP>(
+                C->getSplatValue(AllowUndef, /*AllowPoisons=*/true))) {
           Res = &CI->getValueAPF();
           return true;
         }
@@ -353,7 +353,8 @@ struct cstval_pred_ty : public Predicate {
       return this->isValue(CV->getValue());
     if (const auto *VTy = dyn_cast<VectorType>(V->getType())) {
       if (const auto *C = dyn_cast<Constant>(V)) {
-        if (const auto *CV = dyn_cast_or_null<ConstantVal>(C->getSplatValue()))
+        if (const auto *CV = dyn_cast_or_null<ConstantVal>(
+                C->getSplatValue(/*AllowUndefs=*/false, /*AllowPoisons=*/true)))
           return this->isValue(CV->getValue());
 
         // Number of elements of a scalable vector unknown at compile time
@@ -406,7 +407,8 @@ template <typename Predicate> struct api_pred_ty : public Predicate {
       }
     if (V->getType()->isVectorTy())
       if (const auto *C = dyn_cast<Constant>(V))
-        if (auto *CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue()))
+        if (auto *CI = dyn_cast_or_null<ConstantInt>(
+                C->getSplatValue(/*AllowUndefs=*/false, /*AllowPoisons=*/true)))
           if (this->isValue(CI->getValue())) {
             Res = &CI->getValue();
             return true;
@@ -432,8 +434,8 @@ template <typename Predicate> struct apf_pred_ty : public Predicate {
       }
     if (V->getType()->isVectorTy())
       if (const auto *C = dyn_cast<Constant>(V))
-        if (auto *CI = dyn_cast_or_null<ConstantFP>(
-                C->getSplatValue(/* AllowUndef */ true)))
+        if (auto *CI = dyn_cast_or_null<ConstantFP>(C->getSplatValue(
+                /*AllowUndefs=*/true, /*AllowPoisons=*/true)))
           if (this->isValue(CI->getValue())) {
             Res = &CI->getValue();
             return true;
@@ -892,7 +894,8 @@ template <bool AllowUndefs> struct specific_intval {
     const auto *CI = dyn_cast<ConstantInt>(V);
     if (!CI && V->getType()->isVectorTy())
       if (const auto *C = dyn_cast<Constant>(V))
-        CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowUndefs));
+        CI = dyn_cast_or_null<ConstantInt>(
+            C->getSplatValue(AllowUndefs, /*AllowPoisons=*/true));
 
     return CI && APInt::isSameValue(CI->getValue(), Val);
   }
@@ -907,7 +910,8 @@ template <bool AllowUndefs> struct specific_intval64 {
     const auto *CI = dyn_cast<ConstantInt>(V);
     if (!CI && V->getType()->isVectorTy())
       if (const auto *C = dyn_cast<Constant>(V))
-        CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowUndefs));
+        CI = dyn_cast_or_null<ConstantInt>(
+            C->getSplatValue(AllowUndefs, /*AllowPoisons=*/true));
 
     return CI && CI->getValue() == Val;
   }
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index 45b359a94b3ab7..28e291a59eb89a 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -715,7 +715,7 @@ static bool constantIsDead(const Constant *C, bool RemoveDeadUsers) {
     ReplaceableMetadataImpl::SalvageDebugInfo(*C);
     const_cast<Constant *>(C)->destroyConstant();
   }
-  
+
   return true;
 }
 
@@ -1696,14 +1696,14 @@ void ConstantVector::destroyConstantImpl() {
   getType()->getContext().pImpl->VectorConstants.remove(this);
 }
 
-Constant *Constant::getSplatValue(bool AllowUndefs) const {
+Constant *Constant::getSplatValue(bool AllowUndefs, bool AllowPoisons) const {
   assert(this->getType()->isVectorTy() && "Only valid for vectors!");
   if (isa<ConstantAggregateZero>(this))
     return getNullValue(cast<VectorType>(getType())->getElementType());
   if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this))
     return CV->getSplatValue();
   if (const ConstantVector *CV = dyn_cast<ConstantVector>(this))
-    return CV->getSplatValue(AllowUndefs);
+    return CV->getSplatValue(AllowUndefs, AllowPoisons);
 
   // Check if this is a constant expression splat of the form returned by
   // ConstantVector::getSplat()
@@ -1728,7 +1728,8 @@ Constant *Constant::getSplatValue(bool AllowUndefs) const {
   return nullptr;
 }
 
-Constant *ConstantVector::getSplatValue(bool AllowUndefs) const {
+Constant *ConstantVector::getSplatValue(bool AllowUndefs,
+                                        bool AllowPoisons) const {
   // Check out first element.
   Constant *Elt = getOperand(0);
   // Then make sure all remaining elements point to the same value.
@@ -1737,16 +1738,28 @@ Constant *ConstantVector::getSplatValue(bool AllowUndefs) const {
     if (OpC == Elt)
       continue;
 
-    // Strict mode: any mismatch is not a splat.
-    if (!AllowUndefs)
+    if (!AllowPoisons && !AllowUndefs)
       return nullptr;
 
-    // Allow undefs mode: ignore undefined elements.
-    if (isa<UndefValue>(OpC))
+    if (isa<PoisonValue>(OpC)) {
+      assert(isa<UndefValue>(OpC));
+      // Strict mode: any mismatch is not a splat.
+      if (!AllowPoisons && !AllowUndefs)
+        return nullptr;
+      // Allow poisons mode: ignore poison elements.
+      continue;
+    } else if (isa<UndefValue>(OpC)) {
+      // Strict mode: any mismatch is not a splat.
+      if (!AllowUndefs)
+        return nullptr;
+      // Allow undefs/poisons mode: ignore undefined elements.
       continue;
+    }
 
     // If we do not have a defined element yet, use the current operand.
-    if (isa<UndefValue>(Elt))
+    if (AllowPoisons && isa<PoisonValue>(Elt))
+      Elt = OpC;
+    else if (AllowUndefs && isa<UndefValue>(Elt))
       Elt = OpC;
 
     if (OpC != Elt)
diff --git a/llvm/test/Analysis/ValueTracking/known-non-zero.ll b/llvm/test/Analysis/ValueTracking/known-non-zero.ll
index 0159050d925c3e..0ee09a131658de 100644
--- a/llvm/test/Analysis/ValueTracking/known-non-zero.ll
+++ b/llvm/test/Analysis/ValueTracking/known-non-zero.ll
@@ -1189,11 +1189,7 @@ define <2 x i1> @cmp_excludes_zero_with_nonsplat_vec_wundef(<2 x i8> %a, <2 x i8
 
 define <2 x i1> @cmp_excludes_zero_with_nonsplat_vec_wpoison(<2 x i8> %a, <2 x i8> %b) {
 ; CHECK-LABEL: @cmp_excludes_zero_with_nonsplat_vec_wpoison(
-; CHECK-NEXT:    [[C:%.*]] = icmp sge <2 x i8> [[A:%.*]], <i8 1, i8 poison>
-; CHECK-NEXT:    [[S:%.*]] = select <2 x i1> [[C]], <2 x i8> [[A]], <2 x i8> <i8 4, i8 5>
-; CHECK-NEXT:    [[AND:%.*]] = or <2 x i8> [[S]], [[B:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = icmp eq <2 x i8> [[AND]], zeroinitializer
-; CHECK-NEXT:    ret <2 x i1> [[R]]
+; CHECK-NEXT:    ret <2 x i1> zeroinitializer
 ;
   %c = icmp sge <2 x i8> %a, <i8 1, i8 poison>
   %s = select <2 x i1> %c, <2 x i8> %a, <2 x i8> <i8 4, i8 5>
@@ -1314,8 +1310,8 @@ define i1 @range_attr(i8 range(i8 1, 0) %x, i8 %y) {
 
 define i1 @neg_range_attr(i8 range(i8 -1, 1) %x, i8 %y) {
 ; CHECK-LABEL: @neg_range_attr(
-; CHECK-NEXT:    [[I:%.*]] = or i8 [[Y:%.*]], [[X:%.*]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[I]], 0
+; CHECK-NEXT:    [[OR:%.*]] = or i8 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[OR]], 0
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %or = or i8 %y, %x
@@ -1328,7 +1324,7 @@ declare range(i8 -1, 1) i8 @returns_contain_zero_range_helper()
 
 define i1 @range_return(i8 %y) {
 ; CHECK-LABEL: @range_return(
-; CHECK-NEXT:    [[I:%.*]] = call i8 @returns_non_zero_range_helper()
+; CHECK-NEXT:    [[X:%.*]] = call i8 @returns_non_zero_range_helper()
 ; CHECK-NEXT:    ret i1 false
 ;
   %x = call i8 @returns_non_zero_range_helper()
@@ -1339,8 +1335,8 @@ define i1 @range_return(i8 %y) {
 
 define i1 @neg_range_return(i8 %y) {
 ; CHECK-LABEL: @neg_range_return(
-; CHECK-NEXT:    [[I:%.*]] = call i8 @returns_contain_zero_range_helper()
-; CHECK-NEXT:    [[OR:%.*]] = or i8 [[Y:%.*]], [[I]]
+; CHECK-NEXT:    [[X:%.*]] = call i8 @returns_contain_zero_range_helper()
+; CHECK-NEXT:    [[OR:%.*]] = or i8 [[Y:%.*]], [[X]]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[OR]], 0
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
@@ -1354,7 +1350,7 @@ declare i8 @returns_i8_helper()
 
 define i1 @range_call(i8 %y) {
 ; CHECK-LABEL: @range_call(
-; CHECK-NEXT:    [[I:%.*]] = call range(i8 1, 0) i8 @returns_i8_helper()
+; CHECK-NEXT:    [[X:%.*]] = call range(i8 1, 0) i8 @returns_i8_helper()
 ; CHECK-NEXT:    ret i1 false
 ;
   %x = call range(i8 1, 0) i8 @returns_i8_helper()
@@ -1365,8 +1361,8 @@ define i1 @range_call(i8 %y) {
 
 define i1 @neg_range_call(i8 %y) {
 ; CHECK-LABEL: @neg_range_call(
-; CHECK-NEXT:    [[I:%.*]] = call range(i8 -1, 1) i8 @returns_i8_helper()
-; CHECK-NEXT:    [[OR:%.*]] = or i8 [[Y:%.*]], [[I]]
+; CHECK-NEXT:    [[X:%.*]] = call range(i8 -1, 1) i8 @returns_i8_helper()
+; CHECK-NEXT:    [[OR:%.*]] = or i8 [[Y:%.*]], [[X]]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[OR]], 0
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
@@ -1401,7 +1397,7 @@ declare range(i8 -1, 1) <2 x i8> @returns_contain_zero_range_helper_vec()
 
 define <2 x i1> @range_return_vec(<2 x i8> %y) {
 ; CHECK-LABEL: @range_return_vec(
-; CHECK-NEXT:    [[I:%.*]] = call <2 x i8> @returns_non_zero_range_helper_vec()
+; CHECK-NEXT:    [[X:%.*]] = call <2 x i8> @returns_non_zero_range_helper_vec()
 ; CHECK-NEXT:    ret <2 x i1> <i1 true, i1 true>
 ;
   %x = call <2 x i8> @returns_non_zero_range_helper_vec()
@@ -1412,8 +1408,8 @@ define <2 x i1> @range_return_vec(<2 x i8> %y) {
 
 define <2 x i1> @neg_range_return_vec(<2 x i8> %y) {
 ; CHECK-LABEL: @neg_range_return_vec(
-; CHECK-NEXT:    [[I:%.*]] = call <2 x i8> @returns_contain_zero_range_helper_vec()
-; CHECK-NEXT:    [[OR:%.*]] = or <2 x i8> [[Y:%.*]], [[I]]
+; CHECK-NEXT:    [[X:%.*]] = call <2 x i8> @returns_contain_zero_range_helper_vec()
+; CHECK-NEXT:    [[OR:%.*]] = or <2 x i8> [[Y:%.*]], [[X]]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <2 x i8> [[OR]], zeroinitializer
 ; CHECK-NEXT:    ret <2 x i1> [[CMP]]
 ;
@@ -1427,7 +1423,7 @@ declare <2 x i8> @returns_i8_helper_vec()
 
 define <2 x i1> @range_call_vec(<2 x i8> %y) {
 ; CHECK-LABEL: @range_call_vec(
-; CHECK-NEXT:    [[I:%.*]] = call range(i8 1, 0) <2 x i8> @returns_i8_helper_vec()
+; CHECK-NEXT:    [[X:%.*]] = call range(i8 1, 0) <2 x i8> @returns_i8_helper_vec()
 ; CHECK-NEXT:    ret <2 x i1> <i1 true, i1 true>
 ;
   %x = call range(i8 1, 0) <2 x i8> @returns_i8_helper_vec()
@@ -1438,8 +1434,8 @@ define <2 x i1> @range_call_vec(<2 x i8> %y) {
 
 define <2 x i1> @neg_range_call_vec(<2 x i8> %y) {
 ; CHECK-LABEL: @neg_range_call_vec(
-; CHECK-NEXT:    [[I:%.*]] = call range(i8 -1, 1) <2 x i8> @returns_i8_helper_vec()
-; CHECK-NEXT:    [[OR:%.*]] = or <2 x i8> [[Y:%.*]], [[I]]
+; CHECK-NEXT:    [[X:%.*]] = call range(i8 -1, 1) <2 x i8> @returns_i8_helper_vec()
+; CHECK-NEXT:    [[OR:%.*]] = or <2 x i8> [[Y:%.*]], [[X]]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <2 x i8> [[OR]], zeroinitializer
 ; CHECK-NEXT:    ret <2 x i1> [[CMP]]
 ;
diff --git a/llvm/test/Transforms/InstCombine/add.ll b/llvm/test/Transforms/InstCombine/add.ll
index 39b4ad80550889..28d5b0bedc082f 100644
--- a/llvm/test/Transforms/InstCombine/add.ll
+++ b/llvm/test/Transforms/InstCombine/add.ll
@@ -1198,8 +1198,7 @@ define <2 x i32> @test44_vec_non_matching(<2 x i32> %A) {
 
 define <2 x i32> @test44_vec_poison(<2 x i32> %A) {
 ; CHECK-LABEL: @test44_vec_poison(
-; CHECK-NEXT:    [[B:%.*]] = or <2 x i32> [[A:%.*]], <i32 123, i32 poison>
-; CHECK-NEXT:    [[C:%.*]] = add nsw <2 x i32> [[B]], <i32 -123, i32 poison>
+; CHECK-NEXT:    [[C:%.*]] = and <2 x i32> [[A:%.*]], <i32 -124, i32 -124>
 ; CHECK-NEXT:    ret <2 x i32> [[C]]
 ;
   %B = or <2 x i32> %A, <i32 123, i32 poison>
@@ -3139,9 +3138,7 @@ define <2 x i32> @dec_zext_add_nonzero_vec_poison1(<2 x i8> %x) {
 define <2 x i32> @dec_zext_add_nonzero_vec_poison2(<2 x i8> %x) {
 ; CHECK-LABEL: @dec_zext_add_nonzero_vec_poison2(
 ; CHECK-NEXT:    [[O:%.*]] = or <2 x i8> [[X:%.*]], <i8 8, i8 8>
-; CHECK-NEXT:    [[A:%.*]] = add nsw <2 x i8> [[O]], <i8 -1, i8 -1>
-; CHECK-NEXT:    [[B:%.*]] = zext <2 x i8> [[A]] to <2 x i32>
-; CHECK-NEXT:    [[C:%.*]] = add nuw nsw <2 x i32> [[B]], <i32 1, i32 poison>
+; CHECK-NEXT:    [[C:%.*]] = zext <2 x i8> [[O]] to <2 x i32>
 ; CHECK-NEXT:    ret <2 x i32> [[C]]
 ;
   %o = or <2 x i8> %x, <i8 8, i8 8>
@@ -4018,8 +4015,8 @@ define i32 @add_reduce_sqr_sum_varC_invalid2(i32 %a, i32 %b) {
 
 define i32 @fold_sext_addition_or_disjoint(i8 %x) {
 ; CHECK-LABEL: @fold_sext_addition_or_disjoint(
-; CHECK-NEXT:    [[SE:%.*]] = sext i8 [[XX:%.*]] to i32
-; CHECK-NEXT:    [[R:%.*]] = add nsw i32 [[SE]], 1246
+; CHECK-NEXT:    [[TMP1:%.*]] = sext i8 [[X:%.*]] to i32
+; CHECK-NEXT:    [[R:%.*]] = add nsw i32 [[TMP1]], 1246
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %xx = or disjoint i8 %x, 12
@@ -4043,8 +4040,8 @@ define i32 @fold_sext_addition_fail(i8 %x) {
 
 define i32 @fold_zext_addition_or_disjoint(i8 %x) {
 ; CHECK-LABEL: @fold_zext_addition_or_disjoint(
-; CHECK-NEXT:    [[SE:%.*]] = zext i8 [[XX:%.*]] to i32
-; CHECK-NEXT:    [[R:%.*]] = add nuw nsw i32 [[SE]], 1246
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i8 [[X:%.*]] to i32
+; CHECK-NEXT:    [[R:%.*]] = add nuw nsw i32 [[TMP1]], 1246
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %xx = or disjoint i8 %x, 12
@@ -4055,9 +4052,9 @@ define i32 @fold_zext_addition_or_disjoint(i8 %x) {
 
 define i32 @fold_zext_addition_or_disjoint2(i8 %x) {
 ; CHECK-LABEL: @fold_zext_addition_or_disjoint2(
-; CHECK-NEXT:    [[XX:%.*]] = add nuw i8 [[X:%.*]], 4
-; CHECK-NEXT:    [[SE:%.*]] = zext i8 [[XX]] to i32
-; CHECK-NEXT:    ret i32 [[SE]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add nuw i8 [[X:%.*]], 4
+; CHECK-NEXT:    [[R:%.*]] = zext i8 [[TMP1]] to i32
+; CHECK-NEXT:    ret i32 [[R]]
 ;
   %xx = or disjoint i8 %x, 18
   %se = zext i8 %xx to i32
diff --git a/llvm/test/Transforms/InstCombine/and.ll b/llvm/test/Transforms/InstCombine/and.ll
index b5250fc1a7849d..738a4a6a4cfbfc 100644
--- a/llvm/test/Transforms/InstCombine/and.ll
+++ b/llvm/test/Transforms/InstCombine/and.ll
@@ -754,9 +754,9 @@ define <2 x i64> @test36_uniform(<2 x i32> %X) {
 
 define <2 x i64> @test36_poison(<2 x i32> %X) {
 ; CHECK-LABEL: @test36_poison(
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext <2 x i32> [[X:%.*]] to <2 x i64>
-; CHECK-NEXT:    [[ZSUB:%.*]] = add nuw nsw <2 x i64> [[ZEXT]], <i64 7, i64 poison>
-; CHECK-NEXT:    [[RES:%.*]] = and <2 x i64> [[ZSUB]], <i64 240, i64 poison>
+; CHECK-NEXT:    [[TMP1:%.*]] = add <2 x i32> [[X:%.*]], <i32 7, i32 7>
+; CHECK-NEXT:    [[TMP2:%.*]] = and <2 x i32> [[TMP1]], <i32 240, i32 240>
+; CHECK-NEXT:    [[RES:%.*]] = zext nneg <2 x i32> [[TMP2]] to <2 x i64>
 ; CHECK-NEXT:    ret <2 x i64> [[RES]]
 ;
   %zext = zext <2 x i32> %X to <2 x i64>
@@ -1681,8 +1681,8 @@ define <2 x i8> @flip_masked_bit_uniform(<2 x i8> %A) {
 
 define <2 x i8> @flip_masked_bit_poison(<2 x i8> %A) {
 ; CHECK-LABEL: @flip_masked_bit_poison(
-; CHECK-NEXT:    [[TMP1:%.*]] = xor <2 x i8> [[A:%.*]], <i8 -1, i8 -1>
-; CHECK-NEXT:    [[C:%.*]] = and <2 x i8> [[TMP1]], <i8 16, i8 poison>
+; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i8> [[A:%.*]], <i8 16, i8 poison>
+; CHECK-NEXT:    [[C:%.*]] = xor <2 x i8> [[TMP1]], <i8 16, i8 16>
 ; CHECK-NEXT:    ret <2 x i8> [[C]]
 ;
   %B = add <2 x i8> %A, <i8 16, i8 poison>
@@ -1960,8 +1960,8 @@ define i16 @invert_signbit_splat_mask(i8 %x, i16 %y) {
 define <2 x i16> @invert_signbit_splat_mask_commute(<2 x i5> %x, <2 x i16> %p) {
 ; CHECK-LABEL: @invert_signbit_splat_mask_commute(
 ; CHECK-NEXT:    [[Y:%.*]] = mul <2 x i16> [[P:%.*]], [[P]]
-; CHECK-NEXT:    [[ISNEG:%.*]] = icmp slt <2 x i5> [[X:%.*]], zeroinitializer
-; CHECK-NEXT:    [[R:%.*]] = select <2 x i1> [[ISNEG]], <2 x i16> zeroinitializer, <2 x i16> [[Y]]
+; CHECK-NEXT:    [[ISNOTNEG:%.*]] = icmp sgt <2 x i5> [[X:%.*]], <i5 -1, i5 -1>
+; CHECK-NEXT:    [[R:%.*]] = select <2 x i1> [[ISNOTNEG]], <2 x i16> [[Y]], <2 x i16> zeroinitializer
 ; CHECK-NEXT:    ret <2 x i16> [[R]]
 ;
   %y = mul <2 x i16> %p, %p ; thwart complexity-based canonicalization
@@ -2122,7 +2122,7 @@ define <3 x i16> @shl_lshr_pow2_const_case1_non_uniform_vec_negative(<3 x i16> %
 
 define <3 x i16> @shl_lshr_pow2_const_case1_poison1_vec(<3 x i16> %x) {
 ; CHECK-LABEL: @shl_lshr_pow2_const_case1_poison1_vec(
-; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq <3 x i16> [[X:%.*]], <i16 8, i16 4, i16 4>
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq <3 x i16> [[X:%.*]], <i16 4, i16 4, i16 4>
 ; CHECK-NEXT:    [[R:%.*]] = select <3 x i1> [[TMP1]], <3 x i16> <i16 8, i16 8, i16 8>, <3 x i16> zeroinitializer
 ; CHECK-NEXT:    ret <3 x i16> [[R]]
 ;
@@ -2146,9 +2146,8 @@ define <3 x i16> @shl_lshr_pow2_const_case1_poison2_vec(<3 x i16> %x) {
 
 define <3 x i16> @shl_lshr_pow2_const_case1_poison3_vec(<3 x i16> %x) {
 ; CHECK-LABEL: @shl_lshr_pow2_const_case1_poison3_vec(
-; CHECK-NEXT:    [[SHL:%.*]] = shl <3 x i16> <i16 16, i16 16, i16 16>, [[X:%.*]]
-; CHECK-NEXT:    [[LSHR:%.*]] = lshr <3 x i16> [[SHL]], <i16 5, i16 5, i16 5>
-; CHECK-NEXT:    [[R:%.*]] = and <3 x i16> [[LSHR]], <i16 poison, i16 8, i16 8>
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq <3 x i16> [[X:%.*]], <i16 4, i16 4, i16 4>
+; CHECK-NEXT:    [[R:%.*]] = select <3 x i1> [[TMP1]], <3 x i16> <i16 8, i16 8, i16 8>, <3 x i16> zeroinitializer
 ; CHECK-NEXT:    ret <3 x i16> [[R]]
 ;
   %shl = shl <3 x i16> <i16 16, i16 16, i16 16>, %x
@@ -2418,7 +2417,7 @@ define <3 x i16> @lshr_shl_pow2_const_case1_non_uniform_vec_negative(<3 x i16> %
 
 define <3 x i16> @lshr_shl_pow2_const_case1_poison1_vec(<3 x i16> %x) {
 ; CHECK-LABEL: @lshr_shl_pow2_const_case1_poison1_vec(
-; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq <3 x i16> [[X:%.*]], <i16 -1, i16 12, i16 12>
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq <3 x i16> [[X:%.*]], <i16 12, i16 12, i16 12>
 ; CHECK-NEXT:    [[R:%.*]] = select <3 x i1> [[TMP1]], <3 x i16> <i16 128, i16 128, i16 128>, <3 x i16> zeroinitializer
 ; CHECK-NEXT:    ret <3 x i16> [[R]]
 ;
@@ -2443,9 +2442,8 @@ define <3 x i16> @lshr_shl_pow2_const_case1_poison2_vec(<3 x i16> %x) {
 
 define <3 x i16> @lshr_shl_pow2_const_case1_poison3_vec(<3 x i16> %x) {
 ; CHECK-LABEL: @lshr_shl_pow2_const_case1_poison3_vec(
-; CHECK-NEXT:    [[LSHR:%.*]] = lshr <3 x i16> <i16 8192, i16 8192, i16 8192>, [[X:%.*]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl <3 x i16> [[LSHR]], <i16 6, i16 6, i16 6>
-; CHECK-NEXT: ...
[truncated]

@nikic
Copy link
Contributor

nikic commented Apr 18, 2024

I don't think this is the change we want to do. I've posted an alternative at #89159.

Basically:

  • In the spirit of [PatternMatch] Do not accept undef elements in m_AllOnes() and friends #88217, we should remove support for undef splats to the degree possible, not support both.
  • Allowing poison splats by default in m_APInt/m_APFloat is a separate change, and requires at least some due diligence (at least a cursory review of select combines, where this is most likely to cause issues.)

@goldsteinn
Copy link
Contributor Author

I don't think this is the change we want to do. I've posted an alternative at #89159.

Basically:

  • In the spirit of [PatternMatch] Do not accept undef elements in m_AllOnes() and friends #88217, we should remove support for undef splats to the degree possible, not support both.
  • Allowing poison splats by default in m_APInt/m_APFloat is a separate change, and requires at least some due diligence (at least a cursory review of select combines, where this is most likely to cause issues.)

Yeah I'm happy with that. The goal was poison support in getsplat, don't have any particular attachment/reason for keeping the undef stuff.

@goldsteinn goldsteinn closed this May 5, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants