101 changes: 54 additions & 47 deletions llvm/lib/Analysis/LazyValueInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,11 +336,10 @@ void LazyValueInfoCache::threadEdgeImpl(BasicBlock *OldSucc,
}
}


namespace llvm {
namespace {
/// An assembly annotator class to print LazyValueCache information in
/// comments.
class LazyValueInfoImpl;
class LazyValueInfoAnnotatedWriter : public AssemblyAnnotationWriter {
LazyValueInfoImpl *LVIImpl;
// While analyzing which blocks we can solve values for, we need the dominator
Expand All @@ -357,8 +356,7 @@ class LazyValueInfoAnnotatedWriter : public AssemblyAnnotationWriter {
void emitInstructionAnnot(const Instruction *I,
formatted_raw_ostream &OS) override;
};
}
namespace {
} // namespace
// The actual implementation of the lazy analysis and update. Note that the
// inheritance from LazyValueInfoCache is intended to be temporary while
// splitting the code and then transitioning to a has-a relationship.
Expand Down Expand Up @@ -465,6 +463,10 @@ class LazyValueInfoImpl {
F.print(OS, &Writer);
}

/// This is part of the update interface to remove information related to this
/// value from the cache.
void forgetValue(Value *V) { TheCache.eraseValue(V); }

/// This is part of the update interface to inform the cache
/// that a block has been deleted.
void eraseBlock(BasicBlock *BB) {
Expand All @@ -479,8 +481,7 @@ class LazyValueInfoImpl {
Function *GuardDecl)
: AC(AC), DL(DL), GuardDecl(GuardDecl) {}
};
} // end anonymous namespace

} // namespace llvm

void LazyValueInfoImpl::solve() {
SmallVector<std::pair<BasicBlock *, Value *>, 8> StartingStack(
Expand Down Expand Up @@ -1542,25 +1543,12 @@ void LazyValueInfoImpl::threadEdge(BasicBlock *PredBB, BasicBlock *OldSucc,
// LazyValueInfo Impl
//===----------------------------------------------------------------------===//

/// This lazily constructs the LazyValueInfoImpl.
static LazyValueInfoImpl &getImpl(void *&PImpl, AssumptionCache *AC,
const Module *M) {
if (!PImpl) {
assert(M && "getCache() called with a null Module");
const DataLayout &DL = M->getDataLayout();
Function *GuardDecl = M->getFunction(
Intrinsic::getName(Intrinsic::experimental_guard));
PImpl = new LazyValueInfoImpl(AC, DL, GuardDecl);
}
return *static_cast<LazyValueInfoImpl*>(PImpl);
}

bool LazyValueInfoWrapperPass::runOnFunction(Function &F) {
Info.AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
Info.TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);

if (Info.PImpl)
getImpl(Info.PImpl, Info.AC, F.getParent()).clear();
if (auto *Impl = Info.getImpl())
Impl->clear();

// Fully lazy.
return false;
Expand All @@ -1574,12 +1562,30 @@ void LazyValueInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {

LazyValueInfo &LazyValueInfoWrapperPass::getLVI() { return Info; }

/// This lazily constructs the LazyValueInfoImpl.
LazyValueInfoImpl &LazyValueInfo::getOrCreateImpl(const Module *M) {
if (!PImpl) {
assert(M && "getCache() called with a null Module");
const DataLayout &DL = M->getDataLayout();
Function *GuardDecl =
M->getFunction(Intrinsic::getName(Intrinsic::experimental_guard));
PImpl = new LazyValueInfoImpl(AC, DL, GuardDecl);
}
return *static_cast<LazyValueInfoImpl *>(PImpl);
}

LazyValueInfoImpl *LazyValueInfo::getImpl() {
if (!PImpl)
return nullptr;
return static_cast<LazyValueInfoImpl *>(PImpl);
}

LazyValueInfo::~LazyValueInfo() { releaseMemory(); }

void LazyValueInfo::releaseMemory() {
// If the cache was allocated, free it.
if (PImpl) {
delete &getImpl(PImpl, AC, nullptr);
if (auto *Impl = getImpl()) {
delete &*Impl;
PImpl = nullptr;
}
}
Expand Down Expand Up @@ -1626,7 +1632,7 @@ Constant *LazyValueInfo::getConstant(Value *V, Instruction *CxtI) {

BasicBlock *BB = CxtI->getParent();
ValueLatticeElement Result =
getImpl(PImpl, AC, BB->getModule()).getValueInBlock(V, BB, CxtI);
getOrCreateImpl(BB->getModule()).getValueInBlock(V, BB, CxtI);

if (Result.isConstant())
return Result.getConstant();
Expand All @@ -1644,7 +1650,7 @@ ConstantRange LazyValueInfo::getConstantRange(Value *V, Instruction *CxtI,
unsigned Width = V->getType()->getIntegerBitWidth();
BasicBlock *BB = CxtI->getParent();
ValueLatticeElement Result =
getImpl(PImpl, AC, BB->getModule()).getValueInBlock(V, BB, CxtI);
getOrCreateImpl(BB->getModule()).getValueInBlock(V, BB, CxtI);
if (Result.isUnknown())
return ConstantRange::getEmpty(Width);
if (Result.isConstantRange(UndefAllowed))
Expand Down Expand Up @@ -1710,7 +1716,7 @@ Constant *LazyValueInfo::getConstantOnEdge(Value *V, BasicBlock *FromBB,
Instruction *CxtI) {
Module *M = FromBB->getModule();
ValueLatticeElement Result =
getImpl(PImpl, AC, M).getValueOnEdge(V, FromBB, ToBB, CxtI);
getOrCreateImpl(M).getValueOnEdge(V, FromBB, ToBB, CxtI);

if (Result.isConstant())
return Result.getConstant();
Expand All @@ -1729,7 +1735,7 @@ ConstantRange LazyValueInfo::getConstantRangeOnEdge(Value *V,
unsigned Width = V->getType()->getIntegerBitWidth();
Module *M = FromBB->getModule();
ValueLatticeElement Result =
getImpl(PImpl, AC, M).getValueOnEdge(V, FromBB, ToBB, CxtI);
getOrCreateImpl(M).getValueOnEdge(V, FromBB, ToBB, CxtI);

if (Result.isUnknown())
return ConstantRange::getEmpty(Width);
Expand Down Expand Up @@ -1815,7 +1821,7 @@ LazyValueInfo::getPredicateOnEdge(unsigned Pred, Value *V, Constant *C,
Instruction *CxtI) {
Module *M = FromBB->getModule();
ValueLatticeElement Result =
getImpl(PImpl, AC, M).getValueOnEdge(V, FromBB, ToBB, CxtI);
getOrCreateImpl(M).getValueOnEdge(V, FromBB, ToBB, CxtI);

return getPredicateResult(Pred, C, Result, M->getDataLayout(), TLI);
}
Expand All @@ -1837,9 +1843,10 @@ LazyValueInfo::getPredicateAt(unsigned Pred, Value *V, Constant *C,
return LazyValueInfo::True;
}

ValueLatticeElement Result = UseBlockValue
? getImpl(PImpl, AC, M).getValueInBlock(V, CxtI->getParent(), CxtI)
: getImpl(PImpl, AC, M).getValueAt(V, CxtI);
auto &Impl = getOrCreateImpl(M);
ValueLatticeElement Result =
UseBlockValue ? Impl.getValueInBlock(V, CxtI->getParent(), CxtI)
: Impl.getValueAt(V, CxtI);
Tristate Ret = getPredicateResult(Pred, C, Result, DL, TLI);
if (Ret != Unknown)
return Ret;
Expand Down Expand Up @@ -1943,12 +1950,12 @@ LazyValueInfo::Tristate LazyValueInfo::getPredicateAt(unsigned P, Value *LHS,
if (UseBlockValue) {
Module *M = CxtI->getModule();
ValueLatticeElement L =
getImpl(PImpl, AC, M).getValueInBlock(LHS, CxtI->getParent(), CxtI);
getOrCreateImpl(M).getValueInBlock(LHS, CxtI->getParent(), CxtI);
if (L.isOverdefined())
return LazyValueInfo::Unknown;

ValueLatticeElement R =
getImpl(PImpl, AC, M).getValueInBlock(RHS, CxtI->getParent(), CxtI);
getOrCreateImpl(M).getValueInBlock(RHS, CxtI->getParent(), CxtI);
Type *Ty = CmpInst::makeCmpResultType(LHS->getType());
if (Constant *Res = L.getCompare((CmpInst::Predicate)P, Ty, R,
M->getDataLayout())) {
Expand All @@ -1963,28 +1970,28 @@ LazyValueInfo::Tristate LazyValueInfo::getPredicateAt(unsigned P, Value *LHS,

void LazyValueInfo::threadEdge(BasicBlock *PredBB, BasicBlock *OldSucc,
BasicBlock *NewSucc) {
if (PImpl) {
getImpl(PImpl, AC, PredBB->getModule())
.threadEdge(PredBB, OldSucc, NewSucc);
}
if (auto *Impl = getImpl())
Impl->threadEdge(PredBB, OldSucc, NewSucc);
}

void LazyValueInfo::forgetValue(Value *V) {
if (auto *Impl = getImpl())
getImpl()->forgetValue(V);
}

void LazyValueInfo::eraseBlock(BasicBlock *BB) {
if (PImpl) {
getImpl(PImpl, AC, BB->getModule()).eraseBlock(BB);
}
if (auto *Impl = getImpl())
getImpl()->eraseBlock(BB);
}

void LazyValueInfo::clear(const Module *M) {
if (PImpl) {
getImpl(PImpl, AC, M).clear();
}
void LazyValueInfo::clear() {
if (auto *Impl = getImpl())
getImpl()->clear();
}

void LazyValueInfo::printLVI(Function &F, DominatorTree &DTree, raw_ostream &OS) {
if (PImpl) {
getImpl(PImpl, AC, F.getParent()).printLVI(F, DTree, OS);
}
if (auto *Impl = getImpl())
getImpl()->printLVI(F, DTree, OS);
}

// Print the LVI for the function arguments at the start of each basic block.
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Transforms/Scalar/JumpThreading.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1269,6 +1269,7 @@ bool JumpThreadingPass::simplifyPartiallyRedundantLoad(LoadInst *LoadI) {
if (IsLoadCSE) {
LoadInst *NLoadI = cast<LoadInst>(AvailableVal);
combineMetadataForCSE(NLoadI, LoadI, false);
LVI->forgetValue(NLoadI);
};

// If the returned value is the load itself, replace with poison. This can
Expand Down Expand Up @@ -1461,6 +1462,7 @@ bool JumpThreadingPass::simplifyPartiallyRedundantLoad(LoadInst *LoadI) {

for (LoadInst *PredLoadI : CSELoads) {
combineMetadataForCSE(PredLoadI, LoadI, true);
LVI->forgetValue(PredLoadI);
}

LoadI->replaceAllUsesWith(PN);
Expand Down
59 changes: 59 additions & 0 deletions llvm/test/Transforms/JumpThreading/invalidate-lvi.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 3
; RUN: opt -S -passes=jump-threading < %s | FileCheck %s

declare void @set_value(ptr)

declare void @bar()

define void @foo(i1 %0) {
; CHECK-LABEL: define void @foo(
; CHECK-SAME: i1 [[TMP0:%.*]]) {
; CHECK-NEXT: start:
; CHECK-NEXT: [[V:%.*]] = alloca i64, align 8
; CHECK-NEXT: call void @set_value(ptr [[V]])
; CHECK-NEXT: [[L1:%.*]] = load i64, ptr [[V]], align 8
; CHECK-NEXT: br i1 [[TMP0]], label [[BB0:%.*]], label [[BB2:%.*]]
; CHECK: bb0:
; CHECK-NEXT: [[C1:%.*]] = icmp eq i64 [[L1]], 0
; CHECK-NEXT: br i1 [[C1]], label [[BB2_THREAD:%.*]], label [[BB2]]
; CHECK: bb2.thread:
; CHECK-NEXT: store i64 0, ptr [[V]], align 8
; CHECK-NEXT: br label [[BB4:%.*]]
; CHECK: bb2:
; CHECK-NEXT: [[L2:%.*]] = phi i64 [ [[L1]], [[BB0]] ], [ [[L1]], [[START:%.*]] ]
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i64 [[L2]], 2
; CHECK-NEXT: br i1 [[TMP1]], label [[BB3:%.*]], label [[BB4]]
; CHECK: bb3:
; CHECK-NEXT: call void @bar()
; CHECK-NEXT: ret void
; CHECK: bb4:
; CHECK-NEXT: ret void
;
start:
%v = alloca i64, align 8
call void @set_value(ptr %v)
%l1 = load i64, ptr %v, align 8, !range !0
br i1 %0, label %bb0, label %bb2

bb0: ; preds = %start
%c1 = icmp eq i64 %l1, 0
br i1 %c1, label %bb1, label %bb2

bb1: ; preds = %bb0
store i64 0, ptr %v, align 8
br label %bb2

bb2: ; preds = %bb1, %bb0, %start
%l2 = load i64, ptr %v, align 8
%1 = icmp eq i64 %l2, 2
br i1 %1, label %bb3, label %bb4

bb3: ; preds = %bb2
call void @bar()
ret void

bb4: ; preds = %bb2
ret void
}

!0 = !{i64 0, i64 2}