In [1]:
# Goals
# 1. Build a trainable implementation of MobileNet v1 (in PyTorch).
# 2. Compare parameter counts against a network that uses standard convolutions.

import torch

In [2]:
def print_model_size(model, disp=True):
    """Count the trainable parameters of ``model`` and optionally print the total.

    Args:
        model: any object exposing ``parameters()`` whose items provide
            ``numel()`` and ``requires_grad`` (e.g. a ``torch.nn.Module``).
        disp: when True (the default) the total is printed; pass False to
            compute it silently.

    Returns:
        int: number of elements summed over every parameter whose
        ``requires_grad`` flag is set.

    Note:
        When this call is the last expression of a notebook cell the returned
        value is echoed as well; end the line with ``;`` to suppress that.
    """
    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # ``disp`` used to be accepted but ignored; it now actually gates the print.
    if disp:
        print('Total number of trainable parameters: {}'.format(pytorch_total_params))
    return pytorch_total_params
    

In [3]:
class TransformerNet(torch.nn.Module):
    """Image-to-image network built from standard convolutions.

    Layout: three conv + instance-norm + ReLU stages (3 -> 32 -> 64 -> 128
    channels, the last two with stride 2), five residual blocks at 128
    channels, two upsample-by-2 conv stages (128 -> 64 -> 32) and a final
    9x9 convolution back to 3 channels with no activation.
    """

    def __init__(self):
        super().__init__()

        def instance_norm(channels):
            # Every normalisation layer in this network is affine InstanceNorm.
            return torch.nn.InstanceNorm2d(channels, affine=True)

        # Downsampling stem.
        self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
        self.in1 = instance_norm(32)
        self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.in2 = instance_norm(64)
        self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
        self.in3 = instance_norm(128)

        # Residual trunk, registered as res1 .. res5.
        for index in range(1, 6):
            setattr(self, 'res{}'.format(index), ResidualBlock(128))

        # Upsampling head.
        self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
        self.in4 = instance_norm(64)
        self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
        self.in5 = instance_norm(32)
        self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)

        # Shared non-linearity.
        self.relu = torch.nn.ReLU()

    def forward(self, X):
        """Run ``X`` through the stem, the residual trunk and the upsampling head."""
        activate = self.relu
        y = X

        stem = ((self.conv1, self.in1), (self.conv2, self.in2), (self.conv3, self.in3))
        for conv, norm in stem:
            y = activate(norm(conv(y)))

        for index in range(1, 6):
            y = getattr(self, 'res{}'.format(index))(y)

        head = ((self.deconv1, self.in4), (self.deconv2, self.in5))
        for upconv, norm in head:
            y = activate(norm(upconv(y)))

        # The last convolution is deliberately left without norm/activation.
        return self.deconv3(y)


class ConvLayer(torch.nn.Module):
    """Reflection-padded 2-D convolution.

    The input is reflection-padded by ``kernel_size // 2`` on every side and
    then convolved; the convolution itself adds no padding of its own.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        self.reflection_pad = torch.nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = torch.nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )

    def forward(self, x):
        """Pad ``x`` by reflection, then convolve it."""
        return self.conv2d(self.reflection_pad(x))


class ResidualBlock(torch.nn.Module):
    """Two-convolution residual block with an identity skip connection.

    Residual learning was introduced in https://arxiv.org/abs/1512.03385;
    the layer arrangement follows the recommendation in
    http://torch.ch/blog/2016/02/04/resnets.html
    """

    def __init__(self, channels):
        super().__init__()

        def conv3x3():
            return ConvLayer(channels, channels, kernel_size=3, stride=1)

        def instance_norm():
            return torch.nn.InstanceNorm2d(channels, affine=True)

        self.conv1 = conv3x3()
        self.in1 = instance_norm()
        self.conv2 = conv3x3()
        self.in2 = instance_norm()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Return ``F(x) + x`` where ``F`` is conv-norm-ReLU-conv-norm."""
        branch = self.conv1(x)
        branch = self.in1(branch)
        branch = self.relu(branch)
        branch = self.conv2(branch)
        branch = self.in2(branch)
        # No activation after the addition.
        return branch + x


class UpsampleConvLayer(torch.nn.Module):
    """Optional nearest-neighbour upsampling followed by a reflection-padded convolution.

    Used instead of ``ConvTranspose2d`` because this approach gives better
    results; see http://distill.pub/2016/deconv-checkerboard/
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super().__init__()
        # Scale factor for the upsampling step; a falsy value skips it.
        self.upsample = upsample
        self.reflection_pad = torch.nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = torch.nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )

    def forward(self, x):
        """Upsample ``x`` when a scale factor was configured, then pad and convolve."""
        if not self.upsample:
            return self.conv2d(self.reflection_pad(x))

        enlarged = torch.nn.functional.interpolate(x, mode='nearest', scale_factor=self.upsample)
        return self.conv2d(self.reflection_pad(enlarged))

In [4]:
# Instantiate the standard-convolution network and report its trainable-parameter count.
net = TransformerNet()
print_model_size(net)

Total number of trainable parameters: 1679235


In [7]:
# Same transform net, but implemented with Depthwise Separable Convolutions
import torch
# this class implements the DwConv block introduced in MobileNet v1 paper
class MobileNetConvBlock(torch.nn.Module):
    """Depthwise-separable convolution block in the style of MobileNet v1.

    Pipeline: depthwise conv -> ReLU -> 1x1 pointwise conv -> ReLU.
    No batch normalisation is applied inside this block.

    Args:
        in_channels: channels of the input (and of the depthwise output).
        out_channels: channels produced by the pointwise convolution.
        kernel_size: spatial size of the depthwise kernel.
        stride: stride of the depthwise convolution.
        pad: zero padding applied by the depthwise convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=0):
        super().__init__()

        # groups == in_channels gives one spatial filter per input channel.
        self.depthwise = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=pad,
            groups=in_channels,
        )
        self.mnb_relu1 = torch.nn.ReLU()

        # 1x1 convolution mixes channels and sets the output width.
        self.pointwise = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
        self.mnb_relu2 = torch.nn.ReLU()

    def forward(self, x):
        """Apply the depthwise stage, then the pointwise stage."""
        spatial = self.mnb_relu1(self.depthwise(x))
        return self.mnb_relu2(self.pointwise(spatial))

class DSC_TransformerNet(torch.nn.Module):
    """Image-to-image network assembled from depthwise-separable conv layers.

    Same layer layout as the standard-convolution transformer network: three
    conv + instance-norm + ReLU stages (3 -> 32 -> 64 -> 128 channels), five
    residual blocks at 128 channels, two upsample-by-2 stages (128 -> 64 -> 32)
    and a final 9x9 layer back to 3 channels.

    NOTE(review): each DSC layer's convolution block ends in a ReLU, so the
    output of ``deconv3`` is non-negative here - confirm this is intended.
    """

    def __init__(self):
        super().__init__()

        def instance_norm(channels):
            # Every normalisation layer in this network is affine InstanceNorm.
            return torch.nn.InstanceNorm2d(channels, affine=True)

        # Downsampling stem.
        self.conv1 = DSC_ConvLayer(3, 32, kernel_size=9, stride=1)
        self.in1 = instance_norm(32)
        self.conv2 = DSC_ConvLayer(32, 64, kernel_size=3, stride=2)
        self.in2 = instance_norm(64)
        self.conv3 = DSC_ConvLayer(64, 128, kernel_size=3, stride=2)
        self.in3 = instance_norm(128)

        # Residual trunk, registered as res1 .. res5.
        for index in range(1, 6):
            setattr(self, 'res{}'.format(index), DSC_ResidualBlock(128))

        # Upsampling head.
        self.deconv1 = DSC_UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
        self.in4 = instance_norm(64)
        self.deconv2 = DSC_UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
        self.in5 = instance_norm(32)
        self.deconv3 = DSC_ConvLayer(32, 3, kernel_size=9, stride=1)

        # Shared non-linearity.
        self.relu = torch.nn.ReLU()

    def forward(self, X):
        """Run ``X`` through the stem, the residual trunk and the upsampling head."""
        activate = self.relu
        y = X

        stem = ((self.conv1, self.in1), (self.conv2, self.in2), (self.conv3, self.in3))
        for conv, norm in stem:
            y = activate(norm(conv(y)))

        for index in range(1, 6):
            y = getattr(self, 'res{}'.format(index))(y)

        head = ((self.deconv1, self.in4), (self.deconv2, self.in5))
        for upconv, norm in head:
            y = activate(norm(upconv(y)))

        return self.deconv3(y)


class DSC_ConvLayer(torch.nn.Module):
    """Reflection padding followed by a depthwise-separable convolution block.

    The input is reflection-padded by ``kernel_size // 2`` on every side and
    then handed to a ``MobileNetConvBlock`` built with its default padding.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        self.reflection_pad = torch.nn.ReflectionPad2d(kernel_size // 2)
        # kernel_size and stride are passed positionally, as the block expects them.
        self.conv2d = MobileNetConvBlock(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        """Pad ``x`` by reflection, then apply the separable convolution block."""
        return self.conv2d(self.reflection_pad(x))


class DSC_ResidualBlock(torch.nn.Module):
    """Residual block whose two convolutions are depthwise-separable layers.

    Residual learning was introduced in https://arxiv.org/abs/1512.03385;
    the layer arrangement follows the recommendation in
    http://torch.ch/blog/2016/02/04/resnets.html
    """

    def __init__(self, channels):
        super().__init__()

        def conv3x3():
            return DSC_ConvLayer(channels, channels, kernel_size=3, stride=1)

        def instance_norm():
            return torch.nn.InstanceNorm2d(channels, affine=True)

        self.conv1 = conv3x3()
        self.in1 = instance_norm()
        self.conv2 = conv3x3()
        self.in2 = instance_norm()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Return ``F(x) + x`` where ``F`` is conv-norm-ReLU-conv-norm."""
        branch = self.conv1(x)
        branch = self.in1(branch)
        branch = self.relu(branch)
        branch = self.conv2(branch)
        branch = self.in2(branch)
        # No activation after the addition.
        return branch + x


class DSC_UpsampleConvLayer(torch.nn.Module):
    """Optional nearest-neighbour upsampling followed by a padded separable conv block.

    Used instead of ``ConvTranspose2d`` because this approach gives better
    results; see http://distill.pub/2016/deconv-checkerboard/
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super().__init__()
        # Scale factor for the upsampling step; a falsy value skips it.
        self.upsample = upsample
        self.reflection_pad = torch.nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = MobileNetConvBlock(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        """Upsample ``x`` when a scale factor was configured, then pad and convolve."""
        if not self.upsample:
            return self.conv2d(self.reflection_pad(x))

        enlarged = torch.nn.functional.interpolate(x, mode='nearest', scale_factor=self.upsample)
        return self.conv2d(self.reflection_pad(enlarged))

In [8]:
# Instantiate the depthwise-separable variant and report its trainable-parameter count.
dscnet = DSC_TransformerNet()
print_model_size(dscnet)

Total number of trainable parameters: 207865


In [66]:
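# Conv output size: o = floor((i + 2*p - k) / s) + 1   (i: input, k: kernel, p: padding, s: stride)
# Size-preserving padding when s = 1 (odd k): p = (k - 1) / 2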

class MobileNetDepthWiseConv(torch.nn.Module):
    """Depthwise-separable block with batch normalisation.

    Pipeline: 3x3 depthwise conv (padding 1) -> BatchNorm -> ReLU ->
    1x1 pointwise conv -> BatchNorm -> ReLU.

    Args:
        n_in: number of input channels.
        n_out: number of channels produced by the pointwise convolution.
        stride: stride of the depthwise convolution; the pointwise one uses 1.
    """

    def __init__(self, n_in, n_out, stride=1):
        super().__init__()

        # Spatial stage: groups == n_in gives one 3x3 filter per input channel.
        self.depthwise = torch.nn.Conv2d(
            n_in, n_in, kernel_size=3, stride=stride, padding=1, groups=n_in
        )
        self.mnb_bn1 = torch.nn.BatchNorm2d(n_in)
        self.mnb_relu1 = torch.nn.ReLU()

        # Channel-mixing stage.
        self.pointwise = torch.nn.Conv2d(n_in, n_out, kernel_size=1, stride=1)
        self.mnb_bn2 = torch.nn.BatchNorm2d(n_out)
        self.mnb_relu2 = torch.nn.ReLU()

    def forward(self, x):
        """Apply the six stages in order and return the result."""
        stages = (
            self.depthwise, self.mnb_bn1, self.mnb_relu1,
            self.pointwise, self.mnb_bn2, self.mnb_relu2,
        )
        out = x
        for stage in stages:
            out = stage(out)
        return out

class MobileNetv1(torch.nn.Module):
    """MobileNet v1 image classifier (https://arxiv.org/abs/1704.04861).

    A strided 3x3 stem convolution, 13 depthwise-separable blocks, 7x7 average
    pooling, a 1024 -> 1000 linear layer and a softmax over the class axis.

    The fixed 7x7 pooling window and the 1024-feature linear layer require the
    final feature map to be 7x7, which is what a 224x224 input produces.

    Returns (from ``forward``): a ``(batch, 1000)`` tensor of class
    probabilities.

    NOTE(review): ``torch.nn.CrossEntropyLoss`` expects raw logits, not
    probabilities - confirm which loss this model will be trained with.
    """

    def __init__(self):

        super(MobileNetv1, self).__init__()
        # padding=1 makes the stem map 224 -> 112 as in the paper
        # (without it the output would be 111).
        self.conv1 = torch.nn.Conv2d(3, 32, stride=2, kernel_size=3, padding=1)
        self.bn1   = torch.nn.BatchNorm2d(32)
        self.relu1 = torch.nn.ReLU()

        # Attribute names are kept as-is (there is no ``mn2``) so that the
        # module's state_dict keys do not change.
        self.mn1   = MobileNetDepthWiseConv(n_in=32, n_out=64, stride=1)
        self.mn3   = MobileNetDepthWiseConv(n_in=64, n_out=128, stride=2)
        self.mn4   = MobileNetDepthWiseConv(n_in=128, n_out=128, stride=1)

        self.mn5   = MobileNetDepthWiseConv(n_in=128, n_out=256, stride=2)
        self.mn6   = MobileNetDepthWiseConv(n_in=256, n_out=256, stride=1)

        self.mnX   = MobileNetDepthWiseConv(n_in=256, n_out=512, stride=2)

        self.mn7   = MobileNetDepthWiseConv(n_in=512, n_out=512, stride=1)
        self.mn8   = MobileNetDepthWiseConv(n_in=512, n_out=512, stride=1)
        self.mn9   = MobileNetDepthWiseConv(n_in=512, n_out=512, stride=1)
        self.mn10   = MobileNetDepthWiseConv(n_in=512, n_out=512, stride=1)
        self.mn11   = MobileNetDepthWiseConv(n_in=512, n_out=512, stride=1)

        self.mn12   = MobileNetDepthWiseConv(n_in=512, n_out=1024, stride=2)
        self.mn13   = MobileNetDepthWiseConv(n_in=1024, n_out=1024, stride=1)

        self.avgpool1 = torch.nn.AvgPool2d(kernel_size=7, stride=1)
        self.fc1      = torch.nn.Linear(1024, 1000)
        # Softmax's first argument is the dimension to normalise over, not the
        # number of classes; the class axis of the (batch, 1000) output is 1.
        self.sm       = torch.nn.Softmax(dim=1)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: input batch; expected to be ``(batch, 3, 224, 224)`` so that the
                pooled features flatten to 1024 values per sample.

        Returns:
            Tensor of shape ``(batch, 1000)`` whose rows sum to 1.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        # Depthwise-separable trunk. The second block is ``mn3``; the previous
        # code called a non-existent ``self.mn2`` and never used ``mn3``.
        trunk = (
            self.mn1, self.mn3, self.mn4, self.mn5, self.mn6, self.mnX,
            self.mn7, self.mn8, self.mn9, self.mn10, self.mn11,
            self.mn12, self.mn13,
        )
        for block in trunk:
            x = block(x)

        x = self.avgpool1(x)
        # (batch, 1024, 1, 1) -> (batch, 1024) so the linear layer can consume it.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.sm(x)
        return x
        
        
        
    

In [67]:
# Instantiate MobileNet v1 and report its trainable-parameter count.
mobilenet = MobileNetv1()
print_model_size(mobilenet)

Total number of trainable parameters: 4242920
