In [None]:
import torch
import numpy as np
from torchvision.models import *
#
import matplotlib.pyplot as plt

In [None]:
def get_model_params(model):
    """Count the trainable parameters of a PyTorch module.

    Only parameters with ``requires_grad=True`` are counted.

    Args:
        model: a ``torch.nn.Module``.

    Returns:
        int: total number of elements across all trainable parameter tensors
        (0 if the module has none).
    """
    # ``numel()`` returns a Python int and counts a 0-dim (scalar) parameter
    # as 1; ``np.prod(p.size())`` returns the float 1.0 for an empty shape,
    # which would silently turn the whole sum into a float.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
# Candidate architectures as (display name, torchvision constructor, accuracy).
# NOTE(review): the dataset/setup behind these accuracy numbers (presumably %)
# is not recorded in this notebook — document where they come from.
# Inception-v3 (`inception_v3`) is intentionally left out: it has no accuracy entry.
model_classes = [
    # VGG-16 without BN: HORRIBLE PERFORMANCE AND HARD TO TRAIN
    ("VGG-16-wo-bn",       vgg16,              90.382),
    ("VGG-19-wo-bn",       vgg19,              90.876),
    ("VGG-16-w-bn",        vgg16_bn,           91.516),
    ("VGG-19-w-bn",        vgg19_bn,           91.842),
    # ResNet family
    ("ResNet-34",          resnet34,           91.420),
    ("ResNet-50",          resnet50,           92.862),
    ("ResNet-101",         resnet101,          93.546),
    ("ResNet-152",         resnet152,          94.046),
    # DenseNet family
    ("Densenet-121",       densenet121,        91.972),
    ("Densenet-169",       densenet169,        92.806),
    ("Densenet-201",       densenet201,        93.370),
    ("Densenet-161",       densenet161,        93.560),
    # Mobile-oriented networks
    ("MobileNet-v2",       mobilenet_v2,       90.286),
    ("MobileNet-v3-Large", mobilenet_v3_large, 91.340),
    # ResNeXt / Wide-ResNet variants
    ("ResNeXt-50-32x4d",   resnext50_32x4d,    93.698),
    ("ResNeXt-101-32x8d",  resnext101_32x8d,   94.526),
    ("Wide-ResNet-50-2",   wide_resnet50_2,    94.086),
    ("Wide-ResNet-101-2",  wide_resnet101_2,   94.284),
    # Architecture-search result
    ("MNASNet 1.0",        mnasnet1_0,         91.510),
]

In [None]:
# Instantiate every architecture (constructor called with default arguments)
# and record (name, trainable-parameter count, accuracy) for the plots below.
model_data = []
# Pad names to the longest one so the printed counts line up; several names
# (e.g. "MobileNet-v3-Large") are longer than a fixed width of 15.
name_width = max((len(name) for name, _, _ in model_classes), default=0)
for model_name, model_class, model_acc in model_classes:
    model = model_class()
    model_params = get_model_params(model)
    model_data.append((model_name, model_params, model_acc))
    # Reuse the count computed above instead of walking the parameters again.
    print("{:{width}}: {:,}".format(model_name, model_params, width=name_width))

In [None]:
xx = np.arange(len(model_data))

# Order the records by accuracy (ascending) and split them into parallel lists
# that share the same index as the x positions in `xx`.
data = sorted(model_data, key=lambda record: record[2])
names = [model_name for model_name, _, _ in data]
params = [n_params for _, n_params, _ in data]
accs = [model_acc for _, _, model_acc in data]

# Accuracy divided by trainable-parameter count, per model.
acc_per_param = [model_acc / n_params for model_acc, n_params in zip(accs, params)]


In [None]:
# Accuracy (left axis, red) vs. trainable-parameter count (right axis, blue)
# for every model, in accuracy-sorted order.
fig, ax1 = plt.subplots()
bar_width = 0.4  # two bars per model, drawn side by side instead of on top of each other

# ACC
color = 'tab:red'
ax1.set_ylabel('ACC', color=color)
ax1.bar(xx - bar_width / 2, accs, width=bar_width, color=color, alpha=0.5)
ax1.tick_params(axis='y', labelcolor=color)
# Derive the limits from the data instead of hard-coding them, so a model
# outside the current range is not clipped (the listed accuracies give 88-95).
ax1.set_ylim(np.floor(min(accs)) - 2, np.ceil(max(accs)))
ax1.set_xticks(xx)
ax1.set_xticklabels(names, rotation=90)

# PARAMS
ax2 = ax1.twinx()  # second y-axis sharing the same x-axis
color = 'tab:blue'
ax2.set_ylabel('#Params', color=color)  # the x-label is already handled by ax1
ax2.bar(xx + bar_width / 2, params, width=bar_width, color=color, alpha=0.5)
ax2.tick_params(axis='y', labelcolor=color)

fig.tight_layout()  # otherwise the right y-label is slightly clipped
plt.show()

In [None]:
# Accuracy per trainable parameter, in the same accuracy-sorted model order.
fig, ax = plt.subplots()
ax.bar(xx, acc_per_param, alpha=0.5, color="green")
ax.set_xticks(xx)
ax.set_xticklabels(names, rotation='vertical')
ax.set_ylabel("ACC / #Params")
ax.set_title("ACC / param")
fig.tight_layout()  # keep the vertical model names from being clipped
plt.show()

- VGG seems to perform poorly -> drop it
- Try adding MobileNetV2/V3 and MNASNet 1.0 (already included in `model_classes` above)

In [None]:
# List the most parameter-efficient models. Since acc_per_param = acc / params,
# `acc_per_param > 1e-5` is equivalent to `params < acc * 1e5`, i.e. fewer than
# roughly 9.0-9.5M trainable parameters for the accuracies listed (~90-95).
ACC_PER_PARAM_THRESHOLD = 1e-5
for name, ratio in zip(names, acc_per_param):
    if ratio > ACC_PER_PARAM_THRESHOLD:
        print(name)