In [1]:
import torch.nn as nn
import torch
import torch.nn.functional as F
from torchsummary import summary

In [2]:
# class to build a simple tree structure
class Node:
    def __init__(self, name):
        self.name = name
        self.children = []

    def add_child(self, child):
        self.children.append(child)

# function to get the node with a specific name
def get_node(root, name):
    if root.name == name:
        return root
    for child in root.children:
        result = get_node(child, name)
        if result:
            return result

# function to get a branch of the tree
def get_branch(node):
    if len(node.children) == 0:
        if 'softmax' in node.name:
            return [None]
        else:
            return [node.name]
        
    elif len(node.children) == 1:
        return [node.name] + get_branch(node.children[0])
    else:
        return [node.name] + node.children

In [3]:
def seq_model(model):
    
    # transforms a fully functional model to a model formed by sequential models

    def visualize_model(model):
        # traces the execution of the model
        tracer = torch.fx.Tracer()
        graph = tracer.trace(model)

        # obtains a summary of the model
        graph_info = []
        for node in graph.nodes:
            node_info = {
                'name': node.name,
                'op': node.op,
                'target': str(node.target),
                'args': [str(arg) for arg in node.args],
                'kwargs': node.kwargs,
                'users': node.users
            }
            graph_info.append(node_info)

        return graph_info
    
    graph_info = visualize_model(model)

    # initializes the root node
    root = Node(graph_info[0]['name'])
    visited = [graph_info[0]['name']]

    # builds the tree structure
    for layer in graph_info[1:]:
        parent = get_node(root, layer['args'][0])
        for arg in layer['args']:
            if arg in visited:
                visited.append(layer['name'])
                parent.add_child(Node(layer['name']))

    # function that creates a partial representation of the model from the tree
    def create_model(node):
        start = get_branch(node)

        model_partial = [nn.Sequential()]
        for layer in start:
            if layer == 'x':
                continue

            if not isinstance(layer, Node) and layer is not None and 'relu' not in layer and 'flatten' not in layer:
                if isinstance(model_partial[-1], torch.nn.Flatten):
                    model_partial.append(nn.Sequential())
                model_partial[-1].add_module(layer, getattr(model, layer))

            elif not isinstance(layer, Node) and layer is not None and 'flatten' in layer:
                model_partial.append(nn.Flatten())
                
            elif not isinstance(layer, Node) and layer is not None and 'relu' in layer:
                model_partial[-1].add_module('relu', nn.ReLU())

            elif layer is None:
                return model_partial
            
            elif isinstance(layer, Node):
                model_partial.append(create_model(layer))

        return model_partial

    list_to_gen = create_model(root)

    write_blocks = {}
    write_forwards = {}

    # function that counts elements recursively
    def count_elements(lista):
        total = 0
        for elemento in lista:
            if isinstance(elemento, list): 
                total += count_elements(elemento)
            else:
                total += 1  
        return total

    num_elements = list(range(count_elements(list_to_gen)))

    # functions that will create the code structure
    def write_block(list_to_gen):
        for block in list_to_gen:
            if isinstance(block, list):
                write_block(block)
            else:
                write_blocks[num_elements.pop(0)] = block

    write_block(list_to_gen)
    num_elements = list(range(count_elements(list_to_gen)))

    def write_forward(list_to_gen, parent=None):
        for block in list_to_gen:
            if isinstance(block, list):

                list_to_gen = list_to_gen[list_to_gen.index(block):]
                for following_list in list_to_gen:
                    write_forward(following_list, parent)
            
                break


            elif isinstance(block, nn.Flatten):
                
                write_forwards[num_elements[0]] = f'flatten_{parent}'
                parent = num_elements[0]
                num_elements.pop(0)

            else:
                
                if block == list_to_gen[-1]:
                    write_forwards[num_elements[0]] = f'output_sequential_{parent}'
                    parent = num_elements[0]
                    num_elements.pop(0)
                else:
                    write_forwards[num_elements[0]] = f'sequential_{parent}'
                    parent = num_elements[0]
                    num_elements.pop(0)


    write_forward(list_to_gen)

    # function that will generate the code
    def generate_model_code(blocks_dict, forward_dict, filename="generated_model.py"):
   
        with open(filename, "w") as f:

            f.write("import torch\n")
            f.write("import torch.nn as nn\n\n")
            f.write("class GeneratedModel(nn.Module):\n")
            f.write("    def __init__(self):\n")
            f.write("        super(GeneratedModel, self).__init__()\n\n")
            
            # block definitions
            for block_id, block in blocks_dict.items():
                if not isinstance(block, nn.Flatten):  
                    if isinstance(block, nn.Sequential):
                        layers = [
                            f"nn.{str(layer)}" 
                            for layer in block
                        ]
                        block_code = f"nn.Sequential(\n            " + ",\n            ".join(layers) + "\n        )"
                    else:
                        block_code = f"nn.{str(block)}"
                    
                    f.write(f"        self.block_{block_id} = {block_code}\n\n")
            

            f.write("    def forward(self, x):\n")
            
            outputs = [] 
            for key, value in forward_dict.items():
                parts = value.split("_")
                block_type = parts[0]
                connection = parts[-1] 
                

                if connection == "None":
                    input_var = "x"
                else:
                    input_var = f"out_{connection}"
                

                if isinstance(blocks_dict[key], nn.Flatten):
                    flatten_params = blocks_dict[key]
                    f.write(f"        out_{key} = {input_var}.flatten(start_dim={flatten_params.start_dim}, end_dim={flatten_params.end_dim})\n")
                else:
                    f.write(f"        out_{key} = self.block_{key}({input_var})\n")
                
                if block_type == "output":
                    outputs.append(f"out_{key}")
            
            if outputs:
                if len(outputs) > 1:
                    outputs_list = ", ".join(outputs)
                    f.write(f"        return {outputs_list}\n")
                else:
                    f.write(f"        return {outputs[0]}\n")
            else:
                f.write("        return None\n")
        
        print(f"Model generated and saved in {filename}")
    
    generate_model_code(write_blocks, write_forwards)
    

### Testing the code

In [None]:

class OneEE(nn.Module):
    def __init__(self):
        super(OneEE, self).__init__()

        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # early exit block
        self.obj_detect_conv = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.obj_detect_bn = nn.BatchNorm2d(128)
        self.obj_detect_fc_ee = nn.Linear(128 * 50 * 50, 4)  

        self.conv5 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(512)
        self.conv6 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.bn6 = nn.BatchNorm2d(512)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.obj_detect_fc_final = nn.Linear(512 * 25 * 25, 4)

        self.binary_fc1 = nn.Linear(512 * 25 * 25, 256)
        self.binary_fc2 = nn.Linear(256, 2)

        self.regression_fc1 = nn.Linear(512 * 25 * 25, 256)
        self.regression_fc2 = nn.Linear(256, 1)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool1(x)

        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.pool2(x)

        obj_detect_ee = F.relu(self.obj_detect_bn(self.obj_detect_conv(x)))
        obj_detect_ee = torch.flatten(obj_detect_ee, 1)  # Aplanar
        obj_detect_ee = F.softmax(self.obj_detect_fc_ee(obj_detect_ee), dim=1)

        x = F.relu(self.bn5(self.conv5(x)))
        x = F.relu(self.bn6(self.conv6(x)))
        x = self.pool3(x)
        x = torch.flatten(x, 1)

        obj_detect_final = F.softmax(self.obj_detect_fc_final(x), dim=1)

        binary_output = F.relu(self.binary_fc1(x))
        binary_output = F.softmax(self.binary_fc2(binary_output), dim=1)

        regression_output = F.relu(self.regression_fc1(x))
        regression_output = self.regression_fc2(regression_output)

        return obj_detect_ee, obj_detect_final, binary_output, regression_output




In [5]:
model = OneEE()
seq_model(model)

Model generated and saved in generated_model.py


Checking if the model is defined as expected:

In [6]:
from generated_model import GeneratedModel

model = GeneratedModel()
summary(model, (1, 200, 200))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 200, 200]             320
       BatchNorm2d-2         [-1, 32, 200, 200]              64
              ReLU-3         [-1, 32, 200, 200]               0
            Conv2d-4         [-1, 64, 200, 200]          18,496
       BatchNorm2d-5         [-1, 64, 200, 200]             128
         MaxPool2d-6         [-1, 64, 100, 100]               0
            Conv2d-7        [-1, 128, 100, 100]          73,856
       BatchNorm2d-8        [-1, 128, 100, 100]             256
            Conv2d-9        [-1, 256, 100, 100]         295,168
      BatchNorm2d-10        [-1, 256, 100, 100]             512
        MaxPool2d-11          [-1, 256, 50, 50]               0
           Conv2d-12          [-1, 128, 50, 50]         295,040
      BatchNorm2d-13          [-1, 128, 50, 50]             256
             ReLU-14          [-1, 128,

In [7]:
print(model)

GeneratedModel(
  (block_0): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (block_1): Sequential(
    (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1):