Skip to content

Commit

Permalink
[LVI] Look through negations when evaluating conditions
Browse files Browse the repository at this point in the history
This teaches LVI (and thus CVP) to extract range information
from branches whose condition is negated using (`xor %c, true`).
On the implementation side, we switch the cache to additionally
track whether we're looking for the inverted value or not, and
otherwise use the existing support for computing inverted
conditions.

I think the biggest question here is why this negation shows up
here at all. After all, it should always be possible for some
other pass to fold such a negation into a branch, comparison or
some other logical operation. Indeed, instcombine does just that.
However, these negations can be otherwise fairly persistent, e.g.
instsimplify is not able to exchange branch conditions from
negations. In addition, jumpthreading, which sits at the same
point in the default pass pipeline, also handles this pattern, which
adds further evidence that we might expect these negations to
not have been canonicalized away yet at this point in the pass
pipeline.

In the particular case I was looking at there was a bit of a
circular dependency where flags computed by cvp were needed
by instcombine, and instcombine's folding of the negation was
needed for cvp. Adding a second instcombine pass would have
worked of course, but instcombine can be somewhat expensive,
so it appeared desirable to not require it to have run
before cvp (as is the case in the default pass pipeline).

Reviewed By: nikic

Differential Revision: https://reviews.llvm.org/D140933
  • Loading branch information
Keno committed Jan 5, 2023
1 parent cf8fd21 commit 1436a92
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 17 deletions.
51 changes: 34 additions & 17 deletions llvm/lib/Analysis/LazyValueInfo.cpp
Expand Up @@ -1166,11 +1166,16 @@ static ValueLatticeElement getValueFromOverflowCondition(
return ValueLatticeElement::getRange(NWR);
}

static std::optional<ValueLatticeElement>
getValueFromConditionImpl(Value *Val, Value *Cond, bool isTrueDest,
bool isRevisit,
SmallDenseMap<Value *, ValueLatticeElement> &Visited,
SmallVectorImpl<Value *> &Worklist) {
// Pairs a condition Value * with a one-bit flag recording whether we are
// interested in the condition itself (flag set) or in its inverse (flag
// clear); consumers read the flag back as `isTrueDest`.
using CondValue = PointerIntPair<Value *, 1, bool>;

static std::optional<ValueLatticeElement> getValueFromConditionImpl(
Value *Val, CondValue CondVal, bool isRevisit,
SmallDenseMap<CondValue, ValueLatticeElement> &Visited,
SmallVectorImpl<CondValue> &Worklist) {

Value *Cond = CondVal.getPointer();
bool isTrueDest = CondVal.getInt();
if (!isRevisit) {
if (ICmpInst *ICI = dyn_cast<ICmpInst>(Cond))
return getValueFromICmpCondition(Val, ICI, isTrueDest);
Expand All @@ -1181,6 +1186,17 @@ getValueFromConditionImpl(Value *Val, Value *Cond, bool isTrueDest,
return getValueFromOverflowCondition(Val, WO, isTrueDest);
}

Value *N;
if (match(Cond, m_Not(m_Value(N)))) {
CondValue NKey(N, !isTrueDest);
auto NV = Visited.find(NKey);
if (NV == Visited.end()) {
Worklist.push_back(NKey);
return std::nullopt;
}
return NV->second;
}

Value *L, *R;
bool IsAnd;
if (match(Cond, m_LogicalAnd(m_Value(L), m_Value(R))))
Expand All @@ -1190,13 +1206,13 @@ getValueFromConditionImpl(Value *Val, Value *Cond, bool isTrueDest,
else
return ValueLatticeElement::getOverdefined();

auto LV = Visited.find(L);
auto RV = Visited.find(R);
auto LV = Visited.find(CondValue(L, isTrueDest));
auto RV = Visited.find(CondValue(R, isTrueDest));

// if (L && R) -> intersect L and R
// if (!(L || R)) -> intersect L and R
// if (!(L || R)) -> intersect !L and !R
// if (L || R) -> union L and R
// if (!(L && R)) -> union L and R
// if (!(L && R)) -> union !L and !R
if ((isTrueDest ^ IsAnd) && (LV != Visited.end())) {
ValueLatticeElement V = LV->second;
if (V.isOverdefined())
Expand All @@ -1210,9 +1226,9 @@ getValueFromConditionImpl(Value *Val, Value *Cond, bool isTrueDest,
if (LV == Visited.end() || RV == Visited.end()) {
assert(!isRevisit);
if (LV == Visited.end())
Worklist.push_back(L);
Worklist.push_back(CondValue(L, isTrueDest));
if (RV == Visited.end())
Worklist.push_back(R);
Worklist.push_back(CondValue(R, isTrueDest));
return std::nullopt;
}

Expand All @@ -1222,12 +1238,13 @@ getValueFromConditionImpl(Value *Val, Value *Cond, bool isTrueDest,
ValueLatticeElement getValueFromCondition(Value *Val, Value *Cond,
bool isTrueDest) {
assert(Cond && "precondition");
SmallDenseMap<Value*, ValueLatticeElement> Visited;
SmallVector<Value *> Worklist;
SmallDenseMap<CondValue, ValueLatticeElement> Visited;
SmallVector<CondValue> Worklist;

Worklist.push_back(Cond);
CondValue CondKey(Cond, isTrueDest);
Worklist.push_back(CondKey);
do {
Value *CurrentCond = Worklist.back();
CondValue CurrentCond = Worklist.back();
// Insert an Overdefined placeholder into the set to prevent
// infinite recursion if there exists IR in which a use is not
// dominated by its def, as in this example:
Expand All @@ -1237,14 +1254,14 @@ ValueLatticeElement getValueFromCondition(Value *Val, Value *Cond,
Visited.try_emplace(CurrentCond, ValueLatticeElement::getOverdefined());
bool isRevisit = !Iter.second;
std::optional<ValueLatticeElement> Result = getValueFromConditionImpl(
Val, CurrentCond, isTrueDest, isRevisit, Visited, Worklist);
Val, CurrentCond, isRevisit, Visited, Worklist);
if (Result) {
Visited[CurrentCond] = *Result;
Worklist.pop_back();
}
} while (!Worklist.empty());

auto Result = Visited.find(Cond);
auto Result = Visited.find(CondKey);
assert(Result != Visited.end());
return Result->second;
}
Expand Down
57 changes: 57 additions & 0 deletions llvm/test/Transforms/CorrelatedValuePropagation/basic.ll
Expand Up @@ -1853,6 +1853,63 @@ define void @xor(i8 %a, ptr %p) {
ret void
}

; The branch condition is %cmp1 negated via `xor i1 %cmp1, true`, so the
; false edge of the branch leads to %guard. In %guard the same
; `icmp eq i32 %a, 10` is recomputed; the CHECK lines expect it to be
; folded to `true`.
define i1 @xor_neg_cond(i32 %a) {
; CHECK-LABEL: @xor_neg_cond(
; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i32 [[A:%.*]], 10
; CHECK-NEXT: [[XOR:%.*]] = xor i1 [[CMP1]], true
; CHECK-NEXT: br i1 [[XOR]], label [[EXIT:%.*]], label [[GUARD:%.*]]
; CHECK: guard:
; CHECK-NEXT: ret i1 true
; CHECK: exit:
; CHECK-NEXT: ret i1 false
;
%cmp1 = icmp eq i32 %a, 10
%xor = xor i1 %cmp1, true
br i1 %xor, label %exit, label %guard

guard:
%cmp2 = icmp eq i32 %a, 10
ret i1 %cmp2

exit:
ret i1 false
}

; The branch condition negates (via `xor i1 %or, true`) an `or` of two
; `and`-ed unsigned range checks: (%a > 2 and %a < 5) or (%a > 7 and %a < 9).
; The false edge of the branch leads to %guard, which compares %a against 6.
; The CHECK lines expect that compare to remain in place (it is not folded).
; NOTE(review): presumably the two ranges are combined conservatively so
; that 6 cannot be excluded, hence "approx" -- confirm against LVI.
define i1 @xor_approx(i32 %a) {
; CHECK-LABEL: @xor_approx(
; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i32 [[A:%.*]], 2
; CHECK-NEXT: [[CMP2:%.*]] = icmp ult i32 [[A]], 5
; CHECK-NEXT: [[CMP3:%.*]] = icmp ugt i32 [[A]], 7
; CHECK-NEXT: [[CMP4:%.*]] = icmp ult i32 [[A]], 9
; CHECK-NEXT: [[AND1:%.*]] = and i1 [[CMP1]], [[CMP2]]
; CHECK-NEXT: [[AND2:%.*]] = and i1 [[CMP3]], [[CMP4]]
; CHECK-NEXT: [[OR:%.*]] = or i1 [[AND1]], [[AND2]]
; CHECK-NEXT: [[XOR:%.*]] = xor i1 [[OR]], true
; CHECK-NEXT: br i1 [[XOR]], label [[EXIT:%.*]], label [[GUARD:%.*]]
; CHECK: guard:
; CHECK-NEXT: [[CMP5:%.*]] = icmp eq i32 [[A]], 6
; CHECK-NEXT: ret i1 [[CMP5]]
; CHECK: exit:
; CHECK-NEXT: ret i1 false
;
%cmp1 = icmp ugt i32 %a, 2
%cmp2 = icmp ult i32 %a, 5
%cmp3 = icmp ugt i32 %a, 7
%cmp4 = icmp ult i32 %a, 9
%and1 = and i1 %cmp1, %cmp2
%and2 = and i1 %cmp3, %cmp4
%or = or i1 %and1, %and2
%xor = xor i1 %or, true
br i1 %xor, label %exit, label %guard

guard:
%cmp5 = icmp eq i32 %a, 6
ret i1 %cmp5

exit:
ret i1 false
}

declare i32 @llvm.uadd.sat.i32(i32, i32)
declare i32 @llvm.usub.sat.i32(i32, i32)
declare i32 @llvm.sadd.sat.i32(i32, i32)
Expand Down

0 comments on commit 1436a92

Please sign in to comment.