Skip to content

Commit

Permalink
Generalize getInvertibleOperand recurrence handling slightly
Browse files Browse the repository at this point in the history
Follow up to D99912, specifically the revert, fix, and reapply thereof.

This generalizes the invertible recurrence logic in two ways:
* By allowing mismatching operand numbers of the phi, we can recurse through a pair of phi recurrences whose operand orders have not been canonicalized.
* By allowing recurrences through operand 1, we can invert these odd (but legal) recurrence.

Differential Revision: https://reviews.llvm.org/D100884
  • Loading branch information
preames committed Apr 28, 2021
1 parent 29cb9dc commit 0c01b37
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 38 deletions.
57 changes: 27 additions & 30 deletions llvm/lib/Analysis/ValueTracking.cpp
Expand Up @@ -2513,26 +2513,31 @@ bool isKnownNonZero(const Value* V, unsigned Depth, const Query& Q) {
return isKnownNonZero(V, DemandedElts, Depth, Q);
}

/// If the pair of operators are the same invertible function of a single
/// operand return the index of that operand. Otherwise, return None. An
/// invertible function is one that is 1-to-1 and maps every input value
/// to exactly one output value. This is equivalent to saying that Op1
/// and Op2 are equal exactly when the specified pair of operands are equal,
/// (except that Op1 and Op2 may be poison more often.)
static Optional<unsigned> getInvertibleOperand(const Operator *Op1,
const Operator *Op2) {
/// If the pair of operators are the same invertible function, return the
/// the operands of the function corresponding to each input. Otherwise,
/// return None. An invertible function is one that is 1-to-1 and maps
/// every input value to exactly one output value. This is equivalent to
/// saying that Op1 and Op2 are equal exactly when the specified pair of
/// operands are equal, (except that Op1 and Op2 may be poison more often.)
static Optional<std::pair<Value*, Value*>>
getInvertibleOperands(const Operator *Op1,
const Operator *Op2) {
if (Op1->getOpcode() != Op2->getOpcode())
return None;

auto getOperands = [&](unsigned OpNum) -> auto {
return std::make_pair(Op1->getOperand(OpNum), Op2->getOperand(OpNum));
};

switch (Op1->getOpcode()) {
default:
break;
case Instruction::Add:
case Instruction::Sub:
if (Op1->getOperand(0) == Op2->getOperand(0))
return 1;
return getOperands(1);
if (Op1->getOperand(1) == Op2->getOperand(1))
return 0;
return getOperands(0);
break;
case Instruction::Mul: {
// invertible if A * B == (A * B) mod 2^N where A, and B are integers
Expand All @@ -2548,7 +2553,7 @@ static Optional<unsigned> getInvertibleOperand(const Operator *Op1,
if (Op1->getOperand(1) == Op2->getOperand(1) &&
isa<ConstantInt>(Op1->getOperand(1)) &&
!cast<ConstantInt>(Op1->getOperand(1))->isZero())
return 0;
return getOperands(0);
break;
}
case Instruction::Shl: {
Expand All @@ -2561,7 +2566,7 @@ static Optional<unsigned> getInvertibleOperand(const Operator *Op1,
break;

if (Op1->getOperand(1) == Op2->getOperand(1))
return 0;
return getOperands(0);
break;
}
case Instruction::AShr:
Expand All @@ -2572,13 +2577,13 @@ static Optional<unsigned> getInvertibleOperand(const Operator *Op1,
break;

if (Op1->getOperand(1) == Op2->getOperand(1))
return 0;
return getOperands(0);
break;
}
case Instruction::SExt:
case Instruction::ZExt:
if (Op1->getOperand(0)->getType() == Op2->getOperand(0)->getType())
return 0;
return getOperands(0);
break;
case Instruction::PHI: {
const PHINode *PN1 = cast<PHINode>(Op1);
Expand All @@ -2596,18 +2601,12 @@ static Optional<unsigned> getInvertibleOperand(const Operator *Op1,
!matchSimpleRecurrence(PN2, BO2, Start2, Step2))
break;

Optional<unsigned> Idx = getInvertibleOperand(cast<Operator>(BO1),
cast<Operator>(BO2));
if (!Idx || *Idx != 0)
auto Values = getInvertibleOperands(cast<Operator>(BO1),
cast<Operator>(BO2));
if (!Values)
break;
assert(BO1->getOperand(*Idx) == PN1 && BO2->getOperand(*Idx) == PN2);

// Phi operands might not be in the same order. TODO: generalize
// interface to return pair of operands.
if (PN1->getOperand(0) == BO1 && PN2->getOperand(0) == BO2)
return 1;
if (PN1->getOperand(1) == BO1 && PN2->getOperand(1) == BO2)
return 0;
assert(Values->first == PN1 && Values->second == PN2);
return std::make_pair(Start1, Start2);
}
}
return None;
Expand Down Expand Up @@ -2704,11 +2703,9 @@ static bool isKnownNonEqual(const Value *V1, const Value *V2, unsigned Depth,
auto *O1 = dyn_cast<Operator>(V1);
auto *O2 = dyn_cast<Operator>(V2);
if (O1 && O2 && O1->getOpcode() == O2->getOpcode()) {
if (Optional<unsigned> Opt = getInvertibleOperand(O1, O2)) {
unsigned Idx = *Opt;
return isKnownNonEqual(O1->getOperand(Idx), O2->getOperand(Idx),
Depth + 1, Q);
}
if (auto Values = getInvertibleOperands(O1, O2))
return isKnownNonEqual(Values->first, Values->second, Depth + 1, Q);

if (const PHINode *PN1 = dyn_cast<PHINode>(V1)) {
const PHINode *PN2 = cast<PHINode>(V2);
// FIXME: This is missing a generalization to handle the case where one is
Expand Down
12 changes: 4 additions & 8 deletions llvm/test/Analysis/ValueTracking/known-non-equal.ll
Expand Up @@ -736,8 +736,7 @@ define i1 @recurrence_add_op_order(i8 %A) {
; CHECK-NEXT: [[CMP:%.*]] = icmp ne i64 [[IV_NEXT]], 10
; CHECK-NEXT: br i1 [[CMP]], label [[LOOP]], label [[EXIT:%.*]]
; CHECK: exit:
; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[A_IV]], [[B_IV]]
; CHECK-NEXT: ret i1 [[RES]]
; CHECK-NEXT: ret i1 false
;
entry:
%B = add i8 %A, 1
Expand Down Expand Up @@ -808,8 +807,7 @@ define i1 @recurrence_add_phi_different_order1(i8 %A) {
; CHECK-NEXT: [[CMP:%.*]] = icmp ne i64 [[IV_NEXT]], 10
; CHECK-NEXT: br i1 [[CMP]], label [[LOOP]], label [[EXIT:%.*]]
; CHECK: exit:
; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[A_IV]], [[B_IV]]
; CHECK-NEXT: ret i1 [[RES]]
; CHECK-NEXT: ret i1 false
;
entry:
%B = add i8 %A, 1
Expand Down Expand Up @@ -843,8 +841,7 @@ define i1 @recurrence_add_phi_different_order2(i8 %A) {
; CHECK-NEXT: [[CMP:%.*]] = icmp ne i64 [[IV_NEXT]], 10
; CHECK-NEXT: br i1 [[CMP]], label [[LOOP]], label [[EXIT:%.*]]
; CHECK: exit:
; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[A_IV]], [[B_IV]]
; CHECK-NEXT: ret i1 [[RES]]
; CHECK-NEXT: ret i1 false
;
entry:
%B = add i8 %A, 1
Expand Down Expand Up @@ -979,8 +976,7 @@ define i1 @recurrence_sub_op_order(i8 %A) {
; CHECK-NEXT: [[CMP:%.*]] = icmp ne i64 [[IV_NEXT]], 10
; CHECK-NEXT: br i1 [[CMP]], label [[LOOP]], label [[EXIT:%.*]]
; CHECK: exit:
; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[A_IV]], [[B_IV]]
; CHECK-NEXT: ret i1 [[RES]]
; CHECK-NEXT: ret i1 false
;
entry:
%B = add i8 %A, 1
Expand Down

0 comments on commit 0c01b37

Please sign in to comment.