From 1436a9232b10487a097f62bf85025fc6b6b66fde Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Wed, 4 Jan 2023 01:49:17 +0000 Subject: [PATCH] [LVI] Look through negations when evaluating conditions This teaches LVI (and thus CVP) to extract range information from branches whose condition is negated using (`xor %c, true`). On the implementation side, we switch the cache to additionally track whether we're looking for the inverted value or not and otherwise use the existing support for computing inverted conditions. I think the biggest question here is why this negation shows up here at all. After all, it should always be possible for some other pass to fold such a negation into a branch, comparison or some other logical operation. Indeed, instcombine does just that. However, these negations can be otherwise fairly persistent, e.g. instsimplify is not able to exchange branch conditions from negations. In addition, jumpthreading, which sits at the same point in the default pass pipeline, also handles this pattern, which adds further evidence that we might expect these negations to not have been canonicalized away yet at this point in the pass pipeline. In the particular case I was looking at there was a bit of a circular dependency where flags computed by cvp were needed by instcombine, and instcombine's folding of the negation was needed for cvp. Adding a second instcombine pass would have worked of course, but instcombine can be somewhat expensive, so it appeared desirable to not require it to have run before cvp (as is the case in the default pass pipeline). 
Reviewed By: nikic Differential Revision: https://reviews.llvm.org/D140933 --- llvm/lib/Analysis/LazyValueInfo.cpp | 51 +++++++++++------ .../CorrelatedValuePropagation/basic.ll | 57 +++++++++++++++++++ 2 files changed, 91 insertions(+), 17 deletions(-) diff --git a/llvm/lib/Analysis/LazyValueInfo.cpp b/llvm/lib/Analysis/LazyValueInfo.cpp index 2a655af9dde9c..1832c847da45d 100644 --- a/llvm/lib/Analysis/LazyValueInfo.cpp +++ b/llvm/lib/Analysis/LazyValueInfo.cpp @@ -1166,11 +1166,16 @@ static ValueLatticeElement getValueFromOverflowCondition( return ValueLatticeElement::getRange(NWR); } -static std::optional -getValueFromConditionImpl(Value *Val, Value *Cond, bool isTrueDest, - bool isRevisit, - SmallDenseMap &Visited, - SmallVectorImpl &Worklist) { +// Tracks a Value * condition and whether we're interested in it or its inverse +typedef PointerIntPair CondValue; + +static std::optional getValueFromConditionImpl( + Value *Val, CondValue CondVal, bool isRevisit, + SmallDenseMap &Visited, + SmallVectorImpl &Worklist) { + + Value *Cond = CondVal.getPointer(); + bool isTrueDest = CondVal.getInt(); if (!isRevisit) { if (ICmpInst *ICI = dyn_cast(Cond)) return getValueFromICmpCondition(Val, ICI, isTrueDest); @@ -1181,6 +1186,17 @@ getValueFromConditionImpl(Value *Val, Value *Cond, bool isTrueDest, return getValueFromOverflowCondition(Val, WO, isTrueDest); } + Value *N; + if (match(Cond, m_Not(m_Value(N)))) { + CondValue NKey(N, !isTrueDest); + auto NV = Visited.find(NKey); + if (NV == Visited.end()) { + Worklist.push_back(NKey); + return std::nullopt; + } + return NV->second; + } + Value *L, *R; bool IsAnd; if (match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))) @@ -1190,13 +1206,13 @@ getValueFromConditionImpl(Value *Val, Value *Cond, bool isTrueDest, else return ValueLatticeElement::getOverdefined(); - auto LV = Visited.find(L); - auto RV = Visited.find(R); + auto LV = Visited.find(CondValue(L, isTrueDest)); + auto RV = Visited.find(CondValue(R, isTrueDest)); // if (L && 
R) -> intersect L and R - // if (!(L || R)) -> intersect L and R + // if (!(L || R)) -> intersect !L and !R // if (L || R) -> union L and R - // if (!(L && R)) -> union L and R + // if (!(L && R)) -> union !L and !R if ((isTrueDest ^ IsAnd) && (LV != Visited.end())) { ValueLatticeElement V = LV->second; if (V.isOverdefined()) @@ -1210,9 +1226,9 @@ getValueFromConditionImpl(Value *Val, Value *Cond, bool isTrueDest, if (LV == Visited.end() || RV == Visited.end()) { assert(!isRevisit); if (LV == Visited.end()) - Worklist.push_back(L); + Worklist.push_back(CondValue(L, isTrueDest)); if (RV == Visited.end()) - Worklist.push_back(R); + Worklist.push_back(CondValue(R, isTrueDest)); return std::nullopt; } @@ -1222,12 +1238,13 @@ getValueFromConditionImpl(Value *Val, Value *Cond, bool isTrueDest, ValueLatticeElement getValueFromCondition(Value *Val, Value *Cond, bool isTrueDest) { assert(Cond && "precondition"); - SmallDenseMap Visited; - SmallVector Worklist; + SmallDenseMap Visited; + SmallVector Worklist; - Worklist.push_back(Cond); + CondValue CondKey(Cond, isTrueDest); + Worklist.push_back(CondKey); do { - Value *CurrentCond = Worklist.back(); + CondValue CurrentCond = Worklist.back(); // Insert an Overdefined placeholder into the set to prevent // infinite recursion if there exists IRs that use not // dominated by its def as in this example: @@ -1237,14 +1254,14 @@ ValueLatticeElement getValueFromCondition(Value *Val, Value *Cond, Visited.try_emplace(CurrentCond, ValueLatticeElement::getOverdefined()); bool isRevisit = !Iter.second; std::optional Result = getValueFromConditionImpl( - Val, CurrentCond, isTrueDest, isRevisit, Visited, Worklist); + Val, CurrentCond, isRevisit, Visited, Worklist); if (Result) { Visited[CurrentCond] = *Result; Worklist.pop_back(); } } while (!Worklist.empty()); - auto Result = Visited.find(Cond); + auto Result = Visited.find(CondKey); assert(Result != Visited.end()); return Result->second; } diff --git 
a/llvm/test/Transforms/CorrelatedValuePropagation/basic.ll b/llvm/test/Transforms/CorrelatedValuePropagation/basic.ll index b7f1c80565f83..c3c753375f1b4 100644 --- a/llvm/test/Transforms/CorrelatedValuePropagation/basic.ll +++ b/llvm/test/Transforms/CorrelatedValuePropagation/basic.ll @@ -1853,6 +1853,63 @@ define void @xor(i8 %a, ptr %p) { ret void } +define i1 @xor_neg_cond(i32 %a) { +; CHECK-LABEL: @xor_neg_cond( +; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i32 [[A:%.*]], 10 +; CHECK-NEXT: [[XOR:%.*]] = xor i1 [[CMP1]], true +; CHECK-NEXT: br i1 [[XOR]], label [[EXIT:%.*]], label [[GUARD:%.*]] +; CHECK: guard: +; CHECK-NEXT: ret i1 true +; CHECK: exit: +; CHECK-NEXT: ret i1 false +; + %cmp1 = icmp eq i32 %a, 10 + %xor = xor i1 %cmp1, true + br i1 %xor, label %exit, label %guard + +guard: + %cmp2 = icmp eq i32 %a, 10 + ret i1 %cmp2 + +exit: + ret i1 false +} + +define i1 @xor_approx(i32 %a) { +; CHECK-LABEL: @xor_approx( +; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i32 [[A:%.*]], 2 +; CHECK-NEXT: [[CMP2:%.*]] = icmp ult i32 [[A]], 5 +; CHECK-NEXT: [[CMP3:%.*]] = icmp ugt i32 [[A]], 7 +; CHECK-NEXT: [[CMP4:%.*]] = icmp ult i32 [[A]], 9 +; CHECK-NEXT: [[AND1:%.*]] = and i1 [[CMP1]], [[CMP2]] +; CHECK-NEXT: [[AND2:%.*]] = and i1 [[CMP3]], [[CMP4]] +; CHECK-NEXT: [[OR:%.*]] = or i1 [[AND1]], [[AND2]] +; CHECK-NEXT: [[XOR:%.*]] = xor i1 [[OR]], true +; CHECK-NEXT: br i1 [[XOR]], label [[EXIT:%.*]], label [[GUARD:%.*]] +; CHECK: guard: +; CHECK-NEXT: [[CMP5:%.*]] = icmp eq i32 [[A]], 6 +; CHECK-NEXT: ret i1 [[CMP5]] +; CHECK: exit: +; CHECK-NEXT: ret i1 false +; + %cmp1 = icmp ugt i32 %a, 2 + %cmp2 = icmp ult i32 %a, 5 + %cmp3 = icmp ugt i32 %a, 7 + %cmp4 = icmp ult i32 %a, 9 + %and1 = and i1 %cmp1, %cmp2 + %and2 = and i1 %cmp3, %cmp4 + %or = or i1 %and1, %and2 + %xor = xor i1 %or, true + br i1 %xor, label %exit, label %guard + +guard: + %cmp5 = icmp eq i32 %a, 6 + ret i1 %cmp5 + +exit: + ret i1 false +} + declare i32 @llvm.uadd.sat.i32(i32, i32) declare i32 @llvm.usub.sat.i32(i32, 
i32) declare i32 @llvm.sadd.sat.i32(i32, i32)