Skip to content

Commit

Permalink
[PatternMatch] Do not accept undef elements in m_AllOnes() and friends (
Browse files Browse the repository at this point in the history
#88217)

Change all the cstval_pred_ty based PatternMatch helpers (things like
m_AllOnes and m_Zero) to only allow poison elements inside vector
splats, not undef elements.

Historically, we used to represent non-demanded elements in vectors
using undef. Nowadays, we use poison instead. As such, I believe that
support for undef in vector splats is no longer useful.

At the same time, while poison splat elements are pretty much always
safe to ignore, this is not generally the case for undef elements. We
have existing miscompiles in our tests due to this (see the
masked-merge-*.ll tests changed here) and it's easy to miss such cases
in the future, now that we write tests using poison instead of undef
elements.

I think overall, keeping support for undef elements no longer makes
sense, and we should drop it. Once this is done consistently, I think we
may also consider allowing poison in m_APInt by default, as doing that
change is much less risky than doing the same with undef.

This change involves a substantial amount of test changes. For most
tests, I've just replaced undef with poison, as I don't think there is
value in retaining both. For some tests (where the distinction between
undef and poison is important), I've duplicated tests.
  • Loading branch information
nikic committed Apr 17, 2024
1 parent a16bb07 commit d9a5aa8
Show file tree
Hide file tree
Showing 158 changed files with 2,042 additions and 1,839 deletions.
35 changes: 5 additions & 30 deletions llvm/include/llvm/IR/PatternMatch.h
Expand Up @@ -345,7 +345,7 @@ template <int64_t Val> inline constantint_match<Val> m_ConstantInt() {

/// This helper class is used to match constant scalars, vector splats,
/// and fixed width vectors that satisfy a specified predicate.
/// For fixed width vector constants, undefined elements are ignored.
/// For fixed width vector constants, poison elements are ignored.
template <typename Predicate, typename ConstantVal>
struct cstval_pred_ty : public Predicate {
template <typename ITy> bool match(ITy *V) {
Expand All @@ -364,19 +364,19 @@ struct cstval_pred_ty : public Predicate {
// Non-splat vector constant: check each element for a match.
unsigned NumElts = FVTy->getNumElements();
assert(NumElts != 0 && "Constant vector with no elements?");
bool HasNonUndefElements = false;
bool HasNonPoisonElements = false;
for (unsigned i = 0; i != NumElts; ++i) {
Constant *Elt = C->getAggregateElement(i);
if (!Elt)
return false;
if (isa<UndefValue>(Elt))
if (isa<PoisonValue>(Elt))
continue;
auto *CV = dyn_cast<ConstantVal>(Elt);
if (!CV || !this->isValue(CV->getValue()))
return false;
HasNonUndefElements = true;
HasNonPoisonElements = true;
}
return HasNonUndefElements;
return HasNonPoisonElements;
}
}
return false;
Expand Down Expand Up @@ -2587,31 +2587,6 @@ m_Not(const ValTy &V) {
return m_c_Xor(m_AllOnes(), V);
}

template <typename ValTy> struct NotForbidUndef_match {
ValTy Val;
NotForbidUndef_match(const ValTy &V) : Val(V) {}

template <typename OpTy> bool match(OpTy *V) {
// We do not use m_c_Xor because that could match an arbitrary APInt that is
// not -1 as C and then fail to match the other operand if it is -1.
// This code should still work even when both operands are constants.
Value *X;
const APInt *C;
if (m_Xor(m_Value(X), m_APIntForbidUndef(C)).match(V) && C->isAllOnes())
return Val.match(X);
if (m_Xor(m_APIntForbidUndef(C), m_Value(X)).match(V) && C->isAllOnes())
return Val.match(X);
return false;
}
};

/// Matches a bitwise 'not' as 'xor V, -1' or 'xor -1, V'. For vectors, the
/// constant value must be composed of only -1 scalar elements.
template <typename ValTy>
inline NotForbidUndef_match<ValTy> m_NotForbidUndef(const ValTy &V) {
return NotForbidUndef_match<ValTy>(V);
}

/// Matches an SMin with LHS and RHS in either order.
template <typename LHS, typename RHS>
inline MaxMin_match<ICmpInst, LHS, RHS, smin_pred_ty, true>
Expand Down
23 changes: 10 additions & 13 deletions llvm/lib/Analysis/InstructionSimplify.cpp
Expand Up @@ -1513,7 +1513,7 @@ static Value *simplifyAShrInst(Value *Op0, Value *Op1, bool IsExact,

// -1 >>a X --> -1
// (-1 << X) a>> X --> -1
// Do not return Op0 because it may contain undef elements if it's a vector.
// We could return the original -1 constant to preserve poison elements.
if (match(Op0, m_AllOnes()) ||
match(Op0, m_Shl(m_AllOnes(), m_Specific(Op1))))
return Constant::getAllOnesValue(Op0->getType());
Expand Down Expand Up @@ -2281,7 +2281,7 @@ static Value *simplifyOrLogic(Value *X, Value *Y) {
// (B ^ ~A) | (A & B) --> B ^ ~A
// (~A ^ B) | (B & A) --> ~A ^ B
// (B ^ ~A) | (B & A) --> B ^ ~A
if (match(X, m_c_Xor(m_NotForbidUndef(m_Value(A)), m_Value(B))) &&
if (match(X, m_c_Xor(m_Not(m_Value(A)), m_Value(B))) &&
match(Y, m_c_And(m_Specific(A), m_Specific(B))))
return X;

Expand All @@ -2298,31 +2298,29 @@ static Value *simplifyOrLogic(Value *X, Value *Y) {
// (B & ~A) | ~(A | B) --> ~A
// (B & ~A) | ~(B | A) --> ~A
Value *NotA;
if (match(X,
m_c_And(m_CombineAnd(m_Value(NotA), m_NotForbidUndef(m_Value(A))),
m_Value(B))) &&
if (match(X, m_c_And(m_CombineAnd(m_Value(NotA), m_Not(m_Value(A))),
m_Value(B))) &&
match(Y, m_Not(m_c_Or(m_Specific(A), m_Specific(B)))))
return NotA;
// The same is true of Logical And
// TODO: This could share the logic of the version above if there was a
// version of LogicalAnd that allowed more than just i1 types.
if (match(X, m_c_LogicalAnd(
m_CombineAnd(m_Value(NotA), m_NotForbidUndef(m_Value(A))),
m_Value(B))) &&
if (match(X, m_c_LogicalAnd(m_CombineAnd(m_Value(NotA), m_Not(m_Value(A))),
m_Value(B))) &&
match(Y, m_Not(m_c_LogicalOr(m_Specific(A), m_Specific(B)))))
return NotA;

// ~(A ^ B) | (A & B) --> ~(A ^ B)
// ~(A ^ B) | (B & A) --> ~(A ^ B)
Value *NotAB;
if (match(X, m_CombineAnd(m_NotForbidUndef(m_Xor(m_Value(A), m_Value(B))),
if (match(X, m_CombineAnd(m_Not(m_Xor(m_Value(A), m_Value(B))),
m_Value(NotAB))) &&
match(Y, m_c_And(m_Specific(A), m_Specific(B))))
return NotAB;

// ~(A & B) | (A ^ B) --> ~(A & B)
// ~(A & B) | (B ^ A) --> ~(A & B)
if (match(X, m_CombineAnd(m_NotForbidUndef(m_And(m_Value(A), m_Value(B))),
if (match(X, m_CombineAnd(m_Not(m_And(m_Value(A), m_Value(B))),
m_Value(NotAB))) &&
match(Y, m_c_Xor(m_Specific(A), m_Specific(B))))
return NotAB;
Expand Down Expand Up @@ -2552,9 +2550,8 @@ static Value *simplifyXorInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
// The 'not' op must contain a complete -1 operand (no undef elements for
// vector) for the transform to be safe.
Value *NotA;
if (match(X,
m_c_Or(m_CombineAnd(m_NotForbidUndef(m_Value(A)), m_Value(NotA)),
m_Value(B))) &&
if (match(X, m_c_Or(m_CombineAnd(m_Not(m_Value(A)), m_Value(NotA)),
m_Value(B))) &&
match(Y, m_c_And(m_Specific(A), m_Specific(B))))
return NotA;

Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/IR/Constants.cpp
Expand Up @@ -316,7 +316,7 @@ bool Constant::isElementWiseEqual(Value *Y) const {
Constant *C0 = ConstantExpr::getBitCast(const_cast<Constant *>(this), IntTy);
Constant *C1 = ConstantExpr::getBitCast(cast<Constant>(Y), IntTy);
Constant *CmpEq = ConstantExpr::getICmp(ICmpInst::ICMP_EQ, C0, C1);
return isa<UndefValue>(CmpEq) || match(CmpEq, m_One());
return isa<PoisonValue>(CmpEq) || match(CmpEq, m_One());
}

static bool
Expand Down
12 changes: 3 additions & 9 deletions llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
Expand Up @@ -2538,6 +2538,8 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
}
}

// and(shl(zext(X), Y), SignMask) -> and(sext(X), SignMask)
// where Y is a valid shift amount.
if (match(&I, m_And(m_OneUse(m_Shl(m_ZExt(m_Value(X)), m_Value(Y))),
m_SignMask())) &&
match(Y, m_SpecificInt_ICMP(
Expand All @@ -2546,15 +2548,7 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
Ty->getScalarSizeInBits() -
X->getType()->getScalarSizeInBits())))) {
auto *SExt = Builder.CreateSExt(X, Ty, X->getName() + ".signext");
auto *SanitizedSignMask = cast<Constant>(Op1);
// We must be careful with the undef elements of the sign bit mask, however:
// the mask elt can be undef iff the shift amount for that lane was undef,
// otherwise we need to sanitize undef masks to zero.
SanitizedSignMask = Constant::replaceUndefsWith(
SanitizedSignMask, ConstantInt::getNullValue(Ty->getScalarType()));
SanitizedSignMask =
Constant::mergeUndefsWith(SanitizedSignMask, cast<Constant>(Y));
return BinaryOperator::CreateAnd(SExt, SanitizedSignMask);
return BinaryOperator::CreateAnd(SExt, Op1);
}

if (Instruction *Z = narrowMaskedBinOp(I))
Expand Down
30 changes: 15 additions & 15 deletions llvm/test/Transforms/InstCombine/X86/x86-vector-shifts.ll
Expand Up @@ -2032,23 +2032,23 @@ define <4 x i64> @avx2_psrlv_q_256_allbig(<4 x i64> %v) {
ret <4 x i64> %1
}

; The shift amount is 0 (the undef lane could be 0), so we return the unshifted input.
; The shift amount is 0 (the poison lane could be 0), so we return the unshifted input.

define <2 x i64> @avx2_psrlv_q_128_undef(<2 x i64> %v) {
; CHECK-LABEL: @avx2_psrlv_q_128_undef(
define <2 x i64> @avx2_psrlv_q_128_poison(<2 x i64> %v) {
; CHECK-LABEL: @avx2_psrlv_q_128_poison(
; CHECK-NEXT: ret <2 x i64> [[V:%.*]]
;
%1 = insertelement <2 x i64> <i64 0, i64 8>, i64 undef, i64 1
%1 = insertelement <2 x i64> <i64 0, i64 8>, i64 poison, i64 1
%2 = tail call <2 x i64> @llvm.x86.avx2.psrlv.q(<2 x i64> %v, <2 x i64> %1)
ret <2 x i64> %2
}

define <4 x i64> @avx2_psrlv_q_256_undef(<4 x i64> %v) {
; CHECK-LABEL: @avx2_psrlv_q_256_undef(
; CHECK-NEXT: [[TMP1:%.*]] = lshr <4 x i64> [[V:%.*]], <i64 undef, i64 8, i64 16, i64 31>
define <4 x i64> @avx2_psrlv_q_256_poison(<4 x i64> %v) {
; CHECK-LABEL: @avx2_psrlv_q_256_poison(
; CHECK-NEXT: [[TMP1:%.*]] = lshr <4 x i64> [[V:%.*]], <i64 poison, i64 8, i64 16, i64 31>
; CHECK-NEXT: ret <4 x i64> [[TMP1]]
;
%1 = insertelement <4 x i64> <i64 0, i64 8, i64 16, i64 31>, i64 undef, i64 0
%1 = insertelement <4 x i64> <i64 0, i64 8, i64 16, i64 31>, i64 poison, i64 0
%2 = tail call <4 x i64> @llvm.x86.avx2.psrlv.q.256(<4 x i64> %v, <4 x i64> %1)
ret <4 x i64> %2
}
Expand Down Expand Up @@ -2435,21 +2435,21 @@ define <4 x i64> @avx2_psllv_q_256_allbig(<4 x i64> %v) {

; The shift amount is 0 (the poison lane could be 0), so we return the unshifted input.

define <2 x i64> @avx2_psllv_q_128_undef(<2 x i64> %v) {
; CHECK-LABEL: @avx2_psllv_q_128_undef(
define <2 x i64> @avx2_psllv_q_128_poison(<2 x i64> %v) {
; CHECK-LABEL: @avx2_psllv_q_128_poison(
; CHECK-NEXT: ret <2 x i64> [[V:%.*]]
;
%1 = insertelement <2 x i64> <i64 0, i64 8>, i64 undef, i64 1
%1 = insertelement <2 x i64> <i64 0, i64 8>, i64 poison, i64 1
%2 = tail call <2 x i64> @llvm.x86.avx2.psllv.q(<2 x i64> %v, <2 x i64> %1)
ret <2 x i64> %2
}

define <4 x i64> @avx2_psllv_q_256_undef(<4 x i64> %v) {
; CHECK-LABEL: @avx2_psllv_q_256_undef(
; CHECK-NEXT: [[TMP1:%.*]] = shl <4 x i64> [[V:%.*]], <i64 undef, i64 8, i64 16, i64 31>
define <4 x i64> @avx2_psllv_q_256_poison(<4 x i64> %v) {
; CHECK-LABEL: @avx2_psllv_q_256_poison(
; CHECK-NEXT: [[TMP1:%.*]] = shl <4 x i64> [[V:%.*]], <i64 poison, i64 8, i64 16, i64 31>
; CHECK-NEXT: ret <4 x i64> [[TMP1]]
;
%1 = insertelement <4 x i64> <i64 0, i64 8, i64 16, i64 31>, i64 undef, i64 0
%1 = insertelement <4 x i64> <i64 0, i64 8, i64 16, i64 31>, i64 poison, i64 0
%2 = tail call <4 x i64> @llvm.x86.avx2.psllv.q.256(<4 x i64> %v, <4 x i64> %1)
ret <4 x i64> %2
}
Expand Down
16 changes: 8 additions & 8 deletions llvm/test/Transforms/InstCombine/abs-1.ll
Expand Up @@ -63,14 +63,14 @@ define <2 x i8> @abs_canonical_2(<2 x i8> %x) {
ret <2 x i8> %abs
}

; Even if a constant has undef elements.
; Even if a constant has poison elements.

define <2 x i8> @abs_canonical_2_vec_undef_elts(<2 x i8> %x) {
; CHECK-LABEL: @abs_canonical_2_vec_undef_elts(
define <2 x i8> @abs_canonical_2_vec_poison_elts(<2 x i8> %x) {
; CHECK-LABEL: @abs_canonical_2_vec_poison_elts(
; CHECK-NEXT: [[ABS:%.*]] = call <2 x i8> @llvm.abs.v2i8(<2 x i8> [[X:%.*]], i1 false)
; CHECK-NEXT: ret <2 x i8> [[ABS]]
;
%cmp = icmp sgt <2 x i8> %x, <i8 undef, i8 -1>
%cmp = icmp sgt <2 x i8> %x, <i8 poison, i8 -1>
%neg = sub <2 x i8> zeroinitializer, %x
%abs = select <2 x i1> %cmp, <2 x i8> %x, <2 x i8> %neg
ret <2 x i8> %abs
Expand Down Expand Up @@ -208,15 +208,15 @@ define <2 x i8> @nabs_canonical_2(<2 x i8> %x) {
ret <2 x i8> %abs
}

; Even if a constant has undef elements.
; Even if a constant has poison elements.

define <2 x i8> @nabs_canonical_2_vec_undef_elts(<2 x i8> %x) {
; CHECK-LABEL: @nabs_canonical_2_vec_undef_elts(
define <2 x i8> @nabs_canonical_2_vec_poison_elts(<2 x i8> %x) {
; CHECK-LABEL: @nabs_canonical_2_vec_poison_elts(
; CHECK-NEXT: [[TMP1:%.*]] = call <2 x i8> @llvm.abs.v2i8(<2 x i8> [[X:%.*]], i1 false)
; CHECK-NEXT: [[ABS:%.*]] = sub <2 x i8> zeroinitializer, [[TMP1]]
; CHECK-NEXT: ret <2 x i8> [[ABS]]
;
%cmp = icmp sgt <2 x i8> %x, <i8 -1, i8 undef>
%cmp = icmp sgt <2 x i8> %x, <i8 -1, i8 poison>
%neg = sub <2 x i8> zeroinitializer, %x
%abs = select <2 x i1> %cmp, <2 x i8> %neg, <2 x i8> %x
ret <2 x i8> %abs
Expand Down
6 changes: 3 additions & 3 deletions llvm/test/Transforms/InstCombine/add-mask-neg.ll
Expand Up @@ -89,16 +89,16 @@ define <2 x i32> @dec_mask_neg_v2i32(<2 x i32> %X) {
ret <2 x i32> %dec
}

define <2 x i32> @dec_mask_neg_v2i32_undef(<2 x i32> %X) {
; CHECK-LABEL: @dec_mask_neg_v2i32_undef(
define <2 x i32> @dec_mask_neg_v2i32_poison(<2 x i32> %X) {
; CHECK-LABEL: @dec_mask_neg_v2i32_poison(
; CHECK-NEXT: [[TMP1:%.*]] = add <2 x i32> [[X:%.*]], <i32 -1, i32 -1>
; CHECK-NEXT: [[TMP2:%.*]] = xor <2 x i32> [[X]], <i32 -1, i32 -1>
; CHECK-NEXT: [[DEC:%.*]] = and <2 x i32> [[TMP1]], [[TMP2]]
; CHECK-NEXT: ret <2 x i32> [[DEC]]
;
%neg = sub <2 x i32> zeroinitializer, %X
%mask = and <2 x i32> %neg, %X
%dec = add <2 x i32> %mask, <i32 -1, i32 undef>
%dec = add <2 x i32> %mask, <i32 -1, i32 poison>
ret <2 x i32> %dec
}

Expand Down
28 changes: 14 additions & 14 deletions llvm/test/Transforms/InstCombine/add.ll
Expand Up @@ -150,24 +150,24 @@ define i32 @test5_add_nsw(i32 %A, i32 %B) {
ret i32 %D
}

define <2 x i8> @neg_op0_vec_undef_elt(<2 x i8> %a, <2 x i8> %b) {
; CHECK-LABEL: @neg_op0_vec_undef_elt(
define <2 x i8> @neg_op0_vec_poison_elt(<2 x i8> %a, <2 x i8> %b) {
; CHECK-LABEL: @neg_op0_vec_poison_elt(
; CHECK-NEXT: [[R:%.*]] = sub <2 x i8> [[B:%.*]], [[A:%.*]]
; CHECK-NEXT: ret <2 x i8> [[R]]
;
%nega = sub <2 x i8> <i8 0, i8 undef>, %a
%nega = sub <2 x i8> <i8 0, i8 poison>, %a
%r = add <2 x i8> %nega, %b
ret <2 x i8> %r
}

define <2 x i8> @neg_neg_vec_undef_elt(<2 x i8> %a, <2 x i8> %b) {
; CHECK-LABEL: @neg_neg_vec_undef_elt(
define <2 x i8> @neg_neg_vec_poison_elt(<2 x i8> %a, <2 x i8> %b) {
; CHECK-LABEL: @neg_neg_vec_poison_elt(
; CHECK-NEXT: [[TMP1:%.*]] = add <2 x i8> [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[R:%.*]] = sub <2 x i8> zeroinitializer, [[TMP1]]
; CHECK-NEXT: ret <2 x i8> [[R]]
;
%nega = sub <2 x i8> <i8 undef, i8 0>, %a
%negb = sub <2 x i8> <i8 undef, i8 0>, %b
%nega = sub <2 x i8> <i8 poison, i8 0>, %a
%negb = sub <2 x i8> <i8 poison, i8 0>, %b
%r = add <2 x i8> %nega, %negb
ret <2 x i8> %r
}
Expand Down Expand Up @@ -1196,14 +1196,14 @@ define <2 x i32> @test44_vec_non_matching(<2 x i32> %A) {
ret <2 x i32> %C
}

define <2 x i32> @test44_vec_undef(<2 x i32> %A) {
; CHECK-LABEL: @test44_vec_undef(
; CHECK-NEXT: [[B:%.*]] = or <2 x i32> [[A:%.*]], <i32 123, i32 undef>
; CHECK-NEXT: [[C:%.*]] = add <2 x i32> [[B]], <i32 -123, i32 undef>
define <2 x i32> @test44_vec_poison(<2 x i32> %A) {
; CHECK-LABEL: @test44_vec_poison(
; CHECK-NEXT: [[B:%.*]] = or <2 x i32> [[A:%.*]], <i32 123, i32 poison>
; CHECK-NEXT: [[C:%.*]] = add nsw <2 x i32> [[B]], <i32 -123, i32 poison>
; CHECK-NEXT: ret <2 x i32> [[C]]
;
%B = or <2 x i32> %A, <i32 123, i32 undef>
%C = add <2 x i32> %B, <i32 -123, i32 undef>
%B = or <2 x i32> %A, <i32 123, i32 poison>
%C = add <2 x i32> %B, <i32 -123, i32 poison>
ret <2 x i32> %C
}

Expand Down Expand Up @@ -2983,7 +2983,7 @@ define i8 @signum_i8_i8_use3(i8 %x) {
ret i8 %r
}

; poison/undef is ok to propagate in shift amount
; poison is ok to propagate in shift amount
; complexity canonicalization guarantees that shift is op0 of add

define <2 x i5> @signum_v2i5_v2i5(<2 x i5> %x) {
Expand Down

0 comments on commit d9a5aa8

Please sign in to comment.