![Depthwise separable convolution diagram](dw.png)

In [1]:
import torch 
import matplotlib.pyplot as plt
from torch import nn


In [2]:
class depthWiseConvs(nn.Module):
    """Depthwise 3x3 convolution followed by BatchNorm and ReLU.

    Every input channel is filtered by its own 3x3 kernel
    (``groups=inChannels``), so the output has the same number of channels
    as the input.

    Args:
        inChannels: number of input channels (also the output channel count).
        outChannels: accepted for signature symmetry with ``pointWiseConvs``;
            unused, because a depthwise convolution cannot change the number
            of channels.
        stride: spatial stride. With padding 1 and ``stride == 1`` the spatial
            size is preserved.
    """

    def __init__(self, inChannels, outChannels, stride):
        super().__init__()
        depthwise = nn.Conv2d(
            in_channels=inChannels,
            out_channels=inChannels,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=inChannels,
            bias=False,  # the BatchNorm that follows provides the learned shift
        )
        self.conv = nn.Sequential(depthwise, nn.BatchNorm2d(inChannels), nn.ReLU())

    def forward(self, x):
        """Apply depthwise conv -> BatchNorm -> ReLU to ``x``."""
        return self.conv(x)


In [3]:
class pointWiseConvs(nn.Module):
    """Pointwise (1x1) convolution followed by BatchNorm and ReLU.

    Maps ``inChannels`` to ``outChannels`` per spatial location; the spatial
    size is unchanged.
    """

    def __init__(self, inChannels, outChannels):
        super().__init__()
        layers = [
            nn.Conv2d(inChannels, outChannels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(outChannels),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        """Apply 1x1 conv -> BatchNorm -> ReLU to ``x``."""
        return self.conv(x)


In [4]:
class depthWiseSeperableConvs(nn.Module):
    """Depthwise-separable convolution: a depthwise stage then a pointwise stage.

    ``depthWise`` (``depthWiseConvs``) filters each channel separately and
    applies ``stride``. ``pointWise`` (``pointWiseConvs``) then maps
    ``inChannels`` to ``outChannels`` with a 1x1 convolution.
    """

    def __init__(self, inChannels, outChannels, stride):
        super().__init__()
        self.depthWise = depthWiseConvs(inChannels, outChannels, stride)
        self.pointWise = pointWiseConvs(inChannels, outChannels)

    def forward(self, x):
        """Run the depthwise stage, then the pointwise stage."""
        return self.pointWise(self.depthWise(x))



In [5]:
class mobileNet(nn.Module):
    """MobileNet (v1-style) image classifier built from depthwise-separable blocks.

    Layout: a 3x3 stride-2 stem convolution + ReLU, then 13
    ``depthWiseSeperableConvs`` blocks (32 -> 1024 channels, with four more
    stride-2 downsamplings). These are followed by global average pooling to
    1x1 and a linear classifier over the 1024 pooled features.

    Args:
        numClasses: number of classifier outputs.
        returnLogits: keyword-only.
            - ``False`` (default, the original behaviour): ``forward`` returns
              softmax probabilities over dim 1.
            - ``True``: ``forward`` returns the raw ``fc`` outputs. This is
              what ``nn.CrossEntropyLoss`` expects; it applies log-softmax
              internally, so feeding it probabilities would apply softmax
              twice.
    """

    def __init__(self, numClasses=1000, *, returnLogits=False):
        super().__init__()
        self.returnLogits = returnLogits

        # (inChannels, outChannels, stride) for each depthwise-separable block,
        # in order. The Sequential indices and parameter count match the
        # original explicit list.
        blockSpecs = (
            (32, 64, 1),
            (64, 128, 2),
            (128, 128, 1),
            (128, 256, 2),
            (256, 256, 1),
            (256, 512, 2),
            (512, 512, 1),
            (512, 512, 1),
            (512, 512, 1),
            (512, 512, 1),
            (512, 512, 1),
            (512, 1024, 2),
            (1024, 1024, 1),
        )

        self.model = nn.Sequential(
            # Stem. With padding=1 a 224x224 input maps to 112x112; without
            # padding it mapped to 111x111 and the kernel never reached the
            # last image row/column.
            # NOTE(review): the MobileNet paper also puts BatchNorm after the
            # stem. It is not added here, so the parameter count and
            # state_dict keys stay unchanged.
            nn.Conv2d(
                in_channels=3,
                out_channels=32,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=True,
            ),
            nn.ReLU(inplace=True),
            *(depthWiseSeperableConvs(i, o, s) for i, o, s in blockSpecs),
            nn.AdaptiveAvgPool2d(1),
        )

        self.fc = nn.Linear(1024, numClasses)

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x: torch.Tensor):
        """Classify a batch of images.

        Returns softmax probabilities of shape (batch, numClasses). If
        ``returnLogits`` was set, returns the unnormalised logits of the same
        shape instead.
        """
        pooled = self.model(x)  # (batch, 1024, 1, 1) after global average pooling
        logits = self.fc(torch.flatten(pooled, 1))
        if self.returnLogits:
            return logits
        return self.softmax(logits)


In [6]:
model = mobileNet(numClasses=2)  # two-way classifier head


In [7]:
# Shape smoke test: one random 3x224x224 input through the model.
img = torch.randn((1, 3, 224, 224))
model(img).shape


torch.Size([1, 2])

In [8]:
# Total element count over parameters that have requires_grad=True.
num_parms = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f'the no of trainable params : {num_parms}')


the no of trainable params : 3208994


# Training and test data