# 👖 Autoencoders on Fashion MNIST

In this notebook, we'll walk through the steps required to train your own autoencoder on the Fashion-MNIST dataset using PyTorch.

In [None]:
import os, sys
from dotenv import load_dotenv

# Load variables from a local .env file, then read the two paths this notebook uses.
load_dotenv()
python_path = os.getenv('PYTHONPATH')
data_path = os.getenv('DATA_PATH')

# Make each PYTHONPATH entry importable (presumably so `notebooks.pt_utils`,
# imported below, resolves), appending only entries sys.path does not yet hold.
extra_paths = python_path.split(os.pathsep) if python_path else []
for entry in extra_paths:
    if entry not in sys.path:
        sys.path.append(entry)

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

from tqdm import tqdm, trange

from notebooks.pt_utils import display, Trainer

## 0. Parameters <a name="parameters"></a>

In [None]:
IMAGE_SIZE = 32
CHANNELS = 1
BATCH_SIZE = 100
BUFFER_SIZE = 1000  # not referenced anywhere else in this notebook
VALIDATION_SPLIT = 0.2
EMBEDDING_DIM = 2
EPOCHS = 3
# Cap DataLoader worker processes at the machine's CPU count: asking for more
# workers than cores only adds start-up cost and makes DataLoader warn.
NUM_WORKERS = min(24, os.cpu_count() or 1)
# Misspelled name kept as an alias because the loader cells refer to it.
NUM_WOERKERS = NUM_WORKERS

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## 1. Prepare the data <a name="prepare"></a>

In [None]:
class AutoEncDataset(Dataset):
    """Adapt a labelled dataset for reconstruction training.

    Each item of the wrapped dataset must be a 2-tuple ``(image, label)``;
    the label is discarded and the image is returned as both the model input
    and the target.
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset  # underlying (image, label) dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, _label = self.dataset[idx]
        # Input and target are the same tensor for an autoencoder.
        return (img, img)

In [None]:
# ToTensor scales pixels to [0, 1]; padding by 2 on every side turns the
# 28x28 Fashion-MNIST images into 32x32 ones (IMAGE_SIZE).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Pad(padding=2),
])
# download=True fetches the data into data_path on first run instead of raising.
train_dataset = datasets.FashionMNIST(root=data_path, train=True, download=True, transform=transform)
# Hold out VALIDATION_SPLIT of the training images for validation
# (previously hard-coded as [0.8, 0.2], ignoring the parameter).
train_dataset, val_dataset = torch.utils.data.random_split(
    train_dataset, [1 - VALIDATION_SPLIT, VALIDATION_SPLIT]
)
# Autoencoder views: each item is (image, image) so the input is also the target.
ae_train_dataset, ae_val_dataset = AutoEncDataset(train_dataset), AutoEncDataset(val_dataset)

test_dataset = datasets.FashionMNIST(root=data_path, train=False, download=True, transform=transform)
ae_test_dataset = AutoEncDataset(test_dataset)

In [None]:
def _make_loader(dataset, shuffle):
    """Return a DataLoader using the notebook-wide batch size and worker count."""
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WOERKERS,
        # Pinned host memory only helps when batches are copied to a GPU.
        # Do not hard-code pin_memory_device='cuda': pinning to CUDA can fail
        # on a CPU-only machine.
        pin_memory=device.type == 'cuda',
    )


# Labelled loaders: batches of (image, label).
train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)

# Autoencoder loaders: batches of (image, image) for reconstruction training.
ae_train_loader = _make_loader(ae_train_dataset, shuffle=True)
ae_val_loader = _make_loader(ae_val_dataset, shuffle=False)
ae_test_loader = _make_loader(ae_test_dataset, shuffle=False)

In [None]:
# Pull one shuffled training batch; show its images and the first ten labels.
batch_iter = iter(train_loader)
images, labels = next(batch_iter)
display(images)
print(labels[0:10])

In [None]:
# Shape and pixel-value range of the first image in the batch.
first_image = images[0]
print(f'image size: {first_image.size()}, min: {first_image.min().item()}, max: {first_image.max().item()}')

## 2. Build the autoencoder <a name="build"></a>

In [None]:
class AutoEncoder(nn.Module):
    """Convolutional autoencoder mapping (c, h, w) images to a `latent_size` code.

    Three stride-2 convolutions halve the spatial size three times, and three
    stride-2 transposed convolutions double it back, so h and w must be
    multiples of 8 for the output to match the input shape.

    Args:
        input_size: (channels, height, width) of the input images.
        latent_size: dimensionality of the bottleneck embedding.

    Raises:
        ValueError: if height or width is not a multiple of 8.
    """

    def __init__(self, input_size, latent_size):
        super().__init__()
        c, h, w = input_size
        if h % 8 or w % 8:
            raise ValueError(f'height and width must be multiples of 8, got {h}x{w}')
        # Spatial size after three stride-2 convolutions.
        h_enc, w_enc = h // 8, w // 8
        flat_size = 128 * h_enc * w_enc

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=c, out_channels=32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(in_features=flat_size, out_features=latent_size),
        )

        self.decoder = nn.Sequential(
            nn.Linear(in_features=latent_size, out_features=flat_size),
            nn.Unflatten(dim=1, unflattened_size=(128, h_enc, w_enc)),
            # Each transposed conv (stride 2, padding 1, output_padding 1) doubles h and w.
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=c, kernel_size=3, stride=1, padding='same'),
            # Sigmoid keeps outputs in [0, 1], valid targets for binary cross-entropy.
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))

# Build the model from the parameter cell (CHANNELS x IMAGE_SIZE x IMAGE_SIZE inputs)
# rather than repeating the literals, and move it to the training device.
autoencoder = AutoEncoder((CHANNELS, IMAGE_SIZE, IMAGE_SIZE), EMBEDDING_DIM).to(device)

## 3. Train the autoencoder <a name="train"></a>

In [None]:
# Pixel-wise binary cross-entropy between the reconstruction and the input image.
loss_fn = F.binary_cross_entropy
# Adam with PyTorch's default hyperparameters (no lr given).
optimizer = torch.optim.Adam(params=autoencoder.parameters())


def pred_fn(output):
    """Return the model output unchanged.

    Presumably used by Trainer to turn raw outputs into predictions; an
    identity is enough because the decoder already emits pixel values.
    Confirm against notebooks.pt_utils.Trainer.
    """
    return output


trainer = Trainer(
    model=autoencoder,
    optimizer=optimizer,
    loss_fn=loss_fn,
    pred_fn=pred_fn,
    train_loader=ae_train_loader,
    val_loader=ae_val_loader,
    device=device,
)

In [None]:
# Train for EPOCHS epochs. Trainer was built with ae_train_loader as its train_loader
# and ae_val_loader as its val_loader. The contents of `history` are defined by
# notebooks.pt_utils.Trainer.fit; presumably per-epoch metrics, so confirm there.
history = trainer.fit(epochs=EPOCHS)

In [None]:
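# # Save the final model weights (PyTorch persists state_dicts, not Keras-style .save())
# torch.save(autoencoder.state_dict(), "./models/autoencoder.pt")
# torch.save(autoencoder.encoder.state_dict(), "./models/encoder.pt")
# torch.save(autoencoder.decoder.state_dict(), "./models/decoder.pt")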

## 4. Reconstruct using the autoencoder <a name="reconstruct"></a>

In [None]:
# Draw one large random batch (5000 images) from the test set for visualisation.
# Use the same worker count as the other loaders instead of a hard-coded 24.
ex_loader = DataLoader(test_dataset, batch_size=5000, shuffle=True, num_workers=NUM_WOERKERS)
example_images, example_labels = next(iter(ex_loader))

In [None]:
# Inference only: put the model in eval mode and skip building an autograd
# graph for the whole 5000-image batch. The model has no dropout or batch-norm
# layers, so eval() does not change its outputs; it just makes the intent explicit.
autoencoder.eval()
with torch.no_grad():
    predictions = autoencoder(example_images.to(device)).cpu()

print("Example real clothing items")
display(example_images)
print("Reconstructions")
display(predictions)

## 5. Embed using the encoder <a name="encode"></a>

In [None]:
# Encode the example images (inference only, so no autograd graph is built)
with torch.no_grad():
    embeddings = autoencoder.encoder(example_images.to(device)).cpu()

In [None]:
# Peek at the latent vectors of the first ten examples
head = embeddings[0:10]
print(head)

In [None]:
# Plot every encoded example as a point in the 2-D latent space
figsize = 8
xs, ys = embeddings[:, 0], embeddings[:, 1]

plt.figure(figsize=(figsize, figsize))
plt.scatter(xs, ys, c="black", alpha=0.5, s=3)
plt.show()

In [None]:
# Same latent scatter, coloured by each example's class label (0-9)
figsize = 8
plt.figure(figsize=(figsize, figsize))
plt.scatter(
    embeddings[:, 0], embeddings[:, 1],
    c=example_labels, cmap="rainbow",
    alpha=0.8, s=3,
)
plt.colorbar()
plt.show()

## 6. Generate using the decoder <a name="decode"></a>

In [None]:
# Per-dimension range (bounding box) of the existing embeddings
mins, maxs = np.min(embeddings.numpy(), axis=0), np.max(embeddings.numpy(), axis=0)

# Sample points uniformly inside that bounding box of the latent space
grid_width, grid_height = (6, 3)
sample = np.random.uniform(
    mins, maxs, size=(grid_width * grid_height, EMBEDDING_DIM)
)
sample = torch.from_numpy(sample).to(torch.float32)

# Decode the sampled points (inference only, so no autograd graph is built)
with torch.no_grad():
    reconstructions = autoencoder.decoder(sample.to(device)).cpu()

In [None]:
# Latent space: original embeddings (small black dots) overlaid with the
# newly sampled points (large blue dots)
figsize = 8
plt.figure(figsize=(figsize, figsize))
plt.scatter(embeddings[:, 0], embeddings[:, 1], c="black", alpha=0.5, s=2)
plt.scatter(sample[:, 0], sample[:, 1], c="#00B0F0", alpha=1, s=40)
plt.show()

# Underneath: the decoded image for each sampled point, captioned with its
# latent coordinates rounded to one decimal
n_samples = grid_width * grid_height
fig = plt.figure(figsize=(figsize, grid_height * 2))
fig.subplots_adjust(hspace=0.4, wspace=0.4)

for idx in range(n_samples):
    ax = fig.add_subplot(grid_height, grid_width, idx + 1)
    ax.axis("off")
    caption = str(np.round(sample[idx, :].numpy(), 1))
    ax.text(0.5, -0.35, caption, fontsize=10, ha="center", transform=ax.transAxes)
    ax.imshow(reconstructions[idx, :, :].permute(1, 2, 0), cmap="Greys")

In [None]:
# Colour the embeddings by their label, then decode a regular grid of latent
# points covering the same region and show the resulting images.
figsize = 12
grid_size = 15
plt.figure(figsize=(figsize, figsize))
plt.scatter(
    embeddings[:, 0],
    embeddings[:, 1],
    cmap="rainbow",
    c=example_labels,
    alpha=0.8,
    s=300,
)
plt.colorbar()

# Tensor reductions instead of the builtin min()/max(), which iterate the
# tensor element by element in Python and return 0-d tensors.
x_min, x_max = embeddings[:, 0].min().item(), embeddings[:, 0].max().item()
y_min, y_max = embeddings[:, 1].min().item(), embeddings[:, 1].max().item()
x = np.linspace(x_min, x_max, grid_size)
# y runs from max to min so the first row of the image grid matches the top
# of the scatter plot.
y = np.linspace(y_max, y_min, grid_size)
xv, yv = np.meshgrid(x, y)
xv = xv.flatten()
yv = yv.flatten()
# (grid_size**2, 2) array of latent points in row-major order.
grid = torch.from_numpy(np.column_stack((xv, yv))).to(torch.float32).to(device)

# Decoding is inference only, so skip building an autograd graph.
with torch.no_grad():
    reconstructions = autoencoder.decoder(grid).cpu()
plt.show()

fig = plt.figure(figsize=(figsize, figsize))
fig.subplots_adjust(hspace=0.4, wspace=0.4)
for i in range(grid_size**2):
    ax = fig.add_subplot(grid_size, grid_size, i + 1)
    ax.axis("off")
    # reconstructions is (N, 1, H, W); plot the single channel as a 2-D image.
    ax.imshow(reconstructions[i, 0], cmap="Greys")