In [11]:
import torch
import torch.nn as nn

In [107]:
class DepthwiseSeparable(nn.Module):
  def __init__(self,in_channel,out_channel,stride):
    super(DepthwiseSeparable,self).__init__()

    self.stride=1
    if stride==True:
      self.stride=2

    self.block=nn.Sequential(
        nn.Conv2d(in_channels=in_channel,out_channels=in_channel,kernel_size=3,stride=self.stride,padding=1,bias=False),
        nn.BatchNorm2d(in_channel),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=in_channel,out_channels=out_channel,kernel_size=1,stride=1,bias=False),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(inplace=True)
    )

  def forward(self,x):    
    print(x.shape)
    return self.block(x)


In [125]:
class MobileNetwork(nn.Module):
  def __init__(self,alpha=1.0,num_classes=1000):
    super(MobileNetwork,self).__init__()

    # There are 27 convolution layer 1 layer is a fully convolutional layer and the
    # rest of 26 can be grouped in a block of depthwise seperable block, 
    stride=torch.zeros(13).bool()
    index=[1,3,5,11,12]
    stride[index]=True
    out_channel=32
    in_channel=32
    self.alpha=alpha
    # we can implement alpha by multiplying it to the in_channel and out_channel respectively.

    self.conv1=nn.Conv2d(in_channels=3,out_channels=out_channel,kernel_size=3,stride=2,padding=1,bias=False)
    self.bn1=nn.BatchNorm2d(out_channel)
    self.relu=nn.ReLU(inplace=True)

    self.features=nn.Sequential()

    # since the out_channel doubles when ever there is a stride, that is reduction of feature map
    # but in the first block without reduction of feature map there is doubling of channel so, 
    # the i==0 is there in the logic. Also the 12th block have a stride of 2 that is reduction in
    # feature map by half but in the last block the out_channel did not get double so in the logic
    # 'is not 12' is there. 
    for i in range(13):
      if (stride[i]==True or i==0) and i is not 12:
        out_channel=out_channel*2
      self.features.add_module('SepConvBlock%d'%i,DepthwiseSeparable(in_channel,out_channel,stride[i]))
      in_channel=out_channel

    self.avgpool=nn.AdaptiveAvgPool2d(1)
    self.classifier=nn.Linear(out_channel,num_classes)

  def forward(self,x):    
    out=self.conv1(x)   
    out=self.bn1(out)   
    out=self.relu(out)   
    out=self.features(out)
    out=self.avgpool(out)
    out=out.view(out.shape[0],-1)
    out=self.classifier(out)
    return out
    

In [126]:
def MobileNet():
  return MobileNetwork(alpha=1.0)

In [127]:
model=MobileNet()

In [128]:
def test():
  x=torch.randn(1,3,224,224)
  y=model(x)
  print(y.shape)

test()

torch.Size([1, 32, 112, 112])
torch.Size([1, 64, 112, 112])
torch.Size([1, 128, 56, 56])
torch.Size([1, 128, 56, 56])
torch.Size([1, 256, 28, 28])
torch.Size([1, 256, 28, 28])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 1024, 7, 7])
torch.Size([1, 1000])


In [10]:
down=torch.zeros(13).bool()
index=[1,3,5,11,12]
down[index]=True
for i in range(13):
  if down[i]==True or i==0:
    print(i)

0
1
3
5
11
12


In [None]:
model

MobileNetwork(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (features): Sequential(
    (SepConvBlock0): DepthwiseSeparable(
      (block): Sequential(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (SepConvBlock1): DepthwiseSeparable(
      (block): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   