# 👖 Variational Autoencoders - Fashion-MNIST

In this notebook, we'll walk through the steps required to train your own variational autoencoder (VAE) on the Fashion-MNIST dataset.

The code is a PyTorch adaptation of the excellent [VAE tutorial](https://keras.io/examples/generative/vae/) created by Francois Chollet, available on the Keras website.

In [None]:
import os, sys
from dotenv import load_dotenv

# Load environment configuration (python-dotenv), then read the two settings
# this notebook relies on: extra import roots and the dataset directory.
load_dotenv()
python_path = os.getenv('PYTHONPATH')
data_path = os.getenv('DATA_PATH')

# Register every PYTHONPATH entry that is not already importable.
if python_path:
    for path in python_path.split(os.pathsep):
        if path in sys.path:
            continue
        sys.path.append(path)


import numpy as np
from scipy.stats import norm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

from tqdm import tqdm, trange

from notebooks.pt_utils import display

## 0. Parameters <a name="parameters"></a>

In [None]:
# Data and model hyperparameters.
IMAGE_SIZE = 32          # side length of the (padded) square images fed to the model
BATCH_SIZE = 100
VALIDATION_SPLIT = 0.2   # fraction of the training set held out for validation
EMBEDDING_DIM = 2        # dimensionality of the latent space
EPOCHS = 5
BETA = 400               # weight applied to the reconstruction term of the loss

# Number of DataLoader worker processes.
NUM_WORKERS = 24
# Misspelled legacy name, kept as an alias because other cells still refer to it.
NUM_WOERKERS = NUM_WORKERS

# Seed torch's RNG so weight init, shuffling and latent sampling are repeatable
# under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## 1. Prepare the data <a name="prepare"></a>

In [None]:
class AutoencoderDataset(Dataset):
    """Adapter that turns a labelled dataset into an autoencoder dataset.

    Each item of the wrapped dataset must unpack into exactly two values;
    the second one is discarded and the first is returned twice, as both
    the model input and the reconstruction target.
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def __len__(self):
        # Same number of items as the wrapped dataset.
        return len(self.dataset)

    def __getitem__(self, idx):
        sample, _discarded = self.dataset[idx]
        return (sample, sample)

In [None]:
# Convert each image to a tensor, then pad 2 pixels on every side.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Pad(padding=2),
])

# download=True only fetches the files when they are not already under data_path,
# so the notebook also runs on a machine that has never seen the dataset.
full_train_dataset = datasets.FashionMNIST(root=data_path, train=True, transform=transform, download=True)
test_dataset = datasets.FashionMNIST(root=data_path, train=False, transform=transform, download=True)

# Hold out VALIDATION_SPLIT of the training data. The dedicated, seeded generator
# makes the train/validation partition identical on every run.
train_dataset, val_dataset = torch.utils.data.random_split(
    full_train_dataset,
    [1 - VALIDATION_SPLIT, VALIDATION_SPLIT],
    generator=torch.Generator().manual_seed(42),
)

# Autoencoder views of the same splits: every item is (image, image).
vae_train_dataset = AutoencoderDataset(train_dataset)
vae_val_dataset = AutoencoderDataset(val_dataset)
vae_test_dataset = AutoencoderDataset(test_dataset)

In [None]:
# Keyword arguments shared by every loader. Pinned host memory only helps when
# batches are copied to a CUDA device, so it follows `device` instead of being
# hard-coded: forcing pin_memory_device='cuda' breaks the CPU fallback. With
# pin_memory_device left unset, DataLoader pins for CUDA by default.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WOERKERS,
    pin_memory=(device.type == 'cuda'),
)

# Loaders over (image, label) pairs.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Loaders over (image, image) pairs for training the autoencoder.
vae_train_loader = DataLoader(vae_train_dataset, shuffle=True, **loader_kwargs)
vae_val_loader = DataLoader(vae_val_dataset, shuffle=False, **loader_kwargs)
vae_test_loader = DataLoader(vae_test_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Pull a single batch from the training loader as a sanity check of the pipeline.
images, labels = next(iter(train_loader))
# `display` is the helper imported from notebooks.pt_utils, not IPython's display.
display(images)
# Labels of the first ten images in the batch.
print(labels[:10])

In [None]:
# Report the shape of one image and the value range across the whole batch.
print(f'image size: {images[0].shape}')
print(f'image min: {images.min().item()}')
print(f'image max: {images.max().item()}')

## 2. Build the variational autoencoder <a name="build"></a>

In [None]:
class _SampleZ(nn.Module):
    """Reparameterisation step: z = z_mean + exp(0.5 * z_log_var) * eps, eps ~ N(0, I)."""

    def forward(self, inputs):
        z_mean, z_log_var = inputs
        epsilon = torch.randn_like(z_mean)
        return z_mean + torch.exp(0.5 * z_log_var) * epsilon


class _Encoder(nn.Module):
    """Three stride-2 convolutions followed by linear heads for z_mean and z_logvar."""

    def __init__(self, in_channels, flat_size, z_size):
        super().__init__()

        self.seq = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        )

        self.z_mean_branch = nn.Linear(flat_size, z_size)
        self.z_logvar_branch = nn.Linear(flat_size, z_size)
        self.z_sampler = _SampleZ()

    def forward(self, x):
        """Return (sampled z, z_mean, z_logvar) for a batch of images."""
        features = self.seq(x)
        z_mean = self.z_mean_branch(features)
        z_logvar = self.z_logvar_branch(features)
        z = self.z_sampler([z_mean, z_logvar])
        return z, z_mean, z_logvar


class _Decoder(nn.Module):
    """Linear projection plus three stride-2 transposed convolutions, ending in a sigmoid."""

    def __init__(self, out_channels, unflat_size, z_size):
        super().__init__()

        flat_size = unflat_size[0] * unflat_size[1] * unflat_size[2]
        self.seq = nn.Sequential(
            nn.Linear(z_size, flat_size),
            nn.Unflatten(dim=1, unflattened_size=unflat_size),
            nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.Conv2d(32, out_channels, 3, stride=1, padding='same'),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.seq(x)


class VariationalAutoencoder(nn.Module):
    """Convolutional variational autoencoder.

    Args:
        input_size: ``(channels, dim_a, dim_b)`` of a single input sample, in the
            same axis order as the sample tensor. Both spatial sizes must be
            positive multiples of 8, because the encoder halves them three
            times and the decoder doubles them three times.
        latent_size: dimensionality of the latent vector ``z``.

    Raises:
        ValueError: if a spatial size is not a positive multiple of 8.

    The sub-modules are reachable as ``self.encoder`` (returns
    ``(z, z_mean, z_logvar)``) and ``self.decoder``. They are module-level
    classes rather than classes local to ``__init__`` so the whole model can be
    pickled; attribute names are unchanged, so state_dict keys are the same.
    """

    def __init__(self, input_size, latent_size):
        super().__init__()
        c, w, h = input_size
        if w <= 0 or h <= 0 or w % 8 or h % 8:
            raise ValueError(
                f'spatial input sizes must be positive multiples of 8, got {(w, h)}'
            )

        # Each of the three stride-2 convolutions halves the spatial size.
        unflat_size = (128, w // 8, h // 8)
        flat_size = 128 * (w // 8) * (h // 8)

        self.encoder = _Encoder(c, flat_size, latent_size)
        self.decoder = _Decoder(c, unflat_size, latent_size)

    def forward(self, x):
        """Encode ``x``, sample a latent vector and return its reconstruction."""
        z, _, _ = self.encoder(x)
        return self.decoder(z)

    def _construction_loss(self, original, reconstructed, beta=BETA):
        """Reconstruction loss: ``beta`` times the mean binary cross-entropy."""
        return beta * F.binary_cross_entropy(reconstructed, original)

    def _kl_div_loss(self, z_mean, z_logvar):
        """KL term: summed over latent dimensions, averaged over the batch."""
        return torch.mean(
            torch.sum(
                -0.5 * (1 + z_logvar - torch.square(z_mean) - torch.exp(z_logvar)),
                dim=1
            )
        )

    def _batch_loss(self, inputs, beta):
        """Total loss (weighted reconstruction + KL) for one batch of inputs."""
        z, z_mean, z_logvar = self.encoder(inputs)
        reconstructions = self.decoder(z)
        return (
            self._construction_loss(inputs, reconstructions, beta) +
            self._kl_div_loss(z_mean, z_logvar)
        )

    def fit(self, train_loader, val_loader, epochs, optimizer, beta=BETA):
        """Train for ``epochs`` epochs, reporting the mean validation loss per epoch.

        Args:
            train_loader: iterable of 2-element batches; only the first element
                (the images) is used, since the loss reconstructs the inputs.
            val_loader: same batch layout, used for the per-epoch validation loss.
            epochs: number of passes over ``train_loader``.
            optimizer: optimizer stepped once per training batch; expected to
                have been built over this model's parameters.
            beta: weight of the reconstruction term.
        """
        # Batches go to wherever the model's weights live, instead of relying
        # on a module-level `device` variable.
        model_device = next(self.parameters()).device

        with trange(epochs, desc='Epochs', unit='epoch', leave=True) as pbar:
            for _ in pbar:

                # Train
                self.train()
                for inputs, _ in train_loader:
                    inputs = inputs.to(model_device)
                    optimizer.zero_grad()
                    loss = self._batch_loss(inputs, beta)
                    loss.backward()
                    optimizer.step()

                # Validate
                self.eval()
                val_loss = 0.0
                num_batches = 0
                with torch.no_grad():
                    for inputs, _ in val_loader:
                        inputs = inputs.to(model_device)
                        val_loss += self._batch_loss(inputs, beta).item()
                        num_batches += 1

                # An empty validation loader reports NaN instead of crashing.
                val_loss = val_loss / num_batches if num_batches else float('nan')
                pbar.set_postfix_str(f'Loss: {val_loss:0.4f}')



## 3. Train the variational autoencoder <a name="train"></a>

In [None]:
# Build the model from the config constants rather than repeating their values.
vae = VariationalAutoencoder((1, IMAGE_SIZE, IMAGE_SIZE), EMBEDDING_DIM).to(device)
optimizer = torch.optim.Adam(params=vae.parameters(), lr=0.0005)

vae.fit(
    train_loader=vae_train_loader,
    val_loader=vae_val_loader,
    epochs=EPOCHS,
    optimizer=optimizer,
    beta=BETA
)


## 3. Reconstruct using the variational autoencoder <a name="reconstruct"></a>

In [None]:
# Draw one shuffled batch of 5000 (image, label) pairs from the test set to use
# as examples below. The worker count comes from the config cell.
ex_loader = DataLoader(test_dataset, batch_size=5000, shuffle=True, num_workers=NUM_WOERKERS)
example_images, example_labels = next(iter(ex_loader))

In [None]:
# Create autoencoder predictions and display.
# Inference only: no_grad avoids building an autograd graph for the whole batch.
with torch.no_grad():
    reconstructions = vae(example_images.to(device)).cpu()
print("Example real clothing items")
display(example_images)
print("Reconstructions")
display(reconstructions)

## 4. Embed using the encoder <a name="encode"></a>

In [None]:
# Encode the example images (inference only, so no autograd graph is needed)
with torch.no_grad():
    z, z_mean, z_logvar = vae.encoder(example_images.to(device))

In [None]:
# Some examples of the embeddings
# Detach the sampled codes from autograd and move them to host memory so the
# plotting cells can use them; re-running this cell leaves z unchanged.
z = z.detach().cpu()
print(z[:10])

In [None]:
# Show the encoded points in 2D space
figsize = 8

# One square figure with a single axes; each point is one encoded example.
fig, ax = plt.subplots(figsize=(figsize, figsize))
ax.scatter(z[:, 0], z[:, 1], c="black", alpha=0.5, s=3)
plt.show()

## 5. Generate using the decoder <a name="decode"></a>

In [None]:
# Sample some points in the latent space, from the standard normal distribution
# (one 2-D point per cell of a 6 x 3 display grid).
grid_width = 6
grid_height = 3
z_sample = torch.randn(grid_width * grid_height, 2)

In [None]:
# Decode the sampled points (inference only, so no autograd graph is needed)
with torch.no_grad():
    reconstructions = vae.decoder(z_sample.to(device)).cpu()

In [None]:
# Convert original embeddings and sampled embeddings to p-values
# by pushing each coordinate through the standard normal CDF.
p = norm.cdf(z.numpy())
p_sample = norm.cdf(z_sample.numpy())

In [None]:
# First figure: the encoded examples (small black dots) overlaid with the
# freshly sampled latent points (large blue dots).
figsize = 8
plt.figure(figsize=(figsize, figsize))
plt.scatter(z[:, 0], z[:, 1], c="black", alpha=0.5, s=2)
plt.scatter(z_sample[:, 0], z_sample[:, 1], c="#00B0F0", alpha=1, s=40)
plt.show()

# Second figure: a grid with one decoded image per sampled point, each
# captioned with its latent coordinates rounded to one decimal.
fig = plt.figure(figsize=(figsize, grid_height * 2))
fig.subplots_adjust(hspace=0.4, wspace=0.4)

n_cells = grid_width * grid_height
for i in range(n_cells):
    latent_point = z_sample[i, :]
    decoded = reconstructions[i, :, :]

    ax = fig.add_subplot(grid_height, grid_width, i + 1)
    ax.axis("off")
    caption = str(np.round(latent_point.numpy(), 1))
    ax.text(0.5, -0.35, caption, fontsize=10, ha="center", transform=ax.transAxes)
    # Channel axis is moved last for imshow.
    ax.imshow(decoded.permute(1, 2, 0), cmap="Greys")

## 6. Explore the latent space <a name="explore"></a>

In [None]:
# Colour the embeddings by their label (clothing type - see table)
figsize = 8
fig = plt.figure(figsize=(figsize * 2, figsize))

# Both panels share the same marker styling and label colouring.
label_style = dict(cmap="rainbow", c=example_labels, alpha=0.8, s=3)

# Left panel: raw latent coordinates, with a colour bar for the label ids.
ax = fig.add_subplot(1, 2, 1)
plot_1 = ax.scatter(z[:, 0], z[:, 1], **label_style)
plt.colorbar(plot_1)

# Right panel: the same points after the normal-CDF transform (array `p`).
ax = fig.add_subplot(1, 2, 2)
plot_2 = ax.scatter(p[:, 0], p[:, 1], **label_style)
plt.show()

| ID | Clothing Label |
| :- | :- |
| 0 | T-shirt/top |
| 1 | Trouser |
| 2 | Pullover |
| 3 | Dress |
| 4 | Coat |
| 5 | Sandal |
| 6 | Shirt |
| 7 | Sneaker |
| 8 | Bag |
| 9 | Ankle boot |

In [None]:
# Decode a regular grid of latent points to see what the decoder generates
# across the latent space.
figsize = 12
grid_size = 15

# Grid coordinates are standard-normal quantiles of evenly spaced probabilities.
# The probabilities stop short of 0 and 1 because norm.ppf(0) = -inf and
# norm.ppf(1) = +inf, which would feed non-finite values to the decoder for
# every cell on the border of the grid.
x = norm.ppf(np.linspace(0.01, 0.99, grid_size))
y = norm.ppf(np.linspace(0.99, 0.01, grid_size))
xv, yv = np.meshgrid(x, y)
xv = xv.flatten()
yv = yv.flatten()
# One (x, y) latent point per row.
grid = np.stack([xv, yv], axis=1)
grid = torch.from_numpy(grid).to(torch.float32).to(device)

# Inference only, so no autograd graph is needed.
with torch.no_grad():
    reconstructions = vae.decoder(grid).cpu()

fig = plt.figure(figsize=(figsize, figsize))
fig.subplots_adjust(hspace=0.4, wspace=0.4)
for i in range(grid_size**2):
    ax = fig.add_subplot(grid_size, grid_size, i + 1)
    ax.axis("off")
    # Channel axis is moved last for imshow.
    ax.imshow(reconstructions[i, :, :].permute(1, 2, 0), cmap="Greys")