diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 9c9b3f4dc89928..b34ba4e7908f3b 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -2283,6 +2283,8 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0); Value *LHS1 = LHS->getOperand(1), *RHS1 = RHS->getOperand(1); + auto *LHSC = dyn_cast(LHS1); + auto *RHSC = dyn_cast(RHS1); // Fold (icmp ult/ule (A + C1), C3) | (icmp ult/ule (A + C2), C3) // --> (icmp ult/ule ((A & ~(C1 ^ C2)) + max(C1, C2)), C3) @@ -2294,43 +2296,42 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, // 3) C1 ^ C2 is one-bit mask. // 4) LowRange1 ^ LowRange2 and HighRange1 ^ HighRange2 are one-bit mask. // This implies all values in the two ranges differ by exactly one bit. - const APInt *LHSVal, *RHSVal; if ((PredL == ICmpInst::ICMP_ULT || PredL == ICmpInst::ICMP_ULE) && - PredL == PredR && LHS->getType() == RHS->getType() && - LHS->getType()->isIntOrIntVectorTy() && match(LHS1, m_APInt(LHSVal)) && - match(RHS1, m_APInt(RHSVal)) && *LHSVal == *RHSVal && LHS->hasOneUse() && - RHS->hasOneUse()) { - Value *AddOpnd; - const APInt *LAddVal, *RAddVal; - if (match(LHS0, m_Add(m_Value(AddOpnd), m_APInt(LAddVal))) && - match(RHS0, m_Add(m_Specific(AddOpnd), m_APInt(RAddVal))) && - LAddVal->ugt(*LHSVal) && RAddVal->ugt(*LHSVal)) { - - APInt DiffC = *LAddVal ^ *RAddVal; - if (DiffC.isPowerOf2()) { - const APInt *MaxAddC = nullptr; - if (LAddVal->ult(*RAddVal)) - MaxAddC = RAddVal; + PredL == PredR && LHSC && RHSC && LHS->hasOneUse() && RHS->hasOneUse() && + LHSC->getType() == RHSC->getType() && + LHSC->getValue() == (RHSC->getValue())) { + + Value *LAddOpnd, *RAddOpnd; + ConstantInt *LAddC, *RAddC; + if (match(LHS0, m_Add(m_Value(LAddOpnd), m_ConstantInt(LAddC))) && + match(RHS0, m_Add(m_Value(RAddOpnd), m_ConstantInt(RAddC))) && + LAddC->getValue().ugt(LHSC->getValue()) && + RAddC->getValue().ugt(LHSC->getValue())) { + + APInt DiffC = LAddC->getValue() ^ RAddC->getValue(); + if (LAddOpnd == RAddOpnd && DiffC.isPowerOf2()) { + ConstantInt *MaxAddC = nullptr; + if (LAddC->getValue().ult(RAddC->getValue())) + MaxAddC = RAddC; else - MaxAddC = LAddVal; + MaxAddC = LAddC; - APInt RRangeLow = -*RAddVal; - APInt RRangeHigh = RRangeLow + *LHSVal; - APInt LRangeLow = -*LAddVal; - APInt LRangeHigh = LRangeLow + *LHSVal; + APInt RRangeLow = -RAddC->getValue(); + APInt RRangeHigh = RRangeLow + LHSC->getValue(); + APInt LRangeLow = -LAddC->getValue(); + APInt LRangeHigh = LRangeLow + LHSC->getValue(); APInt LowRangeDiff = RRangeLow ^ LRangeLow; APInt HighRangeDiff = RRangeHigh ^ LRangeHigh; APInt RangeDiff = LRangeLow.sgt(RRangeLow) ? LRangeLow - RRangeLow : RRangeLow - LRangeLow; if (LowRangeDiff.isPowerOf2() && LowRangeDiff == HighRangeDiff && - RangeDiff.ugt(*LHSVal)) { - Value *NewAnd = Builder.CreateAnd( - AddOpnd, ConstantInt::get(LHS0->getType(), ~DiffC)); - Value *NewAdd = Builder.CreateAdd( - NewAnd, ConstantInt::get(LHS0->getType(), *MaxAddC)); - return Builder.CreateICmp(LHS->getPredicate(), NewAdd, - ConstantInt::get(LHS0->getType(), *LHSVal)); + RangeDiff.ugt(LHSC->getValue())) { + Value *MaskC = ConstantInt::get(LAddC->getType(), ~DiffC); + + Value *NewAnd = Builder.CreateAnd(LAddOpnd, MaskC); + Value *NewAdd = Builder.CreateAdd(NewAnd, MaxAddC); + return Builder.CreateICmp(LHS->getPredicate(), NewAdd, LHSC); } } } @@ -2416,8 +2417,6 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, } // This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2). - auto *LHSC = dyn_cast(LHS1); - auto *RHSC = dyn_cast(RHS1); if (!LHSC || !RHSC) return nullptr; diff --git a/llvm/test/Transforms/InstCombine/or.ll b/llvm/test/Transforms/InstCombine/or.ll index 7e4115cc8934d9..b5e3af2c765254 100644 --- a/llvm/test/Transforms/InstCombine/or.ll +++ b/llvm/test/Transforms/InstCombine/or.ll @@ -650,10 +650,12 @@ define i1 @test46(i8 signext %c) { define <2 x i1> @test46_uniform(<2 x i8> %c) { ; CHECK-LABEL: @test46_uniform( -; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i8> [[C:%.*]], -; CHECK-NEXT: [[TMP2:%.*]] = add <2 x i8> [[TMP1]], -; CHECK-NEXT: [[TMP3:%.*]] = icmp ult <2 x i8> [[TMP2]], -; CHECK-NEXT: ret <2 x i1> [[TMP3]] +; CHECK-NEXT: [[C_OFF:%.*]] = add <2 x i8> [[C:%.*]], +; CHECK-NEXT: [[CMP1:%.*]] = icmp ult <2 x i8> [[C_OFF]], +; CHECK-NEXT: [[C_OFF17:%.*]] = add <2 x i8> [[C]], +; CHECK-NEXT: [[CMP2:%.*]] = icmp ult <2 x i8> [[C_OFF17]], +; CHECK-NEXT: [[OR:%.*]] = or <2 x i1> [[CMP1]], [[CMP2]] +; CHECK-NEXT: ret <2 x i1> [[OR]] ; %c.off = add <2 x i8> %c, %cmp1 = icmp ult <2 x i8> %c.off,