From da822ce90ed9c46ef8f0abc40f572865f117b384 Mon Sep 17 00:00:00 2001
From: Maksim Kita
Date: Sat, 15 Jul 2023 15:55:04 -0500
Subject: [PATCH] [InstCombine] Generalise ((x1 ^ y1) | (x2 ^ y2)) == 0 transform

Generalise the ((x1 ^ y1) | (x2 ^ y2)) == 0 transform to more than two
pairs of variables (https://github.com/llvm/llvm-project/issues/57831).

Depends on D154384.

Reviewed By: goldstein.w.n, nikic

Differential Revision: https://reviews.llvm.org/D154306
---
 .../InstCombine/InstCombineCompares.cpp      | 67 +++++++++++++++----
 llvm/test/Transforms/InstCombine/icmp-or.ll  | 66 +++++++++---------
 2 files changed, 85 insertions(+), 48 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index a70877549ce7f..d956d99fb2ba8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1945,6 +1945,59 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp,
   return nullptr;
 }
 
+/// Fold icmp eq/ne (or (xor (X1, X2), xor(X3, X4))), 0.
+static Value *foldICmpOrXorChain(ICmpInst &Cmp, BinaryOperator *Or,
+                                 InstCombiner::BuilderTy &Builder) {
+  // Are we using xors to bitwise check for a pair or pairs of (in)equalities?
+  // Convert to a shorter form that has more potential to be folded even
+  // further.
+  // ((X1 ^ X2) || (X3 ^ X4)) == 0 --> (X1 == X2) && (X3 == X4)
+  // ((X1 ^ X2) || (X3 ^ X4)) != 0 --> (X1 != X2) || (X3 != X4)
+  // ((X1 ^ X2) || (X3 ^ X4) || (X5 ^ X6)) == 0 -->
+  // (X1 == X2) && (X3 == X4) && (X5 == X6)
+  // ((X1 ^ X2) || (X3 ^ X4) || (X5 ^ X6)) != 0 -->
+  // (X1 != X2) || (X3 != X4) || (X5 != X6)
+  // TODO: Implement for sub
+  SmallVector<std::pair<Value *, Value *>, 2> CmpValues;
+  SmallVector<Value *, 16> WorkList(1, Or);
+
+  while (!WorkList.empty()) {
+    auto MatchOrOperatorArgument = [&](Value *OrOperatorArgument) {
+      Value *Lhs, *Rhs;
+
+      if (match(OrOperatorArgument,
+                m_OneUse(m_Xor(m_Value(Lhs), m_Value(Rhs))))) {
+        CmpValues.emplace_back(Lhs, Rhs);
+      } else {
+        WorkList.push_back(OrOperatorArgument);
+      }
+    };
+
+    Value *CurrentValue = WorkList.pop_back_val();
+    Value *OrOperatorLhs, *OrOperatorRhs;
+
+    if (!match(CurrentValue,
+               m_Or(m_Value(OrOperatorLhs), m_Value(OrOperatorRhs)))) {
+      return nullptr;
+    }
+
+    MatchOrOperatorArgument(OrOperatorRhs);
+    MatchOrOperatorArgument(OrOperatorLhs);
+  }
+
+  ICmpInst::Predicate Pred = Cmp.getPredicate();
+  auto BOpc = Pred == CmpInst::ICMP_EQ ? Instruction::And : Instruction::Or;
+  Value *LhsCmp = Builder.CreateICmp(Pred, CmpValues.rbegin()->first,
+                                     CmpValues.rbegin()->second);
+
+  for (auto It = CmpValues.rbegin() + 1; It != CmpValues.rend(); ++It) {
+    Value *RhsCmp = Builder.CreateICmp(Pred, It->first, It->second);
+    LhsCmp = Builder.CreateBinOp(BOpc, LhsCmp, RhsCmp);
+  }
+
+  return LhsCmp;
+}
+
 /// Fold icmp (or X, Y), C.
 Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp,
                                                   BinaryOperator *Or,
@@ -2030,18 +2083,8 @@ Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp,
     return BinaryOperator::Create(BOpc, CmpP, CmpQ);
   }
 
-  // Are we using xors to bitwise check for a pair of (in)equalities? Convert to
-  // a shorter form that has more potential to be folded even further.
-  Value *X1, *X2, *X3, *X4;
-  if (match(OrOp0, m_OneUse(m_Xor(m_Value(X1), m_Value(X2)))) &&
-      match(OrOp1, m_OneUse(m_Xor(m_Value(X3), m_Value(X4))))) {
-    // ((X1 ^ X2) || (X3 ^ X4)) == 0 --> (X1 == X2) && (X3 == X4)
-    // ((X1 ^ X2) || (X3 ^ X4)) != 0 --> (X1 != X2) || (X3 != X4)
-    Value *Cmp12 = Builder.CreateICmp(Pred, X1, X2);
-    Value *Cmp34 = Builder.CreateICmp(Pred, X3, X4);
-    auto BOpc = Pred == CmpInst::ICMP_EQ ? Instruction::And : Instruction::Or;
-    return BinaryOperator::Create(BOpc, Cmp12, Cmp34);
-  }
+  if (Value *V = foldICmpOrXorChain(Cmp, Or, Builder))
+    return replaceInstUsesWith(Cmp, V);
 
   return nullptr;
 }
diff --git a/llvm/test/Transforms/InstCombine/icmp-or.ll b/llvm/test/Transforms/InstCombine/icmp-or.ll
index 820aeb0ac2ff0..9d35b5c916969 100644
--- a/llvm/test/Transforms/InstCombine/icmp-or.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-or.ll
@@ -366,10 +366,9 @@ define i1 @not_decrement_sgt_n1(i8 %x) {
 
 define i1 @icmp_or_xor_2_eq(i64 %x1, i64 %y1, i64 %x2, i64 %y2) {
 ; CHECK-LABEL: @icmp_or_xor_2_eq(
-; CHECK-NEXT:    [[XOR:%.*]] = xor i64 [[X1:%.*]], [[Y1:%.*]]
-; CHECK-NEXT:    [[XOR1:%.*]] = xor i64 [[X2:%.*]], [[Y2:%.*]]
-; CHECK-NEXT:    [[OR:%.*]] = or i64 [[XOR]], [[XOR1]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i64 [[OR]], 0
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i64 [[X1:%.*]], [[Y1:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i64 [[X2:%.*]], [[Y2:%.*]]
+; CHECK-NEXT:    [[CMP:%.*]] = and i1 [[TMP1]], [[TMP2]]
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %xor = xor i64 %x1, %y1
@@ -381,10 +380,9 @@ define i1 @icmp_or_xor_2_eq(i64 %x1, i64 %y1, i64 %x2, i64 %y2) {
 
 define i1 @icmp_or_xor_2_ne(i64 %x1, i64 %y1, i64 %x2, i64 %y2) {
 ; CHECK-LABEL: @icmp_or_xor_2_ne(
-; CHECK-NEXT:    [[XOR:%.*]] = xor i64 [[X1:%.*]], [[Y1:%.*]]
-; CHECK-NEXT:    [[XOR1:%.*]] = xor i64 [[X2:%.*]], [[Y2:%.*]]
-; CHECK-NEXT:    [[OR:%.*]] = or i64 [[XOR]], [[XOR1]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i64 [[OR]], 0
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ne i64 [[X1:%.*]], [[Y1:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i64 [[X2:%.*]], [[Y2:%.*]]
+; CHECK-NEXT:    [[CMP:%.*]] = or i1 [[TMP1]], [[TMP2]]
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %xor = xor i64 %x1, %y1
@@ -472,12 +470,11 @@ define i1 @icmp_or_xor_2_4_fail(i64 %x1, i64 %y1, i64 %x2, i64 %y2) {
 
 define i1 @icmp_or_xor_3_1(i64 %x1, i64 %y1, i64 %x2, i64 %y2, i64 %x3, i64 %y3) {
 ; CHECK-LABEL: @icmp_or_xor_3_1(
-; CHECK-NEXT:    [[XOR:%.*]] = xor i64 [[X1:%.*]], [[Y1:%.*]]
-; CHECK-NEXT:    [[XOR1:%.*]] = xor i64 [[X2:%.*]], [[Y2:%.*]]
-; CHECK-NEXT:    [[OR:%.*]] = or i64 [[XOR]], [[XOR1]]
-; CHECK-NEXT:    [[XOR2:%.*]] = xor i64 [[X3:%.*]], [[Y3:%.*]]
-; CHECK-NEXT:    [[OR1:%.*]] = or i64 [[OR]], [[XOR2]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i64 [[OR1]], 0
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i64 [[X1:%.*]], [[Y1:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i64 [[X2:%.*]], [[Y2:%.*]]
+; CHECK-NEXT:    [[TMP3:%.*]] = and i1 [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp eq i64 [[X3:%.*]], [[Y3:%.*]]
+; CHECK-NEXT:    [[CMP:%.*]] = and i1 [[TMP3]], [[TMP4]]
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %xor = xor i64 %x1, %y1
@@ -512,12 +509,11 @@ define i1 @icmp_or_xor_3_fail(i64 %x1, i64 %y1, i64 %x2, i64 %y2, i64 %x3, i64 %
 
 define i1 @icmp_or_xor_3_3(i64 %x1, i64 %y1, i64 %x2, i64 %y2, i64 %x3, i64 %y3) {
 ; CHECK-LABEL: @icmp_or_xor_3_3(
-; CHECK-NEXT:    [[XOR:%.*]] = xor i64 [[X1:%.*]], [[Y1:%.*]]
-; CHECK-NEXT:    [[XOR1:%.*]] = xor i64 [[X2:%.*]], [[Y2:%.*]]
-; CHECK-NEXT:    [[OR:%.*]] = or i64 [[XOR]], [[XOR1]]
-; CHECK-NEXT:    [[XOR2:%.*]] = xor i64 [[X3:%.*]], [[Y3:%.*]]
-; CHECK-NEXT:    [[OR1:%.*]] = or i64 [[XOR2]], [[OR]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i64 [[OR1]], 0
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i64 [[X1:%.*]], [[Y1:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i64 [[X2:%.*]], [[Y2:%.*]]
+; CHECK-NEXT:    [[TMP3:%.*]] = and i1 [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp eq i64 [[X3:%.*]], [[Y3:%.*]]
+; CHECK-NEXT:    [[CMP:%.*]] = and i1 [[TMP3]], [[TMP4]]
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %xor = xor i64 %x1, %y1
@@ -552,14 +548,13 @@ define i1 @icmp_or_xor_3_4_fail(i64 %x1, i64 %y1, i64 %x2, i64 %y2, i64 %x3, i64
 
 define i1 @icmp_or_xor_4_1(i64 %x1, i64 %y1, i64 %x2, i64 %y2, i64 %x3, i64 %y3, i64 %x4, i64 %y4) {
 ; CHECK-LABEL: @icmp_or_xor_4_1(
-; CHECK-NEXT:    [[XOR:%.*]] = xor i64 [[X1:%.*]], [[Y1:%.*]]
-; CHECK-NEXT:    [[XOR1:%.*]] = xor i64 [[X2:%.*]], [[Y2:%.*]]
-; CHECK-NEXT:    [[OR:%.*]] = or i64 [[XOR]], [[XOR1]]
-; CHECK-NEXT:    [[XOR2:%.*]] = xor i64 [[X3:%.*]], [[Y3:%.*]]
-; CHECK-NEXT:    [[XOR3:%.*]] = xor i64 [[X4:%.*]], [[Y4:%.*]]
-; CHECK-NEXT:    [[OR1:%.*]] = or i64 [[XOR2]], [[XOR3]]
-; CHECK-NEXT:    [[OR2:%.*]] = or i64 [[OR]], [[OR1]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i64 [[OR2]], 0
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i64 [[X3:%.*]], [[Y3:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i64 [[X4:%.*]], [[Y4:%.*]]
+; CHECK-NEXT:    [[TMP3:%.*]] = and i1 [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp eq i64 [[X1:%.*]], [[Y1:%.*]]
+; CHECK-NEXT:    [[TMP5:%.*]] = and i1 [[TMP3]], [[TMP4]]
+; CHECK-NEXT:    [[TMP6:%.*]] = icmp eq i64 [[X2:%.*]], [[Y2:%.*]]
+; CHECK-NEXT:    [[CMP:%.*]] = and i1 [[TMP5]], [[TMP6]]
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %xor = xor i64 %x1, %y1
@@ -575,14 +570,13 @@ define i1 @icmp_or_xor_4_1(i64 %x1, i64 %y1, i64 %x2, i64 %y2, i64 %x3, i64 %y3,
 
 define i1 @icmp_or_xor_4_2(i64 %x1, i64 %y1, i64 %x2, i64 %y2, i64 %x3, i64 %y3, i64 %x4, i64 %y4) {
 ; CHECK-LABEL: @icmp_or_xor_4_2(
-; CHECK-NEXT:    [[XOR:%.*]] = xor i64 [[X1:%.*]], [[Y1:%.*]]
-; CHECK-NEXT:    [[XOR1:%.*]] = xor i64 [[X2:%.*]], [[Y2:%.*]]
-; CHECK-NEXT:    [[OR:%.*]] = or i64 [[XOR]], [[XOR1]]
-; CHECK-NEXT:    [[XOR2:%.*]] = xor i64 [[X3:%.*]], [[Y3:%.*]]
-; CHECK-NEXT:    [[XOR3:%.*]] = xor i64 [[X4:%.*]], [[Y4:%.*]]
-; CHECK-NEXT:    [[OR1:%.*]] = or i64 [[XOR2]], [[XOR3]]
-; CHECK-NEXT:    [[OR2:%.*]] = or i64 [[OR1]], [[OR]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i64 [[OR2]], 0
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i64 [[X1:%.*]], [[Y1:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i64 [[X2:%.*]], [[Y2:%.*]]
+; CHECK-NEXT:    [[TMP3:%.*]] = and i1 [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp eq i64 [[X3:%.*]], [[Y3:%.*]]
+; CHECK-NEXT:    [[TMP5:%.*]] = and i1 [[TMP3]], [[TMP4]]
+; CHECK-NEXT:    [[TMP6:%.*]] = icmp eq i64 [[X4:%.*]], [[Y4:%.*]]
+; CHECK-NEXT:    [[CMP:%.*]] = and i1 [[TMP5]], [[TMP6]]
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %xor = xor i64 %x1, %y1
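
For context, not part of the patch: a minimal C++ sketch of the source-level idiom this fold targets. The function and parameter names are invented for illustration and do not come from the LLVM tree. Each return expression lowers to roughly the xor/or/icmp chains exercised by the icmp_or_xor_* tests above, which foldICmpOrXorChain now rewrites into chains of scalar comparisons.

// Illustrative only; assumes any C++11 compiler.
#include <cstdint>

// Before this change InstCombine only folded the two-pair form of this
// idiom; with the generalised fold the three-pair chain below is also
// rewritten to (x1 == y1) && (x2 == y2) && (x3 == y3).
bool all_pairs_equal(uint64_t x1, uint64_t y1, uint64_t x2, uint64_t y2,
                     uint64_t x3, uint64_t y3) {
  return ((x1 ^ y1) | (x2 ^ y2) | (x3 ^ y3)) == 0;
}

// The != 0 form is folded the same way, except the comparisons are
// joined with 'or' instead of 'and' (see icmp_or_xor_2_ne above).
bool any_pair_differs(uint64_t x1, uint64_t y1, uint64_t x2, uint64_t y2) {
  return ((x1 ^ y1) | (x2 ^ y2)) != 0;
}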