In [1]:
import sys
sys.path.append('../code')

import torchvision.models as models
import torch
import pandas as pd

from src.utils.macs import calc_macs

# torchvision의 주요 모델

In [2]:
MODEL_MAP_FUNCTION = {
    'alexnet' : models.alexnet,
    'vgg11' : models.vgg11,
    'vgg11_bn' : models.vgg11_bn,
    'resnet18' : models.resnet18,
    'squeezenet1_0' : models.squeezenet1_0,
    'squeezenet1_1' : models.squeezenet1_1,
    'densenet121' : models.densenet121,
    # 'inception_v3' : models.inception_v3,
    # 'googlenet' : models.googlenet,
    'shufflenet_v2_x0_5' : models.shufflenet_v2_x0_5,
    'mobilenet_v2' : models.mobilenet_v2,
    'mobilenet_v3_small' : models.mobilenet_v3_small,
    'resnext50_32x4d' : models.resnext50_32x4d,
    'wide_resnet' : models.wide_resnet50_2,
    'mnasnet' : models.mnasnet0_5,
}

# torchvision 모델의 param 수, MACs

In [3]:
model_table = {'model_name': [], 'param_num': [], 'MACs': []}

for model_name, model_func in MODEL_MAP_FUNCTION.items():
    print(model_name)
    
    model = model_func()
    param_num = 0
    for params in model.parameters(): param_num += len(params)
    try:
        macs = calc_macs(model, (3, 32, 32))
    except RuntimeError:
        macs = 'Not_32'
    model_table['model_name'].append(model_name)
    model_table['param_num'].append(param_num)
    model_table['MACs'].append(macs)

model_table = pd.DataFrame(model_table, 
                           columns=['model_name', 'param_num', 'MACs'])

print(model_table)

alexnet
vgg11
vgg11_bn
resnet18
squeezenet1_0
squeezenet1_1
densenet121
shufflenet_v2_x0_5
mobilenet_v2
mobilenet_v3_small
resnext50_32x4d
wide_resnet
mnasnet
            model_name  param_num         MACs
0              alexnet      20688       Not_32
1                vgg11      23888  276684264.0
2             vgg11_bn      29392  276987368.0
3             resnet18      16400   37642728.0
4        squeezenet1_0       7952    9740464.0
5        squeezenet1_1       7888    3563472.0
6          densenet121      95888   59514344.0
7   shufflenet_v2_x0_5      13928    1878088.0
8         mobilenet_v2      53168    7665704.0
9   mobilenet_v3_small      27992    3246704.0
10     resnext50_32x4d     104336   88949736.0
11         wide_resnet     104336  235226088.0
12             mnasnet      32936    3511752.0


# 주요 모델들 dense_layer 축소

In [4]:
model_table = {'model_name': [], 'param_num': [], 'MACs': []}

def append_model_table(model, model_name):
    param_num = 0
    for params in model.parameters(): param_num += len(params)
    macs = calc_macs(model, (3, 32, 32))
    print(model_name, param_num, macs)
    model_table['model_name'].append(model_name)
    model_table['param_num'].append(param_num)
    model_table['MACs'].append(macs)

def copy_weight(model, pretrained_model):
    params = list(zip(model.parameters(), pretrained_model.parameters()))[:-2]

    with torch.no_grad():
        for param, pretrained_param in params:
            param.data = pretrained_param.data.clone().detach()


In [5]:
model = models.shufflenet_v2_x0_5(num_classes=9)
pretrained_model = models.shufflenet_v2_x0_5(pretrained=True)

copy_weight(model, pretrained_model)
append_model_table(model, 'shufflenet_v2_x0_5')

shufflenet_v2_x0_5 11946 862313.0


In [6]:
model = models.squeezenet1_1(num_classes=9)
pretrained_model = models.squeezenet1_1(pretrained=True, progress=False)

copy_weight(model, pretrained_model)
append_model_table(model, 'SqueezeNet1_1')

SqueezeNet1_1 5906 3054098.0


In [7]:
model = models.mobilenet_v3_small(num_classes=9)
pretrained_model = models.mobilenet_v3_small(pretrained=True, progress=False)

copy_weight(model, pretrained_model)
append_model_table(model, 'mobilenet_v3_small')

mobilenet_v3_small 26010 2230929.0


In [8]:
model = models.mnasnet0_5(num_classes=9)
pretrained_model = models.mnasnet0_5(pretrained=True, progress=False)

copy_weight(model, pretrained_model)
append_model_table(model, 'mnasnet')

mnasnet 30954 2242281.0


In [35]:
pd.DataFrame(model_table, columns=['model_name', 'param_num', 'MACs'])

Unnamed: 0,model_name,param_num,MACs
0,shufflenet_v2_x0_5,11946,862313.0
1,SqueezeNet1_1,5906,3054098.0
2,mobilenet_v3_small,26010,2230929.0
3,mnasnet,30954,2242281.0


In [16]:
model = models.ShuffleNetV2([4, 8, 4], [24, 48, 96, 128, 512], num_classes=9)
calc_macs(model, (3, 32, 32))

672233.0