In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

device = 'cuda' if torch.cuda.is_available else 'cpu'
IMG_SIZE = 224

In [7]:
# SE Block
class SEBlock(nn.Module):
    def __init__(self, in_channels, r=16):
        super(SEBlock, self).__init__()
        self.squeeze = nn.AdaptiveAvgPool2d((1,1))
        self.excitation = nn.Sequential(
            nn.Linear(in_channels, in_channels//r),
            nn.ReLU(),
            nn.Linear(in_channels//r, in_channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.squeeze(x)
        x = x.view(x.size(0), -1) 
        x = self.excitation(x)
        return x.view(x.size(0), x.size(1), 1, 1)

In [9]:
# Depthwise Separable Convolution
class Depthwise(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        self.depthwise = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, stride=stride, padding=1, groups=in_channels, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU6(),
        )

        self.pointwise = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU6(),
        )

        self.seblock = SEBlock(out_channels)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return self.seblock(x) * x

In [14]:
# Model - MobileNet V1
class MobileNet(nn.Module):
    def __init__(self, width_multiplier, num_classes=10, init_weights=True):
        super().__init__()
        self.init_weights=init_weights
        alpha = width_multiplier
    
        # Basic Conv
        self.conv_layer_1 = nn.Sequential(
            nn.Conv2d(3, int(32*alpha), kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(int(32*alpha)),
            nn.ReLU()
        )
        
        self.conv_layer_2 = Depthwise(int(32*alpha), int(64*alpha), stride=1)
        
        # down sample
        self.conv_layer_3 = nn.Sequential(
            Depthwise(int(64*alpha), int(128*alpha), stride=2),
            Depthwise(int(128*alpha), int(128*alpha), stride=1)
        )
        
        # down sample
        self.conv_layer_4 = nn.Sequential(
            Depthwise(int(128*alpha), int(256*alpha), stride=2),
            Depthwise(int(256*alpha), int(256*alpha), stride=1)
        )
        
        # down sample
        self.conv_layer_5 = nn.Sequential(
            Depthwise(int(256*alpha), int(512*alpha), stride=2),
            Depthwise(int(512*alpha), int(512*alpha), stride=1),
            Depthwise(int(512*alpha), int(512*alpha), stride=1),
            Depthwise(int(512*alpha), int(512*alpha), stride=1),
            Depthwise(int(512*alpha), int(512*alpha), stride=1),
            Depthwise(int(512*alpha), int(512*alpha), stride=1),
        )
        
        # down sample
        self.conv_layer_6 = nn.Sequential(
            Depthwise(int(512*alpha), int(1024*alpha), stride=2)
        )
        
        # down sample
        self.conv_layer_7 = nn.Sequential(
            Depthwise(int(1024*alpha), int(1024*alpha), stride=2)
        )

        self.avg_pool = nn.AdaptiveAvgPool2d((1,1))
        self.fc_layer = nn.Linear(int(1024*alpha), num_classes)

        # weights initialization
        if self.init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.conv_layer_1(x)
        x = self.conv_layer_2(x)
        x = self.conv_layer_3(x)
        x = self.conv_layer_4(x)
        x = self.conv_layer_5(x)
        x = self.conv_layer_6(x)
        x = self.conv_layer_7(x)
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        return self.fc_layer(x)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

In [15]:
model = MobileNet(width_multiplier=1, num_classes=10).to(device)
x = torch.randn(10, 3, IMG_SIZE, IMG_SIZE).to(device)
model(x)

tensor([[ 0.0021,  0.0357,  0.0950, -0.0139, -0.0605,  0.0507,  0.0057,  0.1063,
          0.1422, -0.0256],
        [-0.0455,  0.0758,  0.1003, -0.0226,  0.0140,  0.0129, -0.0250,  0.1055,
          0.1140,  0.0038],
        [ 0.0050,  0.0702,  0.1219, -0.0434, -0.0218,  0.0253, -0.0117,  0.0845,
          0.1222,  0.0144],
        [ 0.0265,  0.0235,  0.1170, -0.0175, -0.0115, -0.0412, -0.0321,  0.0846,
          0.1110, -0.0755],
        [-0.0023,  0.0348,  0.1024, -0.0860,  0.0108, -0.0112,  0.0132,  0.0941,
          0.1382, -0.0342],
        [ 0.0087,  0.0703,  0.1140, -0.0419,  0.0204,  0.0085, -0.0274,  0.0853,
          0.1541, -0.0207],
        [ 0.0239,  0.0326,  0.0654, -0.0582,  0.0095, -0.0230, -0.0152,  0.0665,
          0.1489, -0.0499],
        [ 0.0115,  0.0379,  0.1636, -0.0363, -0.0177, -0.0018, -0.0176,  0.1083,
          0.1524, -0.0904],
        [-0.0252,  0.0895,  0.1325, -0.0723,  0.0135,  0.0224,  0.0129,  0.0903,
          0.1094, -0.0263],
        [-0.0087,  