In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from collections import OrderedDict

In [2]:
def _ensure_divisible(number, divisor, min_value=None):
    '''
    Ensure that 'number' can be 'divisor' divisible
    Reference from original tensorflow repo:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    '''
    if min_value is None:
        min_value = divisor
    new_num = max(min_value, int(number + divisor / 2) // divisor * divisor)
    if new_num < 0.9 * number:
        new_num += divisor
        
    return new_num

In [10]:
class H_sigmoid(nn.Module):
    '''
    Hard sigmoid activation: relu6(x + 3) / 6.

    inplace: forwarded to F.relu6. The clamp is applied to the temporary
             tensor 'x + 3', not to the tensor passed to forward().
    '''
    def __init__(self, inplace=True):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        shifted = x + 3
        clamped = F.relu6(shifted, inplace=self.inplace)
        return clamped / 6

class H_swish(nn.Module):
    '''
    Hard swish activation: x * relu6(x + 3) / 6.

    inplace: forwarded to F.relu6. The clamp is applied to the temporary
             tensor 'x + 3', not to the tensor passed to forward().
    '''
    def __init__(self, inplace=True):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        gate = F.relu6(x + 3, inplace=self.inplace)
        # Multiply first, then divide, to keep the original operation order.
        return x * gate / 6

In [5]:
class SEModule(nn.Module):
    '''
    Squeeze-and-Excitation block.

    Each channel of a 4-D feature map is multiplied by a weight computed from
    the globally average-pooled features: pool -> Linear -> ReLU -> Linear ->
    hard sigmoid.

    in_channels_num: number of channels of the incoming feature map
    reduction_ratio: factor by which the hidden layer is narrower than the input;
                     must divide in_channels_num exactly, otherwise ValueError
    '''
    def __init__(self, in_channels_num, reduction_ratio=4):
        super().__init__()

        hidden_channels_num, remainder = divmod(in_channels_num, reduction_ratio)
        if remainder:
            raise ValueError('in_channels_num must be divisible by reduction_ratio(default = 4)')

        # Squeeze: collapse every channel to a single value
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # Excitation: C -> C/r -> C, ending in a hard sigmoid
        reduce_fc = nn.Linear(in_channels_num, hidden_channels_num, bias=False)
        restore_fc = nn.Linear(hidden_channels_num, in_channels_num, bias=False)
        self.fc = nn.Sequential(
            reduce_fc,
            nn.ReLU(inplace=True),
            restore_fc,
            H_sigmoid()
        )

    def forward(self, x):
        # x is the feature map; unpacking four sizes requires a 4-D tensor
        n, c, _, _ = x.shape

        pooled = self.avg_pool(x).view(n, c)
        channel_weights = self.fc(pooled).view(n, c, 1, 1)

        # Recalibrate the features channel by channel
        return x * channel_weights

In [6]:
class Bottleneck(nn.Module):
    '''
    MobileNetV3 bottleneck block.

    The block optionally expands the channels with a 1x1 convolution, applies
    a depthwise convolution (optionally followed by an SE module), and projects
    to the output channels with a linear 1x1 convolution. When the stride is 1
    and the input and output channel counts match, the input is added to the
    output.
    '''

    def __init__(self, in_channels_num, exp_size, out_channels_num, kernel_size, stride, use_SE, NL, BN_momentum):
        '''
        in_channels_num: number of input channels
        exp_size: number of channels in the expanded middle stage
                  (the "wide" part of MobileNetV2's narrow -> wide -> narrow)
        out_channels_num: number of output channels (noted as <= exp_size by the
                          original author; not enforced here)
        kernel_size: kernel size of the depthwise convolution
        stride: stride of the depthwise convolution, 1 or 2
        use_SE: True or False -- use SE Module or not
        NL: Non-linearity, 'RE' (ReLU) or 'HS' (hard swish), case-insensitive
        BN_momentum: momentum used by every BatchNorm2d layer in the block
        '''
        super(Bottleneck, self).__init__()

        assert stride in [1, 2]  # a skip connection is only possible when stride == 1

        NL = NL.upper()
        assert NL in ['RE', 'HS']

        use_HS = NL == 'HS'

        # Whether to use residual structure or not
        self.use_residual = (stride == 1 and in_channels_num == out_channels_num)

        depthwise_padding = (kernel_size - 1) // 2

        # W.O expansion
        if exp_size == in_channels_num:
            # Depthwise-Conv
            self.conv1 = nn.Sequential(
                nn.Conv2d(in_channels=in_channels_num, out_channels=exp_size, kernel_size=kernel_size, stride=stride, padding=depthwise_padding, groups=in_channels_num, bias=False),
                nn.BatchNorm2d(num_features=exp_size, momentum=BN_momentum),

                # SE Module
                SEModule(exp_size) if use_SE else nn.Sequential(),
                H_swish() if use_HS else nn.ReLU(inplace=True))

            # Linear Pointwise-Conv
            self.conv2 = nn.Sequential(
                nn.Conv2d(in_channels=exp_size, out_channels=out_channels_num, kernel_size=1, stride=1, padding=0, bias=False),
                self._make_last_bn(out_channels_num, BN_momentum, self.use_residual)
            )

        # With expansion
        else:
            # Pointwise-Conv for expansion
            self.conv1 = nn.Sequential(
                nn.Conv2d(in_channels=in_channels_num, out_channels=exp_size, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(num_features=exp_size, momentum=BN_momentum),
                H_swish() if use_HS else nn.ReLU(inplace=True))

            self.conv2 = nn.Sequential(
                # Depthwise Convolution
                nn.Conv2d(in_channels=exp_size, out_channels=exp_size, kernel_size=kernel_size, stride=stride, padding=depthwise_padding, groups=exp_size, bias=False),
                nn.BatchNorm2d(num_features=exp_size, momentum=BN_momentum),

                # SE Module
                SEModule(exp_size) if use_SE else nn.Sequential(),
                H_swish() if use_HS else nn.ReLU(inplace=True),

                # Linear Pointwise-Conv
                nn.Conv2d(in_channels=exp_size, out_channels=out_channels_num, kernel_size=1, stride=1, padding=0, bias=False),
                self._make_last_bn(out_channels_num, BN_momentum, self.use_residual)
            )

    @staticmethod
    def _make_last_bn(out_channels_num, BN_momentum, use_residual):
        '''
        Build the BatchNorm2d that closes the block.

        For residual blocks the layer is wrapped in a Sequential under the
        name 'lastBN' so that code iterating over modules can find it with
        hasattr(m, 'lastBN'). BN_momentum is applied in both cases; the
        wrapped layer previously fell back to the BatchNorm2d default.
        '''
        last_bn = nn.BatchNorm2d(num_features=out_channels_num, momentum=BN_momentum)
        if use_residual:
            return nn.Sequential(OrderedDict([('lastBN', last_bn)]))
        return last_bn

    def forward(self, x, expand=False):
        '''
        x: input feature map
        expand: when True, also return the output of self.conv1

        Returns the block output, or (output, conv1 output) when expand is True.
        '''
        # conv1 is the depthwise stage when there is no expansion,
        # otherwise it is the 1x1 expansion stage
        out1 = self.conv1(x)
        out = self.conv2(out1)

        if self.use_residual:
            out = out + x

        if expand:
            return out, out1
        return out

In [7]:
class MobileNetV3(nn.Module):
    '''
    MobileNetV3 image classifier ('large' or 'small' variant).

    Layout:
      self.featureList       -- stem convolution, the Bottleneck blocks and a
                                final 1x1 convolution
      self.last_stage_layers -- global average pooling, 1x1 convolution, hard swish
      self.classifier        -- dropout followed by a linear layer
    '''
    def __init__(self, mode='small', classes_num=1000, input_size=224, width_multiplier=1.0, dropout=0.2, BN_momentum=0.1, zero_gamma=False):
        '''
        mode: type of the model, 'large' or 'small' (case-insensitive)
        classes_num: number of outputs of the final linear layer
        input_size: input resolution; for 32 or 56 the stem and the first
                    configurable bottleneck stride use 1 instead of 2
        width_multiplier: factor applied to the channel counts (rounded to a
                          multiple of 8 with _ensure_divisible)
        dropout: drop probability of the Dropout layer before the classifier
        BN_momentum: momentum of the BatchNorm2d layers created here; also
                     passed to every Bottleneck
        zero_gamma: when True, the weight of every 'lastBN' layer is set to 0
        '''
        super(MobileNetV3, self).__init__()

        mode = mode.lower()
        assert mode in ['large', 'small']

        early_stride = 2
        if input_size == 32 or input_size == 56:
            # using cifar-10, cifar-100 or Tiny-ImageNet
            early_stride = 1

        # Configuration of a MobileNetV3-Large Model
        if mode == 'large':
            # kernel_size, exp_size, out_channels_num, use_SE, NL, stride
            configs = [
                [3, 16, 16, False, 'RE', 1],
                [3, 64, 24, False, 'RE', early_stride],
                [3, 72, 24, False, 'RE', 1],
                [5, 72, 40, True, 'RE', 2],
                [5, 120, 40, True, 'RE', 1],
                [5, 120, 40, True, 'RE', 1],
                [3, 240, 80, False, 'HS', 2],
                [3, 200, 80, False, 'HS', 1],
                [3, 184, 80, False, 'HS', 1],
                [3, 184, 80, False, 'HS', 1],
                [3, 480, 112, True, 'HS', 1],
                [3, 672, 112, True, 'HS', 1],
                [5, 672, 160, True, 'HS', 2],
                [5, 960, 160, True, 'HS', 1],
                [5, 960, 160, True, 'HS', 1]
            ]

        # Configuration of a MobileNetV3-Small Model
        elif mode == 'small':
            # kernel_size, exp_size, out_channels_num, use_SE, NL, stride
            configs = [
                [3, 16, 16, True, 'RE', early_stride],
                [3, 72, 24, False, 'RE', 2],
                [3, 88, 24, False, 'RE', 1],
                [5, 96, 40, True, 'HS', 2],
                [5, 240, 40, True, 'HS', 1],
                [5, 240, 40, True, 'HS', 1],
                [5, 120, 48, True, 'HS', 1],
                [5, 144, 48, True, 'HS', 1],
                [5, 288, 96, True, 'HS', 2],
                [5, 576, 96, True, 'HS', 1],
                [5, 576, 96, True, 'HS', 1]
            ]

        first_channels_num = 16
        last_channels_num = 1280 if mode == 'large' else 1024

        divisor = 8

        ## Feature extraction ##
        # Input layer
        input_channels_num = _ensure_divisible(first_channels_num * width_multiplier, divisor)
        # The classifier width is only scaled up, never down
        last_channels_num = _ensure_divisible(last_channels_num * width_multiplier, divisor) if width_multiplier > 1 else last_channels_num

        feature_extraction_layers = []
        first_layer = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=input_channels_num, kernel_size=3, stride=early_stride, padding=1, bias=False),
            nn.BatchNorm2d(num_features=input_channels_num, momentum=BN_momentum),
            H_swish()
        )
        feature_extraction_layers.append(first_layer)

        # Bottleneck layers
        scaled_exp_size = None
        for kernel_size, exp_size, out_channels_num, use_SE, NL, stride in configs:
            output_channels_num = _ensure_divisible(out_channels_num * width_multiplier, divisor)
            scaled_exp_size = _ensure_divisible(exp_size * width_multiplier, divisor)

            feature_extraction_layers.append(Bottleneck(input_channels_num, scaled_exp_size, output_channels_num, kernel_size, stride, use_SE, NL, BN_momentum))
            input_channels_num = output_channels_num

        # Last layer
        # Its width equals the expansion size of the final bottleneck, which
        # has already been multiplied by width_multiplier above. Multiplying
        # again (as earlier revisions did) scaled this layer twice whenever
        # width_multiplier != 1; for width_multiplier == 1 nothing changes.
        last_stage_channels_num = scaled_exp_size
        last_stage_layer1 = nn.Sequential(
            nn.Conv2d(in_channels=input_channels_num, out_channels=last_stage_channels_num, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(num_features=last_stage_channels_num, momentum=BN_momentum),
            H_swish()
        )
        feature_extraction_layers.append(last_stage_layer1)

        self.featureList = nn.ModuleList(feature_extraction_layers)

        # Last stage: global pooling + 1x1 convolution + hard swish
        last_stage = []
        last_stage.append(nn.AdaptiveAvgPool2d(1))
        last_stage.append(nn.Conv2d(in_channels=last_stage_channels_num, out_channels=last_channels_num, kernel_size=1, stride=1, padding=0, bias=False))
        last_stage.append(H_swish())

        self.last_stage_layers = nn.Sequential(*last_stage)

        # Classification
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(last_channels_num, classes_num)
        )

        # Initialize the weights
        self._initialize_weights(zero_gamma)

    def forward(self, x):
        '''
        Run the feature layers, the last stage and the classifier on x and
        return the classifier output.
        '''
        for layer in self.featureList:
            x = layer(x)

        x = self.last_stage_layers(x)

        # Flatten everything except the batch dimension
        x = x.view(x.size(0), -1)
        x = self.classifier(x)

        return x

    def _initialize_weights(self, zero_gamma):
        '''
        Initialize the weights

        Conv2d: Kaiming normal (fan_out), zero bias
        BatchNorm2d: weight 1, bias 0
        Linear: normal with std 0.001, zero bias
        zero_gamma: when True, additionally set the weight of every module
                    attribute named 'lastBN' to 0
        '''
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        if zero_gamma:
            for m in self.modules():
                if hasattr(m, 'lastBN'):
                    nn.init.constant_(m.lastBN.weight, 0.0)

In [8]:
!pip install torchsummaryX

Collecting torchsummaryX
  Downloading torchsummaryX-1.3.0-py3-none-any.whl (3.6 kB)
Installing collected packages: torchsummaryX
Successfully installed torchsummaryX-1.3.0


In [16]:
from torchsummaryX import summary

width_multiplier = 1
# The model below is built for 32x32 inputs, so the dummy input used for the
# summary has to use the same resolution for the reported shapes to match.
cifar_input_size = 32

model_small = MobileNetV3(mode='small', classes_num=10, input_size=cifar_input_size, width_multiplier=width_multiplier)
model_small.eval()
summary(model_small, torch.zeros((1, 3, cifar_input_size, cifar_input_size)))
print('MobileNetV3-Small-%.2f cifar10-summaryX\n' % width_multiplier)

                                                         Kernel Shape  \
Layer                                                                   
0_featureList.0.Conv2d_0                                [3, 16, 3, 3]   
1_featureList.0.BatchNorm2d_1                                    [16]   
2_featureList.0.H_swish_2                                           -   
3_featureList.1.conv1.Conv2d_0                          [1, 16, 3, 3]   
4_featureList.1.conv1.BatchNorm2d_1                              [16]   
5_featureList.1.conv1.2.AdaptiveAvgPool2d_avg_pool                  -   
6_featureList.1.conv1.2.fc.Linear_0                           [16, 4]   
7_featureList.1.conv1.2.fc.ReLU_1                                   -   
8_featureList.1.conv1.2.fc.Linear_2                           [4, 16]   
9_featureList.1.conv1.2.fc.H_sigmoid_3                              -   
10_featureList.1.conv1.ReLU_3                                       -   
11_featureList.1.conv2.Conv2d_0                    

  df_sum = df.sum()


In [None]:
# Summarise model_small on a random 1x3x224x224 input.
x = torch.randn(1, 3, 224, 224)
summary(model_small, x)