In [1]:
from UNetKerasMod import UNet
import tensorflow as tf
import copy
import itertools

# Generalized Model Info for scalability
# { id : {'prev_layer' : [str], 'output_shape': [int], 'next_layer' : [str] }}
# { 'model_order': [str]}
def get_model_info(model):
    model_info = {}
    graph = model.get_config()['layers']
    for layer_info in graph[::-1]:
        layer_name = layer_info['name']
        input_layers = layer_info['inbound_nodes']
        if len(input_layers):
            input_layers = [l[0] for l in input_layers[0]]
        else:
            input_layers = []

        # Fill the prev_layer
        if layer_name not in list(model_info.keys()):
            model_info[layer_name] = {}
        model_info[layer_name]['prev_layer'] = copy.deepcopy(input_layers)

        # Fill output_shape
        output_shape = list(model.get_layer(name=layer_name).output_shape)
        if len(output_shape) == 1:
            output_shape = [i for i in output_shape[0] if i]
        elif len(output_shape) > 1:
            output_shape = [i for i in output_shape if i]
        else:
            output_shape = []
        model_info[layer_name]['output_shape'] = copy.deepcopy(output_shape)

        # Fill next_layer of current input_layer
        for input_layer in input_layers:
            # Generate first time
            if input_layer not in list(model_info.keys()):
                model_info[input_layer] = {}
            if 'next_layer' not in list(model_info[input_layer].keys()):
                model_info[input_layer]['next_layer'] = []

            model_info[input_layer]['next_layer'].append(layer_name)

    model_info['model_order'] = copy.deepcopy([layer_info['name'] for layer_info in graph])

    return model_info

2022-08-11 13:02:13.931403: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
  from .autonotebook import tqdm as notebook_tqdm


In [5]:
for i in range(0,5):
    if i == 0:
        model = UNet(3, 1, 16,[])
    else:
        model = UNet(3, 1, 16,[i])
    print("=================================")
    print(i)
    print("=================================")
    batch_size = 1
    img_size = (512,512)
    model.build(input_shape=(batch_size,img_size,3))

    # Get all node informations
    # get model info 
    model_info = get_model_info(model)

    model_order = model_info['model_order']
    candidates_points = list(range(len(model_order)))[1:-1]
    partition_combinations = itertools.combinations(candidates_points, 2)

    # range : [ start_node, end_node )
    start_node = 0
    end_node = len(model_order)
    for partition in partition_combinations:
        p1, p2 = partition        
        # Front
        front_layers = model_order[start_node:p1]
        # Middle
        middle_layers = model_order[p1:p2]
        # Back
        back_layers = model_order[p2:end_node]

        # We go simple - cutting sequential

        # Find front to middle data  
        # Naive method : traverse middle and check the dependency node is in front model 
        dep_set = set()
        for mlayer in middle_layers:
            prev_layers = model_info[mlayer]['prev_layer']
            for player in prev_layers:
                if player in front_layers:
                    dep_set.add(player)

        # INTERSECT (Middle.out, Back.in) 
        # Naive method : traverse back and check the dependency node is in middle model         
        dep_set = set()
        for blayer in back_layers:
            prev_layers = model_info[blayer]['prev_layer']
            for player in prev_layers:
                if player in middle_layers:
                    dep_set.add(player)

        print("{}-{}-{}-{}|".format(start_node, p1, p2, end_node), end='')
        total_data_size = 0
        for layer in dep_set:
            output_shape = model_info[layer]['output_shape']
            sz = 1
            for i in output_shape:
                sz = sz * i
            total_data_size += sz
        print(total_data_size)



0
0-1-2-67|4194304
0-1-3-67|4194304
0-1-4-67|4194304
0-1-5-67|4194304
0-1-6-67|4194304
0-1-7-67|5242880
0-1-8-67|6291456
0-1-9-67|6291456
0-1-10-67|6291456
0-1-11-67|6291456
0-1-12-67|6291456
0-1-13-67|6815744
0-1-14-67|7340032
0-1-15-67|7340032
0-1-16-67|7340032
0-1-17-67|7340032
0-1-18-67|7340032
0-1-19-67|7602176
0-1-20-67|7864320
0-1-21-67|7864320
0-1-22-67|7864320
0-1-23-67|7864320
0-1-24-67|7864320
0-1-25-67|7995392
0-1-26-67|8126464
0-1-27-67|8126464
0-1-28-67|8126464
0-1-29-67|8126464
0-1-30-67|8126464
0-1-31-67|8388608
0-1-32-67|8388608
0-1-33-67|8388608
0-1-34-67|8388608
0-1-35-67|7864320
0-1-36-67|7864320
0-1-37-67|7864320
0-1-38-67|7864320
0-1-39-67|7864320
0-1-40-67|8388608
0-1-41-67|8388608
0-1-42-67|8388608
0-1-43-67|8388608
0-1-44-67|7340032
0-1-45-67|7340032
0-1-46-67|7340032
0-1-47-67|7340032
0-1-48-67|7340032
0-1-49-67|8388608
0-1-50-67|8388608
0-1-51-67|8388608
0-1-52-67|8388608
0-1-53-67|6291456
0-1-54-67|6291456
0-1-55-67|6291456
0-1-56-67|6291456
0-1-57-67|629145

In [3]:
print(model_info)

{'conv2d_94': {'prev_layer': ['batch_normalization_109'], 'output_shape': [512, 512, 1]}, 'batch_normalization_109': {'next_layer': ['conv2d_94'], 'prev_layer': ['conv2d_93'], 'output_shape': [512, 512, 16]}, 'conv2d_93': {'next_layer': ['batch_normalization_109'], 'prev_layer': ['leaky_re_lu_69'], 'output_shape': [512, 512, 16]}, 'leaky_re_lu_69': {'next_layer': ['conv2d_93'], 'prev_layer': ['batch_normalization_108'], 'output_shape': [512, 512, 16]}, 'batch_normalization_108': {'next_layer': ['leaky_re_lu_69'], 'prev_layer': ['conv2d_92'], 'output_shape': [512, 512, 16]}, 'conv2d_92': {'next_layer': ['batch_normalization_108'], 'prev_layer': ['concatenate_19'], 'output_shape': [512, 512, 16]}, 'concatenate_19': {'next_layer': ['conv2d_92'], 'prev_layer': ['leaky_re_lu_68', 'batch_normalization_89'], 'output_shape': [512, 512, 32]}, 'leaky_re_lu_68': {'next_layer': ['concatenate_19'], 'prev_layer': ['batch_normalization_107'], 'output_shape': [512, 512, 16]}, 'batch_normalization_89':