In [10]:
import torch.nn as nn
import torch

class MobileNetV1(nn.Module):
    def __init__(self, ch_in, n_classes):
        super(MobileNetV1, self).__init__()

        def conv_bn(inp, oup, stride):
            return nn.Sequential(
                nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True)
                )
        
        def conv_dw_1(inp, oup, stride):
            return nn.Sequential(
                # dw
                nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
                nn.BatchNorm2d(inp),
                nn.ReLU(inplace=True),

                )
        def conv_pw_1(inp, oup, stride):
            return nn.Sequential(
                # pw
                nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True),

                )
        

        def conv_dw(inp, oup, stride):
            return nn.Sequential(
                # dw
                nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
                nn.BatchNorm2d(inp),
                nn.ReLU(inplace=True),

                # pw
                nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True),
                )

        self.model = nn.Sequential(
            conv_bn(ch_in, 32, 2),
            conv_dw_1(32, 64, 1),
            conv_pw_1(32, 64, 1),
            
#             conv_dw(32, 64, 1),
            conv_dw(64, 128, 2),
            conv_dw(128, 128, 1),
            conv_dw(128, 256, 2),
            conv_dw(256, 256, 1),
            conv_dw(256, 512, 2),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 1024, 2),
            conv_dw(1024, 1024, 1),
            nn.AdaptiveAvgPool2d(1)
        )
        self.fc = nn.Linear(1024, n_classes)

    def forward(self, x):
        for layer in self.model:
            x = layer(x)
            print(x.size())
        
        x = x.view(-1, 1024)
        x = self.fc(x)
        return x



In [11]:
model = MobileNetV1(ch_in=3, n_classes=1000)
x = torch.randn(1, 3, 224, 224)

# Let's print it
model(x)

torch.Size([1, 32, 112, 112])
torch.Size([1, 64, 112, 112])
torch.Size([1, 128, 56, 56])
torch.Size([1, 128, 56, 56])
torch.Size([1, 256, 28, 28])
torch.Size([1, 256, 28, 28])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 1024, 7, 7])
torch.Size([1, 1024, 7, 7])
torch.Size([1, 1024, 1, 1])


tensor([[-2.8609e-01,  6.1796e-02,  1.1068e-01, -1.2091e-01,  2.5382e-01,
          1.6703e-01,  1.4827e-01,  1.8367e-01, -1.4449e-01, -5.4212e-01,
         -2.2482e-02,  5.8198e-02, -3.0763e-01,  3.1021e-01,  4.6761e-01,
          2.6555e-01, -4.0350e-01,  1.1226e-01, -1.3069e-01, -5.0062e-01,
         -4.5845e-01, -5.8796e-02,  9.0302e-02, -3.6002e-02,  4.8029e-01,
          1.2245e-01, -2.4034e-01,  3.1523e-01, -1.7444e-01,  8.5946e-02,
          5.1488e-02,  1.6509e-01, -1.4716e-01, -1.6591e-01, -4.9232e-02,
          1.2449e-01, -6.9452e-02, -1.4568e-01,  3.6624e-01,  2.0602e-01,
         -5.2416e-02, -4.5046e-02, -2.9539e-01, -9.9957e-02, -3.2074e-01,
          1.3619e-01, -1.6479e-01, -2.2650e-01,  1.1530e-01,  3.0435e-01,
         -2.7842e-01, -8.2706e-02, -8.3590e-02, -1.6402e-01,  3.4598e-01,
          8.5705e-02, -1.1435e-01,  1.6272e-01,  1.2121e-01,  1.2676e-01,
          2.7765e-01, -1.7774e-01, -3.6725e-01, -2.0794e-01, -2.1455e-01,
          1.5054e-01, -7.8047e-02, -2.