<a href="https://colab.research.google.com/github/moh2236945/pytorch_classification/blob/master/models/MobileNet_V1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import shutil
from collections import OrderedDict
from IPython.display import clear_output

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, datasets
from torchsummary import summary
from torch.utils.data import Dataset, DataLoader, random_split

In [2]:
class DSConv(nn.Module):
    """Depthwise-separable convolution block.

    A 3x3 convolution with one group per input channel is followed by a 1x1
    convolution that maps ``f_3x3`` channels to ``f_1x1`` channels. Each
    convolution is bias-free and is followed by batch normalisation and ReLU.

    Args:
        f_3x3: Number of input channels (also the width of the 3x3 stage).
        f_1x1: Number of output channels produced by the 1x1 stage.
        stride: Stride of the 3x3 convolution.
        padding: Padding of the 3x3 convolution.
    """

    def __init__(self, f_3x3, f_1x1, stride=1, padding=0):
        super().__init__()

        # groups == in_channels makes the 3x3 conv filter each channel
        # independently; the 1x1 conv below then mixes channels.
        depthwise = nn.Conv2d(
            in_channels=f_3x3,
            out_channels=f_3x3,
            kernel_size=3,
            stride=stride,
            padding=padding,
            groups=f_3x3,
            bias=False,
        )
        pointwise = nn.Conv2d(
            in_channels=f_3x3,
            out_channels=f_1x1,
            kernel_size=1,
            bias=False,
        )

        # Layer names are part of the state-dict keys, so keep them stable.
        stages = OrderedDict()
        stages['dconv'] = depthwise
        stages['bn1'] = nn.BatchNorm2d(f_3x3)
        stages['act1'] = nn.ReLU()
        stages['pconv'] = pointwise
        stages['bn2'] = nn.BatchNorm2d(f_1x1)
        stages['act2'] = nn.ReLU()
        self.feature = nn.Sequential(stages)

    def forward(self, x):
        """Run ``x`` through the depthwise and pointwise stages."""
        return self.feature(x)

class MobileNet(nn.Module):
    """MobileNet-V1 style image classifier.

    The network is a stride-2 3x3 stem convolution on 3-channel input,
    thirteen ``DSConv`` blocks, global average pooling, and a linear
    classifier with ``num_classes`` outputs (1000 by default).

    Args:
        channels: Base channel widths; the first six entries are used
            (stem width followed by the five stage widths).
        width_multiplier: Factor applied to every entry of ``channels``.
            Scaled widths are truncated to integers and clamped to at
            least 1 so that no layer ends up with zero channels.
        num_classes: Number of outputs of the final linear layer.
    """

    def __init__(self, channels, width_multiplier=1.0, num_classes=1000):
        super(MobileNet, self).__init__()

        # Truncation alone can yield 0 for small multipliers, which would
        # make the depthwise convs (groups == channels) impossible to build.
        channels = [max(1, int(elt * width_multiplier)) for elt in channels]

        self.conv = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(3, channels[0], kernel_size=3,
                               stride=2, padding=1, bias=False)),
            ('bn', nn.BatchNorm2d(channels[0])),
            ('act', nn.ReLU()),
        ]))

        # (layer name, input width index, output width index, stride).
        # Names and order are kept stable because they define the
        # state-dict keys and the parameter initialisation order.
        block_specs = [
            ('dsconv1', 0, 1, 1),
            ('dsconv2', 1, 2, 2),
            ('dsconv3', 2, 2, 1),
            ('dsconv4', 2, 3, 2),
            ('dsconv5', 3, 3, 1),
            ('dsconv6', 3, 4, 2),
            ('dsconv7_a', 4, 4, 1),
            ('dsconv7_b', 4, 4, 1),
            ('dsconv7_c', 4, 4, 1),
            ('dsconv7_d', 4, 4, 1),
            ('dsconv7_e', 4, 4, 1),
            ('dsconv8', 4, 5, 2),
            ('dsconv9', 5, 5, 1),
        ]
        self.features = nn.Sequential(OrderedDict(
            (name, DSConv(channels[src], channels[dst], stride, 1))
            for name, src, dst, stride in block_specs
        ))

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(channels[5], num_classes)

    def forward(self, x):
        """Return class scores for a batch of 3-channel images.

        Args:
            x: Input tensor; the stem convolution expects 3 channels in
                dimension 1.

        Returns:
            Tensor with ``num_classes`` values per batch element.
        """
        out = self.conv(x)
        out = self.features(out)
        out = self.avgpool(out)
        # Collapse the pooled 1x1 spatial dims so the linear layer sees
        # one feature vector per sample.
        out = torch.flatten(out, 1)
        out = self.linear(out)
        return out

def MobileNetV1(width_multiplier=1.0, num_classes=1000):
    """Build a ``MobileNet`` with base widths 32, 64, 128, 256, 512, 1024.

    Args:
        width_multiplier: Factor applied to every base width. The default
            of 1.0 keeps the base widths unchanged.
        num_classes: Number of outputs of the classifier layer.

    Returns:
        A newly constructed ``MobileNet`` instance.
    """
    return MobileNet(
        channels=[32, 64, 128, 256, 512, 1024],
        width_multiplier=width_multiplier,
        num_classes=num_classes,
    )



In [3]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network, move it to the chosen device, and print a per-layer
# summary for a single 3x96x96 input.
model = MobileNetV1()
model = model.to(device)
summary(model, (3, 96, 96))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 48, 48]             864
       BatchNorm2d-2           [-1, 32, 48, 48]              64
              ReLU-3           [-1, 32, 48, 48]               0
            Conv2d-4           [-1, 32, 48, 48]             288
       BatchNorm2d-5           [-1, 32, 48, 48]              64
              ReLU-6           [-1, 32, 48, 48]               0
            Conv2d-7           [-1, 64, 48, 48]           2,048
       BatchNorm2d-8           [-1, 64, 48, 48]             128
              ReLU-9           [-1, 64, 48, 48]               0
           DSConv-10           [-1, 64, 48, 48]               0
           Conv2d-11           [-1, 64, 24, 24]             576
      BatchNorm2d-12           [-1, 64, 24, 24]             128
             ReLU-13           [-1, 64, 24, 24]               0
           Conv2d-14          [-1, 128,