Skip to content

Commit

Permalink
Respect dataflow deps in dead code elimination (plaidml#812)
Browse files Browse the repository at this point in the history
  • Loading branch information
earhart authored and mergify[bot] committed Jan 6, 2020
1 parent 7fe7fb2 commit 405f82b
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 41 deletions.
40 changes: 19 additions & 21 deletions tile/codegen/dce.cc
Expand Up @@ -183,7 +183,7 @@ void DeadCodeElimination(const AliasMap& alias_map, Block* block) {
}
}

ComputeDepsForBlock(block, alias_map);
ComputeDataflowDepsForBlock(block, alias_map);
// Map a statement to its uses
std::map<Statement*, std::vector<Statement*>> stmt_uses;
for (const auto& stmt : block->stmts) {
Expand Down Expand Up @@ -231,27 +231,25 @@ void DeadCodeElimination(const AliasMap& alias_map, Block* block) {

void DeadCodeEliminationPass::Apply(CompilerState* state) const {
auto reqs = stripe::FromProto(options_.reqs());
RunOnBlocksBackward(
state->entry(), reqs,
[](const AliasMap& alias_map, stripe::Block* block) { //
DeadCodeElimination(alias_map, block);
},
true);
RunOnBlocksBackward(state->entry(), reqs,
[](const AliasMap& alias_map, stripe::Block* block) { //
DeadCodeElimination(alias_map, block);
},
true);

RunOnBlocks(
state->entry(), reqs,
[&](const AliasMap& map, stripe::Block* block) { //
if (options_.fix_deps()) {
// Rebuild deps
ComputeDepsForBlock(block, map);
} else {
// Clean up deps after use
for (auto& stmt : block->stmts) {
stmt.get()->deps.clear();
}
}
},
true);
RunOnBlocks(state->entry(), reqs,
[&](const AliasMap& map, stripe::Block* block) { //
if (options_.fix_deps()) {
// Rebuild deps
ComputeDepsForBlock(block, map);
} else {
// Clean up deps after use
for (auto& stmt : block->stmts) {
stmt.get()->deps.clear();
}
}
},
true);
}

namespace {
Expand Down
57 changes: 37 additions & 20 deletions tile/codegen/deps.cc
Expand Up @@ -2,9 +2,12 @@

#include "tile/codegen/deps.h"

#include <algorithm>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include <boost/format.hpp>

Expand Down Expand Up @@ -117,58 +120,51 @@ struct Tracker {

buffer_info.readers.insert(it);
}
};

} // namespace

void ComputeDepsForBlock(Block* block, const AliasMap& alias_map) {
IVLOG(3, "ComputeDeps> " << block->name);
Tracker tracker;
std::unordered_map<StatementIt, std::set<StatementIt>> transitive_deps;
for (auto it = block->stmts.begin(); it != block->stmts.end(); it++) {
void ApplyEffectsOf(StatementIt it, Block* block, const AliasMap& alias_map) {
// Adjust the current scalar and buffer tracking structures and update dataflow_deps.
switch ((*it)->kind()) {
case StmtKind::Load: {
auto load = Load::Downcast(*it);
IVLOG(3, " load: " << load);
tracker.ReadBuffer(it, load->from, alias_map);
tracker.WriteScalar(*block, it, load->into);
ReadBuffer(it, load->from, alias_map);
WriteScalar(*block, it, load->into);
} break;
case StmtKind::Store: {
auto store = Store::Downcast(*it);
IVLOG(3, " store: " << store);
tracker.ReadScalar(*block, store->from);
tracker.WriteBuffer(it, store->into, alias_map);
ReadScalar(*block, store->from);
WriteBuffer(it, store->into, alias_map);
} break;
case StmtKind::LoadIndex: {
auto load_index = LoadIndex::Downcast(*it);
IVLOG(3, " loadIndex: " << load_index);
tracker.WriteScalar(*block, it, load_index->into);
WriteScalar(*block, it, load_index->into);
} break;
case StmtKind::Special: {
auto special = Special::Downcast(*it);
IVLOG(3, " special: " << special);
for (const auto& in : special->inputs) {
tracker.ReadBuffer(it, in, alias_map);
ReadBuffer(it, in, alias_map);
}
for (const auto& out : special->outputs) {
tracker.WriteBuffer(it, out, alias_map);
WriteBuffer(it, out, alias_map);
}
} break;
case StmtKind::Intrinsic: {
auto intrinsic = Intrinsic::Downcast(*it);
IVLOG(3, " intrinsic: " << intrinsic);
for (const auto& in : intrinsic->inputs) {
tracker.ReadScalar(*block, in);
ReadScalar(*block, in);
}
for (const auto& out : intrinsic->outputs) {
tracker.WriteScalar(*block, it, out);
WriteScalar(*block, it, out);
}
} break;
case StmtKind::Constant: {
auto constant = Constant::Downcast(*it);
IVLOG(3, " constant: " << constant);
tracker.WriteScalar(*block, it, constant->name);
WriteScalar(*block, it, constant->name);
} break;
case StmtKind::Block: {
auto inner = Block::Downcast(*it);
Expand All @@ -180,14 +176,35 @@ void ComputeDepsForBlock(Block* block, const AliasMap& alias_map) {
// the same underlying physical buffer; we handle these cases in
// ReadBuffer() and WriteBuffer().
if (IsReadDir(ref.dir)) {
tracker.ReadBuffer(it, ref.into(), inner_map);
ReadBuffer(it, ref.into(), inner_map);
}
if (IsWriteDir(ref.dir)) {
tracker.WriteBuffer(it, ref.into(), inner_map);
WriteBuffer(it, ref.into(), inner_map);
}
}
} break;
}
}
};

} // namespace

void ComputeDataflowDepsForBlock(stripe::Block* block, const AliasMap& alias_map) {
Tracker tracker;
for (auto it = block->stmts.begin(); it != block->stmts.end(); it++) {
tracker.ApplyEffectsOf(it, block, alias_map);
(*it)->deps.clear();
std::copy(tracker.dataflow_deps.begin(), tracker.dataflow_deps.end(), std::back_inserter((*it)->deps));
tracker.dataflow_deps.clear();
}
}

void ComputeDepsForBlock(Block* block, const AliasMap& alias_map) {
IVLOG(3, "ComputeDeps> " << block->name);
Tracker tracker;
std::unordered_map<StatementIt, std::set<StatementIt>> transitive_deps;
for (auto it = block->stmts.begin(); it != block->stmts.end(); it++) {
tracker.ApplyEffectsOf(it, block, alias_map);

// At this point, dataflow_deps describes the dataflow dependencies of the current Statement.
// Use it to compute the Statement's transitive dependencies.
Expand Down
11 changes: 11 additions & 0 deletions tile/codegen/deps.h
Expand Up @@ -26,7 +26,18 @@ inline bool ZeroBlock(const std::shared_ptr<stripe::Statement>& stmt) {
return false;
}

// Recomputes Statement dataflow dependencies within a single Block.
//
// After this call, each statement X's dependencies will be the set of all statements that write to an input
// of X -- the dataflow dependencies. E.g. if A's dataflow dependencies are B and C, and B also depends on C,
// A's statement dependencies will be [B, C].
void ComputeDataflowDepsForBlock(stripe::Block* block, const AliasMap& alias_map);

// Recomputes Statement dependencies within a single Block.
//
// After this call, each statement X's dependencies will be a set of statements that must be completed in
// order for X's inputs to be ready. E.g. if A's dataflow dependencies are B and C, and B also depends on C,
// A's statement dependencies will be [B].
void ComputeDepsForBlock(stripe::Block* block, const AliasMap& alias_map);

class ComputeDepsPass final : public CompilePass {
Expand Down

0 comments on commit 405f82b

Please sign in to comment.