Spatial Transformer Network

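Purpose: data augmentation. This notebook builds the components of a Spatial Transformer Network (STN) and steps through its forward pass on a single MNIST image.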

In [None]:
import torch
import torchvision
from torch import nn
from torch.nn import functional as F

import matplotlib.pyplot as plt

STN components

In [None]:
# Classifier building blocks: two 5 x 5 convolutions, a 2-D dropout layer and
# two fully connected layers ending in 10 outputs. They are only instantiated
# in this cell.
net_conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
net_conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
net_conv_drop = nn.Dropout2d()
net_fc1 = nn.Linear(in_features=320, out_features=50)
net_fc2 = nn.Linear(in_features=50, out_features=10)

In [None]:
# Localization network of the spatial transformer: two
# convolution -> max-pool -> ReLU stages producing a 10-channel feature map.
net_localization = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=8, kernel_size=7),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.ReLU(inplace=True),
    nn.Conv2d(in_channels=8, out_channels=10, kernel_size=5),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.ReLU(inplace=True),
)

# Regressor head: maps the flattened 10 * 3 * 3 features to the six
# parameters of the affine matrix.
net_fc_loc = nn.Sequential(
    nn.Linear(in_features=10 * 3 * 3, out_features=32),
    nn.ReLU(inplace=True),
    nn.Linear(in_features=32, out_features=3 * 2),
)

Initialize weights

In [None]:
# Start the regressor at the identity transform: with the last layer's weights
# zeroed, its output equals the bias [1, 0, 0, 0, 1, 0], which is the matrix
# [[1, 0, 0], [0, 1, 0]] once reshaped to 2 x 3.
# The in-place writes run under torch.no_grad() instead of going through
# `.data`, so autograd does not record them but still tracks the tensor version.
identity_theta = torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)
with torch.no_grad():
    net_fc_loc[2].weight.zero_()
    net_fc_loc[2].bias.copy_(identity_theta)

Prepare MNIST dataset

In [None]:
# Number of images per batch, shared by the train and test loaders.
# Kept at 1 so each sampled batch holds a single image to inspect.
BATCH_SIZE = 1

In [None]:
# Preprocessing shared by both splits: convert the image to a tensor, then
# normalise it with mean 0.1307 and std 0.3081.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# train loader
mnist_train = torchvision.datasets.MNIST(
    '../data', train=True, download=True, transform=mnist_transform,
)
train_loader = torch.utils.data.DataLoader(
    mnist_train, batch_size=BATCH_SIZE, shuffle=True,
)

# test loader
mnist_test = torchvision.datasets.MNIST(
    '../data', train=False, download=True, transform=mnist_transform,
)
test_loader = torch.utils.data.DataLoader(
    mnist_test, batch_size=BATCH_SIZE, shuffle=True,
)

Sample one batch from dataloader

In [None]:
# Draw one batch from the training loader and report the shape of its first
# element (the image tensor).
train_batches = iter(train_loader)
sample_batch = next(train_batches)
print(sample_batch[0].shape)

In [None]:
# Display the first image of the sampled batch in grayscale.
first_image = sample_batch[0][0].numpy().squeeze()
plt.imshow(first_image, cmap='gray')

STN forward propagation

In [None]:
# The image tensor of the sampled batch is the input to the STN forward pass.
img = sample_batch[0]
print(img.size())

In [None]:
# Run the localization network on the image batch to obtain its feature maps.
xs = net_localization(img)
print(xs.size())

In [None]:
# Lay out the first sample's 10 feature-map channels (3 x 3 each) as a
# 2 x 5 grid of grayscale images.
feature_maps = xs[0].detach().numpy()
for position, feature_map in enumerate(feature_maps, start=1):
    plt.subplot(2, 5, position)
    plt.imshow(feature_map.squeeze(), cmap='gray')
plt.show()

In [None]:
# Flatten each sample's 10 x 3 x 3 feature map into a single row.
xs = xs.view(-1, 10 * 3 * 3)

# Regress six affine parameters per sample, then arrange them as a 2 x 3 matrix.
theta_flat = net_fc_loc(xs)
theta = theta_flat.view(-1, 2, 3)

# Report the shape after each stage: flattened features, raw regressor
# output, reshaped affine matrix.
for stage in (xs, theta_flat, theta):
    print(stage.shape)

In [None]:
# Visualize the 2 x 3 affine matrix predicted for the first sample.
# The sample is indexed explicitly rather than squeezing the whole batch:
# squeeze() only yields a 2-D array when the batch holds exactly one sample,
# and a (B, 2, 3) array would be read by imshow as an RGB image.
plt.imshow(theta[0].detach().numpy(), cmap='gray')

In [None]:
# Build the sampling grid described by theta at the input's own size, then
# resample the image through it. Both calls use align_corners=True so they
# agree on the coordinate convention.
output_size = img.size()
grid = F.affine_grid(theta, output_size, align_corners=True)
img_affine = F.grid_sample(img, grid, align_corners=True)
print(img_affine.shape)
print(img_affine.shape)

In [None]:
# Show the transformed version of the first image in the batch.
# [0, 0] selects sample 0, channel 0 explicitly; squeezing the whole batch
# only produces a 2-D image when the batch contains a single sample.
plt.imshow(img_affine[0, 0].detach().numpy(), cmap='gray')