# Spatial Transformer Network
This is a guide for creating a Convolutional Neural Network with a Spatial Transformer Network (STN) for **image classification**. <br>
The guide is for personal use as a reference for future work regarding image classification problems.

The guide is heavily influenced by the official PyTorch tutorial at https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html. The STN was first proposed by Google DeepMind in the paper https://arxiv.org/abs/1506.02025.

An STN is a visual attention mechanism that allows neural networks to learn to apply spatial transformations to their input images, which enhances the geometric invariance of the model. This is useful for CNNs because they are not inherently invariant to geometric transformations of the input, such as rotation, scaling, or larger translations. Image augmentation is the usual workaround for this.

The STN architecture consists of three main components:
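- A **localization network**, which is a regular CNN that predicts the transformation parameters from the input image.
- A **grid generator**, which uses the transformation parameters to compute a grid of sampling coordinates in the input image, one for each pixel of the output image.
- A **sampler**, which reads the input image at those coordinates (with bilinear interpolation in this guide) to produce the transformed image.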
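For the affine transformation used in this guide, the grid generator can be written down explicitly, following the notation of the STN paper: each target coordinate $(x_i^t, y_i^t)$ of the regular output grid is mapped to a source coordinate $(x_i^s, y_i^s)$ in the input image by the $2 \times 3$ matrix $A_\theta$ that the localization network predicts. Coordinates are normalized to $[-1, 1]$, and $\theta_{13}, \theta_{23}$ act as translations.

$$
\begin{pmatrix} x_i^s \\ y_i^s \end{pmatrix}
= A_\theta \begin{pmatrix} x_i^t \\ y_i^t \\ 1 \end{pmatrix}
= \begin{bmatrix} \theta_{11} & \theta_{12} & \theta_{13} \\ \theta_{21} & \theta_{22} & \theta_{23} \end{bmatrix}
\begin{pmatrix} x_i^t \\ y_i^t \\ 1 \end{pmatrix}
$$

With $A_\theta = \begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \end{bmatrix}$ every output pixel samples the same location in the input, i.e. the identity transformation.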

The figure below shows the architecture of an STN.
<img src="./static_files/stn-architecture.png"/>
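In PyTorch the grid generator and the sampler correspond to `F.affine_grid` and `F.grid_sample`. The minimal sketch below applies them to a hypothetical 4 x 4 single-channel image; the helper name `spatial_transform` is illustrative only. It shows that the identity matrix returns the image unchanged, while a translation in θ moves its content.

```python
import torch
import torch.nn.functional as F


def spatial_transform(image, theta):
    """Grid generator + sampler for a batch of affine matrices theta (N x 2 x 3)."""
    # grid generator: for every output pixel, an (x, y) location in the input image,
    # in normalized coordinates where -1 and 1 are the image borders
    grid = F.affine_grid(theta, image.size(), align_corners=False)
    # sampler: bilinear interpolation of the input at those locations;
    # locations outside the image read as zero (padding_mode="zeros")
    return F.grid_sample(image, grid, mode="bilinear",
                         padding_mode="zeros", align_corners=False)


# a hypothetical 1 x 1 x 4 x 4 image (N x C x H x W) with pixel values 0..15
img = torch.arange(16, dtype=torch.float32).view(1, 1, 4, 4)

# identity matrix: every output pixel samples the same input pixel
identity = torch.tensor([[[1.0, 0.0, 0.0],
                          [0.0, 1.0, 0.0]]])

# translation of +0.5 in x: the normalized width is 2, so on a 4-pixel-wide image
# each output pixel samples the input pixel one step to its right
shift = torch.tensor([[[1.0, 0.0, 0.5],
                       [0.0, 1.0, 0.0]]])

unchanged = spatial_transform(img, identity)
shifted = spatial_transform(img, shift)
print(shifted[0, 0])  # approximately rows [1, 2, 3, 0], [5, 6, 7, 0], ...
```

In the STN, θ is not set by hand like this. The localization network predicts it for every image, and because both functions are differentiable with respect to θ, it is learned by backpropagation together with the classifier.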

In this guide we will use the CIFAR-10 dataset. The steps of the guide are as follows:
1. Load the CIFAR-10 dataset using *torchvision*
2. Define a *Convolutional Neural Network*
3. Define a loss function for the CNN
4. Train the network
5. Test the network
6. Visualize the STN results

## 1. Load the data
The CIFAR-10 dataset consists of images of size 3x32x32, which means they are RGB color images with 3 channels and a size of 32x32 pixels. There are ten different classes in the dataset, as seen in the figure below.

<img src="./static_files/cifar10.png"/>

We will load the dataset easily using ``torchvision``.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt

# Defines
N_CHANNELS = 3
DIM_HEIGHT = DIM_WIDTH = 32
N_CLASSES = 10

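The torchvision datasets return PIL images with 8-bit pixel values in [0, 255]. ``transforms.ToTensor()`` converts them to tensors with values in [0, 1], and ``transforms.Normalize`` with mean 0.5 and standard deviation 0.5 per channel then maps them to the range [-1, 1] by computing (x - 0.5) / 0.5.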
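As a quick arithmetic check of this normalization, the sketch below applies the same two steps by hand to a few hypothetical 8-bit pixel values: the division by 255 that ``ToTensor`` performs, then the per-channel formula (x - mean) / std that ``Normalize`` uses.

```python
import torch

# hypothetical 8-bit pixel values as stored in the raw CIFAR-10 images
raw = torch.tensor([0, 64, 128, 255], dtype=torch.uint8)

# transforms.ToTensor() divides 8-bit values by 255 -> range [0, 1]
scaled = raw.float() / 255

# transforms.Normalize with mean 0.5 and std 0.5 computes (x - mean) / std -> range [-1, 1]
normalized = (scaled - 0.5) / 0.5
print(normalized)  # 0 maps to -1, 255 maps to 1, the rest lie in between
```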

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ToTensor scales pixels to [0, 1]; Normalize then maps each channel to [-1, 1]
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5),
                          (0.5, 0.5, 0.5))])

# Training dataset
trainset = torchvision.datasets.CIFAR10(root='./data',
                                        train=True,
                                        download=True,
                                        transform=transform)

# Test dataset
testset = torchvision.datasets.CIFAR10(root='./data',
                                       train=False,
                                       download=True,
                                       transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
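# .data holds the raw uint8 NumPy array (before any transform); it replaces the
# train_data / test_data attributes of older torchvision versions
print("Training data\n(samples, height, width, channels)")
print(trainset.data.shape)

print("\nTest data\n(samples, height, width, channels)")
print(testset.data.shape)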

Training data
(samples, height, width, channels)
(50000, 32, 32, 3)

Test data
(samples, height, width, channels)
(10000, 32, 32, 3)


## 2. Define a Convolutional Neural Network

In [68]:
class Net(nn.Module):
    
    def __init__(self, input_channels, input_height, input_width, num_classes):
        super(Net, self).__init__()
        self.input_channels = input_channels
        self.input_height = input_height
        self.input_width = input_width
        self.num_classes = num_classes
        
        # spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(
                in_channels=input_channels,
                out_channels=8,
                kernel_size=7,
                stride=1,
                padding=3),
            nn.MaxPool2d(
                kernel_size=2,
                stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels=8,
                out_channels=12,
                kernel_size=5,
                stride=2,
                padding=2),
            nn.MaxPool2d(
                kernel_size=2,
                stride=2),
            nn.ReLU(inplace=True)
        )
        
        # number of features the localization network outputs, found with a dummy
        # forward pass (12 x 4 x 4 = 192 for a 3 x 32 x 32 CIFAR-10 image)
        with torch.no_grad():
            dummy = torch.zeros(1, input_channels, input_height, input_width)
            self.loc_features = self.localization(dummy).numel()
        
        # regressor for the 3 * 2 affine matrix that we use 
        # to make the bilinear interpolation for the spatial transformer
        self.fc_loc = nn.Sequential(
            nn.Linear(
                in_features=self.loc_features,
                out_features=32,
                bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(
                in_features=32,
                out_features=3 * 2,
                bias=True))
        
        ## network for the classification based on the transformed input image
        # convolutional layers
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=input_channels, 
                      out_channels=16, 
                      kernel_size=5, 
                      stride=1, 
                      padding=2),
            nn.MaxPool2d(kernel_size=2,
                         stride=2),
            nn.ReLU(inplace=True))
            
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=16,
                      out_channels=32,
                      kernel_size=5,
                      stride=1,
                      padding=2),
            nn.MaxPool2d(kernel_size=2,
                         stride=2),
            nn.Dropout2d(p=0.2),
            nn.ReLU(inplace=True))
        
        # number of features after conv2: each MaxPool2d halves height and width,
        # giving 32 x 8 x 8 = 2048 for a 32 x 32 input
        self.cls_features = 32 * (input_height // 4) * (input_width // 4)
        
        # fully connected layers
        self.fc1 = nn.Sequential(
            nn.Linear(in_features=self.cls_features,
                      out_features=50,
                      bias=True),
            nn.Dropout(p=0.2),  # plain Dropout: the input here is (N, 50), not a feature map
            nn.ReLU(inplace=True))
        
        self.fc_out = nn.Sequential(
            nn.Linear(in_features=50,
                      out_features=num_classes,
                      bias=False))
        
        # initialize the regressor to the identity transformation [[1, 0, 0], [0, 1, 0]],
        # so the STN starts out passing the image through unchanged
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
        
        
    def stn(self, x):
        """ Spatial Transformer Network """
        
        # find the transformation parameters
        xs = self.localization(x)

        # flatten the feature maps so their size matches in_features of the regressor (fc_loc)
        xs = xs.view(-1, self.loc_features)
        
        # batch of affine matrices (N x 2 x 3)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3) # the size -1 (N) is inferred from other dimensions

        # grid generator: sampling coordinates in the input image for every output pixel
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        
        # sampler: bilinearly sample the input image at the grid coordinates
        x = F.grid_sample(x, grid, align_corners=False)

        return x
        
    def forward(self, x):
        """ Forward Pass """
        # perform transformation of input image
        x = self.stn(x)

        # usual forward pass - classification network
        x = self.conv1(x)
        x = self.conv2(x)
        print(x.size())  # shape check: (N, 32, 8, 8) for CIFAR-10; remove before training
        x = x.view(-1, self.cls_features)
        x = self.fc1(x)
        x = self.fc_out(x)

        return F.log_softmax(x, dim=1)
        
net = Net(N_CHANNELS, DIM_HEIGHT, DIM_WIDTH, N_CLASSES).to(device)
print(net)

Net(
  (localization): Sequential(
    (0): Conv2d(3, 8, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): ReLU(inplace)
    (3): Conv2d(8, 12, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): ReLU(inplace)
  )
  (fc_loc): Sequential(
    (0): Linear(in_features=192, out_features=32, bias=True)
    (1): ReLU(inplace)
    (2): Linear(in_features=32, out_features=6, bias=True)
  )
  (conv1): Sequential(
    (0): Conv2d(3, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): ReLU(inplace)
  )
  (conv2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Dropout2d(p=0.2)
    (3): ReLU
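The sizes 192 (the `in_features` of `fc_loc`) and 2048 (the number of features the classifier flattens the `conv2` output into) follow from the standard output-size formula for convolution and pooling layers: floor((size + 2 * padding - kernel) / stride) + 1. The pure-Python sketch below traces a 32 x 32 CIFAR-10 image through both networks; the helper `conv_out` is hypothetical and not a PyTorch function.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output size of Conv2d / MaxPool2d along one dimension (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


# localization network on a 32 x 32 input
h = 32
h = conv_out(h, kernel=7, stride=1, padding=3)  # Conv2d     -> 32
h = conv_out(h, kernel=2, stride=2)             # MaxPool2d  -> 16
h = conv_out(h, kernel=5, stride=2, padding=2)  # Conv2d     -> 8
h = conv_out(h, kernel=2, stride=2)             # MaxPool2d  -> 4
loc_features = 12 * h * h                       # 12 channels -> 192

# classification network: 'same' convolutions, each MaxPool2d halves the size
c = 32
c = conv_out(c, kernel=5, stride=1, padding=2)  # conv1 Conv2d -> 32
c = conv_out(c, kernel=2, stride=2)             # MaxPool2d    -> 16
c = conv_out(c, kernel=5, stride=1, padding=2)  # conv2 Conv2d -> 16
c = conv_out(c, kernel=2, stride=2)             # MaxPool2d    -> 8
cls_features = 32 * c * c                       # 32 channels  -> 2048

print(loc_features, cls_features)  # 192 2048
```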

## 2.1 Test Forward Pass
We test the network's forward pass on random dummy data.

In [69]:
# a batch of 5 random "images" of size 3 x 32 x 32
x = np.random.normal(0, 1, (5, 3, 32, 32)).astype('float32')
x = torch.from_numpy(x).to(device)  # Variable wrappers are no longer needed
output = net(x)
print([sample.size() for sample in output])

torch.Size([5, 32, 8, 8])
[torch.Size([10]), torch.Size([10]), torch.Size([10]), torch.Size([10]), torch.Size([10])]
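`Net.forward` returns `F.log_softmax(x, dim=1)`, so each row of `output` holds log-probabilities for the ten classes rather than raw scores. This is relevant for choosing the loss in step 3: the natural pairing is the negative log-likelihood (`F.nll_loss` or `nn.NLLLoss`), which on log-probabilities gives the same value as `F.cross_entropy` on raw logits. The sketch below uses hypothetical random logits and labels and is not the guide's own training code.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# hypothetical raw scores (logits) for a batch of 5 images and 10 classes
logits = torch.randn(5, 10)
targets = torch.tensor([3, 1, 0, 9, 4])  # hypothetical class labels

log_probs = F.log_softmax(logits, dim=1)  # the kind of output Net.forward returns

# each row is a log-probability distribution: its exponentials sum to 1
print(log_probs.exp().sum(dim=1))

# negative log-likelihood on log-probabilities equals cross-entropy on the logits
nll = F.nll_loss(log_probs, targets)
ce = F.cross_entropy(logits, targets)
print(torch.allclose(nll, ce))  # True

# the predicted class is the one with the highest log-probability
predicted = log_probs.argmax(dim=1)
```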
