Skip to content

Commit

Permalink
[InstCombine] Support splat vectors in some or of icmp folds
Browse files Browse the repository at this point in the history
Replace m_ConstantInt() with m_APInt() in order to support splat
constants in addition to scalar integers.
  • Loading branch information
nikic committed Nov 10, 2021
1 parent 80072fd commit 0242a6a
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 44 deletions.
45 changes: 23 additions & 22 deletions llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
Expand Up @@ -2340,8 +2340,9 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0);
Value *LHS1 = LHS->getOperand(1), *RHS1 = RHS->getOperand(1);
auto *LHSC = dyn_cast<ConstantInt>(LHS1);
auto *RHSC = dyn_cast<ConstantInt>(RHS1);
const APInt *LHSC = nullptr, *RHSC = nullptr;
match(LHS1, m_APInt(LHSC));
match(RHS1, m_APInt(RHSC));

// Fold (icmp ult/ule (A + C1), C3) | (icmp ult/ule (A + C2), C3)
// --> (icmp ult/ule ((A & ~(C1 ^ C2)) + max(C1, C2)), C3)
Expand All @@ -2355,40 +2356,41 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
// This implies all values in the two ranges differ by exactly one bit.
if ((PredL == ICmpInst::ICMP_ULT || PredL == ICmpInst::ICMP_ULE) &&
PredL == PredR && LHSC && RHSC && LHS->hasOneUse() && RHS->hasOneUse() &&
LHSC->getType() == RHSC->getType() &&
LHSC->getValue() == (RHSC->getValue())) {
LHSC->getBitWidth() == RHSC->getBitWidth() && *LHSC == *RHSC) {

Value *AddOpnd;
ConstantInt *LAddC, *RAddC;
if (match(LHS0, m_Add(m_Value(AddOpnd), m_ConstantInt(LAddC))) &&
match(RHS0, m_Add(m_Specific(AddOpnd), m_ConstantInt(RAddC))) &&
LAddC->getValue().ugt(LHSC->getValue()) &&
RAddC->getValue().ugt(LHSC->getValue())) {
const APInt *LAddC, *RAddC;
if (match(LHS0, m_Add(m_Value(AddOpnd), m_APInt(LAddC))) &&
match(RHS0, m_Add(m_Specific(AddOpnd), m_APInt(RAddC))) &&
LAddC->ugt(*LHSC) && RAddC->ugt(*LHSC)) {

APInt DiffC = LAddC->getValue() ^ RAddC->getValue();
APInt DiffC = *LAddC ^ *RAddC;
if (DiffC.isPowerOf2()) {
ConstantInt *MaxAddC = nullptr;
if (LAddC->getValue().ult(RAddC->getValue()))
const APInt *MaxAddC = nullptr;
if (LAddC->ult(*RAddC))
MaxAddC = RAddC;
else
MaxAddC = LAddC;

APInt RRangeLow = -RAddC->getValue();
APInt RRangeHigh = RRangeLow + LHSC->getValue();
APInt LRangeLow = -LAddC->getValue();
APInt LRangeHigh = LRangeLow + LHSC->getValue();
APInt RRangeLow = -*RAddC;
APInt RRangeHigh = RRangeLow + *LHSC;
APInt LRangeLow = -*LAddC;
APInt LRangeHigh = LRangeLow + *LHSC;
APInt LowRangeDiff = RRangeLow ^ LRangeLow;
APInt HighRangeDiff = RRangeHigh ^ LRangeHigh;
APInt RangeDiff = LRangeLow.sgt(RRangeLow) ? LRangeLow - RRangeLow
: RRangeLow - LRangeLow;

if (LowRangeDiff.isPowerOf2() && LowRangeDiff == HighRangeDiff &&
RangeDiff.ugt(LHSC->getValue())) {
Value *MaskC = ConstantInt::get(LAddC->getType(), ~DiffC);
RangeDiff.ugt(*LHSC)) {
Type *Ty = AddOpnd->getType();
Value *MaskC = ConstantInt::get(Ty, ~DiffC);

Value *NewAnd = Builder.CreateAnd(AddOpnd, MaskC);
Value *NewAdd = Builder.CreateAdd(NewAnd, MaxAddC);
return Builder.CreateICmp(LHS->getPredicate(), NewAdd, LHSC);
Value *NewAdd = Builder.CreateAdd(NewAnd,
ConstantInt::get(Ty, *MaxAddC));
return Builder.CreateICmp(LHS->getPredicate(), NewAdd,
ConstantInt::get(Ty, *LHSC));
}
}
}
Expand Down Expand Up @@ -2480,8 +2482,7 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
if (!LHSC || !RHSC)
return nullptr;

return foldAndOrOfICmpsUsingRanges(PredL, LHS0, LHSC->getValue(),
PredR, RHS0, RHSC->getValue(),
return foldAndOrOfICmpsUsingRanges(PredL, LHS0, *LHSC, PredR, RHS0, *RHSC,
Builder, /* IsAnd */ false);
}

Expand Down
38 changes: 16 additions & 22 deletions llvm/test/Transforms/InstCombine/or.ll
Expand Up @@ -121,13 +121,11 @@ define i1 @test18_logical(i32 %A) {
ret i1 %D
}

; FIXME: Vectors should fold too.
define <2 x i1> @test18vec(<2 x i32> %A) {
; CHECK-LABEL: @test18vec(
; CHECK-NEXT: [[B:%.*]] = icmp sgt <2 x i32> [[A:%.*]], <i32 99, i32 99>
; CHECK-NEXT: [[C:%.*]] = icmp slt <2 x i32> [[A]], <i32 50, i32 50>
; CHECK-NEXT: [[D:%.*]] = or <2 x i1> [[B]], [[C]]
; CHECK-NEXT: ret <2 x i1> [[D]]
; CHECK-NEXT: [[TMP1:%.*]] = add <2 x i32> [[A:%.*]], <i32 -100, i32 -100>
; CHECK-NEXT: [[TMP2:%.*]] = icmp ult <2 x i32> [[TMP1]], <i32 -50, i32 -50>
; CHECK-NEXT: ret <2 x i1> [[TMP2]]
;
%B = icmp sge <2 x i32> %A, <i32 100, i32 100>
%C = icmp slt <2 x i32> %A, <i32 50, i32 50>
Expand Down Expand Up @@ -481,9 +479,9 @@ define i1 @test36_logical(i32 %x) {

define i1 @test37(i32 %x) {
; CHECK-LABEL: @test37(
; CHECK-NEXT: [[ADD1:%.*]] = add i32 [[X:%.*]], 7
; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i32 [[ADD1]], 31
; CHECK-NEXT: ret i1 [[TMP1]]
; CHECK-NEXT: [[TMP1:%.*]] = add i32 [[X:%.*]], 7
; CHECK-NEXT: [[TMP2:%.*]] = icmp ult i32 [[TMP1]], 31
; CHECK-NEXT: ret i1 [[TMP2]]
;
%add1 = add i32 %x, 7
%cmp1 = icmp ult i32 %add1, 30
Expand All @@ -494,9 +492,9 @@ define i1 @test37(i32 %x) {

define i1 @test37_logical(i32 %x) {
; CHECK-LABEL: @test37_logical(
; CHECK-NEXT: [[ADD1:%.*]] = add i32 [[X:%.*]], 7
; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i32 [[ADD1]], 31
; CHECK-NEXT: ret i1 [[TMP1]]
; CHECK-NEXT: [[TMP1:%.*]] = add i32 [[X:%.*]], 7
; CHECK-NEXT: [[TMP2:%.*]] = icmp ult i32 [[TMP1]], 31
; CHECK-NEXT: ret i1 [[TMP2]]
;
%add1 = add i32 %x, 7
%cmp1 = icmp ult i32 %add1, 30
Expand All @@ -507,11 +505,9 @@ define i1 @test37_logical(i32 %x) {

define <2 x i1> @test37_uniform(<2 x i32> %x) {
; CHECK-LABEL: @test37_uniform(
; CHECK-NEXT: [[ADD1:%.*]] = add <2 x i32> [[X:%.*]], <i32 7, i32 7>
; CHECK-NEXT: [[CMP1:%.*]] = icmp ult <2 x i32> [[ADD1]], <i32 30, i32 30>
; CHECK-NEXT: [[CMP2:%.*]] = icmp eq <2 x i32> [[X]], <i32 23, i32 23>
; CHECK-NEXT: [[RET1:%.*]] = or <2 x i1> [[CMP1]], [[CMP2]]
; CHECK-NEXT: ret <2 x i1> [[RET1]]
; CHECK-NEXT: [[TMP1:%.*]] = add <2 x i32> [[X:%.*]], <i32 7, i32 7>
; CHECK-NEXT: [[TMP2:%.*]] = icmp ult <2 x i32> [[TMP1]], <i32 31, i32 31>
; CHECK-NEXT: ret <2 x i1> [[TMP2]]
;
%add1 = add <2 x i32> %x, <i32 7, i32 7>
%cmp1 = icmp ult <2 x i32> %add1, <i32 30, i32 30>
Expand Down Expand Up @@ -792,12 +788,10 @@ define i1 @test46_logical(i8 signext %c) {

define <2 x i1> @test46_uniform(<2 x i8> %c) {
; CHECK-LABEL: @test46_uniform(
; CHECK-NEXT: [[C_OFF:%.*]] = add <2 x i8> [[C:%.*]], <i8 -97, i8 -97>
; CHECK-NEXT: [[CMP1:%.*]] = icmp ult <2 x i8> [[C_OFF]], <i8 26, i8 26>
; CHECK-NEXT: [[C_OFF17:%.*]] = add <2 x i8> [[C]], <i8 -65, i8 -65>
; CHECK-NEXT: [[CMP2:%.*]] = icmp ult <2 x i8> [[C_OFF17]], <i8 26, i8 26>
; CHECK-NEXT: [[OR:%.*]] = or <2 x i1> [[CMP1]], [[CMP2]]
; CHECK-NEXT: ret <2 x i1> [[OR]]
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i8> [[C:%.*]], <i8 -33, i8 -33>
; CHECK-NEXT: [[TMP2:%.*]] = add <2 x i8> [[TMP1]], <i8 -65, i8 -65>
; CHECK-NEXT: [[TMP3:%.*]] = icmp ult <2 x i8> [[TMP2]], <i8 26, i8 26>
; CHECK-NEXT: ret <2 x i1> [[TMP3]]
;
%c.off = add <2 x i8> %c, <i8 -97, i8 -97>
%cmp1 = icmp ult <2 x i8> %c.off, <i8 26, i8 26>
Expand Down

0 comments on commit 0242a6a

Please sign in to comment.