Skip to content

Commit

Permalink
[Attributor] Use the proper context instruction in genericValueTraversal
Browse files Browse the repository at this point in the history
There was a TODO in genericValueTraversal to provide the context
instruction and due to the lack of it users that wanted one just used
something available. Unfortunately, using a fixed instruction is wrong
in the presence of PHIs so we need to update the context instruction
properly.

Reviewed By: uenoku

Differential Revision: https://reviews.llvm.org/D76870
  • Loading branch information
jdoerfert committed Apr 2, 2020
1 parent ac96c8f commit bcd8009
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 43 deletions.
96 changes: 53 additions & 43 deletions llvm/lib/Transforms/IPO/Attributor.cpp
Expand Up @@ -398,8 +398,10 @@ static Value *constructPointer(Type *ResTy, Value *Ptr, int64_t Offset,
template <typename AAType, typename StateTy>
static bool genericValueTraversal(
Attributor &A, IRPosition IRP, const AAType &QueryingAA, StateTy &State,
function_ref<bool(Value &, StateTy &, bool)> VisitValueCB,
int MaxValues = 8, function_ref<Value *(Value *)> StripCB = nullptr) {
function_ref<bool(Value &, const Instruction *, StateTy &, bool)>
VisitValueCB,
const Instruction *CtxI, int MaxValues = 16,
function_ref<Value *(Value *)> StripCB = nullptr) {

const AAIsDead *LivenessAA = nullptr;
if (IRP.getAnchorScope())
Expand All @@ -408,20 +410,22 @@ static bool genericValueTraversal(
/* TrackDependence */ false);
bool AnyDead = false;

// TODO: Use Positions here to allow context sensitivity in VisitValueCB
SmallPtrSet<Value *, 16> Visited;
SmallVector<Value *, 16> Worklist;
Worklist.push_back(&IRP.getAssociatedValue());
using Item = std::pair<Value *, const Instruction *>;
SmallSet<Item, 16> Visited;
SmallVector<Item, 16> Worklist;
Worklist.push_back({&IRP.getAssociatedValue(), CtxI});

int Iteration = 0;
do {
Value *V = Worklist.pop_back_val();
Item I = Worklist.pop_back_val();
Value *V = I.first;
CtxI = I.second;
if (StripCB)
V = StripCB(V);

// Check if we should process the current value. To prevent endless
// recursion keep a record of the values we followed!
if (!Visited.insert(V).second)
if (!Visited.insert(I).second)
continue;

// Make sure we limit the compile time for complex expressions.
Expand All @@ -444,14 +448,14 @@ static bool genericValueTraversal(
}
}
if (NewV && NewV != V) {
Worklist.push_back(NewV);
Worklist.push_back({NewV, CtxI});
continue;
}

// Look through select instructions, visit both potential values.
if (auto *SI = dyn_cast<SelectInst>(V)) {
Worklist.push_back(SI->getTrueValue());
Worklist.push_back(SI->getFalseValue());
Worklist.push_back({SI->getTrueValue(), CtxI});
Worklist.push_back({SI->getFalseValue(), CtxI});
continue;
}

Expand All @@ -460,20 +464,21 @@ static bool genericValueTraversal(
assert(LivenessAA &&
"Expected liveness in the presence of instructions!");
for (unsigned u = 0, e = PHI->getNumIncomingValues(); u < e; u++) {
const BasicBlock *IncomingBB = PHI->getIncomingBlock(u);
BasicBlock *IncomingBB = PHI->getIncomingBlock(u);
if (A.isAssumedDead(*IncomingBB->getTerminator(), &QueryingAA,
LivenessAA,
/* CheckBBLivenessOnly */ true)) {
AnyDead = true;
continue;
}
Worklist.push_back(PHI->getIncomingValue(u));
Worklist.push_back(
{PHI->getIncomingValue(u), IncomingBB->getTerminator()});
}
continue;
}

// Once a leaf is reached we inform the user through the callback.
if (!VisitValueCB(*V, State, Iteration > 1))
if (!VisitValueCB(*V, CtxI, State, Iteration > 1))
return false;
} while (!Worklist.empty());

Expand Down Expand Up @@ -710,7 +715,7 @@ void IRPosition::getAttrs(ArrayRef<Attribute::AttrKind> AKs,
}
if (A)
for (Attribute::AttrKind AK : AKs)
getAttrsFromAssumes(AK, Attrs, *A);
getAttrsFromAssumes(AK, Attrs, *A);
}

bool IRPosition::getAttrsFromIRAttr(Attribute::AttrKind AK,
Expand Down Expand Up @@ -1466,7 +1471,8 @@ ChangeStatus AAReturnedValuesImpl::updateImpl(Attributor &A) {
};

// Callback for a leaf value returned by the associated function.
auto VisitValueCB = [](Value &Val, RVState &RVS, bool) -> bool {
auto VisitValueCB = [](Value &Val, const Instruction *, RVState &RVS,
bool) -> bool {
auto Size = RVS.RetValsMap[&Val].size();
RVS.RetValsMap[&Val].insert(RVS.RetInsts.begin(), RVS.RetInsts.end());
bool Inserted = RVS.RetValsMap[&Val].size() != Size;
Expand All @@ -1480,18 +1486,19 @@ ChangeStatus AAReturnedValuesImpl::updateImpl(Attributor &A) {
};

// Helper method to invoke the generic value traversal.
auto VisitReturnedValue = [&](Value &RV, RVState &RVS) {
auto VisitReturnedValue = [&](Value &RV, RVState &RVS,
const Instruction *CtxI) {
IRPosition RetValPos = IRPosition::value(RV);
return genericValueTraversal<AAReturnedValues, RVState>(A, RetValPos, *this,
RVS, VisitValueCB);
return genericValueTraversal<AAReturnedValues, RVState>(
A, RetValPos, *this, RVS, VisitValueCB, CtxI);
};

// Callback for all "return intructions" live in the associated function.
auto CheckReturnInst = [this, &VisitReturnedValue, &Changed](Instruction &I) {
ReturnInst &Ret = cast<ReturnInst>(I);
RVState RVS({ReturnedValues, Changed, {}});
RVS.RetInsts.insert(&Ret);
return VisitReturnedValue(*Ret.getReturnValue(), RVS);
return VisitReturnedValue(*Ret.getReturnValue(), RVS, &I);
};

// Start by discovering returned values from all live returned instructions in
Expand Down Expand Up @@ -1576,7 +1583,7 @@ ChangeStatus AAReturnedValuesImpl::updateImpl(Attributor &A) {
// again.
bool Unused = false;
RVState RVS({NewRVsMap, Unused, RetValAAIt.second});
VisitReturnedValue(*CB->getArgOperand(Arg->getArgNo()), RVS);
VisitReturnedValue(*CB->getArgOperand(Arg->getArgNo()), RVS, CB);
continue;
} else if (isa<CallBase>(RetVal)) {
// Call sites are resolved by the callee attribute over time, no need to
Expand Down Expand Up @@ -2148,11 +2155,11 @@ struct AANonNullFloating
AC = InfoCache.getAnalysisResultForFunction<AssumptionAnalysis>(*Fn);
}

auto VisitValueCB = [&](Value &V, AANonNull::StateType &T,
bool Stripped) -> bool {
auto VisitValueCB = [&](Value &V, const Instruction *CtxI,
AANonNull::StateType &T, bool Stripped) -> bool {
const auto &AA = A.getAAFor<AANonNull>(*this, IRPosition::value(V));
if (!Stripped && this == &AA) {
if (!isKnownNonZero(&V, DL, 0, AC, getCtxI(), DT))
if (!isKnownNonZero(&V, DL, 0, AC, CtxI, DT))
T.indicatePessimisticFixpoint();
} else {
// Use abstract attribute information.
Expand All @@ -2164,8 +2171,8 @@ struct AANonNullFloating
};

StateType T;
if (!genericValueTraversal<AANonNull, StateType>(A, getIRPosition(), *this,
T, VisitValueCB))
if (!genericValueTraversal<AANonNull, StateType>(
A, getIRPosition(), *this, T, VisitValueCB, getCtxI()))
return indicatePessimisticFixpoint();

return clampStateAndIndicateChange(getState(), T);
Expand Down Expand Up @@ -3776,7 +3783,8 @@ struct AADereferenceableFloating

const DataLayout &DL = A.getDataLayout();

auto VisitValueCB = [&](Value &V, DerefState &T, bool Stripped) -> bool {
auto VisitValueCB = [&](Value &V, const Instruction *, DerefState &T,
bool Stripped) -> bool {
unsigned IdxWidth =
DL.getIndexSizeInBits(V.getType()->getPointerAddressSpace());
APInt Offset(IdxWidth, 0);
Expand Down Expand Up @@ -3831,7 +3839,7 @@ struct AADereferenceableFloating

DerefState T;
if (!genericValueTraversal<AADereferenceable, DerefState>(
A, getIRPosition(), *this, T, VisitValueCB))
A, getIRPosition(), *this, T, VisitValueCB, getCtxI()))
return indicatePessimisticFixpoint();

return Change | clampStateAndIndicateChange(getState(), T);
Expand Down Expand Up @@ -4073,8 +4081,8 @@ struct AAAlignFloating : AAFromMustBeExecutedContext<AAAlign, AAAlignImpl> {

const DataLayout &DL = A.getDataLayout();

auto VisitValueCB = [&](Value &V, AAAlign::StateType &T,
bool Stripped) -> bool {
auto VisitValueCB = [&](Value &V, const Instruction *,
AAAlign::StateType &T, bool Stripped) -> bool {
const auto &AA = A.getAAFor<AAAlign>(*this, IRPosition::value(V));
if (!Stripped && this == &AA) {
// Use only IR information if we did not strip anything.
Expand All @@ -4092,7 +4100,7 @@ struct AAAlignFloating : AAFromMustBeExecutedContext<AAAlign, AAAlignImpl> {

StateType T;
if (!genericValueTraversal<AAAlign, StateType>(A, getIRPosition(), *this, T,
VisitValueCB))
VisitValueCB, getCtxI()))
return indicatePessimisticFixpoint();

// TODO: If we know we visited all incoming values, thus no are assumed
Expand Down Expand Up @@ -4958,7 +4966,8 @@ struct AAValueSimplifyFloating : AAValueSimplifyImpl {
ChangeStatus updateImpl(Attributor &A) override {
bool HasValueBefore = SimplifiedAssociatedValue.hasValue();

auto VisitValueCB = [&](Value &V, bool &, bool Stripped) -> bool {
auto VisitValueCB = [&](Value &V, const Instruction *CtxI, bool &,
bool Stripped) -> bool {
auto &AA = A.getAAFor<AAValueSimplify>(*this, IRPosition::value(V));
if (!Stripped && this == &AA) {
// TODO: Look the instruction and check recursively.
Expand All @@ -4971,8 +4980,8 @@ struct AAValueSimplifyFloating : AAValueSimplifyImpl {
};

bool Dummy = false;
if (!genericValueTraversal<AAValueSimplify, bool>(A, getIRPosition(), *this,
Dummy, VisitValueCB))
if (!genericValueTraversal<AAValueSimplify, bool>(
A, getIRPosition(), *this, Dummy, VisitValueCB, getCtxI()))
if (!askSimplifiedValueForAAValueConstantRange(A))
return indicatePessimisticFixpoint();

Expand Down Expand Up @@ -6605,7 +6614,8 @@ void AAMemoryLocationImpl::categorizePtrValue(
return V;
};

auto VisitValueCB = [&](Value &V, AAMemoryLocation::StateType &T,
auto VisitValueCB = [&](Value &V, const Instruction *,
AAMemoryLocation::StateType &T,
bool Stripped) -> bool {
assert(!isa<GEPOperator>(V) && "GEPs should have been stripped.");
if (isa<UndefValue>(V))
Expand Down Expand Up @@ -6652,7 +6662,7 @@ void AAMemoryLocationImpl::categorizePtrValue(
};

if (!genericValueTraversal<AAMemoryLocation, AAMemoryLocation::StateType>(
A, IRPosition::value(Ptr), *this, State, VisitValueCB,
A, IRPosition::value(Ptr), *this, State, VisitValueCB, getCtxI(),
/* MaxValues */ 32, StripGEPCB)) {
LLVM_DEBUG(
dbgs() << "[AAMemoryLocation] Pointer locations not categorized\n");
Expand Down Expand Up @@ -7132,7 +7142,7 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl {

bool calculateBinaryOperator(
Attributor &A, BinaryOperator *BinOp, IntegerRangeState &T,
Instruction *CtxI,
const Instruction *CtxI,
SmallVectorImpl<const AAValueConstantRange *> &QuerriedAAs) {
Value *LHS = BinOp->getOperand(0);
Value *RHS = BinOp->getOperand(1);
Expand Down Expand Up @@ -7160,7 +7170,8 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl {
}

bool calculateCastInst(
Attributor &A, CastInst *CastI, IntegerRangeState &T, Instruction *CtxI,
Attributor &A, CastInst *CastI, IntegerRangeState &T,
const Instruction *CtxI,
SmallVectorImpl<const AAValueConstantRange *> &QuerriedAAs) {
assert(CastI->getNumOperands() == 1 && "Expected cast to be unary!");
// TODO: Allow non integers as well.
Expand All @@ -7178,7 +7189,7 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl {

bool
calculateCmpInst(Attributor &A, CmpInst *CmpI, IntegerRangeState &T,
Instruction *CtxI,
const Instruction *CtxI,
SmallVectorImpl<const AAValueConstantRange *> &QuerriedAAs) {
Value *LHS = CmpI->getOperand(0);
Value *RHS = CmpI->getOperand(1);
Expand Down Expand Up @@ -7233,9 +7244,8 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl {

/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
Instruction *CtxI = getCtxI();
auto VisitValueCB = [&](Value &V, IntegerRangeState &T,
bool Stripped) -> bool {
auto VisitValueCB = [&](Value &V, const Instruction *CtxI,
IntegerRangeState &T, bool Stripped) -> bool {
Instruction *I = dyn_cast<Instruction>(&V);
if (!I || isa<CallBase>(I)) {

Expand Down Expand Up @@ -7285,7 +7295,7 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl {
IntegerRangeState T(getBitWidth());

if (!genericValueTraversal<AAValueConstantRange, IntegerRangeState>(
A, getIRPosition(), *this, T, VisitValueCB))
A, getIRPosition(), *this, T, VisitValueCB, getCtxI()))
return indicatePessimisticFixpoint();

return clampStateAndIndicateChange(getState(), T);
Expand Down
70 changes: 70 additions & 0 deletions llvm/test/Transforms/Attributor/range.ll
Expand Up @@ -1212,6 +1212,76 @@ define i1 @callee_range_2(i1 %c1, i1 %c2) {
}


define i32 @ret100() {
; CHECK-LABEL: define {{[^@]+}}@ret100()
; CHECK-NEXT: ret i32 100
;
ret i32 100
}

define i1 @ctx_adjustment(i32 %V) {
; OLD_PM-LABEL: define {{[^@]+}}@ctx_adjustment
; OLD_PM-SAME: (i32 [[V:%.*]])
; OLD_PM-NEXT: [[C1:%.*]] = icmp sge i32 [[V]], 100
; OLD_PM-NEXT: br i1 [[C1]], label [[IF_TRUE:%.*]], label [[IF_FALSE:%.*]]
; OLD_PM: if.true:
; OLD_PM-NEXT: br label [[END:%.*]]
; OLD_PM: if.false:
; OLD_PM-NEXT: br label [[END]]
; OLD_PM: end:
; OLD_PM-NEXT: [[PHI:%.*]] = phi i32 [ [[V]], [[IF_TRUE]] ], [ 100, [[IF_FALSE]] ]
; OLD_PM-NEXT: [[C2:%.*]] = icmp sge i32 [[PHI]], 100
; OLD_PM-NEXT: ret i1 [[C2]]
;
; NEW_PM-LABEL: define {{[^@]+}}@ctx_adjustment
; NEW_PM-SAME: (i32 [[V:%.*]])
; NEW_PM-NEXT: [[C1:%.*]] = icmp sge i32 [[V]], 100
; NEW_PM-NEXT: br i1 [[C1]], label [[IF_TRUE:%.*]], label [[IF_FALSE:%.*]]
; NEW_PM: if.true:
; NEW_PM-NEXT: br label [[END:%.*]]
; NEW_PM: if.false:
; NEW_PM-NEXT: br label [[END]]
; NEW_PM: end:
; NEW_PM-NEXT: ret i1 true
;
; CGSCC_OLD_PM-LABEL: define {{[^@]+}}@ctx_adjustment
; CGSCC_OLD_PM-SAME: (i32 [[V:%.*]])
; CGSCC_OLD_PM-NEXT: [[C1:%.*]] = icmp sge i32 [[V]], 100
; CGSCC_OLD_PM-NEXT: br i1 [[C1]], label [[IF_TRUE:%.*]], label [[IF_FALSE:%.*]]
; CGSCC_OLD_PM: if.true:
; CGSCC_OLD_PM-NEXT: br label [[END:%.*]]
; CGSCC_OLD_PM: if.false:
; CGSCC_OLD_PM-NEXT: br label [[END]]
; CGSCC_OLD_PM: end:
; CGSCC_OLD_PM-NEXT: [[PHI:%.*]] = phi i32 [ [[V]], [[IF_TRUE]] ], [ 100, [[IF_FALSE]] ]
; CGSCC_OLD_PM-NEXT: [[C2:%.*]] = icmp sge i32 [[PHI]], 100
; CGSCC_OLD_PM-NEXT: ret i1 [[C2]]
;
; CGSCC_NEW_PM-LABEL: define {{[^@]+}}@ctx_adjustment
; CGSCC_NEW_PM-SAME: (i32 [[V:%.*]])
; CGSCC_NEW_PM-NEXT: [[C1:%.*]] = icmp sge i32 [[V]], 100
; CGSCC_NEW_PM-NEXT: br i1 [[C1]], label [[IF_TRUE:%.*]], label [[IF_FALSE:%.*]]
; CGSCC_NEW_PM: if.true:
; CGSCC_NEW_PM-NEXT: br label [[END:%.*]]
; CGSCC_NEW_PM: if.false:
; CGSCC_NEW_PM-NEXT: br label [[END]]
; CGSCC_NEW_PM: end:
; CGSCC_NEW_PM-NEXT: ret i1 true
;
%c1 = icmp sge i32 %V, 100
br i1 %c1, label %if.true, label %if.false
if.true:
br label %end
if.false:
%call = call i32 @ret100()
br label %end
end:
%phi = phi i32 [ %V, %if.true ], [ %call, %if.false ]
%c2 = icmp sge i32 %phi, 100
ret i1 %c2
}


!0 = !{i32 0, i32 10}
!1 = !{i32 10, i32 100}
; CHECK: !0 = !{i32 0, i32 10}
Expand Down

0 comments on commit bcd8009

Please sign in to comment.