Define a MobileNet from scratch and train it on MNIST

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
KERNEL_SIZE = 3
PADDING = 1  # should be (KERNEL_SIZE - 1) / 2. Why: so that depthwise convs preserve channel shape, from in to out
FMAPS = [3, 16, 32, 64]
STRIDES = [1, 2, 2, 2]
FC_SIZE = 1024

LEARNING_RATE = 0.001
MOMENTUM = 0.9

## Logging

In [3]:
import logging

class LoggerMixin(object):
    @property
    def logger(self):
        name = '.'.join([
            self.__module__,
            self.__class__.__name__
        ])
        logger = logging.getLogger(name)
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        chandler = logging.StreamHandler()
#         chandler.setLevel(logging.DEBUG)
        chandler.setFormatter(formatter)
#         logger.addHandler(chandler)
        logger.setLevel(logging.DEBUG)
        return logger


## Model

In [4]:
class MobileNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(MobileNetBlock, self).__init__()
        # depthwise - operates on each channel individually. input is one entire channel. 
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=1, groups=in_channels, bias=False)
        self.bn1 = nn.BatchNorm2d(in_channels)
        # pointwise - operates on each point individually. input is the channel vector at that point. 
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        return

    def forward(self, x):
        out = F.relu(self.bn1(self.depthwise(x)))
        out = F.relu(self.bn2(self.pointwise(out)))
        return out


### notes
- padding and stride in convolutions

In [5]:
class MobileNet(LoggerMixin, nn.Module):
    def __init__(self):
        super(MobileNet, self).__init__()  # to do: what does this do?
        self.conv1 = nn.Conv2d(3, FMAPS[0], kernel_size=KERNEL_SIZE, padding=1, stride=STRIDES[0], bias=False)
        self.bn1 = nn.BatchNorm2d(FMAPS[0])
        self.mobile_net_blocks = nn.Sequential(*[
            MobileNetBlock(FMAPS[i], FMAPS[i+1], KERNEL_SIZE, STRIDES[i])
            for i in range(0, len(FMAPS) - 1)
        ])
        self.fc1 = nn.Linear(FC_SIZE, 10)
        self.logger.debug("initializing model")
        return

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.mobile_net_blocks(out)
        out = F.avg_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        return out


### notes
- in init...
  - nn.Linear for fully connected layers
  - Adding a normal conv before the MobileNet blocks
- in forward...
  - computing the first conv and batchnorm
  - pooling!
  - resizing the output of convolution blocks for the fully connected layer

In [6]:
model = MobileNet()
model

MobileNet(
  (conv1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (mobile_net_blocks): Sequential(
    (0): MobileNetBlock(
      (depthwise): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=3, bias=False)
      (bn1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (pointwise): Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): MobileNetBlock(
      (depthwise): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=16, bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (pointwise): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, trac

In [7]:
from torchsummary import summary

In [8]:
summary(model, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 3, 32, 32]              81
       BatchNorm2d-2            [-1, 3, 32, 32]               6
            Conv2d-3            [-1, 3, 32, 32]              27
       BatchNorm2d-4            [-1, 3, 32, 32]               6
            Conv2d-5           [-1, 16, 32, 32]              48
       BatchNorm2d-6           [-1, 16, 32, 32]              32
    MobileNetBlock-7           [-1, 16, 32, 32]               0
            Conv2d-8           [-1, 16, 16, 16]             144
       BatchNorm2d-9           [-1, 16, 16, 16]              32
           Conv2d-10           [-1, 32, 16, 16]             512
      BatchNorm2d-11           [-1, 32, 16, 16]              64
   MobileNetBlock-12           [-1, 32, 16, 16]               0
           Conv2d-13             [-1, 32, 8, 8]             288
      BatchNorm2d-14             [-1, 3

In [9]:
model(torch.randn(1, 3, 32, 32))

tensor([[0.2990, 0.0000, 0.0000, 0.0000, 0.2405, 0.0000, 0.0635, 0.1372, 0.1149,
         0.0830]], grad_fn=<ReluBackward0>)

## data

In [17]:
import torch.utils
import torchvision
import torchvision.transforms as transforms

In [18]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


In [19]:
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


## execution

In [20]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [23]:
model = model.to(device)
if device == 'cuda':
    model = torch.nn.DataParallel(model)
    cudnn.benchmark = True

In [24]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

In [27]:
model.train()
train_loss = 0
correct = 0
total = 0

class Tracker(object):
    def __init__(self):
        pass

    def keep_track(outputs, targets, loss, ):
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

for batch_idx, (inputs, targets) in enumerate(trainloader):
    inputs, targets = inputs.to(device), targets.to(device)
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    
    tracker.keep_track()


In [11]:
model = MobileNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)  # to do: change to adam

NameError: name 'net' is not defined

In [None]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')