diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 90e4b5d71be608..da33d24d972899 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -4684,12 +4684,32 @@ bool SelectionDAG::isEqualTo(SDValue A, SDValue B) const { return false; } +// Only bits set in Mask must be negated, other bits may be arbitrary. +static SDValue getBitwiseNotOperand(SDValue V, SDValue Mask) { + if (isBitwiseNot(V, true)) + return V.getOperand(0); + + // Handle any_extend (not (truncate X)) pattern, where Mask only sets + // bits in the non-extended part. + ConstantSDNode *MaskC = isConstOrConstSplat(Mask); + if (!MaskC || V.getOpcode() != ISD::ANY_EXTEND) + return SDValue(); + SDValue ExtArg = V.getOperand(0); + if (ExtArg.getScalarValueSizeInBits() >= + MaskC->getAPIntValue().getActiveBits() && + isBitwiseNot(ExtArg, true) && + ExtArg.getOperand(0).getOpcode() == ISD::TRUNCATE && + ExtArg.getOperand(0).getOperand(0).getValueType() == V.getValueType()) + return ExtArg.getOperand(0).getOperand(0); + return SDValue(); +} + static bool haveNoCommonBitsSetCommutative(SDValue A, SDValue B) { // Match masked merge pattern (X & ~M) op (Y & M) // Including degenerate case (X & ~M) op M - auto MatchNoCommonBitsPattern = [&](SDValue NotM, SDValue Other) { - if (isBitwiseNot(NotM, true)) { - SDValue NotOperand = NotM->getOperand(0); + auto MatchNoCommonBitsPattern = [&](SDValue Not, SDValue Mask, + SDValue Other) { + if (SDValue NotOperand = getBitwiseNotOperand(Not, Mask)) { if (Other == NotOperand) return true; if (Other->getOpcode() == ISD::AND) @@ -4699,8 +4719,8 @@ static bool haveNoCommonBitsSetCommutative(SDValue A, SDValue B) { return false; }; if (A->getOpcode() == ISD::AND) - return MatchNoCommonBitsPattern(A->getOperand(0), B) || - MatchNoCommonBitsPattern(A->getOperand(1), B); + return MatchNoCommonBitsPattern(A->getOperand(0), A->getOperand(1), B) || + MatchNoCommonBitsPattern(A->getOperand(1), A->getOperand(0), B); return false; } diff --git a/llvm/test/CodeGen/X86/add-and-not.ll b/llvm/test/CodeGen/X86/add-and-not.ll index bf8e507d2f82f8..c0434b1a5b2961 100644 --- a/llvm/test/CodeGen/X86/add-and-not.ll +++ b/llvm/test/CodeGen/X86/add-and-not.ll @@ -121,7 +121,7 @@ define i64 @add_and_xor_const(i64 %x) { ; CHECK-NEXT: movl %edi, %eax ; CHECK-NEXT: notl %eax ; CHECK-NEXT: andl $1, %eax -; CHECK-NEXT: addq %rdi, %rax +; CHECK-NEXT: orq %rdi, %rax ; CHECK-NEXT: retq %xor = xor i64 %x, -1 %and = and i64 %xor, 1 @@ -148,7 +148,7 @@ define i64 @add_and_xor_const_explicit_trunc(i64 %x) { ; CHECK-NEXT: movl %edi, %eax ; CHECK-NEXT: notl %eax ; CHECK-NEXT: andl $1, %eax -; CHECK-NEXT: addq %rdi, %rax +; CHECK-NEXT: orq %rdi, %rax ; CHECK-NEXT: retq %trunc = trunc i64 %x to i32 %xor = xor i32 %trunc, -1 @@ -195,7 +195,7 @@ define i8* @gep_and_xor_const(i8* %a) { ; CHECK-NEXT: movl %edi, %eax ; CHECK-NEXT: notl %eax ; CHECK-NEXT: andl $1, %eax -; CHECK-NEXT: addq %rdi, %rax +; CHECK-NEXT: orq %rdi, %rax ; CHECK-NEXT: retq %old = ptrtoint i8* %a to i64 %old.not = and i64 %old, 1