In [None]:
# importing the libraries
import numpy as np

# for reading and displaying images
import matplotlib.pyplot as plt
%matplotlib inline

# PyTorch libraries and modules
import torch
import torch.nn as nn
import torch.nn.functional as F

## Dataset
In this toy example, we generate images of handwritten multi-digit numbers by concatenating MNIST digits side by side.

In [None]:
import torchvision
import torchvision.transforms as transforms
import random

# Upper bound on how many digits are joined into one image. It is also used as
# the DataLoader batch size, so that one loader batch always holds enough
# digits for a single example.
MAX_DIGITS = 8
# Label used to pad target sequences; MNIST digit classes are 0-9, so 10 is free.
BLANK_SYMBOL = 10

# Seed every RNG the notebook draws from (Python's `random` for the digit
# count, torch for loader shuffling and weight initialisation) so that
# "Restart Kernel -> Run All" reproduces the same samples and training run.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

transform = transforms.ToTensor()

train_set = torchvision.datasets.MNIST(root='/tmp/MNIST/', train=True, transform=transform, download=True)
test_set = torchvision.datasets.MNIST(root='/tmp/MNIST/', train=False, transform=transform, download=True)

# shuffle=True on both loaders is deliberate: examples are built from the
# *first* batch of a fresh iterator, so shuffling is what makes them random.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=MAX_DIGITS,
    shuffle=True)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=MAX_DIGITS,
    shuffle=True)

def sample_training_example(data_loader):
    """Build one example by joining a random number of digit images side by side.

    One batch is drawn from ``data_loader``; its first ``k`` images
    (``k`` uniform in ``1..MAX_DIGITS``) are concatenated along the width axis.

    Args:
        data_loader: iterable yielding ``(images, class_labels)`` batches, with
            ``images`` of shape ``(batch, channels, height, width)``.

    Returns:
        ``(joined_image, labels)`` where ``joined_image`` has shape
        ``(1, channels, height, k * width)`` and ``labels`` holds the class
        labels of the ``k`` digits, left to right.
    """
    # A fresh iterator is created on every call, so only its first batch is
    # used; the result is random only if the loader itself shuffles.
    images, class_labels = next(iter(data_loader))
    num_digits = random.randint(1, MAX_DIGITS)

    digits = images[:num_digits]
    count, channels, height, width = digits.shape
    # (count, C, H, W) -> (C, H, count, W): within every pixel row the digits
    # then sit next to each other, so flattening the last two axes glues them
    # together horizontally. Sizes are read from the tensor instead of
    # assuming 1x28x28 inputs.
    joined_image = digits.permute(1, 2, 0, 3).reshape(1, channels, height, count * width)
    return joined_image, class_labels[:num_digits]
    
def generate_batch(data_loader, batch_size=16):
    """Assemble ``batch_size`` random multi-digit examples into padded tensors.

    Args:
        data_loader: loader handed to ``sample_training_example`` for each example.
        batch_size: number of examples to generate.

    Returns:
        ``(images, targets, lengths)``: the images concatenated along the batch
        axis after right-padding each to the widest one, the label sequences
        stacked after right-padding with ``BLANK_SYMBOL``, and a list with the
        unpadded length of every label sequence.
    """
    pictures, label_seqs = [], []
    for _ in range(batch_size):
        picture, labels = sample_training_example(data_loader)
        pictures.append(picture)
        label_seqs.append(labels)

    lengths = [labels.shape[0] for labels in label_seqs]
    widest = max(picture.shape[3] for picture in pictures)
    longest = max(lengths)

    pad = torch.nn.functional.pad
    # Images grow to the right with the default fill value (zero).
    padded_pictures = [
        pad(picture, (0, widest - picture.shape[3])) for picture in pictures
    ]
    # Label sequences grow to the right with the blank label.
    padded_labels = [
        pad(labels, (0, longest - labels.shape[0]), value=BLANK_SYMBOL)
        for labels in label_seqs
    ]
    return torch.cat(padded_pictures), torch.stack(padded_labels), lengths


# Preview one randomly generated training example together with its labels.
image, targets = sample_training_example(train_loader)
# The example has shape (1, 1, height, width); index away the batch and
# channel axes to obtain the 2-D picture that imshow expects.
plt.imshow(image[0, 0])
plt.show()
print(f"Target Sequence  : \t{np.array(targets)}")

## Model
We define our model, which is an LSTM on top of a Convolutional Neural Network.  This model treats each image as a sequence of image slices and classifies the digits in these slices. 

In [None]:
from torch.nn import Linear, ReLU, Sequential, Conv2d, MaxPool2d, Module, BatchNorm2d, CrossEntropyLoss, Softmax
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class NumberReaderModel(nn.Module):
    """CNN + LSTM that reads a row of digits from left to right.

    Two conv/pool stages shrink the image by a factor of four in each spatial
    direction. The resulting feature map is cut along its width into square
    slices, every slice is flattened into one LSTM time step, and a linear
    layer scores ``n_classes`` labels per step.

    The LSTM input size is fixed at 7 * 7 * 16, so the CNN has to produce a
    7-pixel-high map (e.g. from 28-pixel-high images).
    """

    def __init__(self, hidden_size=16, num_layers=1, n_classes=11):
        super().__init__()

        feature_maps = 16  # channels produced by each convolution
        slice_side = 7     # height (and width) of one feature slice

        def conv_stage(in_channels):
            # conv -> batch norm -> ReLU -> 2x2 max-pool (halves height and width)
            return [
                nn.Conv2d(in_channels=in_channels, out_channels=feature_maps,
                          kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(num_features=feature_maps),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ]

        # Feature extractor: two identical stages, the first fed by one input channel.
        self.cnn_layers = nn.Sequential(*conv_stage(1), *conv_stage(feature_maps))

        # Sequence model over the flattened slices.
        self.lstm = nn.LSTM(
            input_size=slice_side * slice_side * feature_maps,
            hidden_size=hidden_size,
            num_layers=num_layers,
        )

        # Per-step classifier.
        self.linear = nn.Linear(in_features=hidden_size, out_features=n_classes)

    def forward(self, x):
        """Return per-step log-probabilities.

        Args:
            x: tuple ``(images, sequence_lengths)``; ``images`` is a
                ``(batch, 1, height, width)`` tensor and ``sequence_lengths``
                gives the number of valid time steps of every batch element.

        Returns:
            Tensor of shape ``(batch, max(sequence_lengths), n_classes)``
            holding log-softmax scores over the class axis.
        """
        images, sequence_lengths = x
        features = self.cnn_layers(images)

        batch_size, channels, height, width = features.shape
        # Slices are square: each is `height` columns wide.
        num_slices = width // height

        # (B, C, H, W) -> (B, C, H, N, h) -> (B, N, H, C, h) -> (B, N, H*C*h):
        # one flattened feature vector per slice, ordered left to right.
        steps = (
            features
            .reshape(batch_size, channels, height, num_slices, height)
            .permute(0, 3, 2, 1, 4)
            .reshape(batch_size, num_slices, channels * height * height)
        )

        # Pack the already-padded batch so the LSTM skips the padding steps.
        packed_steps = pack_padded_sequence(
            steps, sequence_lengths, batch_first=True, enforce_sorted=False
        )
        packed_hidden, _ = self.lstm(packed_steps)
        # Unpack back to a dense (batch, longest_sequence, hidden) tensor.
        hidden, _ = pad_packed_sequence(packed_hidden, batch_first=True)

        return F.log_softmax(self.linear(hidden), dim=2)

## Train the model

In [None]:
NUM_STEPS = 1000  # number of optimisation steps (one generated batch per step)
LOG_EVERY = 50    # print the loss every this many steps

model = NumberReaderModel()
# The model already returns log-probabilities (it ends in log_softmax), so the
# matching criterion is the negative log-likelihood. CrossEntropyLoss would
# apply log_softmax a second time, which is redundant.
# Positions labelled BLANK_SYMBOL are padding and are excluded from the loss.
loss_function = nn.NLLLoss(ignore_index=BLANK_SYMBOL)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for step in range(NUM_STEPS):
    optimizer.zero_grad()

    image, targets, sequence_lengths = generate_batch(train_loader)
    predictions = model((image, sequence_lengths))

    # predictions is (batch, sequence, classes), but the loss expects the class
    # axis second, i.e. (batch, classes, sequence), to line up with the
    # (batch, sequence) targets.
    loss = loss_function(predictions.permute(0, 2, 1), targets)
    loss.backward()
    optimizer.step()

    if step % LOG_EVERY == 0:
        # .item() logs the plain number instead of the tensor repr with grad_fn.
        print(f"Step : {step}\t loss : {loss.item()}")

In [None]:
image, targets = sample_training_example(test_loader)
plt.imshow(image.squeeze())
plt.show()

# Switch to inference mode: BatchNorm then normalises with the running
# statistics gathered during training instead of the statistics of this single
# image, and stops updating them. The model stays in eval mode afterwards.
model.eval()
# No gradients are needed for prediction, so skip building the autograd graph.
with torch.no_grad():
    # Each MNIST digit is 28 pixels wide, so width // 28 is the number of
    # digits, which is the sequence length the model expects.
    prediction = model((image, [image.shape[3] // 28]))

# Most likely class at every time step of the (only) batch element.
predicted_digits = prediction[0].argmax(dim=1).numpy()
print(f"Target Sequence  : \t{np.array(targets)}")
print(f"Model Prediction : \t{predicted_digits}")