Skip to content

Commit

Permalink
[TIR] Handle DeclBuffer in StorageRewrite (apache#15051)
Browse files Browse the repository at this point in the history
Allow `tir::DeclBuffer` to appear in the input of the `StorageRewrite`
transform.  Any `DeclBuffer` whose backing allocation is rewritten is
updated with the new buffer object.  Any `DeclBuffer` whose backing
allocation is unused and has been removed by `StorageRewrite` is itself
removed.

This is a subset of changes, being split out from
apache#14778 into independent portions.
  • Loading branch information
Lunderberg authored and junrushao committed Jun 22, 2023
1 parent 918e4c8 commit f047d2f
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 19 deletions.
70 changes: 51 additions & 19 deletions src/tir/transforms/storage_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,14 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
scope_.push_back(StmtEntry());
// visit subexpr
StmtExprVisitor::VisitStmt_(op);
all_buffers_accessed_.insert(op->buffer.get());

// Add write access.
const VarNode* buf = op->buffer->data.get();
auto it = alloc_info_.find(buf);
const VarNode* buffer_var = op->buffer->data.get();
auto it = alloc_info_.find(buffer_var);
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size());
scope_[it->second.level].touched.push_back(buf);
scope_[it->second.level].touched.push_back(buffer_var);

ICHECK_EQ(op->buffer->axis_separators.size() + 1, it->second.num_physical_dimensions)
<< "Buffer " << op->buffer->name << " is allocated with "
Expand All @@ -128,11 +130,14 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
void VisitExpr_(const BufferLoadNode* op) final {
// Add read access.
StmtExprVisitor::VisitExpr_(op);
const VarNode* buf = op->buffer->data.get();
auto it = alloc_info_.find(buf);

all_buffers_accessed_.insert(op->buffer.get());

const VarNode* buffer_var = op->buffer->data.get();
auto it = alloc_info_.find(buffer_var);
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store.";
scope_[it->second.level].touched.push_back(buf);
scope_[it->second.level].touched.push_back(buffer_var);

ICHECK_EQ(op->buffer->axis_separators.size() + 1, it->second.num_physical_dimensions)
<< "Buffer " << op->buffer->name << " is allocated with "
Expand Down Expand Up @@ -213,6 +218,9 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
std::vector<StmtEntry> linear_seq_;
// The storage scope of each buffer
std::unordered_map<const VarNode*, AllocEntry> alloc_info_;
// A record of which Buffer objects have been accessed, to prune
// unused DeclBuffer instances.
std::unordered_set<const BufferNode*> all_buffers_accessed_;

private:
// Whether already in thread env.
Expand Down Expand Up @@ -378,6 +386,7 @@ class StoragePlanRewriter : public StmtExprMutator {
finder(stmt);
this->LivenessAnalysis(finder.linear_seq_);
this->PlanMemory(finder.linear_seq_, finder.alloc_info_);
all_buffers_accessed_ = finder.all_buffers_accessed_;
this->PrepareNewAlloc();
// start rewrite
stmt = operator()(std::move(stmt));
Expand Down Expand Up @@ -505,6 +514,20 @@ class StoragePlanRewriter : public StmtExprMutator {

// The original Allocate node is discarded: only its body is visited
// and the visited body is returned in place of the Allocate.
Stmt VisitStmt_(const AllocateNode* op) final {
  Stmt visited_body = this->VisitStmt(op->body);
  return visited_body;
}

// Rewrite a DeclBuffer statement.
//
// The DeclBuffer node itself is dropped (only its visited body is
// returned) when either
//   - the buffer is recorded in hoisted_buffer_decls_, or
//   - the buffer does not appear in all_buffers_accessed_.
//
// Otherwise the node is visited by the base mutator and, if the
// buffer's data var has an entry in alloc_map_, the declared buffer
// is replaced by one remapped onto that entry's alloc_var.
Stmt VisitStmt_(const DeclBufferNode* op) final {
  const BufferNode* declared = op->buffer.get();
  const bool already_hoisted = hoisted_buffer_decls_.count(declared) != 0;
  const bool never_accessed = all_buffers_accessed_.count(declared) == 0;
  if (already_hoisted || never_accessed) {
    return this->VisitStmt(op->body);
  }

  DeclBuffer decl = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));

  auto entry = alloc_map_.find(op->buffer->data.get());
  if (entry != alloc_map_.end()) {
    // Build the remapped buffer first, then install it on a writable copy of the node.
    Buffer remapped = RemapBuffer(op->buffer, entry->second->alloc_var);
    decl.CopyOnWrite()->buffer = remapped;
  }
  return std::move(decl);
}

private:
struct StorageEntry {
// The scope that this alloc attaches after
Expand All @@ -523,8 +546,9 @@ class StoragePlanRewriter : public StmtExprMutator {
std::vector<const AllocateNode*> allocs;
// The children of this entry, not including itself.
std::vector<StorageEntry*> merged_children;
// The replacement allocation, if any.
Stmt new_alloc;
// The replacement Allocate, if any. May also include associated
// DeclBuffer statement.
std::vector<Stmt> alloc_nest;
// The var expr of new allocation.
Var alloc_var;
// The allocation element type.
Expand Down Expand Up @@ -560,13 +584,10 @@ class StoragePlanRewriter : public StmtExprMutator {
};

Stmt MakeAttach(const std::vector<StorageEntry*>& svec, Stmt body) {
std::vector<Stmt> nest;
for (StorageEntry* e : svec) {
if (e->new_alloc.defined()) {
nest.push_back(e->new_alloc);
}
for (auto it = svec.rbegin(); it != svec.rend(); it++) {
body = MergeNest((*it)->alloc_nest, body);
}
return MergeNest(nest, body);
return body;
}
// Remap the index
PrimExpr RemapIndex(DataType dtype, PrimExpr index, StorageEntry* e) {
Expand Down Expand Up @@ -636,8 +657,13 @@ class StoragePlanRewriter : public StmtExprMutator {

if (all_allocs_identical) {
// simply use the original allocation.
e->new_alloc = Allocate(e->alloc_var, alloc_type, e->allocs[0]->extents,
e->allocs[0]->condition, Evaluate(0));
e->alloc_nest.push_back(Allocate(e->alloc_var, alloc_type, e->allocs[0]->extents,
e->allocs[0]->condition, Evaluate(0)));
if (auto ptr = e->allocs[0]->body.as<DeclBufferNode>()) {
e->alloc_nest.push_back(
DeclBuffer(RemapBuffer(ptr->buffer, e->alloc_var), Evaluate(0)));
hoisted_buffer_decls_.insert(ptr->buffer.get());
}
if (IsSpecialTaggedMemory(e->scope)) {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
if (info.defined()) {
Expand Down Expand Up @@ -684,8 +710,8 @@ class StoragePlanRewriter : public StmtExprMutator {
combo_size = combo_size + make_const(DataType::Int(32), 1);
}
combo_size = analyzer_.Simplify(combo_size);
e->new_alloc =
Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(), Evaluate(0));
e->alloc_nest.push_back(
Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(), Evaluate(0)));
if (IsSpecialTaggedMemory(e->scope)) {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
if (info.defined()) {
Expand Down Expand Up @@ -729,7 +755,8 @@ class StoragePlanRewriter : public StmtExprMutator {
uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes();
PrimExpr alloc_size =
make_const(e->allocs[0]->extents[0].dtype(), (total_bits + type_bits - 1) / type_bits);
e->new_alloc = Allocate(e->alloc_var, e->elem_type, {alloc_size}, const_true(), Evaluate(0));
e->alloc_nest.push_back(
Allocate(e->alloc_var, e->elem_type, {alloc_size}, const_true(), Evaluate(0)));
if (info.defined()) {
ICHECK_LE(total_bits, info->max_num_bits)
<< "Allocation exceed bound of memory tag " << e->scope.to_string();
Expand Down Expand Up @@ -996,6 +1023,11 @@ class StoragePlanRewriter : public StmtExprMutator {
std::vector<std::unique_ptr<StorageEntry>> alloc_vec_;
// The buffer objects being remapped
std::unordered_map<const BufferNode*, Buffer> buffer_remap_;
// Buffers whose DeclBuffer has been hoisted to be adjacent to the new Allocate location
std::unordered_set<const BufferNode*> hoisted_buffer_decls_;
// Any buffer that is accessed at some point. DeclBuffer instances
// whose buffer does not appear in this set may be removed.
std::unordered_set<const BufferNode*> all_buffers_accessed_;
// analyzer
arith::Analyzer analyzer_;
};
Expand Down
66 changes: 66 additions & 0 deletions tests/python/unittest/test_tir_transform_storage_rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,5 +860,71 @@ def before(A: T.Buffer((16, 16), "float32"), D: T.Buffer((16, 16), "float32")):
expected = before


class TestRewriteDeclBuffer(BaseCompare):
"""A DeclBuffer node may appear in StorageRewrite's input"""

# Input: B and C are declared as two separate 16-element float32
# buffers, used one after the other as intermediates between A and D.
def before(A: T.Buffer(16, "float32"), D: T.Buffer(16, "float32")):
B = T.decl_buffer(16, dtype="float32")
C = T.decl_buffer(16, dtype="float32")

for i in range(16):
B[i] = A[i]

for i in range(16):
C[i] = 2.0 * B[i]

for i in range(16):
D[i] = C[i]

# Expected output: identical loops, but C is declared with
# data=B.data, so both declarations refer to the same data variable.
def expected(A: T.Buffer(16, "float32"), D: T.Buffer(16, "float32")):
B = T.decl_buffer(16, dtype="float32")
C = T.decl_buffer(16, dtype="float32", data=B.data)

for i in range(16):
B[i] = A[i]

for i in range(16):
C[i] = 2.0 * B[i]

for i in range(16):
D[i] = C[i]


class TestNoOrphanedDeclBuffer(BaseCompare):
"""A DeclBuffer of an unused Allocate should be removed.

StorageRewrite removes any allocations that are unused. When it
does so, any DeclBuffer that refers to that allocation should also
be removed.
"""

# Input: same computation as a B -> C -> D pipeline, plus an extra
# buffer `Unused` that is declared but never read or written.
def before(A: T.Buffer(16, "float32"), D: T.Buffer(16, "float32")):
B = T.decl_buffer(16, dtype="float32")
C = T.decl_buffer(16, dtype="float32")
Unused = T.decl_buffer(16, dtype="float32")

for i in range(16):
B[i] = A[i]

for i in range(16):
C[i] = 2.0 * B[i]

for i in range(16):
D[i] = C[i]

# Expected output: no declaration of `Unused` remains, and C is
# declared with data=B.data.
def expected(A: T.Buffer(16, "float32"), D: T.Buffer(16, "float32")):
B = T.decl_buffer(16, dtype="float32")
C = T.decl_buffer(16, dtype="float32", data=B.data)

for i in range(16):
B[i] = A[i]

for i in range(16):
C[i] = 2.0 * B[i]

for i in range(16):
D[i] = C[i]


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit f047d2f

Please sign in to comment.