In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../')

In [2]:
from networks.inception import InceptionResnet
from networks.simple_cnn import BaselineCNNClassifier
from networks.resnet_big import SupCEResNet, SupConResNet, LinearClassifier
import torch
from torchsummary import summary
from thop import profile

In [3]:
import time

def measure_time(model, device, num_trial = 10):
    model = model.to(device=device)
    x = torch.rand((1, 1, 29, 29), device=device)
    total_time = 0
    start_time = time.time()
    for i in range(num_trial):    
        out = model(x)
    total_time = time.time() - start_time
    return total_time*1000/num_trial

In [4]:
baseline_model = SupCEResNet(name='resnet18', num_classes=5)

In [13]:
measure_time(baseline_model, device='cpu', num_trial=1000)

45.92603635787964

In [12]:
measure_time(baseline_model, device='cuda', num_trial=1000)

6.185551881790161

In [7]:
incep = InceptionResnet(n_classes=5)

In [14]:
measure_time(incep, device='cpu', num_trial=1000)

25.710362434387207

In [15]:
measure_time(incep, device= 'cuda', num_trial=1000)

5.972587585449219

In [10]:
x = torch.rand((1, 1, 29, 29), device='cpu')
baseline_model = baseline_model.to(device='cpu')
macs, params = profile(baseline_model, inputs=(x, ))
print('MACs (G): ', macs/1000**3)
print('Params (M): ', params/1000**2)

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
MACs (G):  0.032560592
Params (M):  0.700533


In [11]:
x = torch.rand((1, 1, 29, 29), device='cpu')
incep = incep.to(device='cpu')
macs, params = profile(incep, inputs=(x, ))
print('MACs (G): ', macs/1000**3)
print('Params (M): ', params/1000**2)

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
MACs (G):  0.097190176
Params (M):  1.694181


In [16]:
summary(baseline_model.cuda(), (1, 29, 29))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 29, 29]             144
       BatchNorm2d-2           [-1, 16, 29, 29]              32
            Conv2d-3           [-1, 16, 29, 29]           2,304
       BatchNorm2d-4           [-1, 16, 29, 29]              32
            Conv2d-5           [-1, 16, 29, 29]           2,304
       BatchNorm2d-6           [-1, 16, 29, 29]              32
        BasicBlock-7           [-1, 16, 29, 29]               0
            Conv2d-8           [-1, 16, 29, 29]           2,304
       BatchNorm2d-9           [-1, 16, 29, 29]              32
           Conv2d-10           [-1, 16, 29, 29]           2,304
      BatchNorm2d-11           [-1, 16, 29, 29]              32
       BasicBlock-12           [-1, 16, 29, 29]               0
           Conv2d-13           [-1, 32, 15, 15]           4,608
      BatchNorm2d-14           [-1, 32,

In [17]:
summary(incep.cuda(), (1, 29, 29))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 29, 29]             320
            Conv2d-2           [-1, 32, 27, 27]           9,248
         MaxPool2d-3           [-1, 32, 13, 13]               0
            Conv2d-4           [-1, 64, 13, 13]           2,112
            Conv2d-5          [-1, 128, 13, 13]          73,856
            Conv2d-6          [-1, 128, 13, 13]         147,584
              Stem-7          [-1, 128, 13, 13]               0
            Conv2d-8           [-1, 32, 13, 13]           4,128
            Conv2d-9           [-1, 32, 13, 13]           4,128
           Conv2d-10           [-1, 32, 13, 13]           9,248
           Conv2d-11           [-1, 32, 13, 13]           4,128
           Conv2d-12           [-1, 32, 13, 13]           9,248
           Conv2d-13           [-1, 32, 13, 13]           9,248
           Conv2d-14          [-1, 128,