In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Dataset
We run through the following steps:
1. Import the dataset from Kaggle and load into a `pd.DataFrame`
2. Extract the images and labels from the `pd.DataFrame`

We will use the encoding that:
- The digits "0"->"9" map to 0->9
- The alphabet "A"->"Z" map to 10->35

In [None]:
# Read the whole CSV as float32; the column named "0" carries the label and
# every remaining column is treated as a pixel feature.
label_column = "0"
data = pd.read_csv("./data/A_Z Handwritten Data.csv", dtype="float32")
# Divide by 255 (presumably raw intensities span 0-255 -- confirm against the dataset).
X = data.drop(columns=label_column).astype("float32") / 255
# Offset the labels by 10 so that classes 0-9 remain reserved for the digits.
y = data[label_column].astype("int64") + 10

In [None]:
def label_to_string(y: int) -> str:
    """Convert an encoded class label into the character it represents.

    Labels 0-9 map to the digits "0"-"9"; labels from 10 upwards map to the
    upper-case letters starting at "A" (10 -> "A", ..., 35 -> "Z").

    Args:
        y: Encoded label. Besides a plain ``int``, any integer scalar that
            supports ``int()`` is accepted (NumPy integers, 0-d NumPy arrays,
            0-d ``torch.Tensor`` objects).

    Returns:
        The single-character string for the label.
    """
    # Coerce first: ``str()`` of a 0-d tensor is "tensor(3)", not "3".
    y = int(y)
    if y < 10:
        return str(y)
    return chr(y - 10 + ord("A"))

In [None]:
import torchvision
from torch.utils.data import ConcatDataset, Dataset
from sklearn.model_selection import train_test_split


class HandwritingDataset(Dataset):
    """Map-style dataset of flattened images stored one sample per row.

    Args:
        X: Table (``pd.DataFrame`` or 2-D array-like) with one flattened image
            per row.
        y: Labels (``pd.Series`` or 1-D array-like), one per row of ``X``.
        transform: Optional callable applied to each reshaped image.
        target_transform: Optional callable applied to each label.
        image_shape: Shape every row is reshaped to; defaults to ``(28, 28)``.

    Raises:
        ValueError: If ``X`` and ``y`` do not hold the same number of samples.
    """

    def __init__(
        self,
        X: pd.DataFrame,
        y: pd.Series,
        transform=None,
        target_transform=None,
        image_shape: tuple = (28, 28),
    ):
        # The original inputs stay available under their original names.
        self.X = X
        self.y = y
        self.transform = transform
        self.target_transform = target_transform
        self.image_shape = tuple(image_shape)
        # Convert to NumPy once: a pandas ``.iloc`` row lookup on every
        # ``__getitem__`` call builds a new Series each time and dominates the
        # per-sample cost when iterating through a DataLoader.
        self._images = np.asarray(X)
        self._labels = np.asarray(y)
        if len(self._images) != len(self._labels):
            raise ValueError(
                f"X has {len(self._images)} samples but y has {len(self._labels)}"
            )

    def __len__(self):
        """Return the number of samples."""
        return len(self._labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at positional index ``idx``."""
        # Copy so each sample is an independent, writable, contiguous array.
        image = self._images[idx].reshape(self.image_shape).copy()
        label = self._labels[idx]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label


# Hold out 20% of the letter samples for testing. A fixed ``random_state``
# makes the split reproducible across kernel restarts, and ``stratify=y`` keeps
# the per-letter class proportions the same in both splits.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
transform = torchvision.transforms.ToTensor()
# Each split concatenates the letter samples with the matching MNIST digit
# split, so a single dataset covers digits and letters.
train_dataset = ConcatDataset(
    [
        HandwritingDataset(X_train, y_train, transform=transform),
        torchvision.datasets.MNIST(root='./data/mnist', train=True, download=True, transform=transform),
    ]
)
test_dataset = ConcatDataset(
    [
        HandwritingDataset(X_test, y_test, transform=transform),
        torchvision.datasets.MNIST(root='./data/mnist', train=False, download=True, transform=transform),
    ]
)

In [None]:
import torch
import torch.nn as nn

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Wrap both datasets in mini-batch loaders that reshuffle on every pass.
batch_size = 64
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset=split, batch_size=batch_size, shuffle=True)
    for split in (train_dataset, test_dataset)
)

# Training a model

## Model definition
We will use a supervised-VAE with the following architecture.

### Encoder
This part of the model predicts the class labels and latent variables from a given image.
- Convolutional layer + ReLU (input size 28x28x1, output size 14x14x`num_filters`)
- Convolutional layer + ReLU (input size 14x14x`num_filters`, output size 7x7x`num_filters`)
- Flatten (input size 7x7x`num_filters`, output size 49x`num_filters`)
- To generate *class labels* `y`:
    - Linear + ReLU (input size 49x`num_filters`, output size 128)
    - Linear (input size 128, output size `output_size`)
    - **Note:** there is no `Softmax` here since this is implicitly included as part of the `CrossEntropyLoss` used later on
- To predict mean and variance of *latent variables* `z`:
    - Linear (input size 49x`num_filters`, output size `num_latent_var`)

### Decoder
This part of the model reconstructs an image from class labels and latent variables.
- From *class labels* `y` logits:
    - Softmax (to convert from logits to class probabilities)
    - Linear + ReLU (input size `output_size`, output size 128)
    - Linear (input size 128, output size 49x`num_filters`)
- From *latent variables* `z`:
    - Linear (input size `num_latent_var`, output size 49x`num_filters`)
- Unflatten (input size 49x`num_filters`, output size 7x7x`num_filters`)
- Deconvolutional layer + ReLU (input size 7x7x`num_filters`, output size 14x14x`num_filters`)
- Deconvolutional layer (input size 14x14x`num_filters`, output size 28x28x1)
- **Note:** there is no `Sigmoid` here since this is implicitly included as part of the `BCEWithLogitsLoss` used later on



In [None]:
from training.model import VAE

# 36 output classes in total: the ten digits followed by the 26 letters.
model = VAE(
    input_size=28,
    output_size=10 + 26,
    num_filters=32,
    num_latent_var=64,
).to(device)

## Loss
For a supervised VAE, we need 3 loss terms:
1. **Reconstruction loss** to enforce that the decoder can accurately reconstruct characters from the latent variables (using `BCEWithLogitsLoss`)
2. **KL-loss** to enforce that the encoder accurately predicts the posterior on the latent variables
3. **Categorical loss** to enforce that the model can accurately classify characters (using `CrossEntropyLoss`)

1 & 2 would be needed for an unsupervised VAE as well, but 3 is an additional loss needed for the supervised nature of this problem.

In [None]:
# Reconstruction term: binary cross-entropy evaluated on raw decoder logits.
recon_loss_fun = nn.BCEWithLogitsLoss()
# Classification term: cross-entropy evaluated on raw class logits.
cat_loss_fun = nn.CrossEntropyLoss()


def kl_div_loss_fun(z_mean: torch.Tensor, log_z_var: torch.Tensor) -> torch.Tensor:
    """Closed-form Gaussian KL term, summed over all elements and divided by the batch size.

    Args:
        z_mean: Predicted latent means; the first dimension is used as the batch size.
        log_z_var: Predicted log-variances, broadcastable against ``z_mean``.

    Returns:
        Scalar tensor ``-0.5 * sum(1 + log_var - mean^2 - exp(log_var)) / batch``.
    """
    batch_size = z_mean.shape[0]
    elementwise = 1 + log_z_var - z_mean.pow(2) - log_z_var.exp()
    return -0.5 * torch.sum(elementwise) / batch_size


def loss_fun(model: VAE, x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run one forward pass and evaluate the three supervised-VAE loss terms.

    Args:
        model: The VAE; calling it must return
            ``(x_recon, z_mean, log_z_var, y_pred)``.
        x: Batch of input images, also used as the reconstruction target.
        y: Ground-truth class labels for the batch.

    Returns:
        ``(categorical_loss, reconstruction_loss, kl_divergence_loss)``, in
        that order.
    """
    # Call the module itself rather than ``.forward`` so that any registered
    # forward hooks are executed as well.
    x_recon, z_mean, log_z_var, y_pred = model(x)
    categorical = cat_loss_fun(y_pred, y)
    reconstruction = recon_loss_fun(x_recon, x)
    kl_divergence = kl_div_loss_fun(z_mean, log_z_var)
    return categorical, reconstruction, kl_divergence

In [None]:
num_epochs = 10
# Weight of the classification term relative to the reconstruction term.
cat_loss_weight = 0.1
# The KL weight is annealed linearly: epoch * kl_anneal_rate.
kl_anneal_rate = 0.001
optimiser = torch.optim.Adam(model.parameters())

for epoch in range(1, num_epochs + 1):
    # Ensure training mode even if the model was switched to eval mode before
    # this cell is (re-)run.
    model.train()
    kl_weight = epoch * kl_anneal_rate
    running_kl_div_loss = 0.0
    running_recons_loss = 0.0
    running_cat_loss = 0.0
    num_images = 0
    for img, label in train_loader:
        img = img.to(device)
        label = label.to(device)
        optimiser.zero_grad()
        cat_loss, recons_loss, kl_div_loss = loss_fun(model, img, label)
        loss = cat_loss_weight * cat_loss + recons_loss + kl_weight * kl_div_loss
        loss.backward()
        optimiser.step()

        # Weight each batch by its size so the epoch figures are true
        # per-image averages even when the last batch is smaller.
        batch_len = len(img)
        running_cat_loss += cat_loss.item() * batch_len
        running_recons_loss += recons_loss.item() * batch_len
        running_kl_div_loss += kl_div_loss.item() * batch_len
        num_images += batch_len
    print(
        f"epoch: {epoch}"
        f" cat_loss: {running_cat_loss / num_images}"
        f" recons_loss: {running_recons_loss / num_images}"
        f" kl_div_loss: {running_kl_div_loss / num_images}"
    )

## Save model to file
We need to save the model so that we can use it for inference.

In [None]:
# Write only the model's state dict to disk.
model_state = model.state_dict()
torch.save(model_state, "model.pt")

## Assessing model accuracy
We must first evaluate the model on the test set.

We then can assess the accuracy in 2 ways:
- Compute the multi-class classification accuracy
- Compute the binary accuracy of the reconstructed images

In [None]:
# Per-batch results collected over the whole test loader (all kept on the CPU).
x_truths = []
x_recons = []
z_means = []
log_z_vars = []
y_truths = []
y_preds = []

# Switch any mode-dependent layers to inference behaviour.
model.eval()
# Disable autograd: otherwise every stored ``z_mean`` / ``log_z_var`` keeps
# the computation graph of its batch alive for the whole evaluation.
with torch.no_grad():
    for im, y_true in test_loader:
        # Call the module itself (not ``.forward``) so hooks are honoured.
        x_recon, z_mean, log_z_var, ysoft = model(im.to(device))
        # The decoder emits logits; apply the sigmoid to obtain pixel probabilities.
        im_recon = x_recon.sigmoid().detach()
        # Predicted class is the index of the largest logit.
        _, y_pred = torch.max(ysoft, 1)
        x_truths.append(im)
        x_recons.append(im_recon.cpu())
        z_means.append(z_mean.cpu())
        log_z_vars.append(log_z_var.cpu())
        y_truths.append(y_true.cpu())
        y_preds.append(y_pred.cpu())

### Reconstruction accuracy

In [None]:
from torchmetrics.classification import BinaryAccuracy

# Stack the per-batch tensors into one tensor per side.
original_images = torch.cat(x_truths)
reconstructed_images = torch.cat(x_recons)
# A pixel counts as "on" above 0.5, for predictions and ground truth alike.
recon_metric = BinaryAccuracy(threshold=0.5)
recon_accuracy = recon_metric(reconstructed_images, original_images > 0.5)
print("Recon. acc: {}".format(recon_accuracy))

We can also plot the accuracy over the image. We can see that the outer black pixels are clearly well predicted, but the white pixels where the character is located are less accurately predicted.

In [None]:
def _pixel_accuracy(row: int, col: int):
    """Binary accuracy of one pixel position, evaluated over every test image."""
    predicted = reconstructed_images[:, :, row, col]
    actual = original_images[:, :, row, col] > 0.5
    return recon_metric(predicted, actual)


# 28x28 map holding the accuracy of each pixel position.
recon_metric_image = np.array([[_pixel_accuracy(i, j) for j in range(28)] for i in range(28)])
plt.imshow(recon_metric_image)
plt.colorbar()

### Categorisation accuracy

We can evaluate the total accuracy, as well as the accuracy on each character. We can see that all characters are very accurately classified apart from "0". This is likely because it can easily be miscategorised as one of several similar-looking characters, such as "O" or "D", when poorly written.

In [None]:
from torchmetrics.classification import MulticlassAccuracy

# Concatenate the per-batch predictions and labels once and reuse them below.
y_pred_all = torch.cat(y_preds)
y_true_all = torch.cat(y_truths)
num_classes = model.output_size

# Overall accuracy, with each class weighted by its support.
total_cat_metric = MulticlassAccuracy(num_classes=num_classes, average="weighted")
cat_accuracy_total = total_cat_metric(y_pred_all, y_true_all)
print("Cat. acc (total): {}".format(cat_accuracy_total))

# Per-class accuracy (no averaging), shown with the metric's own plot helper.
cat_metric_by_class = MulticlassAccuracy(num_classes=num_classes, average=None)
cat_metric_by_class.update(y_pred_all, y_true_all)
_, ax = cat_metric_by_class.plot()
ax.legend(bbox_to_anchor=(1.1, 1.05))

We can also get a more qualitative assessment by comparing some of the original and reconstructed images. 
- The examples are all correctly classified
- The reconstructed images are clearly of the correct character. However, the reconstructed images are noticeably blurrier, which is a common phenomenon with VAEs.

In [None]:
from mpl_toolkits.axes_grid1 import ImageGrid

fig = plt.figure(figsize=(8.0, 10.0))
grid = ImageGrid(fig, 111, nrows_ncols=(2, 10), axes_pad=(0.05, 0.5))

# Top row: first ten originals of the first batch with their true labels.
# Bottom row: the matching reconstructions with the predicted labels.
rows = (
    (grid[:10], x_truths[0], y_truths[0], "Truth"),
    (grid[10:], x_recons[0], y_preds[0], "Gen."),
)
for row_axes, images, class_ids, row_name in rows:
    for ax, im, class_id in zip(row_axes, images, class_ids):
        ax.imshow(im.squeeze(), cmap='gray')
        ax.set_title(label_to_string(class_id.numpy()))
    row_axes[0].set_ylabel(row_name)

plt.show()

# Generation of synthetic hand-writing
We can now use the VAE to generate some synthetic handwriting.
- Use the encoder to determine the hand-writing style (in terms of latent variables) from the first 6 items of the first evaluated test batch (the test loader shuffles, so these are a random sample of the test set)
- Use the decoder to generate the full set of characters for each of these styles

In [None]:
labels = torch.arange(0, model.output_size)
# Repeat every class id 6 times in a row: [0, 0, 0, 0, 0, 0, 1, 1, ...].
repeated_labels = labels.reshape(-1, 1).tile(1, 6).flatten()
# NOTE(review): the decoder description above mentions a Softmax over class
# logits -- confirm that `model.decode` is meant to receive one-hot vectors.
y_plot = nn.functional.one_hot(repeated_labels).float().to(device)

# Styles are taken from the first 6 samples of the first evaluated batch.
z_mean_to_plot = z_means[0][:6].to(device)
log_z_var_to_plot = log_z_vars[0][:6].to(device)
# The 6 sampled latents are stacked once per class, so row k of `z_plot`
# pairs class k // 6 with style k % 6.
sampled_styles = model.sample(z_mean_to_plot, log_z_var_to_plot)
z_plot = sampled_styles.tile(model.output_size, 1)
decoded_logits = model.decode(z_plot, y_plot)
generated_images = decoded_logits.sigmoid().reshape(-1, model.input_size, model.input_size).cpu().detach()

In [None]:
fig = plt.figure(figsize=(15.0, 10.0))
# One row per class and one column per style (6 styles).
grid = ImageGrid(fig, 121, nrows_ncols=(model.output_size, 6), axes_pad=0.05)

for ax, im in zip(grid, generated_images):
    ax.imshow(im, cmap='gray')

# `labels` is a tensor; iterate over plain ints so that digit rows are
# labelled "3" rather than "tensor(3)".
for i, label in enumerate(labels.tolist()):
    # The first axis of each row carries the row label.
    grid[i * 6].set_ylabel(label_to_string(label))

plt.show()