In [29]:
import torch
import torch.nn as nn

In [30]:
class MobileNetV1(nn.Module):
    def __init__(self, in_channels=3, num_classes=10):
        super(MobileNetV1, self).__init__()
        self.models = nn.Sequential(
            self.conv_bn(in_channels, 32, 2), 
            self.conv_dw(32, 64, 1),
            self.conv_dw(64, 128, 2),
            self.conv_dw(128, 128, 1),
            self.conv_dw(128, 256, 2),
            self.conv_dw(256, 256, 1),
            self.conv_dw(256, 512, 2),
            self.conv_dw(512, 512, 1),
            self.conv_dw(512, 512, 1),
            self.conv_dw(512, 512, 1),
            self.conv_dw(512, 512, 1),
            self.conv_dw(512, 512, 1),
            self.conv_dw(512, 1024, 2),
            self.conv_dw(1024, 1024, 1),
            nn.AvgPool2d(7)
        )

        self.fc_layer = nn.Linear(1024, num_classes)

    def conv_bn(self, in_channels, out_channels, stride):
        return nn.Sequential(StandardConv(in_channels, out_channels, stride))

    def conv_dw(self, in_channels, out_channels, stride):
        layers = []
        layers.append(DepthwiseConv(in_channels, in_channels, stride))
        layers.append(PointwiseConv(in_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.models(x)
        x = x.view(x.size(0), -1)
        x = self.fc_layer(x)
        return x

In [31]:
class DepthwiseConv(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(DepthwiseConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, groups=in_channels, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

In [32]:
class PointwiseConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(PointwiseConv, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

In [33]:
class StandardConv(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(StandardConv, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

In [34]:
from torchsummary import summary

model = MobileNetV1()
summary(model, (3, 299, 299))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 150, 150]             864
       BatchNorm2d-2         [-1, 32, 150, 150]              64
              ReLU-3         [-1, 32, 150, 150]               0
      StandardConv-4         [-1, 32, 150, 150]               0
            Conv2d-5         [-1, 32, 150, 150]             288
       BatchNorm2d-6         [-1, 32, 150, 150]              64
              ReLU-7         [-1, 32, 150, 150]               0
     DepthwiseConv-8         [-1, 32, 150, 150]               0
            Conv2d-9         [-1, 64, 150, 150]           2,048
      BatchNorm2d-10         [-1, 64, 150, 150]             128
             ReLU-11         [-1, 64, 150, 150]               0
    PointwiseConv-12         [-1, 64, 150, 150]               0
           Conv2d-13           [-1, 64, 75, 75]             576
      BatchNorm2d-14           [-1, 64,

## Reference
- [pytorch-mobilenet-v1](https://github.com/wjc852456/pytorch-mobilenet-v1)