[Unity] Support storage reuse for dynamic shapes (apache#16500)
Before this PR, dynamic shapes required the upper bounds of their variables
to be provided in order to use storage planning. We can relax this
requirement: for shapes with unknown bounds, we can look up other tensors
with the same symbolic shape. This is helpful for deep learning models,
where layers with the same configuration are usually repeated, so many
tensors share the same shape.

This PR changes `StorageToken` to store its size in bytes as a `PrimExpr`,
which can be either a constant integer or a symbolic expression. Tokens with
symbolic sizes are placed into a special bucket for lookup.
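To make the behavior concrete, here is a minimal sketch (not part of the commit) adapted from the updated test in `test_transform_static_plan_block_memory.py`. The module name, the placeholder `tir_exp` kernel, and the variable names are illustrative only. The module chains three intermediate tensors of the unbounded symbolic shape `(n, m)`; after running `StaticPlanBlockMemory`, the third allocation is expected to reuse the storage of the first, because the analyzer can prove their symbolic byte sizes are equal.

```python
import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T


@I.ir_module
class Module:
    @T.prim_func
    def tir_exp(var_A: T.handle, var_B: T.handle):
        # Placeholder kernel: the body does not matter for memory planning.
        n = T.int64()
        m = T.int64()
        A = T.match_buffer(var_A, (n, m), "float32")
        B = T.match_buffer(var_B, (n, m), "float32")
        T.evaluate(0)

    @R.function
    def main(x: R.Tensor(("n", "m"), dtype="float32")) -> R.Tensor(("n", "m"), dtype="float32"):
        n = T.int64()
        m = T.int64()
        R.func_attr({"relax.force_pure": True})
        cls = Module
        # Three intermediates of the same symbolic shape (n, m); no upper
        # bounds for n or m are provided.
        alloc0: R.Tensor((n, m), dtype="float32") = R.builtin.alloc_tensor(
            R.shape([n, m]), R.dtype("float32"), R.prim_value(0)
        )
        _0: R.Tuple = cls.tir_exp(x, alloc0)
        lv0: R.Tensor((n, m), dtype="float32") = alloc0
        alloc1: R.Tensor((n, m), dtype="float32") = R.builtin.alloc_tensor(
            R.shape([n, m]), R.dtype("float32"), R.prim_value(0)
        )
        _1: R.Tuple = cls.tir_exp(lv0, alloc1)
        lv1: R.Tensor((n, m), dtype="float32") = alloc1
        alloc2: R.Tensor((n, m), dtype="float32") = R.builtin.alloc_tensor(
            R.shape([n, m]), R.dtype("float32"), R.prim_value(0)
        )
        _2: R.Tuple = cls.tir_exp(lv1, alloc2)
        lv2: R.Tensor((n, m), dtype="float32") = alloc2
        return lv2


# lv0 (backed by alloc0) is dead once the second tir_exp call has consumed it,
# so the planner can hand its storage, whose size 4 * n * m is a symbolic
# PrimExpr, to alloc2 instead of allocating a third buffer.
mod = relax.transform.StaticPlanBlockMemory()(Module)
mod.show()
```

Before this change, all three allocations would have been left unplanned, since `n` and `m` have no declared upper bound (this is exactly the case covered by the first updated test below).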
vinx13 committed Feb 16, 2024
1 parent dbdb736 commit 63a09cb
Showing 3 changed files with 125 additions and 56 deletions.
99 changes: 69 additions & 30 deletions src/relax/transform/static_plan_block_memory.cc
@@ -99,12 +99,23 @@ class StorageTokenNode : public Object {
/*! \brief Reference counter. */
int ref_counter{0};
/*! \brief Number of bytes that this token requires. */
int64_t bytes;
PrimExpr bytes;
/*! \brief The dtype of this token. */
DataType dtype;
/*! \brief The storage id, reserved for debug and demo use. */
int storage_id{-1};

/*! \brief Get the constant number of bytes that this token requires, or -1 if the number of bytes
* is symbolic */
int64_t const_bytes() const {
const int64_t* const_val = tir::as_const_int(bytes);
if (const_val) {
return *const_val;
} else {
return -1;
}
}

static constexpr const char* _type_key = "relax.transform.StorageToken";
TVM_DECLARE_BASE_OBJECT_INFO(StorageTokenNode, Object);
};
@@ -117,19 +128,22 @@ class StorageToken : public ObjectRef {
public:
explicit StorageToken(Array<PrimExpr> shape, DataType dtype) {
// Compute the tensor size from the shape.
int64_t size = 1;
int64_t const_coeff = dtype.bytes() * dtype.lanes();
PrimExpr size = tir::make_const(DataType::Int(64), 1);
for (const PrimExpr& dim_len : shape) {
const auto* int_len = dim_len.as<IntImmNode>();
ICHECK_NOTNULL(int_len);
size *= int_len->value;
if (const IntImmNode* const_dim_len = dim_len.as<IntImmNode>()) {
const_coeff *= const_dim_len->value;
} else {
size *= dim_len;
}
}
size = tir::make_const(DataType::Int(64), const_coeff) * size;

ObjectPtr<StorageTokenNode> n = make_object<StorageTokenNode>();
n->bytes = size * dtype.bytes() * dtype.lanes();
n->bytes = size;
n->dtype = dtype;
data_ = std::move(n);
}

TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(StorageToken, ObjectRef, StorageTokenNode);
};

@@ -143,6 +157,8 @@ using Tokens = NestedMsg<StorageToken>;
*/
class TokenAllocator1D {
public:
explicit TokenAllocator1D(arith::Analyzer* analyzer) : analyzer_(analyzer) {}

/*!
* \brief Request a storage token from the available token pool for a
* given prototype, or report no appropriate available token in the pool.
@@ -162,8 +178,24 @@ class TokenAllocator1D {
// Step 1. Get the available pool of the token dtype.
std::multimap<int64_t, StorageToken>& pool = available_pool_[prototype->dtype];

int64_t size = prototype->const_bytes();
if (size == -1) {
// Handle the case where the prototype token has dynamic size. Currently it requires the
// symbolic size to be the same as the prototype token in order to reuse the storage.
auto [begin, end] = pool.equal_range(size);
for (; begin != end; ++begin) {
StorageToken available_token = begin->second;
if (analyzer_->CanProveEqual(prototype->bytes, available_token->bytes)) {
ICHECK_EQ(available_token->ref_counter, 0)
<< "Available tokens are expected to have 0 reference.";
available_token->ref_counter = prototype->ref_counter;
pool.erase(begin);
return available_token;
}
}
return NullOpt;
}
// Step 2. Get the range of memory blocks in [size / match_range_, size * match_range_)
int64_t size = prototype->bytes;
auto begin = pool.lower_bound(size / match_range_);
auto mid = pool.lower_bound(size);
auto end = pool.upper_bound(size * match_range_);
@@ -172,7 +204,7 @@ class TokenAllocator1D {
StorageToken available_token = mid->second;
ICHECK_EQ(available_token->ref_counter, 0)
<< "Available tokens are expected to have 0 reference.";
ICHECK_LE(size, available_token->bytes);
ICHECK_LE(size, available_token->const_bytes());
available_token->ref_counter = prototype->ref_counter;
pool.erase(mid);
return available_token;
@@ -181,11 +213,13 @@ class TokenAllocator1D {
if (mid != begin) {
--mid;
StorageToken available_token = mid->second;
int64_t available_size = available_token->const_bytes();
ICHECK_EQ(available_token->ref_counter, 0)
<< "Available tokens are expected to have 0 reference.";
ICHECK_GE(size, available_token->bytes);
ICHECK_GE(available_size, 0);
ICHECK_GE(size, available_size);
// Enlarge the token size.
available_token->bytes = size;
available_token->bytes = tir::make_const(DataType::Int(64), size);
available_token->ref_counter = prototype->ref_counter;
pool.erase(mid);
return available_token;
@@ -216,7 +250,7 @@ class TokenAllocator1D {
ICHECK_GE(token->storage_id, 0)
<< "The token to be released is expected to be allocated before";
ICHECK_EQ(token->ref_counter, 0) << "The token to be released is expected to have 0 reference.";
available_pool_[token->dtype].insert({token->bytes, token});
available_pool_[token->dtype].insert({token->const_bytes(), token});
}

/*! \brief Clear the allocator. */
@@ -226,6 +260,8 @@ class TokenAllocator1D {
}

private:
/*! \brief The arithmetic analyzer. */
arith::Analyzer* analyzer_;
/*! \brief A constant scale representing the token search range. */
const int match_range_{16};
/*! \brief The pool of available storage tokens for each dtype. */
@@ -385,10 +421,12 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
/*!
* \brief The entry of the initialization.
* \param mod The IRModule to be planned
* \param analyzer The arithmetic analyzer.
* \return The mapping from each Expr to the token it uses.
*/
static std::unordered_map<const ExprNode*, Tokens> Initialize(const IRModule& mod) {
StorageAllocatorInit initializer(mod);
static std::unordered_map<const ExprNode*, Tokens> Initialize(const IRModule& mod,
arith::Analyzer* analyzer) {
StorageAllocatorInit initializer(mod, analyzer);

for (auto it : mod->functions) {
const auto* func = it.second.as<FunctionNode>();
@@ -403,11 +441,12 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
private:
using ExprVisitor::VisitExpr_;

explicit StorageAllocatorInit(const IRModule& ctx_mod) : ctx_mod_(ctx_mod) {}
explicit StorageAllocatorInit(const IRModule& ctx_mod, arith::Analyzer* analyzer)
: ctx_mod_(ctx_mod), analyzer_(analyzer) {}

void VisitExpr_(const FunctionNode* func) final {
// Set the upper bound of TIR variables in the analyzer.
SetTIRVarUpperBound(GetRef<Function>(func), &ana_);
SetTIRVarUpperBound(GetRef<Function>(func), analyzer_);
// Recurse into the function to get its tokens.
Tokens body_tokens = GetTokens(func->body);
// Discard the tokens used by the function return value, as they are external referenced.
@@ -508,14 +547,9 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
ICHECK(sinfo->dtype == Downcast<DataTypeImm>(call->args[1])->value);
ICHECK(!token_map_.count(call));

// Use the upper bounds of TIR vars as their values.
Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values, &ana_);

// No support for TIR vars that are not bounded.
if (!IsStaticShape(upper_bounded_shape)) {
token_map_[call] = Tokens();
return Tokens();
}
// Use the upper bounds of TIR vars as their values. The upper bound shape can still be dynamic
// if the upper bounds of some variables are not provided.
Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values, analyzer_);

// Create and set token.
StorageToken token(upper_bounded_shape, sinfo->dtype);
@@ -583,13 +617,13 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
token2block_.erase(token_to_discard.get());
}

/*! \brief The arithmetic analyzer. */
arith::Analyzer ana_;
/*!
* \brief The context IRModule, used for checking if a callee function is
* a PrimFunc inside the IRModule.
*/
const IRModule& ctx_mod_;
/*! \brief The arithmetic analyzer. */
arith::Analyzer* analyzer_;
/*! \brief The mapping from each token to the binding block where it is created. */
std::unordered_map<const StorageTokenNode*, const BindingBlockNode*> token2block_;
/*! \brief The mapping from each token to the Exprs that are using this token. */
@@ -612,7 +646,9 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
*/
class StorageAllocator : public StorageAllocatorBaseVisitor {
public:
explicit StorageAllocator(std::unordered_map<const ExprNode*, Tokens> token_map) {
explicit StorageAllocator(std::unordered_map<const ExprNode*, Tokens> token_map,
arith::Analyzer* analyzer)
: allocator_(analyzer) {
this->token_map_ = std::move(token_map);
}

@@ -797,7 +833,7 @@ class StorageAllocationRewriter : public ExprMutator {
Var storage_var{nullptr};
auto it_token = token2storage_var_.find(token.get());
if (it_token == token2storage_var_.end()) {
ShapeExpr size({tir::make_const(DataType::Int(64), token->bytes)});
ShapeExpr size({token->bytes});
PrimValue virtual_device_index = runtime_device_index;
std::string storage_scope = "global";
DataType dtype = token->dtype;
@@ -868,10 +904,13 @@ };
};

IRModule StaticPlanBlockMemory(IRModule mod) {
arith::Analyzer ana;

// Step 1. Initialize.
std::unordered_map<const ExprNode*, Tokens> token_map = StorageAllocatorInit::Initialize(mod);
std::unordered_map<const ExprNode*, Tokens> token_map =
StorageAllocatorInit::Initialize(mod, &ana);
// Step 2. Collect the memory allocation info.
StorageAllocator allocator(std::move(token_map));
StorageAllocator allocator(std::move(token_map), &ana);
allocator.Allocate(mod);
// Step 3. Rewrite the function.
StorageAllocationRewriter rewriter(std::move(mod), //
35 changes: 17 additions & 18 deletions tests/python/relax/test_dataflow_pattern.py
@@ -1226,29 +1226,27 @@ def expected(
lv1: R.Tensor((640, 1280), dtype="float32") = R.permute_dims(lv, axes=None)
lv2: R.Tensor((2, 1024, 1280), dtype="float32") = R.matmul(x1, lv1, out_dtype="void")
lv3: R.Tuple(
R.Tensor((2, 640, 1280), dtype="float32"),
R.Tensor((2, 384, 1280), dtype="float32"),
R.Tensor((2, 0, 1280), dtype="float32"),
) = R.split(lv2, indices_or_sections=[640, 1280], axis=1)
lv0: R.Tensor((2, 640, 1280), dtype="float32") = lv3[0]
lv1_1: R.Tensor((2, 384, 1280), dtype="float32") = lv3[1]
R.Tensor((2, 1024, 640), dtype="float32"),
R.Tensor((2, 1024, 640), dtype="float32"),
) = R.split(lv2, indices_or_sections=[640], axis=-1)
lv0: R.Tensor((2, 1024, 640), dtype="float32") = lv3[0]
lv1_1: R.Tensor((2, 1024, 640), dtype="float32") = lv3[1]
lv_1: R.Tensor((1280, 640), dtype="float32") = R.concat((w2, w3), axis=0)
lv1_2: R.Tensor((640, 1280), dtype="float32") = R.permute_dims(lv_1, axes=None)
lv2_1: R.Tensor((2, 1024, 1280), dtype="float32") = R.matmul(
x2, lv1_2, out_dtype="void"
)
lv3_1: R.Tuple(
R.Tensor((2, 640, 1280), dtype="float32"),
R.Tensor((2, 384, 1280), dtype="float32"),
R.Tensor((2, 0, 1280), dtype="float32"),
) = R.split(lv2_1, indices_or_sections=[640, 1280], axis=1)
lv2_1_1: R.Tensor((2, 640, 1280), dtype="float32") = lv3_1[0]
lv3_1_1: R.Tensor((2, 384, 1280), dtype="float32") = lv3_1[1]
R.Tensor((2, 1024, 640), dtype="float32"),
R.Tensor((2, 1024, 640), dtype="float32"),
) = R.split(lv2_1, indices_or_sections=[640], axis=-1)
lv2_1_1: R.Tensor((2, 1024, 640), dtype="float32") = lv3_1[0]
lv3_1_1: R.Tensor((2, 1024, 640), dtype="float32") = lv3_1[1]
out: R.Tuple(
R.Tensor((2, 640, 1280), dtype="float32"),
R.Tensor((2, 384, 1280), dtype="float32"),
R.Tensor((2, 640, 1280), dtype="float32"),
R.Tensor((2, 384, 1280), dtype="float32"),
R.Tensor((2, 1024, 640), dtype="float32"),
R.Tensor((2, 1024, 640), dtype="float32"),
R.Tensor((2, 1024, 640), dtype="float32"),
R.Tensor((2, 1024, 640), dtype="float32"),
) = (lv0, lv1_1, lv2_1_1, lv3_1_1)
R.output(out)
return out
@@ -1267,9 +1265,9 @@ def rewriter(matchings, _):

concat = R.concat([w1, w2], axis=0)
matmul = R.matmul(inp, R.permute_dims(concat))
sections = [w1.struct_info.shape[0], w1.struct_info.shape[0] + w2.struct_info.shape[0]]
sections = [w1.struct_info.shape[0]]

chunks = R.split(matmul, sections, 1)
chunks = R.split(matmul, sections, -1)

return {
matchings[matmul1]: chunks[0],
@@ -1282,6 +1280,7 @@ def rewriter(matchings, _):
# make sure it builds
mod = tvm.IRModule()
mod["main"] = rewritten
print(mod)

rx.build(mod, target="llvm")

47 changes: 39 additions & 8 deletions tests/python/relax/test_transform_static_plan_block_memory.py
@@ -701,9 +701,34 @@ def main(x: R.Tensor(("m", "n"), "float32")):
y: R.Tensor((m, n), dtype="float32") = alloc
return x

# The pass does no change.
@tvm.script.ir_module
class Expected:
@T.prim_func
def exp(var_A: T.handle, var_B: T.handle):
m = T.int64()
n = T.int64()
A = T.match_buffer(var_A, (m, n), "float32")
B = T.match_buffer(var_B, (m, n), "float32")
T.evaluate(0)

@R.function
def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype="float32"):
m = T.int64()
n = T.int64()
R.func_attr({"relax.force_pure": True})
cls = Expected
storage: R.Object = R.memory.alloc_storage(
R.shape([4 * (m * n)]), R.prim_value(0), R.str("global"), R.dtype("float32")
)
alloc: R.Tensor((m, n), dtype="float32") = R.memory.alloc_tensor(
storage, R.prim_value(0), R.shape([m, n]), R.dtype("float32")
)
_: R.Tuple = cls.exp(x, alloc)
y: R.Tensor((m, n), dtype="float32") = alloc
return x

mod = relax.transform.StaticPlanBlockMemory()(Module)
tvm.ir.assert_structural_equal(mod, Module)
tvm.ir.assert_structural_equal(mod, Expected)


def test_zero_reference():
@@ -1198,7 +1223,10 @@ def main(s: R.Shape(["n", "m"])) -> R.Tensor(("n", "m"), dtype="float32"):
alloc2: R.Tensor((n, m), dtype="float32") = R.builtin.alloc_tensor(R.shape([n, m]), R.dtype("float32"), R.prim_value(0))
_2: R.Tuple = cls.tir_exp(lv2, alloc2)
lv3: R.Tensor((n, m), dtype="float32") = alloc2
return lv3
alloc3: R.Tensor((n, m), dtype="float32") = R.builtin.alloc_tensor(R.shape([n, m]), R.dtype("float32"), R.prim_value(0))
_3: R.Tuple = cls.tir_exp(lv3, alloc3)
lv4: R.Tensor((n, m), dtype="float32") = alloc3
return lv4

@I.ir_module
class Expected:
@@ -1216,19 +1244,22 @@ def main(s: R.Shape(["n", "m"])) -> R.Tensor(("n", "m"), dtype="float32"):
m = T.int64()
R.func_attr({"relax.force_pure": True, "tir_var_upper_bound": {"n": 20}})
cls = Expected
storage: R.Object = R.memory.alloc_storage(R.shape([20 * m * 4]), R.prim_value(0), R.str("global"), R.dtype("float32"))
storage: R.Object = R.memory.alloc_storage(R.shape([80 * m]), R.prim_value(0), R.str("global"), R.dtype("float32"))
alloc: R.Tensor((n, m), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([n, m]), R.dtype("float32"))
_: R.Tuple = R.vm.call_tir_dyn(cls.tir_full, (alloc, R.shape([n, m])))
full: R.Tensor((n, m), dtype="float32") = alloc
storage1: R.Object = R.memory.alloc_storage(R.shape([20 * m * 4]), R.prim_value(0), R.str("global"), R.dtype("float32"))
storage1: R.Object = R.memory.alloc_storage(R.shape([80 * m]), R.prim_value(0), R.str("global"), R.dtype("float32"))
alloc1: R.Tensor((n, m), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([n, m]), R.dtype("float32"))
_1: R.Tuple = cls.tir_exp(full, alloc1)
lv2: R.Tensor((n, m), dtype="float32") = alloc1
storage2: R.Object = R.memory.alloc_storage(R.shape([20 * m * 4]), R.prim_value(0), R.str("global"), R.dtype("float32"))
alloc2: R.Tensor((n, m), dtype="float32") = R.memory.alloc_tensor(storage2, R.prim_value(0), R.shape([n, m]), R.dtype("float32"))
alloc2: R.Tensor((n, m), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([n, m]), R.dtype("float32"))
_2: R.Tuple = cls.tir_exp(lv2, alloc2)
lv3: R.Tensor((n, m), dtype="float32") = alloc2
return lv3
storage2: R.Object = R.memory.alloc_storage(R.shape([20 * m * 4]), R.prim_value(0), R.str("global"), R.dtype("float32"))
alloc3: R.Tensor((n, m), dtype="float32") = R.memory.alloc_tensor(storage2, R.prim_value(0), R.shape([n, m]), R.dtype("float32"))
_3: R.Tuple = cls.tir_exp(lv3, alloc3)
lv4 = alloc3
return lv4
# fmt: on

mod = relax.transform.StaticPlanBlockMemory()(Module)
