[DSE] Support traversing MemoryPhis.
For MemoryPhis, we have to avoid the case where the MemoryPhi may be executed
before the access we are currently looking at.

To do this, we compute a post-order numbering of the basic blocks in the
function and bail out once we reach a MemoryPhi with a post-order block number
larger than (or equal to) that of the current MemoryAccess.
This changes the order in which we visit stores for elimination.
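
To illustrate the idea, here is a minimal C++ sketch of how the numbering is
built and consulted. It mirrors the code added in this patch, but the helper
names (numberBlocksPostOrder, executedBeforePhi) are made up for this
description and do not exist in the tree:

  #include "llvm/ADT/DenseMap.h"
  #include "llvm/ADT/PostOrderIterator.h"
  #include "llvm/Analysis/MemorySSA.h"
  #include "llvm/IR/Function.h"
  using namespace llvm;

  // Number the blocks once per function (mirrors DSEState::get below).
  static DenseMap<BasicBlock *, unsigned> numberBlocksPostOrder(Function &F) {
    DenseMap<BasicBlock *, unsigned> PostOrderNumbers;
    unsigned PO = 0;
    for (BasicBlock *BB : post_order(&F))
      PostOrderNumbers[BB] = PO++;
    return PostOrderNumbers;
  }

  // An incoming access of a MemoryPhi is only worth following if its block has
  // a strictly greater post-order number than the phi's block, i.e. it comes
  // before the phi on the way to the current access.
  static bool executedBeforePhi(const DenseMap<BasicBlock *, unsigned> &PON,
                                const MemoryPhi *Phi,
                                const MemoryAccess *Incoming) {
    return PON.lookup(Incoming->getBlock()) > PON.lookup(Phi->getBlock());
  }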

This patch also adds support for exploring multiple paths. We keep a worklist
(ToCheck) of memory accesses that might be eliminated by our starting
MemoryDef, plus MemoryPhis that need further exploration. For MemoryPhis, we
add the incoming values to the worklist; for MemoryDefs, we add the defining
access.
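
Condensed, the new exploration looks roughly like the following. This is a
simplified restatement of the loop added to eliminateDeadStoresMemorySSA below
(scan limits, barrier checks and the actual elimination are elided); it refers
to the surrounding KillingDef, State, SILoc, DefVisibleToCaller and ScanLimit
variables from the patch:

  // Worklist-driven walk: start at the killing def's defining access and keep
  // exploring candidate accesses that might be dead.
  SetVector<MemoryAccess *> ToCheck;
  ToCheck.insert(KillingDef->getDefiningAccess());
  for (unsigned I = 0; I < ToCheck.size(); I++) {
    MemoryAccess *Current = ToCheck[I];
    Optional<MemoryAccess *> Next = State.getDomMemoryDef(
        KillingDef, Current, SILoc, DefVisibleToCaller, ScanLimit);
    if (!Next)
      continue; // No further candidate on this path.
    if (auto *Phi = dyn_cast<MemoryPhi>(*Next)) {
      // For MemoryPhis, queue the incoming values that are executed before
      // the phi (post-order number check omitted here).
      for (Value *V : Phi->incoming_values())
        ToCheck.insert(cast<MemoryAccess>(V));
      continue;
    }
    // For MemoryDefs, try to eliminate the def (omitted) and keep walking
    // from its defining access.
    ToCheck.insert(cast<MemoryDef>(*Next)->getDefiningAccess());
  }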

Reviewers: dmgreen, rnk, efriedma, bryant, asbirlea

Reviewed By: asbirlea

Differential Revision: https://reviews.llvm.org/D72148
fhahn committed Mar 20, 2020
1 parent 2cbb8c9 commit 3a8372e
Showing 5 changed files with 278 additions and 71 deletions.
163 changes: 105 additions & 58 deletions llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -18,6 +18,7 @@
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
@@ -42,7 +43,6 @@
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
@@ -1487,6 +1487,9 @@ struct DSEState {
SmallPtrSet<const Value *, 16> InvisibleToCaller;
// Keep track of blocks with throwing instructions not modeled in MemorySSA.
SmallPtrSet<BasicBlock *, 16> ThrowingBlocks;
// Post-order numbers for each basic block. Used to figure out if memory
// accesses are executed before another access.
DenseMap<BasicBlock *, unsigned> PostOrderNumbers;

/// Keep track of instructions (partly) overlapping with killing MemoryDefs per
/// basic block.
@@ -1502,23 +1505,28 @@ struct DSEState {
DSEState State(F, AA, MSSA, DT, PDT, TLI);
// Collect blocks with throwing instructions not modeled in MemorySSA and
// alloc-like objects.
for (Instruction &I : instructions(F)) {
if (I.mayThrow() && !MSSA.getMemoryAccess(&I))
State.ThrowingBlocks.insert(I.getParent());

auto *MD = dyn_cast_or_null<MemoryDef>(MSSA.getMemoryAccess(&I));
if (MD && State.MemDefs.size() < MemorySSADefsPerBlockLimit &&
hasAnalyzableMemoryWrite(&I, TLI) && isRemovable(&I))
State.MemDefs.push_back(MD);

// Track alloca and alloca-like objects. Here we care about objects not
// visible to the caller during function execution. Alloca objects are
// invalid in the caller, for alloca-like objects we ensure that they are
// not captured throughout the function.
if (isa<AllocaInst>(&I) ||
(isAllocLikeFn(&I, &TLI) && !PointerMayBeCaptured(&I, false, true)))
State.InvisibleToCaller.insert(&I);
unsigned PO = 0;
for (BasicBlock *BB : post_order(&F)) {
State.PostOrderNumbers[BB] = PO++;
for (Instruction &I : *BB) {
if (I.mayThrow() && !MSSA.getMemoryAccess(&I))
State.ThrowingBlocks.insert(I.getParent());

auto *MD = dyn_cast_or_null<MemoryDef>(MSSA.getMemoryAccess(&I));
if (MD && State.MemDefs.size() < MemorySSADefsPerBlockLimit &&
hasAnalyzableMemoryWrite(&I, TLI) && isRemovable(&I))
State.MemDefs.push_back(MD);

// Track alloca and alloca-like objects. Here we care about objects not
// visible to the caller during function execution. Alloca objects are
// invalid in the caller, for alloca-like objects we ensure that they
// are not captured throughout the function.
if (isa<AllocaInst>(&I) ||
(isAllocLikeFn(&I, &TLI) && !PointerMayBeCaptured(&I, false, true)))
State.InvisibleToCaller.insert(&I);
}
}

// Treat byval or inalloca arguments the same as Allocas, stores to them are
// dead at the end of the function.
for (Argument &AI : F.args())
@@ -1593,16 +1601,13 @@ struct DSEState {
// Find a MemoryDef writing to \p DefLoc and dominating \p Current, with no
// read access in between or return None otherwise. The returned value may not
// (completely) overwrite \p DefLoc. Currently we bail out when we encounter
// any of the following
// * An aliasing MemoryUse (read).
// * A MemoryPHI.
// an aliasing MemoryUse (read).
Optional<MemoryAccess *> getDomMemoryDef(MemoryDef *KillingDef,
MemoryAccess *Current,
MemoryLocation DefLoc,
bool DefVisibleToCaller,
int &ScanLimit) const {
MemoryDef *DomDef;
MemoryAccess *StartDef = Current;
MemoryAccess *DomAccess;
bool StepAgain;
LLVM_DEBUG(dbgs() << " trying to get dominating access for " << *Current
<< "\n");
@@ -1613,37 +1618,44 @@
if (MSSA.isLiveOnEntryDef(Current))
return None;

MemoryUseOrDef *CurrentUD = dyn_cast<MemoryUseOrDef>(Current);
if (!CurrentUD)
return None;

if (isa<MemoryPhi>(Current)) {
DomAccess = Current;
break;
}
MemoryUseOrDef *CurrentUD = cast<MemoryUseOrDef>(Current);
// Look for access that clobber DefLoc.
MemoryAccess *DomAccess =
MSSA.getSkipSelfWalker()->getClobberingMemoryAccess(
CurrentUD->getDefiningAccess(), DefLoc);
DomDef = dyn_cast<MemoryDef>(DomAccess);
if (!DomDef || MSSA.isLiveOnEntryDef(DomDef))
DomAccess = MSSA.getSkipSelfWalker()->getClobberingMemoryAccess(CurrentUD,
DefLoc);
if (MSSA.isLiveOnEntryDef(DomAccess))
return None;

if (isa<MemoryPhi>(DomAccess))
break;

// Check if we can skip DomDef for DSE. We also require the KillingDef
// execute whenever DomDef executes and use post-dominance to ensure that.
if (canSkipDef(DomDef, DefVisibleToCaller) ||

MemoryDef *DomDef = dyn_cast<MemoryDef>(DomAccess);
if ((DomDef && canSkipDef(DomDef, DefVisibleToCaller)) ||
!PDT.dominates(KillingDef->getBlock(), DomDef->getBlock())) {
StepAgain = true;
Current = DomDef;
Current = DomDef->getDefiningAccess();
}

} while (StepAgain);

LLVM_DEBUG(dbgs() << " Checking for reads of " << *DomDef << " ("
<< *DomDef->getMemoryInst() << ")\n");
LLVM_DEBUG({
dbgs() << " Checking for reads of " << *DomAccess;
if (isa<MemoryDef>(DomAccess))
dbgs() << " (" << *cast<MemoryDef>(DomAccess)->getMemoryInst() << ")\n";
});

SmallSetVector<MemoryAccess *, 32> WorkList;
auto PushMemUses = [&WorkList](MemoryAccess *Acc) {
for (Use &U : Acc->uses())
WorkList.insert(cast<MemoryAccess>(U.getUser()));
};
PushMemUses(DomDef);
PushMemUses(DomAccess);

// Check if DomDef may be read.
for (unsigned I = 0; I < WorkList.size(); I++) {
@@ -1655,10 +1667,9 @@ struct DSEState {
return None;
}

// Bail out on MemoryPhis for now.
if (isa<MemoryPhi>(UseAccess)) {
LLVM_DEBUG(dbgs() << " ... hit MemoryPhi\n");
return None;
PushMemUses(UseAccess);
continue;
}

Instruction *UseInst = cast<MemoryUseOrDef>(UseAccess)->getMemoryInst();
@@ -1676,7 +1687,11 @@ struct DSEState {
return None;
}

if (StartDef == UseAccess)
// For the KillingDef we only have to check if it reads the memory
// location.
// TODO: It would probably be better to check for self-reads before
// calling the function.
if (KillingDef == UseAccess)
continue;

// Check all uses for MemoryDefs, except for defs completely overwriting
@@ -1695,8 +1710,8 @@ struct DSEState {
}
}

// No aliasing MemoryUses of DomDef found, DomDef is potentially dead.
return {DomDef};
// No aliasing MemoryUses of DomAccess found, DomAccess is potentially dead.
return {DomAccess};
}

// Delete dead memory defs
@@ -1788,10 +1803,10 @@ bool eliminateDeadStoresMemorySSA(Function &F, AliasAnalysis &AA,
DSEState State = DSEState::get(F, AA, MSSA, DT, PDT, TLI);
// For each store:
for (unsigned I = 0; I < State.MemDefs.size(); I++) {
MemoryDef *Current = State.MemDefs[I];
if (State.SkipStores.count(Current))
MemoryDef *KillingDef = State.MemDefs[I];
if (State.SkipStores.count(KillingDef))
continue;
Instruction *SI = cast<MemoryDef>(Current)->getMemoryInst();
Instruction *SI = KillingDef->getMemoryInst();
auto MaybeSILoc = State.getLocForWriteEx(SI);
if (!MaybeSILoc) {
LLVM_DEBUG(dbgs() << "Failed to find analyzable write location for "
@@ -1808,22 +1823,54 @@ bool eliminateDeadStoresMemorySSA(Function &F, AliasAnalysis &AA,
!PointerMayBeCapturedBefore(DefObj, false, true, SI, &DT))))
DefVisibleToCaller = false;

LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by " << *SI
<< "\n");
MemoryAccess *Current = KillingDef;
LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by "
<< *KillingDef << " (" << *SI << ")\n");

int ScanLimit = MemorySSAScanLimit;
MemoryDef *StartDef = Current;
// Walk MemorySSA upward to find MemoryDefs that might be killed by SI.
while (Optional<MemoryAccess *> Next = State.getDomMemoryDef(
StartDef, Current, SILoc, DefVisibleToCaller, ScanLimit)) {
// Worklist of MemoryAccesses that may be killed by KillingDef.
SetVector<MemoryAccess *> ToCheck;
ToCheck.insert(KillingDef->getDefiningAccess());

// Check if MemoryAccesses in the worklist are killed by KillingDef.
for (unsigned I = 0; I < ToCheck.size(); I++) {
Current = ToCheck[I];
if (State.SkipStores.count(Current))
continue;

Optional<MemoryAccess *> Next = State.getDomMemoryDef(
KillingDef, Current, SILoc, DefVisibleToCaller, ScanLimit);

if (!Next) {
LLVM_DEBUG(dbgs() << " finished walk\n");
continue;
}

MemoryAccess *DomAccess = *Next;
LLVM_DEBUG(dbgs() << " Checking if we can kill " << *DomAccess << "\n");
if (isa<MemoryPhi>(DomAccess)) {
for (Value *V : cast<MemoryPhi>(DomAccess)->incoming_values()) {
MemoryAccess *IncomingAccess = cast<MemoryAccess>(V);
BasicBlock *IncomingBlock = IncomingAccess->getBlock();
BasicBlock *PhiBlock = DomAccess->getBlock();

// We only consider incoming MemoryAccesses that come before the
// MemoryPhi. Otherwise we could discover candidates that do not
// strictly dominate our starting def.
if (State.PostOrderNumbers[IncomingBlock] >
State.PostOrderNumbers[PhiBlock])
ToCheck.insert(IncomingAccess);
}
continue;
}
MemoryDef *NextDef = dyn_cast<MemoryDef>(DomAccess);
Instruction *NI = NextDef->getMemoryInst();
LLVM_DEBUG(dbgs() << " def " << *NI << "\n");

if (!hasAnalyzableMemoryWrite(NI, TLI))
break;
if (!hasAnalyzableMemoryWrite(NI, TLI)) {
LLVM_DEBUG(dbgs() << " skip, cannot analyze def\n");
continue;
}

if (!isRemovable(NI)) {
LLVM_DEBUG(dbgs() << " skip, cannot remove def\n");
@@ -1834,14 +1881,14 @@ bool eliminateDeadStoresMemorySSA(Function &F, AliasAnalysis &AA,
// Check for anything that looks like it will be a barrier to further
// removal
if (State.isDSEBarrier(SI, SILoc, SILocUnd, NI, NILoc)) {
LLVM_DEBUG(dbgs() << " stop, barrier\n");
break;
LLVM_DEBUG(dbgs() << " skip, barrier\n");
continue;
}

// Before we try to remove anything, check for any extra throwing
// instructions that block us from DSEing
if (State.mayThrowBetween(SI, NI, SILocUnd)) {
LLVM_DEBUG(dbgs() << " stop, may throw!\n");
LLVM_DEBUG(dbgs() << " skip, may throw!\n");
break;
}

@@ -1857,14 +1904,14 @@ bool eliminateDeadStoresMemorySSA(Function &F, AliasAnalysis &AA,
OverwriteResult OR = isOverwrite(SILoc, NILoc, DL, TLI, DepWriteOffset,
InstWriteOffset, NI, IOL, AA, &F);

ToCheck.insert(NextDef->getDefiningAccess());
if (OR == OW_Complete) {
LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: " << *NI
<< "\n KILLER: " << *SI << '\n');
State.deleteDeadInstruction(NI);
++NumFastStores;
MadeChange = true;
} else
Current = NextDef;
}
}
}

@@ -18,7 +18,6 @@ define void @test12(i32* %p) personality i32 (...)* @__CxxFrameHandler3 {
; CHECK-NEXT: invoke void @f()
; CHECK-NEXT: to label [[BLOCK3:%.*]] unwind label [[CATCH_DISPATCH:%.*]]
; CHECK: block3:
; CHECK-NEXT: store i32 30, i32* [[SV]]
; CHECK-NEXT: br label [[EXIT:%.*]]
; CHECK: catch.dispatch:
; CHECK-NEXT: [[CS1:%.*]] = catchswitch within none [label %catch] unwind label [[CLEANUP:%.*]]
39 changes: 36 additions & 3 deletions llvm/test/Transforms/DeadStoreElimination/MSSA/multiblock-loops.ll
@@ -27,10 +27,9 @@ end:
define void @test14(i32* noalias %P) {
; CHECK-LABEL: @test14(
; CHECK-NEXT: entry:
; CHECK-NEXT: store i32 1, i32* [[P:%.*]]
; CHECK-NEXT: br label [[FOR:%.*]]
; CHECK: for:
; CHECK-NEXT: store i32 0, i32* [[P]]
; CHECK-NEXT: store i32 0, i32* [[P:%.*]]
; CHECK-NEXT: br i1 false, label [[FOR]], label [[END:%.*]]
; CHECK: end:
; CHECK-NEXT: ret void
@@ -77,7 +76,8 @@ define void @test21(i32* noalias %P) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[ARRAYIDX0:%.*]] = getelementptr inbounds i32, i32* [[P:%.*]], i64 1
; CHECK-NEXT: [[P3:%.*]] = bitcast i32* [[ARRAYIDX0]] to i8*
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* align 4 [[P3]], i8 0, i64 28, i1 false)
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[P3]], i64 4
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* align 4 [[TMP0]], i8 0, i64 24, i1 false)
; CHECK-NEXT: br label [[FOR:%.*]]
; CHECK: for:
; CHECK-NEXT: [[ARRAYIDX1:%.*]] = getelementptr inbounds i32, i32* [[P]], i64 1
@@ -281,3 +281,36 @@
ret void
}

%struct.hoge = type { i32, i32 }

@global = external local_unnamed_addr global %struct.hoge*, align 8

define void @widget(i8* %tmp) {
; CHECK-LABEL: @widget(
; CHECK-NEXT: bb:
; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 1 [[TMP:%.*]], i8* nonnull align 16 undef, i64 64, i1 false)
; CHECK-NEXT: br label [[BB1:%.*]]
; CHECK: bb1:
; CHECK-NEXT: [[TMP2:%.*]] = load %struct.hoge*, %struct.hoge** @global, align 8
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds [[STRUCT_HOGE:%.*]], %struct.hoge* [[TMP2]], i64 undef, i32 1
; CHECK-NEXT: store i32 0, i32* [[TMP3]], align 4
; CHECK-NEXT: [[TMP4:%.*]] = load %struct.hoge*, %struct.hoge** @global, align 8
; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds [[STRUCT_HOGE]], %struct.hoge* [[TMP4]], i64 undef, i32 1
; CHECK-NEXT: store i32 10, i32* [[TMP5]], align 4
; CHECK-NEXT: br label [[BB1]]
;
bb:
call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 1 %tmp, i8* nonnull align 16 undef, i64 64, i1 false)
br label %bb1

bb1: ; preds = %bb1, %bb
%tmp2 = load %struct.hoge*, %struct.hoge** @global, align 8
%tmp3 = getelementptr inbounds %struct.hoge, %struct.hoge* %tmp2, i64 undef, i32 1
store i32 0, i32* %tmp3, align 4
%tmp4 = load %struct.hoge*, %struct.hoge** @global, align 8
%tmp5 = getelementptr inbounds %struct.hoge, %struct.hoge* %tmp4, i64 undef, i32 1
store i32 10, i32* %tmp5, align 4
br label %bb1
}

declare void @llvm.memcpy.p0i8.p0i8.i64(i8* noalias nocapture writeonly, i8* noalias nocapture readonly, i64, i1 immarg)
