Skip to content

Commit

Permalink
[Attributor] Deduction based on path exploration
Browse files Browse the repository at this point in the history
This patch introduces the propagation of known information based on path exploration.
For example,
```
int u(int c, int *p){
  if(c) {
     return *p;
  } else {
     return *p + 1;
  }
}
```
An argument `p` is dereferenced whatever c's value is.

For an instruction `CtxI`, we accumulate branch instructions in the must-be-executed-context of `CtxI` and then, we take the conjunction of the successors' known state.

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D65593
  • Loading branch information
uenoku committed Mar 9, 2020
1 parent 5e080df commit bdcbdb4
Show file tree
Hide file tree
Showing 6 changed files with 276 additions and 26 deletions.
12 changes: 12 additions & 0 deletions llvm/include/llvm/Analysis/MustExecute.h
Expand Up @@ -461,6 +461,18 @@ struct MustBeExecutedContextExplorer {
}
///}

/// Check \p Pred on all instructions in the context.
///
/// This method will evaluate \p Pred and return
/// true if \p Pred holds in every instruction.
bool checkForAllContext(const Instruction *PP,
const function_ref<bool(const Instruction *)> &Pred) {
for (auto EIt = begin(PP), EEnd = end(PP); EIt != EEnd; EIt++)
if (!Pred(*EIt))
return false;
return true;
}

/// Helper to look for \p I in the context of \p PP.
///
/// The context is expanded until \p I was found or no more expansion is
Expand Down
14 changes: 14 additions & 0 deletions llvm/include/llvm/Transforms/IPO/Attributor.h
Expand Up @@ -1339,6 +1339,13 @@ struct IntegerStateBase : public AbstractState {
handleNewAssumedValue(R.getAssumed());
}

/// "Clamp" this state with \p R. The result is subtype dependent but it is
/// intended that information known in either state will be known in
/// this one afterwards.
void operator+=(const IntegerStateBase<base_t, BestState, WorstState> &R) {
handleNewKnownValue(R.getKnown());
}

void operator|=(const IntegerStateBase<base_t, BestState, WorstState> &R) {
joinOR(R.getAssumed(), R.getKnown());
}
Expand Down Expand Up @@ -2294,6 +2301,13 @@ struct DerefState : AbstractState {
return *this;
}

/// See IntegerStateBase::operator+=
DerefState operator+=(const DerefState &R) {
DerefBytesState += R.DerefBytesState;
GlobalState += R.GlobalState;
return *this;
}

/// See IntegerStateBase::operator&=
DerefState operator&=(const DerefState &R) {
DerefBytesState &= R.DerefBytesState;
Expand Down
117 changes: 98 additions & 19 deletions llvm/lib/Transforms/IPO/Attributor.cpp
Expand Up @@ -980,6 +980,23 @@ struct AAFromMustBeExecutedContext : public Base {
Uses.insert(&U);
}

/// Helper function to accumulate uses.
void followUsesInContext(Attributor &A,
MustBeExecutedContextExplorer &Explorer,
const Instruction *CtxI,
SetVector<const Use *> &Uses, StateType &State) {
auto EIt = Explorer.begin(CtxI), EEnd = Explorer.end(CtxI);
for (unsigned u = 0; u < Uses.size(); ++u) {
const Use *U = Uses[u];
if (const Instruction *UserI = dyn_cast<Instruction>(U->getUser())) {
bool Found = Explorer.findInContextOf(UserI, EIt, EEnd);
if (Found && Base::followUse(A, U, UserI, State))
for (const Use &Us : UserI->uses())
Uses.insert(&Us);
}
}
}

/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
auto BeforeState = this->getState();
Expand All @@ -991,15 +1008,74 @@ struct AAFromMustBeExecutedContext : public Base {
MustBeExecutedContextExplorer &Explorer =
A.getInfoCache().getMustBeExecutedContextExplorer();

auto EIt = Explorer.begin(CtxI), EEnd = Explorer.end(CtxI);
for (unsigned u = 0; u < Uses.size(); ++u) {
const Use *U = Uses[u];
if (const Instruction *UserI = dyn_cast<Instruction>(U->getUser())) {
bool Found = Explorer.findInContextOf(UserI, EIt, EEnd);
if (Found && Base::followUse(A, U, UserI))
for (const Use &Us : UserI->uses())
Uses.insert(&Us);
followUsesInContext(A, Explorer, CtxI, Uses, S);

if (this->isAtFixpoint())
return ChangeStatus::CHANGED;

SmallVector<const BranchInst *, 4> BrInsts;
auto Pred = [&](const Instruction *I) {
if (const BranchInst *Br = dyn_cast<BranchInst>(I))
if (Br->isConditional())
BrInsts.push_back(Br);
return true;
};

// Here, accumulate conditional branch instructions in the context. We
// explore the child paths and collect the known states. The disjunction of
// those states can be merged to its own state. Let ParentState_i be a state
// to indicate the known information for an i-th branch instruction in the
// context. ChildStates are created for its successors respectively.
//
// ParentS_1 = ChildS_{1, 1} /\ ChildS_{1, 2} /\ ... /\ ChildS_{1, n_1}
// ParentS_2 = ChildS_{2, 1} /\ ChildS_{2, 2} /\ ... /\ ChildS_{2, n_2}
// ...
// ParentS_m = ChildS_{m, 1} /\ ChildS_{m, 2} /\ ... /\ ChildS_{m, n_m}
//
// Known State |= ParentS_1 \/ ParentS_2 \/... \/ ParentS_m
//
// FIXME: Currently, recursive branches are not handled. For example, we
// can't deduce that ptr must be dereferenced in below function.
//
// void f(int a, int c, int *ptr) {
// if(a)
// if (b) {
// *ptr = 0;
// } else {
// *ptr = 1;
// }
// else {
// if (b) {
// *ptr = 0;
// } else {
// *ptr = 1;
// }
// }
// }

Explorer.checkForAllContext(CtxI, Pred);
for (const BranchInst *Br : BrInsts) {
StateType ParentState;

// The known state of the parent state is a conjunction of children's
// known states so it is initialized with a best state.
ParentState.indicateOptimisticFixpoint();

for (const BasicBlock *BB : Br->successors()) {
StateType ChildState;

size_t BeforeSize = Uses.size();
followUsesInContext(A, Explorer, &BB->front(), Uses, ChildState);

// Erase uses which only appear in the child.
for (auto It = Uses.begin() + BeforeSize; It != Uses.end();)
It = Uses.erase(It);

ParentState &= ChildState;
}

// Use only known state.
S += ParentState;
}

return BeforeState == S ? ChangeStatus::UNCHANGED : ChangeStatus::CHANGED;
Expand Down Expand Up @@ -1900,7 +1976,7 @@ struct AANoFreeCallSiteReturned final : AANoFreeFloating {

/// ------------------------ NonNull Argument Attribute ------------------------
static int64_t getKnownNonNullAndDerefBytesForUse(
Attributor &A, AbstractAttribute &QueryingAA, Value &AssociatedValue,
Attributor &A, const AbstractAttribute &QueryingAA, Value &AssociatedValue,
const Use *U, const Instruction *I, bool &IsNonNull, bool &TrackUse) {
TrackUse = false;

Expand Down Expand Up @@ -1991,12 +2067,13 @@ struct AANonNullImpl : AANonNull {
}

/// See AAFromMustBeExecutedContext
bool followUse(Attributor &A, const Use *U, const Instruction *I) {
bool followUse(Attributor &A, const Use *U, const Instruction *I,
AANonNull::StateType &State) {
bool IsNonNull = false;
bool TrackUse = false;
getKnownNonNullAndDerefBytesForUse(A, *this, getAssociatedValue(), U, I,
IsNonNull, TrackUse);
setKnown(IsNonNull);
State.setKnown(IsNonNull);
return TrackUse;
}

Expand Down Expand Up @@ -3549,8 +3626,8 @@ struct AADereferenceableImpl : AADereferenceable {
/// }

/// Helper function for collecting accessed bytes in must-be-executed-context
void addAccessedBytesForUse(Attributor &A, const Use *U,
const Instruction *I) {
void addAccessedBytesForUse(Attributor &A, const Use *U, const Instruction *I,
DerefState &State) {
const Value *UseV = U->get();
if (!UseV->getType()->isPointerTy())
return;
Expand All @@ -3563,21 +3640,22 @@ struct AADereferenceableImpl : AADereferenceable {
if (Base == &getAssociatedValue() &&
getPointerOperand(I, /* AllowVolatile */ false) == UseV) {
uint64_t Size = DL.getTypeStoreSize(PtrTy->getPointerElementType());
addAccessedBytes(Offset, Size);
State.addAccessedBytes(Offset, Size);
}
}
return;
}

/// See AAFromMustBeExecutedContext
bool followUse(Attributor &A, const Use *U, const Instruction *I) {
bool followUse(Attributor &A, const Use *U, const Instruction *I,
AADereferenceable::StateType &State) {
bool IsNonNull = false;
bool TrackUse = false;
int64_t DerefBytes = getKnownNonNullAndDerefBytesForUse(
A, *this, getAssociatedValue(), U, I, IsNonNull, TrackUse);

addAccessedBytesForUse(A, U, I);
takeKnownDerefBytesMaximum(DerefBytes);
addAccessedBytesForUse(A, U, I, State);
State.takeKnownDerefBytesMaximum(DerefBytes);
return TrackUse;
}

Expand Down Expand Up @@ -3871,12 +3949,13 @@ struct AAAlignImpl : AAAlign {
Attribute::getWithAlignment(Ctx, Align(getAssumedAlign())));
}
/// See AAFromMustBeExecutedContext
bool followUse(Attributor &A, const Use *U, const Instruction *I) {
bool followUse(Attributor &A, const Use *U, const Instruction *I,
AAAlign::StateType &State) {
bool TrackUse = false;

unsigned int KnownAlign =
getKnownAlignForUse(A, *this, getAssociatedValue(), U, I, TrackUse);
takeKnownMaximum(KnownAlign);
State.takeKnownMaximum(KnownAlign);

return TrackUse;
}
Expand Down
@@ -1,5 +1,5 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --function-signature --scrub-attributes
; RUN: opt -S -passes=attributor -aa-pipeline='basic-aa' -attributor-disable=false -attributor-max-iterations-verify -attributor-max-iterations=1 < %s | FileCheck %s
; RUN: opt -S -passes=attributor -aa-pipeline='basic-aa' -attributor-disable=false -attributor-max-iterations-verify -attributor-max-iterations=2 < %s | FileCheck %s
;
; void bar(int, float, double);
;
Expand Down
147 changes: 147 additions & 0 deletions llvm/test/Transforms/Attributor/dereferenceable-1.ll
Expand Up @@ -308,5 +308,152 @@ entry:
ret void
}

declare void @use0() willreturn nounwind
declare void @use1(i8*) willreturn nounwind
declare void @use2(i8*, i8*) willreturn nounwind
declare void @use3(i8*, i8*, i8*) willreturn nounwind
; simple path test
; if(..)
; fun2(dereferenceable(8) %a, dereferenceable(8) %b)
; else
; fun2(dereferenceable(4) %a, %b)
; We can say that %a is dereferenceable(4) but %b is not.
define void @simple-path(i8* %a, i8 * %b, i8 %c) {
; ATTRIBUTOR: define void @simple-path(i8* nonnull dereferenceable(4) %a, i8* %b, i8 %c)
%cmp = icmp eq i8 %c, 0
br i1 %cmp, label %if.then, label %if.else
if.then:
tail call void @use2(i8* dereferenceable(8) %a, i8* dereferenceable(8) %b)
ret void
if.else:
tail call void @use2(i8* dereferenceable(4) %a, i8* %b)
ret void
}
; More complex test
; {
; fun1(dereferenceable(4) %a)
; if(..)
; ... (willreturn & nounwind)
; fun1(dereferenceable(12) %a)
; else
; ... (willreturn & nounwind)
; fun1(dereferenceable(16) %a)
; fun1(dereferenceable(8) %a)
; }
; %a is dereferenceable(12)

define void @complex-path(i8* %a, i8* %b, i8 %c) {
; ATTRIBUTOR: define void @complex-path(i8* nonnull dereferenceable(12) %a, i8* nocapture nofree readnone %b, i8 %c)
%cmp = icmp eq i8 %c, 0
tail call void @use1(i8* dereferenceable(4) %a)
br i1 %cmp, label %cont.then, label %cont.else
cont.then:
tail call void @use1(i8* dereferenceable(12) %a)
br label %cont2
cont.else:
tail call void @use1(i8* dereferenceable(16) %a)
br label %cont2
cont2:
tail call void @use1(i8* dereferenceable(8) %a)
ret void
}

; void rec-branch-1(int a, int b, int c, int *ptr) {
; if (a) {
; if (b)
; *ptr = 1;
; else
; *ptr = 2;
; } else {
; if (c)
; *ptr = 3;
; else
; *ptr = 4;
; }
; }
;
; FIXME: %ptr should be dereferenceable(4)
; ATTRIBUTOR: define dso_local void @rec-branch-1(i32 %a, i32 %b, i32 %c, i32* nocapture nofree writeonly %ptr)
define dso_local void @rec-branch-1(i32 %a, i32 %b, i32 %c, i32* %ptr) {
entry:
%tobool = icmp eq i32 %a, 0
br i1 %tobool, label %if.else3, label %if.then

if.then: ; preds = %entry
%tobool1 = icmp eq i32 %b, 0
br i1 %tobool1, label %if.else, label %if.then2

if.then2: ; preds = %if.then
store i32 1, i32* %ptr, align 4
br label %if.end8

if.else: ; preds = %if.then
store i32 2, i32* %ptr, align 4
br label %if.end8

if.else3: ; preds = %entry
%tobool4 = icmp eq i32 %c, 0
br i1 %tobool4, label %if.else6, label %if.then5

if.then5: ; preds = %if.else3
store i32 3, i32* %ptr, align 4
br label %if.end8

if.else6: ; preds = %if.else3
store i32 4, i32* %ptr, align 4
br label %if.end8

if.end8: ; preds = %if.then5, %if.else6, %if.then2, %if.else
ret void
}

; void rec-branch-2(int a, int b, int c, int *ptr) {
; if (a) {
; if (b)
; *ptr = 1;
; else
; *ptr = 2;
; } else {
; if (c)
; *ptr = 3;
; else
; rec-branch-2(1, 1, 1, ptr);
; }
; }
; FIXME: %ptr should be dereferenceable(4)
; ATTRIBUTOR: define dso_local void @rec-branch-2(i32 %a, i32 %b, i32 %c, i32* nocapture nofree writeonly %ptr)
define dso_local void @rec-branch-2(i32 %a, i32 %b, i32 %c, i32* %ptr) {
entry:
%tobool = icmp eq i32 %a, 0
br i1 %tobool, label %if.else3, label %if.then

if.then: ; preds = %entry
%tobool1 = icmp eq i32 %b, 0
br i1 %tobool1, label %if.else, label %if.then2

if.then2: ; preds = %if.then
store i32 1, i32* %ptr, align 4
br label %if.end8

if.else: ; preds = %if.then
store i32 2, i32* %ptr, align 4
br label %if.end8

if.else3: ; preds = %entry
%tobool4 = icmp eq i32 %c, 0
br i1 %tobool4, label %if.else6, label %if.then5

if.then5: ; preds = %if.else3
store i32 3, i32* %ptr, align 4
br label %if.end8

if.else6: ; preds = %if.else3
tail call void @rec-branch-2(i32 1, i32 1, i32 1, i32* %ptr)
br label %if.end8

if.end8: ; preds = %if.then5, %if.else6, %if.then2, %if.else
ret void
}

!0 = !{i64 10, i64 100}

0 comments on commit bdcbdb4

Please sign in to comment.