<center><h1>MobileNet V2</h1></center>

<center><p><a href="http://arxiv.org/abs/1704.04861">MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications</a></p></center>
<center><p><a href="http://arxiv.org/abs/1801.04381">MobileNetV2: Inverted Residuals and Linear Bottlenecks</a></p></center>

<img src="https://machinethink.net/images/mobilenet-v2/ResidualBlock@2x.png" width="600"/>

In [1]:
from functools import partial

import torch
from torch import nn
import torch.nn.functional as F

# Blocks

## Basic Block

In [2]:
class ConvBNReLU(nn.Sequential):
    """Conv2d -> BatchNorm2d -> ReLU6, packaged as one sequential unit.

    The convolution carries no bias term (the batch norm that follows has its
    own learnable shift) and is padded by ``(kernel_size - 1) // 2`` pixels on
    every side, so an odd kernel with ``stride=1`` keeps the spatial size.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by the convolution.
        kernel_size: Side length of the square convolution kernel.
        stride: Stride of the convolution.
        groups: Number of convolution groups (``groups == in_channels`` with
            equal in/out channels gives a depth-wise convolution).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, groups=1):
        same_pad = (kernel_size - 1) // 2
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=same_pad,
            groups=groups,
            bias=False,
        )
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU6(inplace=True)
        # Positional order defines the child indices 0, 1, 2 of the Sequential.
        super().__init__(conv, norm, activation)

## Inverted Residual Block

<img src="https://user-images.githubusercontent.com/18547241/53912465-997fd900-401e-11e9-82b0-c0be0f2abf93.png" width="800"/>

In [3]:
class InvertedResidual(nn.Module):
    """MobileNetV2 inverted residual block.

    The block is built as: an optional 1x1 expansion (Conv-BN-ReLU6), a 3x3
    depth-wise Conv-BN-ReLU6, and a linear 1x1 projection (Conv-BN, with no
    activation). The input is added to the output only when the block keeps
    both the spatial size (``stride == 1``) and the channel count
    (``in_channels == out_channels``).

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channel count of the projected output.
        stride: Stride of the 3x3 depth-wise convolution.
        expand_ratio: Multiplier giving the hidden (expanded) channel count.
            When it equals 1 the expansion layer is omitted. Fractional values
            are accepted; the hidden width is rounded to the nearest integer.
    """

    def __init__(self, in_channels, out_channels, stride, expand_ratio):
        super().__init__()
        # nn.Conv2d needs integer channel counts; rounding keeps a fractional
        # expand_ratio usable and is a no-op for integer ratios.
        hidden_channels = int(round(in_channels * expand_ratio))
        self.use_shortcut = stride == 1 and in_channels == out_channels

        layers = []
        if expand_ratio != 1:
            # 1x1 point-wise conv (expansion)
            layers.append(ConvBNReLU(in_channels, hidden_channels, kernel_size=1))
        layers.extend([
            # 3x3 depth-wise conv: one group per channel
            ConvBNReLU(hidden_channels, hidden_channels, stride=stride, groups=hidden_channels),
            # 1x1 point-wise conv (linear bottleneck: no activation afterwards)
            nn.Conv2d(hidden_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
        ])

        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block, adding the identity shortcut when it is enabled."""
        if self.use_shortcut:
            return x + self.conv(x)
        else:
            return self.conv(x)

# MobileNet V2

<img src="https://miro.medium.com/v2/resize:fit:1016/1*5iA55983nBMlQn9f6ICxKg.png" width="600"/>

In [4]:
class MobileNetV2(nn.Module):
    """MobileNetV2 image classifier.

    The network is a stride-2 3x3 stem convolution, a stack of inverted
    residual blocks, a 1x1 convolution, global average pooling, and a
    dropout + linear classifier.

    Args:
        num_classes: Number of outputs of the final linear layer.
        alpha: Width multiplier applied to the channel counts. For
            ``alpha < 1`` the last 1x1 convolution keeps its full 1280
            channels; only ``alpha > 1`` widens it.
        round_nearest: Every scaled channel count is rounded to a multiple of
            this value via ``_make_divisible``.
    """

    def __init__(self, num_classes=1000, alpha=1., round_nearest=8):
        super().__init__()
        input_channels = _make_divisible(32 * alpha, round_nearest)
        # The width multiplier must not shrink the last convolution: for
        # alpha < 1 it stays at 1280 channels, matching the reference
        # implementations.
        last_channels = _make_divisible(1280 * max(1.0, alpha), round_nearest)

        # One row per stage: (t, c, n, s) =
        # (expand ratio, output channels, number of blocks, stride of first block)
        inverted_residual_setting = [
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # conv2d (stem)
        features = [ConvBNReLU(3, input_channels, stride=2)]
        # bottleneck (shapes below are for a 224x224 input and alpha=1)
        # e.g. bottleneck-1: t=1, c=16, n=1, s=1
        # InvertedResidual(32, 16, 1, 1): 112x112x32->(CBR)->112x112x32->(CB)->112x112x16
        # e.g. bottleneck-2: t=6, c=24, n=2, s=2
        # InvertedResidual(16, 24, 2, 6): 112x112x16->(CBR)->112x112x96->(CBR)->56x56x96->(CB)->56x56x24
        # InvertedResidual(24, 24, 1, 6): 56x56x24->(CBR)->56x56x144->(CBR)->56x56x144->(CB)->56x56x24
        for t, c, n, s in inverted_residual_setting:
            output_channels = _make_divisible(c * alpha, round_nearest)
            for i in range(n):
                # Only the first block of a stage uses the stage stride.
                stride = s if i == 0 else 1
                features.append(
                    InvertedResidual(input_channels, output_channels, stride=stride, expand_ratio=t)
                )
                input_channels = output_channels
        # conv2d 1x1
        features.append(ConvBNReLU(input_channels, last_channels, kernel_size=1))
        self.features = nn.Sequential(*features)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(last_channels, num_classes),
        )

        self._initialize_weights()

    def forward(self, x):
        """Run features -> global average pool -> flatten -> classifier."""
        x = self.features(x)
        x = self.avgpool(x)
        # Keep the batch dimension and flatten everything after it.
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Initialise every Conv2d, BatchNorm2d and Linear submodule in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                # Start batch norm with unit scale and zero shift.
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

# Utils

In [5]:
def _make_divisible(v: float, divisor=8, min_value=None) -> int:
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8.
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

# Summary

## Data

In [6]:
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Random stand-in batch: 32 RGB images of 224x224, moved to the chosen device.
data = torch.randn((32, 3, 224, 224)).to(device)

## MobileNet V2

In [7]:
from torchkeras import summary

# Build the model with its default arguments and move it to the selected device.
net = MobileNetV2()
net = net.to(device)

# Print the per-layer table (output shape and parameter count) for the sample batch.
summary(net, input_data=data)

# Drop the reference so the model can be freed once the table is printed.
del net

--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
Conv2d-1                          [-1, 32, 112, 112]                  864
BatchNorm2d-2                     [-1, 32, 112, 112]                   64
ReLU6-3                           [-1, 32, 112, 112]                    0
Conv2d-4                          [-1, 32, 112, 112]                  288
BatchNorm2d-5                     [-1, 32, 112, 112]                   64
ReLU6-6                           [-1, 32, 112, 112]                    0
Conv2d-7                          [-1, 16, 112, 112]                  512
BatchNorm2d-8                     [-1, 16, 112, 112]                   32
Conv2d-9                          [-1, 96, 112, 112]                1,536
BatchNorm2d-10                    [-1, 96, 112, 112]                  192
ReLU6-11                          [-1, 96, 112, 112]                    0
Conv2d-12                           [