In [1]:
import pickle

import ortools.sat.python.cp_model as cp_model

In [2]:
with open("./simple_model_mem_alloc_data.pkl", "rb") as fid:
    mem_alloc_data = pickle.load(fid)

In [3]:
mem_alloc_data.keys()

dict_keys(['tensors_meta', 'op_inputs_map', 'op_outputs_map', 'ops_topo_order'])

In [7]:
tensors_meta = mem_alloc_data["tensors_meta"]
op_inputs_map = mem_alloc_data["op_inputs_map"]
op_outputs_map = mem_alloc_data["op_outputs_map"]
ops_topo_order = mem_alloc_data["ops_topo_order"]

In [5]:
from collections import Counter

In [6]:
tensor_ref_cnt = Counter()

In [11]:
op_inputs_map[ops_topo_order[1]]

['input_1:0']

In [17]:
for op_name in ops_topo_order:
    for tensor_name in op_inputs_map[op_name]:
        tensor_ref_cnt[tensor_name] += 1

In [13]:
tensor_ref_cnt

Counter({'input_1:0': 1,
         'input_1_int8:0': 1,
         'StatefulPartitionedCall/my_model_1/conv2d_1/Conv2D/ReadVariableOp:0': 1,
         'StatefulPartitionedCall/my_model_1/conv2d_1/Conv2D_bias:0': 1,
         'StatefulPartitionedCall/my_model_1/conv2d_1/Relu:0': 1,
         'StatefulPartitionedCall/my_model_1/max_pooling2d_1/MaxPool:0': 1,
         'StatefulPartitionedCall/my_model_1/max_pooling2d_1/MaxPool_0_Reshape0:0': 1,
         'StatefulPartitionedCall/my_model_1/dense_2/MatMul/ReadVariableOp/transpose:0': 1,
         'StatefulPartitionedCall/my_model_1/dense_2/MatMul_bias:0': 1,
         'StatefulPartitionedCall/my_model_1/dense_2/Relu:0': 1,
         'StatefulPartitionedCall/my_model_1/dense_3/MatMul/ReadVariableOp/transpose:0': 1,
         'StatefulPartitionedCall/my_model_1/dense_3/MatMul_bias:0': 1,
         'Identity_int8:0': 1})

In [24]:
from functools import reduce
from math import ceil
from collections import namedtuple

In [25]:
MemAlloc = namedtuple("MemAlloc", ["start", "end", "size"])

In [30]:
def compute_size(tensor_meta, alignment:int=None):
    size, dtype, itemsize = tensor_meta
    total_size = reduce(lambda a, b: a*b, size, 1)
    if alignment is not None and (total_size % alignment != 0):
        total_size = ceil(total_size / alignment) * alignment
    return total_size

In [61]:
naive_total = 0
for tensor_meta in tensors_meta.values():
    naive_total += compute_size(tensor_meta, 4)
print(naive_total)

728132


In [37]:
from itertools import combinations, product

In [59]:
max_pool_size_in_bytes = 1024*1024
visited_tensors = set()
allocation_plan = {}
var_intv_map = {}
no_overlap_pairs = set()

opt_model = cp_model.CpModel()
# Step 1: declare variables
for op_name in ops_topo_order:
    out_tensors = op_outputs_map[op_name]
    for tensor_name in out_tensors:
        var_start = opt_model.NewIntVar(0, max_pool_size_in_bytes, f'{tensor_name}_start')
        var_end = opt_model.NewIntVar(0, max_pool_size_in_bytes, f'{tensor_name}_end')
        tensor_meta = tensors_meta[tensor_name]
        tensor_size_in_bytes = compute_size(tensor_meta, 4) # 32-bits alignment
        var_intv = opt_model.NewIntervalVar(var_start, tensor_size_in_bytes, var_end, f'{tensor_name}_alloc')
        allocation_plan[tensor_name] = MemAlloc(start=var_start, end=var_end, size=tensor_size_in_bytes)
        var_intv_map[tensor_name] = var_intv
# Step 2: Setup constraints
for op_name in ops_topo_order:
    # outputs of the op should not overlap
    output_tensor_names = op_outputs_map[op_name]
    input_tensor_names = op_inputs_map[op_name]
    for this_tensor_name, that_tensor_name in combinations(output_tensor_names, 2):
        this_intv = var_intv_map[this_tensor_name]
        that_intv = var_intv_map[that_tensor_name]
        opt_model.AddNoOverlap([this_intv, that_intv])
        no_overlap_pairs.add((this_tensor_name, that_tensor_name))
    # inputs and outputs should not overlap
    for out_tensor_name, in_tensor_name in product(output_tensor_names, input_tensor_names):
        out_intv = var_intv_map[out_tensor_name]
        in_intv = var_intv_map[in_tensor_name]
        opt_model.AddNoOverlap([out_intv, in_intv])
        no_overlap_pairs.add((out_tensor_name, in_tensor_name))
    # the outputs of the op should not overlap with exisiting tensors with positive ref cnt
    for out_tensor_name, visit_tensor_name in product(output_tensor_names, visited_tensors):
        if tensor_ref_cnt[visit_tensor_name] <= 0:
            continue
        out_intv = var_intv_map[out_tensor_name]
        visited_intv = var_intv_map[visit_tensor_name]
        opt_model.AddNoOverlap([out_intv, visited_intv])
        no_overlap_pairs.add((out_intv, visited_intv))

    # decr ref cnt
    for tensor_name in input_tensor_names:
        tensor_ref_cnt[tensor_name] -= 1
    # update visited tensors
    visited_tensors.update(input_tensor_names)
    visited_tensors.update(output_tensor_names)
# Step 3: setup objective
var_mem_pool_size = opt_model.NewIntVar(0, max_pool_size_in_bytes, 'mem_pool_size')
opt_model.AddMaxEquality(var_mem_pool_size, [alloc.end for alloc in allocation_plan.values()])
opt_model.Minimize(var_mem_pool_size)
# Step 4: Solve
solver = cp_model.CpSolver()
status = solver.Solve(opt_model)

In [62]:
if status == cp_model.OPTIMAL:
    opt_mem_pool_size = solver.Value(var_mem_pool_size)
    print(f"optimal allocation plan found: {opt_mem_pool_size} total bytes")
    print(f"opt/naive: {opt_mem_pool_size/naive_total*100:0.3f}")
    for tensor_name, mem_alloc in allocation_plan.items():
        mem_start = solver.Value(mem_alloc.start)
        mem_end = solver.Value(mem_alloc.end)
        print(f"{tensor_name}: ({mem_start}, {mem_end})")
else:
    print("fail to find optimal allocation plan")

optimal allocation plan found: 692352 total bytes
opt/naive: 95.086
input_1:0: (0, 784)
input_1_int8:0: (21632, 22416)
StatefulPartitionedCall/my_model_1/conv2d_1/Conv2D/ReadVariableOp:0: (21632, 21920)
StatefulPartitionedCall/my_model_1/conv2d_1/Conv2D_bias:0: (21632, 21664)
StatefulPartitionedCall/my_model_1/conv2d_1/Relu:0: (0, 21632)
StatefulPartitionedCall/my_model_1/max_pooling2d_1/MaxPool:0: (21632, 27040)
StatefulPartitionedCall/my_model_1/max_pooling2d_1/MaxPool_0_Reshape0:0: (0, 5408)
StatefulPartitionedCall/my_model_1/dense_2/MatMul/ReadVariableOp/transpose:0: (0, 692224)
StatefulPartitionedCall/my_model_1/dense_2/MatMul_bias:0: (0, 128)
StatefulPartitionedCall/my_model_1/dense_2/Relu:0: (692224, 692352)
StatefulPartitionedCall/my_model_1/dense_3/MatMul/ReadVariableOp/transpose:0: (0, 1280)
StatefulPartitionedCall/my_model_1/dense_3/MatMul_bias:0: (0, 12)
Identity_int8:0: (1280, 1292)
Identity:0: (0, 12)
