In [5]:
import torch
from torch import nn
from simple_model import simple_model
from torchsummary import summary
import torch.nn.functional as F

In [None]:
def summary_to_dict(model, input_size, batch_size=1, device="cpu"):
    # function to convert the summary output to a dictionary
    model.to(device)
    x = torch.rand(batch_size, *input_size).to(device)

    summary_dict = {}
    layer_counts = {}

    def register_hook(module):
        def hook(module, input, output):
            module_name = module.__class__.__name__
            # Contar capas con el mismo nombre
            if module_name not in layer_counts:
                layer_counts[module_name] = 0
            layer_counts[module_name] += 1
            name = f"{module_name}_{layer_counts[module_name]}"

            # Verificar si la salida es un tensor o una tupla
            if isinstance(output, tuple):
                output_shapes = [list(o.shape) for o in output]
            else:
                output_shapes = list(output.shape)

            # Guardar el tamaño de salida y los parámetros en el diccionario
            summary_dict[name] = {
                "output_shape": output_shapes,
                "params": sum(p.numel() for p in module.parameters() if p.requires_grad)
            }

        # Registrar el hook en todas las capas, excluyendo contenedores
        if not isinstance(module, nn.ModuleList) and not isinstance(module, nn.Sequential):
            module.register_forward_hook(hook)

    model.apply(register_hook)
    model(x)
    
    return summary_dict


In [7]:
model = simple_model()
input_size = (1, 200, 200)
summary_model = summary_to_dict(model, input_size)

In [8]:
summary_model

{'Conv2d_1': {'output_shape': [1, 16, 100, 100], 'params': 416},
 'ReLU_1': {'output_shape': [1, 16, 100, 100], 'params': 0},
 'Conv2d_2': {'output_shape': [1, 32, 50, 50], 'params': 4640},
 'ReLU_2': {'output_shape': [1, 32, 50, 50], 'params': 0},
 'Linear_1': {'output_shape': [1, 4], 'params': 320004},
 'Conv2d_3': {'output_shape': [1, 64, 25, 25], 'params': 18496},
 'ReLU_3': {'output_shape': [1, 64, 25, 25], 'params': 0},
 'Conv2d_4': {'output_shape': [1, 64, 13, 13], 'params': 36928},
 'ReLU_4': {'output_shape': [1, 64, 13, 13], 'params': 0},
 'Linear_2': {'output_shape': [1, 4], 'params': 43268},
 'simple_model_1': {'output_shape': [[1, 4], [1, 4]], 'params': 423752}}

In [9]:
def get_context_vectors(model, summary_dict, input_size, n_classes_ee):
    context_vectors = []
    previous_shape = input_size # previous shape of layer considered starts with input shape
    model = [item for j in model.children() for item in j.children()] # serialize the model
    intermediate_size = input_size[0] * input_size[1] * input_size[2] # assuming input is an image

    for p in range(len(summary_dict)-1): # last one is the total model, so we don't need it
        macs_conv, macs_lin, macs_act, n_conv, n_lin, n_act = 0, 0, 0, 0, 0, 0
        model_considered = {k: v for k, v in summary_dict.items() if k in list(summary_dict.keys())[p:]} # get the layers from the current layer to the end    
        
        for ix, layer in enumerate(model_considered):

            if 'conv' in layer.lower():
                n_conv += 1
                output_shape = tuple(summary_dict[layer]['output_shape'][1:]) # remove batch size
                macs_conv += output_shape[0] * output_shape[1] * output_shape[2] * model[ix].kernel_size[0] * model[ix].kernel_size[1] * previous_shape[0]
                # so output channels * output height * output width * kernel height * kernel width * kernel depth
                previous_shape = output_shape

            elif 'linear' in layer.lower():
                n_lin += 1
                output_shape = tuple(summary_dict[layer]['output_shape'][1:])
                macs_lin += model[ix].in_features * model[ix].out_features
                previous_shape = output_shape

            elif 'relu' in layer.lower():
                n_act += 1
                output_shape = summary_dict[layer]['output_shape'][1:]
                
                total_act = 1
                for dim in output_shape:
                    total_act *= dim
                macs_act += total_act
                previous_shape =  tuple(output_shape)
        
        model = model[1:] # remove the current layer
        context_vectors.append([macs_conv, macs_lin, macs_act, n_conv, n_lin, n_act, intermediate_size])

        # calculate the next intermediate size considering early exits 
        # PROBLEM: MAY NOT BE ROBUST IF THE MODEL HAS MORE THAN ONE OUTPUT AND EACH OF THEM (or at least one) HAVE MORE THAN ONE LINEAL LAYER
        if n_classes_ee not in model_considered[list(model_considered.keys())[0]]['output_shape'][1:]: 
            intermediate_size  = 1
            for dim_out in model_considered[list(model_considered.keys())[0]]['output_shape'][1:]:
                intermediate_size *= dim_out
        
        previous_shape = tuple(model_considered[list(model_considered.keys())[0]]['output_shape'][1:]) # reset previous shape to the first layer considered

    context_vectors.append([0,0,0,0,0,0,0])
    return context_vectors


In [10]:
get_context_vectors(model, summary_model, input_size, n_classes_ee=4)

[[23190016, 363264, 290816, 4, 2, 4, 40000],
 [19190016, 363264, 290816, 3, 2, 4, 160000],
 [19190016, 363264, 130816, 3, 2, 3, 160000],
 [7670016, 363264, 130816, 2, 2, 3, 80000],
 [7670016, 363264, 50816, 2, 2, 2, 80000],
 [7670016, 43264, 50816, 2, 1, 2, 80000],
 [6230016, 43264, 50816, 1, 1, 2, 40000],
 [6230016, 43264, 10816, 1, 1, 1, 40000],
 [0, 43264, 10816, 0, 1, 1, 10816],
 [0, 43264, 0, 0, 1, 0, 10816],
 [0, 0, 0, 0, 0, 0, 0]]

In [11]:
class LeNet5(nn.Module):
    def __init__(self, num_classes=10):
        super(LeNet5, self).__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2),  # Conv1
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),                # Pool1
            
            nn.Conv2d(6, 16, kernel_size=5, stride=1),            # Conv2
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),                # Pool2
            
            nn.Conv2d(16, 120, kernel_size=5, stride=1),          # Conv3
            nn.ReLU()
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(480, 84),                                   # FC1
            nn.ReLU(),
            nn.Linear(84, num_classes)                           # FC2
        )
    
    def forward(self, x):
        x = self.feature_extractor(x)
        x = x.view(x.size(0), -1)  # Flatten
        x = self.classifier(x)
        return x

In [12]:
model = LeNet5()
input_size = (1, 32, 32)
summary_model = summary_to_dict(model, input_size)

In [13]:
get_context_vectors(model, summary_model, input_size, n_classes_ee=10)

[[691200, 41160, 9012, 3, 2, 4, 1024],
 [537600, 41160, 9012, 2, 2, 4, 6144],
 [537600, 41160, 2868, 2, 2, 3, 6144],
 [537600, 41160, 2868, 2, 2, 3, 1536],
 [192000, 41160, 2868, 1, 2, 3, 2304],
 [192000, 41160, 564, 1, 2, 2, 2304],
 [192000, 41160, 564, 1, 2, 2, 576],
 [0, 41160, 564, 0, 2, 2, 480],
 [0, 41160, 84, 0, 2, 1, 480],
 [0, 840, 84, 0, 1, 1, 84],
 [0, 840, 0, 0, 1, 0, 84],
 [0, 0, 0, 0, 0, 0, 0]]

In [14]:
class AlexNet(nn.Module):
    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Linear(4096, num_classes),
        )
    
    def forward(self, x):
        x = self.feature_extractor(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

In [15]:
model = AlexNet()
input_size = (3, 224, 224)
summary_model = summary_to_dict(model, input_size)

In [16]:
get_context_vectors(model, summary_model, input_size, n_classes_ee=1000)

[[655566528, 58621952, 493184, 5, 3, 7, 150528],
 [585289728, 58621952, 493184, 4, 3, 7, 193600],
 [585289728, 58621952, 299584, 4, 3, 6, 193600],
 [585289728, 58621952, 299584, 4, 3, 6, 46656],
 [361340928, 58621952, 299584, 3, 3, 6, 139968],
 [361340928, 58621952, 159616, 3, 3, 5, 139968],
 [361340928, 58621952, 159616, 3, 3, 5, 32448],
 [249200640, 58621952, 159616, 2, 3, 5, 64896],
 [249200640, 58621952, 94720, 2, 3, 4, 64896],
 [99680256, 58621952, 94720, 1, 3, 4, 43264],
 [99680256, 58621952, 51456, 1, 3, 3, 43264],
 [0, 58621952, 51456, 0, 3, 3, 43264],
 [0, 58621952, 8192, 0, 3, 2, 43264],
 [0, 58621952, 8192, 0, 3, 2, 9216],
 [0, 58621952, 8192, 0, 3, 2, 9216],
 [0, 20873216, 8192, 0, 2, 2, 4096],
 [0, 20873216, 4096, 0, 2, 1, 4096],
 [0, 20873216, 4096, 0, 2, 1, 4096],
 [0, 4096000, 4096, 0, 1, 1, 4096],
 [0, 4096000, 0, 0, 1, 0, 4096],
 [0, 0, 0, 0, 0, 0, 0]]