<a href="https://colab.research.google.com/github/karlmaji/pytorch_learning/blob/master/MobileNetV1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def set_seed(seed):
  """Seed the NumPy and torch random number generators with ``seed``.

  The CUDA generators (current device and all devices) are seeded as well,
  but only when CUDA is available. Returns ``None``.
  """
  np.random.seed(seed)
  torch.manual_seed(seed)
  if not torch.cuda.is_available():
    return
  torch.cuda.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)

set_seed(42)

# Ask cuDNN for deterministic kernels and turn off its autotuner so repeated
# runs with the same seed give the same results. The attribute must be spelled
# exactly `deterministic`: assigning a misspelled name raises no error and has
# no effect on cuDNN.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


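![MobileNetV1 paper, Figure 3: standard 3x3 conv + BatchNorm + ReLU layer (left) vs. depthwise-separable layer, i.e. 3x3 depthwise conv + BatchNorm + ReLU followed by 1x1 pointwise conv + BatchNorm + ReLU (right)](https://pdf.cdn.readpaper.com/parsed/fetch_target/49903af9a58081720de11b4f6a317304_3_Figure_3.png)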

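![MobileNetV1 paper, Table 1: MobileNet body architecture (layer type / stride, filter shape, input size)](https://pdf.cdn.readpaper.com/parsed/fetch_target/49903af9a58081720de11b4f6a317304_3_Table_1.png)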

In [39]:
def Depthwise_Separable_Conv(input_channel, output_channel, stride):
  """Build one MobileNetV1 depthwise-separable convolution block.

  The block applies a 3x3 depthwise convolution (one filter per input
  channel, via ``groups=input_channel``) and then a 1x1 pointwise convolution
  that mixes channels. Each convolution is followed by BatchNorm and ReLU.
  Neither convolution has a bias, because the following BatchNorm supplies
  the shift.

  Args:
    input_channel: number of channels of the incoming feature map.
    output_channel: number of channels produced by the pointwise conv.
    stride: stride of the depthwise conv (the only place this block
      downsamples spatially; padding=1 keeps the size when stride is 1).

  Returns:
    An ``nn.Sequential`` of six layers, in this order: depthwise conv,
    BatchNorm, ReLU, pointwise conv, BatchNorm, ReLU.
  """
  depthwise_stage = [
      nn.Conv2d(input_channel, input_channel, (3, 3), stride=stride,
                padding=1, groups=input_channel, bias=False),
      nn.BatchNorm2d(input_channel),
      nn.ReLU(inplace=True),
  ]
  pointwise_stage = [
      nn.Conv2d(input_channel, output_channel, (1, 1), stride=1, bias=False),
      nn.BatchNorm2d(output_channel),
      nn.ReLU(inplace=True),
  ]
  return nn.Sequential(*depthwise_stage, *pointwise_stage)

class MobileNetV1(nn.Module):
  """MobileNetV1 image classifier.

  Layout:
    - a 3x3 stride-2 stem convolution to 32 channels;
    - 13 depthwise-separable blocks built by ``Depthwise_Separable_Conv``;
    - global average pooling over the two spatial dimensions;
    - a linear head.

  The stem plus four stride-2 blocks downsample by 2**5 = 32 overall. A
  224x224 input therefore reaches the pooling step as a 7x7x1024 feature map.

  Args:
    input_channel: number of channels in the input image (3 for RGB).
    num_class: number of output logits produced by the linear head.
  """

  def __init__(self, input_channel=3, num_class=1000):
    super().__init__()
    # Stem: a regular (non-separable) 3x3 convolution with stride 2.
    self.Conv1_bn_relu = nn.Sequential(
        nn.Conv2d(input_channel, 32, (3, 3), stride=2, padding=1, bias=False),
        nn.BatchNorm2d(32),
        nn.ReLU(inplace=True),
    )
    # Body: depthwise-separable blocks, each given as (in, out, stride).
    # The final 1024->1024 block uses stride 1, so it preserves the spatial
    # size produced by the preceding stride-2 block.
    self.Conv_to_Conv = nn.Sequential(
        Depthwise_Separable_Conv(32, 64, 1),
        Depthwise_Separable_Conv(64, 128, 2),
        Depthwise_Separable_Conv(128, 128, 1),
        Depthwise_Separable_Conv(128, 256, 2),
        Depthwise_Separable_Conv(256, 256, 1),
        Depthwise_Separable_Conv(256, 512, 2),
        *[Depthwise_Separable_Conv(512, 512, 1) for _ in range(5)],
        Depthwise_Separable_Conv(512, 1024, 2),
        Depthwise_Separable_Conv(1024, 1024, 1),
    )
    self.head = nn.Linear(1024, num_class)
    self._init_weights()

  def forward(self, x):
    """Return class logits for an image batch ``x``.

    ``x`` is expected to be (N, input_channel, H, W), and the result is
    (N, num_class). Pooling is a plain mean over the last two dims, so the
    head does not depend on the input's spatial size.
    """
    x = self.Conv1_bn_relu(x)
    x = self.Conv_to_Conv(x)
    # Global average pool over (H, W), then classify.
    return self.head(x.mean((-1, -2)))

  def _init_weights(self):
    """Initialise every submodule's parameters in place.

    - Conv2d: Kaiming-normal weights (fan_out mode, ReLU gain). The convs
      here have no bias.
    - BatchNorm2d: weight set to 1, bias set to 0.
    - Linear: Xavier-normal weights, bias set to 0.
    """
    for m in self.modules():
      # The three layer types are mutually exclusive, hence elif.
      if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
      elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
      elif isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        nn.init.constant_(m.bias, 0)


# Build the default configuration explicitly: 3 input channels, 1000 classes.
model = MobileNetV1(input_channel=3, num_class=1000)