In [113]:
from collections import OrderedDict

import torch
import torch.nn as nn

def get_same_padding(kernel_size):
    if isinstance(kernel_size, tuple):
        assert len(kernel_size) == 2, 'invalid kernel size: %s' % kernel_size
        p1 = get_same_padding(kernel_size[0])
        p2 = get_same_padding(kernel_size[1])
        return p1, p2
    assert isinstance(kernel_size, int), 'kernel size should be either `int` or `tuple`'
    assert kernel_size % 2 > 0, 'kernel size should be odd number'
    return kernel_size // 2

class MBInvertedConvLayer(nn.Module):

    def __init__(self, in_channels, out_channels,
                 kernel_size=3, stride=1, expand_ratio=6, mid_channels=None):
        super(MBInvertedConvLayer, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.kernel_size = kernel_size
        self.stride = stride
        self.expand_ratio = expand_ratio
        self.mid_channels = mid_channels

        if self.mid_channels is None:
            feature_dim = round(self.in_channels * self.expand_ratio)
        else:
            feature_dim = self.mid_channels

        if self.expand_ratio == 1:
            self.inverted_bottleneck = None
        else:
            self.inverted_bottleneck = nn.Sequential(OrderedDict([
                ('conv', nn.Conv2d(self.in_channels, feature_dim, 1, 1, 0, bias=False)),
                ('bn', nn.BatchNorm2d(feature_dim, affine=False, track_running_stats=False)),
                ('act', nn.ReLU6(inplace=True)),
            ]))

        pad = get_same_padding(self.kernel_size)
        self.depth_conv = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(feature_dim, feature_dim, kernel_size, stride, pad, groups=feature_dim, bias=False)),
            ('bn', nn.BatchNorm2d(feature_dim, affine=False, track_running_stats=False)),
            ('act', nn.ReLU6(inplace=True)),
        ]))

        self.point_linear = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(feature_dim, out_channels, 1, 1, 0, bias=False)),
            ('bn', nn.BatchNorm2d(out_channels, affine=False, track_running_stats=False)),
        ]))

    def forward(self, x):
        if self.inverted_bottleneck:
            x = self.inverted_bottleneck(x)
        x = self.depth_conv(x)
        x = self.point_linear(x)
        return x
    
class MBInvertedConvLayer_with_shortcut(nn.Module):

    def __init__(self, in_channels, out_channels,
                 kernel_size=3, stride=1, expand_ratio=6, mid_channels=None):
        super(MBInvertedConvLayer_with_shortcut, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.kernel_size = kernel_size
        self.stride = stride
        self.expand_ratio = expand_ratio
        self.mid_channels = mid_channels

        if self.mid_channels is None:
            feature_dim = round(self.in_channels * self.expand_ratio)
        else:
            feature_dim = self.mid_channels

        if self.expand_ratio == 1:
            self.inverted_bottleneck = None
        else:
            self.inverted_bottleneck = nn.Sequential(OrderedDict([
                ('conv', nn.Conv2d(self.in_channels, feature_dim, 1, 1, 0, bias=False)),
                ('bn', nn.BatchNorm2d(feature_dim, affine=False, track_running_stats=False)),
                ('act', nn.ReLU6(inplace=True)),
            ]))

        pad = get_same_padding(self.kernel_size)
        self.depth_conv = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(feature_dim, feature_dim, kernel_size, stride, pad, groups=feature_dim, bias=False)),
            ('bn', nn.BatchNorm2d(feature_dim, affine=False, track_running_stats=False)),
            ('act', nn.ReLU6(inplace=True)),
        ]))

        self.point_linear = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(feature_dim, out_channels, 1, 1, 0, bias=False)),
            ('bn', nn.BatchNorm2d(out_channels, affine=False, track_running_stats=False)),
        ]))

    def forward(self, x):
        skip_x = x
        if self.inverted_bottleneck:
            conv_x = self.inverted_bottleneck(x)
            conv_x = self.depth_conv(conv_x)
            conv_x = self.point_linear(conv_x)
            
        else:
            conv_x = self.depth_conv(x)
            conv_x = self.point_linear(conv_x)
        return skip_x + conv_x
    
class Zero_with_shortcut(nn.Module):

    def __init__(self):
        super(Zero_with_shortcut, self).__init__()

    def forward(self, x):
        return x
    
    
class MobileNetV2_search(nn.Module):
    def __init__(self, n_class=1000):
        super(MobileNetV2_search, self).__init__()
    
    
        width_stages = [24,40,80,96,192,320]
        n_cell_stages = [4,4,4,4,4,1]
        stride_stages = [2,2,2,1,2,1]
        
        input_channel = 32
        first_cell_width = 16
        
        self.first_conv = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(3, input_channel, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)),
            ('bn', nn.BatchNorm2d(32, affine=False, track_running_stats=False)),
            ('act', nn.ReLU6(inplace=True))
        ]))
        
        self.first_block_conv = MBInvertedConvLayer(input_channel, first_cell_width, 3, 1, 1)
        input_channel = first_cell_width
        
        self.features = torch.nn.ModuleList()
        for width, n_cell, s in zip(width_stages, n_cell_stages, stride_stages):
            
            for i in range(n_cell):
                self.features.append(torch.nn.ModuleList())
                if i == 0:
                    stride = s
                else:
                    stride = 1
                # conv
                if stride == 1 and input_channel == width:                 
                    self.features[-1].append(MBInvertedConvLayer_with_shortcut(input_channel, width, kernel_size=3, stride = stride, expand_ratio = 3)) 
                    self.features[-1].append(MBInvertedConvLayer_with_shortcut(input_channel, width, kernel_size=3, stride = stride, expand_ratio = 6)) 
                    self.features[-1].append(MBInvertedConvLayer_with_shortcut(input_channel, width, kernel_size=5, stride = stride, expand_ratio = 3))
                    self.features[-1].append(MBInvertedConvLayer_with_shortcut(input_channel, width, kernel_size=5, stride = stride, expand_ratio = 6)) 
                    self.features[-1].append(MBInvertedConvLayer_with_shortcut(input_channel, width, kernel_size=7, stride = stride, expand_ratio = 3)) 
                    self.features[-1].append(MBInvertedConvLayer_with_shortcut(input_channel, width, kernel_size=7, stride = stride, expand_ratio = 6)) 
                    self.features[-1].append(Zero_with_shortcut())
    
                else:
                    self.features[-1].append(MBInvertedConvLayer(input_channel, width, kernel_size=3, stride = stride, expand_ratio = 3)) 
                    self.features[-1].append(MBInvertedConvLayer(input_channel, width, kernel_size=3, stride = stride, expand_ratio = 6)) 
                    self.features[-1].append(MBInvertedConvLayer(input_channel, width, kernel_size=5, stride = stride, expand_ratio = 3))
                    self.features[-1].append(MBInvertedConvLayer(input_channel, width, kernel_size=5, stride = stride, expand_ratio = 6)) 
                    self.features[-1].append(MBInvertedConvLayer(input_channel, width, kernel_size=7, stride = stride, expand_ratio = 3)) 
                    self.features[-1].append(MBInvertedConvLayer(input_channel, width, kernel_size=7, stride = stride, expand_ratio = 6)) 
               
                input_channel = width
                
        last_channel = 1280
        self.feature_mix_layer = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(input_channel, last_channel, kernel_size=(1, 1), stride=(1, 1), bias=False)),
            ('bn', nn.BatchNorm2d(last_channel, affine=False, track_running_stats=False)),
            ('act', nn.ReLU6(inplace=True))
        ]))
        
        self.global_avg_pooling = nn.AdaptiveAvgPool2d(1)
        
        self.classifier = nn.Linear(in_features=1280, out_features=1000, bias=True)
        
        def forward(self, x):
            x = self.first_conv(x)
            x = self.first_block_conv(x)
#             for archs, arch_id in zip(self.features, architecture):
#                 x = archs[arch_id](x)
            x = self.feature_mix_layer(x)
            x = self.global_avg_pooling(x)
            x = x.view(x.size(0), -1)  # flatten
            x = self.classifier(x)
            return x
                


In [114]:
model = MobileNetV2_search()

In [15]:
model

MobileNetV2_search(
  (first_conv): Sequential(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (act): ReLU6(inplace=True)
  )
  (first_block_conv): MBInvertedConvLayer(
    (depth_conv): Sequential(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (act): ReLU6(inplace=True)
    )
    (point_linear): Sequential(
      (conv): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    )
  )
  (features): ModuleList(
    (0): ModuleList(
      (0): MBInvertedConvLayer(
        (inverted_bottleneck): Sequential(
          (conv): Conv2d(16, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn

In [18]:
layer_widths = []

width_stages = [24,40,80,96,192,320]
n_cell_stages = [4,4,4,4,4,1]
stride_stages = [2,2,2,1,2,1]

for width, n_cell, s in zip(width_stages, n_cell_stages, stride_stages):
            
            for i in range(n_cell):
#                 self.features.append(torch.nn.ModuleList())
                if i == 0:
                    stride = s
                else:
                    stride = 1
                # conv
                if stride == 1 and input_channel == width:     #  --> 7          
                    layer_widths.append(7)
    
                else:                                         # --> 6
                    layer_widths.append(6)
               
                input_channel = width

In [204]:
layer_widths

[6, 7, 7, 7, 6, 7, 7, 7, 6, 7, 7, 7, 6, 7, 7, 7, 6, 7, 7, 7, 6]

In [42]:
expand_ratio = [3,6,3,6,3,6]
kernel_size = [3,3,5,5,7,7]


In [50]:
# Calculate # of params:
arch = [0]*21
layer_num = 0
num_params = 0
input_channel = 16
for width, n_cell, s in zip(width_stages, n_cell_stages, stride_stages):           
    for i in range(n_cell):
        num_params += kernel_size[arch[layer_num]] ** 2 * input_channel * expand_ratio[arch[layer_num]] # Depthwise
        num_params += input_channel * input_channel * expand_ratio[arch[layer_num]]
        num_params += input_channel * expand_ratio[arch[layer_num]] * width
        layer_num += 1
        input_channel = width

print(num_params)

1471920


In [51]:
arch = [1]*21
layer_num = 0
num_params = 0
input_channel = 16
for width, n_cell, s in zip(width_stages, n_cell_stages, stride_stages):           
    for i in range(n_cell):
        num_params += kernel_size[arch[layer_num]] ** 2 * input_channel * expand_ratio[arch[layer_num]] # Depthwise
        num_params += input_channel * input_channel * expand_ratio[arch[layer_num]]
        num_params += input_channel * expand_ratio[arch[layer_num]] * width
        layer_num += 1
        input_channel = width

print(num_params)

2943840


In [52]:
arch = [2]*21
layer_num = 0
num_params = 0
input_channel = 16
for width, n_cell, s in zip(width_stages, n_cell_stages, stride_stages):           
    for i in range(n_cell):
        num_params += kernel_size[arch[layer_num]] ** 2 * input_channel * expand_ratio[arch[layer_num]] # Depthwise
        num_params += input_channel * input_channel * expand_ratio[arch[layer_num]]
        num_params += input_channel * expand_ratio[arch[layer_num]] * width
        layer_num += 1
        input_channel = width

print(num_params)

1555632


In [53]:
arch = [3]*21
layer_num = 0
num_params = 0
input_channel = 16
for width, n_cell, s in zip(width_stages, n_cell_stages, stride_stages):           
    for i in range(n_cell):
        num_params += kernel_size[arch[layer_num]] ** 2 * input_channel * expand_ratio[arch[layer_num]] # Depthwise
        num_params += input_channel * input_channel * expand_ratio[arch[layer_num]]
        num_params += input_channel * expand_ratio[arch[layer_num]] * width
        layer_num += 1
        input_channel = width

print(num_params)

3111264


In [54]:
arch = [4]*21
layer_num = 0
num_params = 0
input_channel = 16
for width, n_cell, s in zip(width_stages, n_cell_stages, stride_stages):           
    for i in range(n_cell):
        num_params += kernel_size[arch[layer_num]] ** 2 * input_channel * expand_ratio[arch[layer_num]] # Depthwise
        num_params += input_channel * input_channel * expand_ratio[arch[layer_num]]
        num_params += input_channel * expand_ratio[arch[layer_num]] * width
        layer_num += 1
        input_channel = width

print(num_params)

1681200


In [147]:
def get_arch_num_params(arch):
    layer_num = 0
    num_params = 0
    input_channel = 16
    common_params = 3 * 32 * 9 + 32 * 9 + 32 * 16 + 320 * 1280 + 1280 * 1000 + 1000

    expand_ratio = [3,6,3,6,3,6]
    kernel_size = [3,3,5,5,7,7]

    width_stages = [24,40,80,96,192,320]
    n_cell_stages = [4,4,4,4,4,1]
    stride_stages = [2,2,2,1,2,1]

    for width, n_cell in zip(width_stages, n_cell_stages):           
        for i in range(n_cell):
            num_params += kernel_size[arch[layer_num]] ** 2 * input_channel * expand_ratio[arch[layer_num]] # Depthwise
            num_params += input_channel * input_channel * expand_ratio[arch[layer_num]]
            num_params += input_channel * expand_ratio[arch[layer_num]] * width
            layer_num += 1
            input_channel = width

    return num_params + common_params

In [63]:
get_arch_num_params([4]*21)


1681200

In [32]:
total_params

15818520

In [47]:
15818520/3362400

4.704532476802284

In [58]:
common_params = 0
common_params += 3 * 32 * 9 + 32 * 9 + 32 * 16 + 320 * 1280 + 1280 * 1000 + 1000

In [59]:
common_params

1692264

In [56]:
1471920 + 2943840 + 1555632 + 3111264 + 1681200 + 3362400

14126256

In [60]:
14126256 + common_params

15818520

In [116]:
from torchsummary import summary

summary(model, (1,224,224))

NotImplementedError: 

In [105]:
def get_rand_arch():
    rand_arch = []
    for i in range(len(layer_widths)):
        rand_arch.append(random.randrange(layer_widths[i]))
        
    return(rand_arch)

In [108]:
get_rand_arch()

[2, 2, 5, 0, 4, 5, 4, 2, 2, 2, 5, 0, 0, 5, 4, 4, 0, 6, 1, 0, 0]

In [None]:
3 * 32 * 9 + 32 * 9 + 32 * 16 + 320 * 1280 + 1280 * 1000 + 1000

In [182]:
def get_arch_flops(arch):
    input_size = 224
    total_param = 0
    
    
    layer_num = 0
    flops = 0
    input_channel = 32
    
    expand_ratio = [3,6,3,6,3,6]
    kernel_size = [3,3,5,5,7,7]

    width_stages = [24,40,80,96,192,320]
    n_cell_stages = [4,4,4,4,4,1]
    stride_stages = [2,2,2,1,2,1]
    
    # fisrt_conv
    kernel_ops = 3 * 3 * 3
    params = 32 * kernel_ops
    output_size = input_size // 2
    flops += params * output_size ** 2
    total_param += params
    
    input_size = output_size
    # first_conv_block
    kernel_ops = 1 * 3 * 3
    params = 32 * kernel_ops
    flops += params * output_size ** 2
    total_param += params
    
    kernel_ops = 32
    params = 16 * kernel_ops
    flops += params * output_size ** 2 
    total_param += params
    
    input_channel = 16
    
    
    for width, n_cell, s in zip(width_stages, n_cell_stages, stride_stages):       
        
        for i in range(n_cell):
            if i == 0:
                stride = s
            else:
                stride = 1
            
            # bottleneck
            kernel_ops = input_channel
            params = input_channel * expand_ratio[arch[layer_num]] * kernel_ops
            output_size = input_size
            flops += params * output_size ** 2
            total_param += params
            
            # depthwise
            kernel_ops = kernel_size[arch[layer_num]] ** 2
            params = input_channel * expand_ratio[arch[layer_num]] * kernel_ops
            output_size = input_size // stride 
            flops += params * output_size ** 2
            total_param += params
            
            # separable
            kernel_ops = input_channel * expand_ratio[arch[layer_num]]
            params = width * kernel_ops
            flops += params * output_size ** 2
            total_param += params

            
            layer_num += 1
            input_channel = width
            input_size = output_size
            
    # feature_mix_layer
    kernel_ops = 320
    params = 1280 * kernel_ops
    flops += params * input_size ** 2
    total_param += params
    
    # linear
    params = 1280 * 1000
    flops += params
    total_param += params
    
    return total_param + 1000, flops
    

In [189]:
get_arch_flops([4]*21)[1] /1e6

324.407168

In [179]:
get_arch_num_params([0]*21)

3164184

In [134]:
len([3, 2, 3, 2, 2, 2, 3, 4, 0, 0, 3, 4, 2, 0, 5, 3, 5, 3, 5])

19

In [190]:
def get_arch_flops_fairnas(arch):
    input_size = 224
    total_param = 0
    
    
    layer_num = 0
    flops = 0
    input_channel = 32
    
    expand_ratio = [3,6,3,6,3,6]
    kernel_size = [3,3,5,5,7,7]

    width_stages = [24,40,80,96,192,320]
    n_cell_stages = [2,4,4,4,4,1]
    stride_stages = [2,2,2,1,2,1]
    
    # fisrt_conv
    kernel_ops = 3 * 3 * 3
    params = 32 * kernel_ops
    output_size = input_size // 2
    flops += params * output_size ** 2
    total_param += params
    
    input_size = output_size
    # first_conv_block
    kernel_ops = 1 * 3 * 3
    params = 32 * kernel_ops
    flops += params * output_size ** 2
    total_param += params
    
    kernel_ops = 32
    params = 16 * kernel_ops
    flops += params * output_size ** 2 
    total_param += params
    
    input_channel = 16
    
    
    for width, n_cell, s in zip(width_stages, n_cell_stages, stride_stages):       
        
        for i in range(n_cell):
            if i == 0:
                stride = s
            else:
                stride = 1
            
            # bottleneck
            kernel_ops = input_channel
            params = input_channel * expand_ratio[arch[layer_num]] * kernel_ops
            output_size = input_size
            flops += params * output_size ** 2
            total_param += params
            
            # depthwise
            kernel_ops = kernel_size[arch[layer_num]] ** 2
            params = input_channel * expand_ratio[arch[layer_num]] * kernel_ops
            output_size = input_size // stride 
            flops += params * output_size ** 2
            total_param += params
            
            # separable
            kernel_ops = input_channel * expand_ratio[arch[layer_num]]
            params = width * kernel_ops
            flops += params * output_size ** 2
            total_param += params

            
            layer_num += 1
            input_channel = width
            input_size = output_size
            
    # feature_mix_layer
    kernel_ops = 320
    params = 1280 * kernel_ops
    flops += params * input_size ** 2
    total_param += params
    
    # linear
    params = 1280 * 1000
    flops += params
    total_param += params
    
    return total_param + 1000, flops
    

In [192]:
get_arch_flops_fairnas([3, 2, 3, 2, 2, 2, 3, 4, 0, 0, 3, 4, 2, 0, 5, 3, 5, 3, 5])[1] / 1e6

354.766784

In [193]:
get_arch_flops_fairnas([3, 2, 3, 2, 2, 2, 3, 4, 0, 0, 3, 4, 2, 0, 5, 3, 5, 3, 5])[0] / 1e6

4.488264

In [202]:
get_arch_flops([1]*21)[0] / 1e6

4.636104

In [203]:
get_arch_flops([1]*21)[1] / 1e6

472.6208

In [141]:
def get_arch_num_params_fairnas(arch):
    layer_num = 0
    num_params = 0
    input_channel = 16
    common_params = 3 * 32 * 9 + 32 * 9 + 32 * 16 + 320 * 1280 + 1280 * 1000 + 1000

    expand_ratio = [3,6,3,6,3,6]
    kernel_size = [3,3,5,5,7,7]

    width_stages = [32,40,80,96,192,320]
    n_cell_stages = [2,4,4,4,4,1]
    stride_stages = [2,2,2,1,2,1]

    for width, n_cell in zip(width_stages, n_cell_stages):           
        for i in range(n_cell):
            num_params += kernel_size[arch[layer_num]] ** 2 * input_channel * expand_ratio[arch[layer_num]] # Depthwise
            num_params += input_channel * input_channel * expand_ratio[arch[layer_num]]
            num_params += input_channel * expand_ratio[arch[layer_num]] * width
            layer_num += 1
            input_channel = width

    return num_params + common_params

In [178]:
get_arch_num_params([0]*21)

3164184

In [163]:
get_arch_flops([5]*21) / 1e6

712.161024

In [207]:
for archs, arch_id in zip(model.features, [0]*21):
    print(archs[arch_id])

MBInvertedConvLayer(
  (inverted_bottleneck): Sequential(
    (conv): Conv2d(16, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (act): ReLU6(inplace=True)
  )
  (depth_conv): Sequential(
    (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=48, bias=False)
    (bn): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (act): ReLU6(inplace=True)
  )
  (point_linear): Sequential(
    (conv): Conv2d(48, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  )
)
MBInvertedConvLayer_with_shortcut(
  (inverted_bottleneck): Sequential(
    (conv): Conv2d(24, 72, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(72, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (act): ReLU6(inplace=True)
  )
  (depth_c