In [2]:
import torch
from torch import nn

#### Depthwise convolution 


In [16]:
# 3x3 Depthwise Conv 
class DepthWise_conv(nn.Module):
    def __init__(self, in_channels, out_channels, stride): 
        super(DepthWise_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels,
                      in_channels,
                      kernel_size=3,
                      stride=stride,
                      padding=1,
                      groups=in_channels,
                      bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU())
    def forward(self, x: torch.Tensor):
        return self.conv(x)

### pointwise conv

In [15]:
class PointWise_conv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(PointWise_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels,
                      out_channels,
                      kernel_size=1,  
                      stride=1,
                      bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU())
    
    def forward(self, x: torch.Tensor):
        return self.conv(x)

### Depthwise Seperable

In [17]:
class DepthWiseSeperable_conv(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(DepthWiseSeperable_conv, self).__init__()
        self.depthwise = DepthWise_conv(in_channels, out_channels, stride)
        self.pointwise = PointWise_conv(in_channels, out_channels)

    def forward(self, x : torch.Tensor):
        x1 = self.depthwise(x)              #x go through depthwise 
        x2 = self.pointwise(x1)             #x1 go through pointwise
        return x2


### MobileNet

In [18]:
class MobileNet(nn.Module):
    def __init__(self, num_classes = 1000):
        super(MobileNet, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels= 3,
                      out_channels= 32,
                      kernel_size= 3,
                      stride= 2,
                      bias= False),

            nn.ReLU(inplace=True),
            
            DepthWiseSeperable_conv(32, 64, 1),
            DepthWiseSeperable_conv(64, 128, 2),
            DepthWiseSeperable_conv(128, 128, 1),
            DepthWiseSeperable_conv(128, 256, 2),
            DepthWiseSeperable_conv(256, 256, 1),
            DepthWiseSeperable_conv(256,512, 2),

            DepthWiseSeperable_conv(512, 512, 1),
            DepthWiseSeperable_conv(512, 512, 1),
            DepthWiseSeperable_conv(512, 512, 1),
            DepthWiseSeperable_conv(512, 512, 1),
            DepthWiseSeperable_conv(512, 512, 1),

            DepthWiseSeperable_conv(512, 1024, 2),
            DepthWiseSeperable_conv(1024, 1024, 1),

            nn.AdaptiveAvgPool2d(1),

        )

        self.fc = nn.Linear(1024, 1000)  
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x : torch.Tensor):
        x1 = self.model(x)
        print(f"shape : {x1.shape}")
        x2 = x1.reshape(x1.size(0), -1)
        x3 = self.fc(x2)  
        x4 = self.softmax(x3)
        return x4
    
MobileNet_inst = MobileNet(num_classes=1000)

    

In [20]:
img_tensor = torch.rand(1, 3, 224, 224)

output = MobileNet_inst(img_tensor)
output

shape : torch.Size([1, 1024, 1, 1])


tensor([[0.0007, 0.0008, 0.0008, 0.0009, 0.0010, 0.0014, 0.0007, 0.0010, 0.0011,
         0.0010, 0.0007, 0.0010, 0.0008, 0.0011, 0.0011, 0.0010, 0.0009, 0.0008,
         0.0009, 0.0008, 0.0014, 0.0007, 0.0010, 0.0012, 0.0013, 0.0013, 0.0014,
         0.0010, 0.0013, 0.0008, 0.0011, 0.0010, 0.0008, 0.0012, 0.0011, 0.0009,
         0.0012, 0.0011, 0.0012, 0.0011, 0.0010, 0.0006, 0.0010, 0.0008, 0.0009,
         0.0009, 0.0010, 0.0008, 0.0007, 0.0008, 0.0010, 0.0013, 0.0008, 0.0014,
         0.0011, 0.0016, 0.0011, 0.0013, 0.0015, 0.0009, 0.0012, 0.0014, 0.0009,
         0.0011, 0.0013, 0.0012, 0.0012, 0.0007, 0.0011, 0.0007, 0.0011, 0.0007,
         0.0011, 0.0012, 0.0012, 0.0011, 0.0014, 0.0010, 0.0014, 0.0011, 0.0017,
         0.0011, 0.0007, 0.0008, 0.0013, 0.0010, 0.0008, 0.0009, 0.0015, 0.0013,
         0.0009, 0.0009, 0.0007, 0.0011, 0.0010, 0.0009, 0.0015, 0.0010, 0.0009,
         0.0007, 0.0011, 0.0014, 0.0011, 0.0011, 0.0014, 0.0009, 0.0011, 0.0009,
         0.0011, 0.0014, 0.0

In [21]:
num_parms = sum(p.numel() for p in MobileNet_inst.parameters() if p.requires_grad)
print(f'the no of trainable params : {num_parms}')


the no of trainable params : 4231912
