In [1]:
import sys
import os
import argparse

import torch

sys.path.insert(0, os.path.abspath('..'))

from src.models.san import SAN, Bottleneck
from tools.complexity import (get_model_complexity_info,
                              is_supported_instance,
                              flops_to_string,
                              get_model_parameters_number)

In [2]:
def collect_flops(model, units='GMac', precision=3):
    """Collect MAC and parameter counts for each top-level layer of `model`.

    Precondition: `model` must already have been instrumented and run by the
    complexity tool (i.e. `get_model_complexity_info(model, ...)` was called on
    this same object). This function reads `model.compute_average_flops_cost()`,
    `model.__batch_counter__` and each supported module's `__flops__`, none of
    which are set anywhere in this notebook — presumably the tool attaches them
    (TODO confirm against tools.complexity).

    Top-level children that are `torch.nn.Sequential` are expanded one level,
    giving keys of the form '<child>-<sub>' (e.g. 'layer0-1'); every other
    child gets a single entry keyed by its attribute name.

    Args:
        model: the instrumented model.
        units: unit string forwarded to `flops_to_string` in the temporary
            `extra_repr` installed on every module. Nothing in this function
            prints the model, so it currently has no visible effect.
        precision: decimal precision forwarded to `flops_to_string`, same caveat.

    Returns:
        (flops_dict, param_dict): two dicts with identical, insertion-ordered
        keys, mapping each layer name to its MAC count (module `__flops__`
        divided by `model.__batch_counter__`) and to its parameter count.
    """
    total_flops = model.compute_average_flops_cost()

    def accumulate_flops(self):
        # Modules the counter supports carry their own count; any other module's
        # cost is the sum of its children's costs.
        if is_supported_instance(self):
            return self.__flops__ / model.__batch_counter__
        return sum(child.accumulate_flops() for child in self.children())

    def flops_repr(self):
        accumulated_flops_cost = self.accumulate_flops()
        # Avoid ZeroDivisionError when the model has no counted operations.
        share = accumulated_flops_cost / total_flops if total_flops else 0.0
        return ', '.join([flops_to_string(accumulated_flops_cost, units=units, precision=precision),
                          '{:.3%} MACs'.format(share),
                          self.original_extra_repr()])

    def add_extra_repr(m):
        # Bind the helpers to this specific module instance.
        m.accumulate_flops = accumulate_flops.__get__(m)
        flops_extra_repr = flops_repr.__get__(m)
        if m.extra_repr != flops_extra_repr:
            m.original_extra_repr = m.extra_repr
            m.extra_repr = flops_extra_repr
            assert m.extra_repr != m.original_extra_repr

    def del_extra_repr(m):
        if hasattr(m, 'original_extra_repr'):
            m.extra_repr = m.original_extra_repr
            del m.original_extra_repr
        if hasattr(m, 'accumulate_flops'):
            del m.accumulate_flops

    flops_dict, param_dict = {}, {}
    model.apply(add_extra_repr)
    try:
        # named_children() skips None placeholders that the raw `_modules` dict may hold.
        for name, child in model.named_children():
            if isinstance(child, torch.nn.Sequential):
                # Expand sequential stages one level so each block gets its own row.
                for sub_name, sub_child in child.named_children():
                    key = '{}-{}'.format(name, sub_name)
                    flops_dict[key] = sub_child.accumulate_flops()
                    param_dict[key] = get_model_parameters_number(sub_child)
            else:
                flops_dict[name] = child.accumulate_flops()
                param_dict[name] = get_model_parameters_number(child)
    finally:
        # Always undo the monkey-patching, even if a lookup above raised, so the
        # model is not left with the temporary attributes attached.
        model.apply(del_extra_repr)
    return flops_dict, param_dict


def run_experiments(san_sa_type, san_layers, san_kernels, device='cuda', input_res=(3, 224, 224)):
    """Build a SAN model and print its overall and per-layer complexity.

    Prints the total MACs and parameter count reported by
    `get_model_complexity_info`. It then prints a table with one row per entry
    returned by `collect_flops`: the layer's MACs (GMac), its parameters (M),
    and running totals of both.

    Args:
        san_sa_type: forwarded to `SAN(sa_type=...)`; the cells in this
            notebook use 0 and label it "Pairwise".
        san_layers: forwarded to `SAN(layers=...)`; in the printed tables,
            stage `layer{i}` has `san_layers[i]` rows.
        san_kernels: forwarded to `SAN(kernels=...)`.
        device: device the model is moved to before counting. Defaults to
            'cuda' (the original behaviour); pass 'cpu' on machines without a GPU.
        input_res: input shape without the batch dimension, passed to
            `get_model_complexity_info`.
    """
    model = SAN(
        sa_type=san_sa_type,
        block=Bottleneck,
        layers=san_layers,
        kernels=san_kernels,
        # NOTE(review): nothing in this notebook removes the final classifier, so its
        # MACs/params are presumably included in every total below — TODO confirm intent.
        num_classes=1000,
    ).to(device)

    macs, params = get_model_complexity_info(model, input_res, as_strings=True, print_per_layer_stat=False)
    print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
    print('{:<30}  {:<8}'.format('Number of parameters: ', params))
    print()

    print('{:<15} {:>12} {:>12} {:>12} {:>12}'.format(
        'Layer', 'Flops (GMac)', 'Param (M)', 'AccFlops', 'AccParam'))
    print('-' * 67)  # 15 + 4 * 12 + 4 separating spaces = header width
    # collect_flops reads counters presumably left on `model` by the call above,
    # so it must run after get_model_complexity_info on the same model object.
    flops_dict, param_dict = collect_flops(model)
    acc_gmacs, acc_mparams = 0, 0
    for name, layer_macs in flops_dict.items():
        gmacs = layer_macs * 1e-9
        mparams = param_dict[name] * 1e-6
        acc_gmacs += gmacs
        acc_mparams += mparams
        print('{:<15} {:>12.5f} {:>12.5f} {:>12.2f} {:>12.2f}'.format(
            name, gmacs, mparams, acc_gmacs, acc_mparams))

# SAN10 - pairwise

In [3]:
# SAN10 configuration: self-attention type 0 (pairwise), stage depths, per-stage kernels.
san_sa_type, san_layers, san_kernels = 0, [2, 1, 2, 4, 1], [3, 7, 7, 7, 7]

run_experiments(san_sa_type=san_sa_type, san_layers=san_layers, san_kernels=san_kernels)

Computational complexity:       2.16 GMac
Number of parameters:           10.53 M 

Layer           Flops (GMac)    Param (M)     AccFlops     AccParam
-------------------------------------------------------------------
conv_in              0.00963      0.00019         0.01         0.00
bn_in                0.00642      0.00013         0.02         0.00
conv0                0.05138      0.00410         0.07         0.00
bn0                  0.00161      0.00013         0.07         0.00
layer0-0             0.04601      0.00287         0.12         0.01
layer0-1             0.04601      0.00287         0.16         0.01
conv1                0.05138      0.01638         0.21         0.03
bn1                  0.00161      0.00051         0.21         0.03
layer1-0             0.22601      0.04245         0.44         0.07
conv2                0.10276      0.13107         0.54         0.20
bn2                  0.00080      0.00102         0.54         0.20
layer2-0             0.20642    

# SAN15 - pairwise

In [4]:
# SAN15 configuration: self-attention type 0 (pairwise), stage depths, per-stage kernels.
san_sa_type, san_layers, san_kernels = 0, [3, 2, 3, 5, 2], [3, 7, 7, 7, 7]

run_experiments(san_sa_type=san_sa_type, san_layers=san_layers, san_kernels=san_kernels)

Computational complexity:       3.02 GMac
Number of parameters:           14.07 M 

Layer           Flops (GMac)    Param (M)     AccFlops     AccParam
-------------------------------------------------------------------
conv_in              0.00963      0.00019         0.01         0.00
bn_in                0.00642      0.00013         0.02         0.00
conv0                0.05138      0.00410         0.07         0.00
bn0                  0.00161      0.00013         0.07         0.00
layer0-0             0.04601      0.00287         0.12         0.01
layer0-1             0.04601      0.00287         0.16         0.01
layer0-2             0.04601      0.00287         0.21         0.01
conv1                0.05138      0.01638         0.26         0.03
bn1                  0.00161      0.00051         0.26         0.03
layer1-0             0.22601      0.04245         0.49         0.07
layer1-1             0.22601      0.04245         0.71         0.11
conv2                0.10276    

# SAN19 - pairwise

In [5]:
# SAN19 configuration: self-attention type 0 (pairwise), stage depths, per-stage kernels.
san_sa_type, san_layers, san_kernels = 0, [3, 3, 4, 6, 3], [3, 7, 7, 7, 7]

run_experiments(san_sa_type=san_sa_type, san_layers=san_layers, san_kernels=san_kernels)

Computational complexity:       3.84 GMac
Number of parameters:           17.6 M  

Layer           Flops (GMac)    Param (M)     AccFlops     AccParam
-------------------------------------------------------------------
conv_in              0.00963      0.00019         0.01         0.00
bn_in                0.00642      0.00013         0.02         0.00
conv0                0.05138      0.00410         0.07         0.00
bn0                  0.00161      0.00013         0.07         0.00
layer0-0             0.04601      0.00287         0.12         0.01
layer0-1             0.04601      0.00287         0.16         0.01
layer0-2             0.04601      0.00287         0.21         0.01
conv1                0.05138      0.01638         0.26         0.03
bn1                  0.00161      0.00051         0.26         0.03
layer1-0             0.22601      0.04245         0.49         0.07
layer1-1             0.22601      0.04245         0.71         0.11
layer1-2             0.22601    