In [8]:
## Standard libraries
import os
import json
import math
import numpy as np

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()
sns.set()

## Progress bar
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
# Torchvision
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# Dataset and checkpoint folders are located under the current working directory.
dataset_path = f"{os.getcwd()}/data"
checkpoint_path = f"{os.getcwd()}/saved_models"

# Seed the random number generators through Lightning's helper.
pl.seed_everything(42)

# Select the MPS backend when PyTorch reports it as available; fall back to the CPU otherwise.
if torch.backends.mps.is_available():
    device = torch.device("mps:0")
else:
    device = torch.device("cpu")

  set_matplotlib_formats('svg', 'pdf') # For export
Global seed set to 42


In [10]:
import urllib.request
from urllib.error import HTTPError, URLError

# Github URL where saved models are stored for this tutorial
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial9/"
# Files to download
pretrained_files = ["cifar10_64.ckpt", "cifar10_128.ckpt", "cifar10_256.ckpt", "cifar10_384.ckpt"]
# Create checkpoint path if it doesn't exist yet
os.makedirs(checkpoint_path, exist_ok=True)

# For each file, check whether it already exists. If not, try downloading it.
for file_name in pretrained_files:
    file_path = os.path.join(checkpoint_path, file_name)
    if os.path.isfile(file_path):
        continue

    file_url = base_url + file_name
    print(f"Downloading {file_url}...")

    # Download under a temporary name and move it into place only once the
    # transfer has completed. A truncated download must never end up at
    # `file_path`, because later cells treat any file at that path as a valid
    # checkpoint.
    partial_path = file_path + ".part"
    try:
        urllib.request.urlretrieve(file_url, partial_path)
        os.replace(partial_path, file_path)
    except URLError as e:
        # URLError is the base class of HTTPError and ContentTooShortError, so
        # this also covers unreachable hosts and interrupted transfers.
        if os.path.isfile(partial_path):
            os.remove(partial_path)
        print("Something went wrong. Please try to download the file from the GDrive folder, or contact the author with the full output including the following error:\n", e)

Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial9/cifar10_64.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial9/cifar10_128.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial9/cifar10_256.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial9/cifar10_384.ckpt...


In [11]:
# Convert images to tensors, then normalise them with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Training data, downloaded into `dataset_path` if it is not there yet.
train_dataset = CIFAR10(root=dataset_path, train=True, transform=transform, download=True)

# Re-seed immediately before the random split into 45000 training and 5000 validation samples.
pl.seed_everything(42)
train_set, val_set = data.random_split(train_dataset, [45000, 5000])

# Loading the test set
test_set = CIFAR10(root=dataset_path, train=False, transform=transform, download=True)

# Data loaders used by the later cells. Only the training loader shuffles and
# drops the last incomplete batch.
train_loader = data.DataLoader(
    train_set,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    num_workers=4,
)
val_loader = data.DataLoader(
    val_set,
    batch_size=256,
    shuffle=False,
    drop_last=False,
    num_workers=4,
)
test_loader = data.DataLoader(
    test_set,
    batch_size=256,
    shuffle=False,
    drop_last=False,
    num_workers=4,
)

def get_train_images(num):
    """Return the first `num` images of `train_dataset` stacked along a new leading dimension.

    Only the first element of each dataset item (the image) is used; the
    second element is discarded.
    """
    images = []
    for idx in range(num):
        images.append(train_dataset[idx][0])
    return torch.stack(images, dim=0)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /Users/joesh/deep_autoencoders/data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting /Users/joesh/deep_autoencoders/data/cifar-10-python.tar.gz to /Users/joesh/deep_autoencoders/data


Global seed set to 42


Files already downloaded and verified


Implement the encoder. The encoder consists of a deep convolutional network in which the image is scaled down layer by layer using strided convolutions. After the image has been downscaled three times, the features are flattened and a linear layer is applied.

Because the model is small, we leave out batch normalisation: we want the encoding of each image to be independent of the other images in the batch. Other normalisation techniques, such as instance or layer normalisation, could be used instead.

In [21]:
class Encoder(nn.Module):
    """Convolutional encoder that maps an image batch to latent vectors.

    Three stride-2 convolutions each halve the spatial resolution, with a
    stride-1 convolution after each of the first two. The final feature map is
    flattened and projected to `latent_dim` by one linear layer. That layer
    expects `2 * 16 * base_channels_size` input features, i.e. a 4x4 map with
    `2 * base_channels_size` channels, which is what a 32x32 input produces.

    Args:
        num_input_channels: Number of channels of the input images.
        base_channels_size: Channel count of the first convolutions; the
            deeper convolutions use twice as many.
        latent_dim: Size of the output latent vector.
        act_fn: Callable that builds the activation module placed after every
            convolution (called with no arguments).
    """

    def __init__(self,
                 num_input_channels : int,
                 base_channels_size: int,
                 latent_dim : int,
                 act_fn : object = nn.GELU):
        super().__init__()
        width = base_channels_size
        # Layers are created in the same order in which they are applied.
        layers = [
            nn.Conv2d(num_input_channels, width, kernel_size=3, padding=1, stride=2),  # 32x32 => 16x16
            act_fn(),
            nn.Conv2d(width, width, kernel_size=3, padding=1),
            act_fn(),
            nn.Conv2d(width, 2 * width, kernel_size=3, padding=1, stride=2),  # 16x16 => 8x8
            act_fn(),
            nn.Conv2d(2 * width, 2 * width, kernel_size=3, padding=1),
            act_fn(),
            nn.Conv2d(2 * width, 2 * width, kernel_size=3, padding=1, stride=2),  # 8x8 => 4x4
            act_fn(),
            nn.Flatten(),  # Image grid to single feature vector
            nn.Linear(2 * 16 * width, latent_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Encode the image batch `x` and return the latent vectors."""
        return self.net(x)

The decoder is a mirrored, flipped version of the encoder, except that the strided convolutions are replaced by transposed convolutions to upscale the features.

In [22]:
class Decoder(nn.Module):
    """Convolutional decoder that maps latent vectors back to images.

    A linear layer expands the latent vector to `2 * 16 * base_channel_size`
    features, which `forward` reshapes to a 4x4 feature map. Three stride-2
    transposed convolutions then double the resolution each time
    (4x4 => 8x8 => 16x16 => 32x32), with a plain stride-1 convolution after
    each of the first two. A final `Tanh` bounds the output to [-1, 1].

    Args:
        num_input_channels: Number of channels of the reconstructed images.
        base_channel_size: Channel count of the last convolutions; the earlier
            ones use twice as many.
        latent_dim: Size of the incoming latent vector.
        act_fn: Callable that builds the activation module used between
            layers (called with no arguments).
    """

    def __init__(self,
                 num_input_channels : int,
                 base_channel_size : int,
                 latent_dim : int,
                 act_fn : object = nn.GELU):
        super().__init__()
        c_hid = base_channel_size
        self.linear = nn.Sequential(
            nn.Linear(latent_dim, 2 * 16 * c_hid),
            act_fn()
        )

        self.net = nn.Sequential(
            nn.ConvTranspose2d(2 * c_hid, 2 * c_hid, kernel_size=3, output_padding=1, padding=1, stride=2),  # 4x4 => 8x8
            act_fn(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.ConvTranspose2d(2 * c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=2),  # 8x8 => 16x16
            act_fn(),
            # Plain convolution: only the stride-2 upsampling layers are
            # transposed, matching the stride-1 layer above.
            # NOTE(review): a ConvTranspose2d here has the same weight shape,
            # so a checkpoint trained with Conv2d would load silently but
            # compute a different function - confirm which the checkpoints use.
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.ConvTranspose2d(c_hid, num_input_channels, kernel_size=3, output_padding=1, padding=1, stride=2),  # 16x16 => 32x32
            nn.Tanh()
        )

    def forward(self, x):
        """Decode the latent batch `x` and return the reconstructed images."""
        x = self.linear(x)
        # Turn the flat feature vector back into a 4x4 feature map; the
        # channel dimension is inferred.
        x = x.reshape(x.shape[0], -1, 4, 4)
        x = self.net(x)
        return x

In [24]:
class AutoEncoder(pl.LightningModule):
    """Lightning module combining an encoder and a decoder, trained on a reconstruction loss.

    Args:
        base_channel_size: Base channel count passed to encoder and decoder.
        latent_dim: Size of the latent vector between encoder and decoder.
        encoder_class: Class used to build the encoder.
        decoder_class: Class used to build the decoder.
        num_input_channels: Number of image channels.
        width: Used as the third dimension of the example input.
        height: Used as the fourth dimension of the example input.
    """

    def __init__(self,
                 base_channel_size: int,
                 latent_dim: int,
                 encoder_class : object = Encoder,
                 decoder_class : object = Decoder,
                 num_input_channels: int = 3,
                 width: int = 32,
                 height: int = 32):
        super().__init__()

        # Must run inside __init__ so Lightning can pick up the constructor arguments.
        self.save_hyperparameters()

        # Both sub-networks receive the same three positional arguments.
        self.encoder = encoder_class(num_input_channels, base_channel_size, latent_dim)
        self.decoder = decoder_class(num_input_channels, base_channel_size, latent_dim)

        # Dummy batch of two images; train_cifar turns on graph logging on the logger.
        # NOTE(review): the shape is (N, C, width, height). This only equals the
        # usual (N, C, H, W) layout for square inputs such as the 32x32 default.
        self.example_input_array = torch.zeros(2, num_input_channels, width, height)

    def forward(self, x):
        """Encode `x`, decode the latent code, and return the reconstruction."""
        latent = self.encoder(x)
        return self.decoder(latent)

    def _get_reconstruction_loss(self, batch):
        """
        Given a batch of images, return the reconstruction loss: the squared
        error summed over dimensions 1-3 of each sample, averaged over the batch.
        """
        images, _ = batch
        reconstruction = self.forward(images)
        squared_error = F.mse_loss(images, reconstruction, reduction='none')
        per_sample = squared_error.sum(dim=[1, 2, 3])  # one value per batch entry
        return per_sample.mean(dim=[0])

    def configure_optimizers(self):
        """Return Adam plus a reduce-on-plateau scheduler that monitors `val_loss`."""
        adam = optim.Adam(self.parameters(), lr=1e-3)
        plateau = optim.lr_scheduler.ReduceLROnPlateau(
            adam,
            mode='min',
            factor=0.2,
            patience=20,
            min_lr=5e-5,
        )
        return {"optimizer": adam, "lr_scheduler": plateau, "monitor": "val_loss"}

    def training_step(self, batch, batch_idx):
        """Log the reconstruction loss as `train_loss` and return it."""
        train_loss = self._get_reconstruction_loss(batch)
        self.log('train_loss', train_loss)
        return train_loss

    def validation_step(self, batch, batch_idx):
        """Log the reconstruction loss as `val_loss`."""
        self.log('val_loss', self._get_reconstruction_loss(batch))

    def test_step(self, batch, batch_idx):
        """Log the reconstruction loss as `test_loss`."""
        self.log('test_loss', self._get_reconstruction_loss(batch))



In [25]:
def train_cifar(latent_dim):
    """Load or train an `AutoEncoder` for the given latent size and evaluate it.

    If `cifar10_<latent_dim>.ckpt` exists in `checkpoint_path`, the model is
    loaded from it. Otherwise a new model with base channel size 32 is fitted
    on `train_loader` / `val_loader`. Either way the model is then run through
    `trainer.test` on the validation and the test loader.

    Args:
        latent_dim: Size of the autoencoder's latent vector.

    Returns:
        Tuple `(model, result)`, where `result` is a dict with the keys
        "test" and "val" holding the outputs of `trainer.test`.
    """
    run_dir = os.path.join(checkpoint_path, f"cifar10_{latent_dim}")
    on_mps = str(device).startswith("mps")
    callbacks = [
        ModelCheckpoint(save_weights_only=True),
        LearningRateMonitor("epoch"),
    ]
    trainer = pl.Trainer(
        default_root_dir=run_dir,
        accelerator="gpu" if on_mps else "cpu",
        devices=1,
        max_epochs=500,
        callbacks=callbacks,
    )

    # Logger tweaks via private attributes - presumably these enable logging of the
    # computation graph and disable the default hyperparameter metric; TODO confirm.
    trainer.logger._log_graph = True
    trainer.logger._default_hp_metric = None

    pretrained_filename = os.path.join(checkpoint_path, f"cifar10_{latent_dim}.ckpt")
    if not os.path.isfile(pretrained_filename):
        # No checkpoint available: train from scratch.
        model = AutoEncoder(base_channel_size=32, latent_dim=latent_dim)
        trainer.fit(model, train_loader, val_loader)
    else:
        print("Found pretrained model, loading...")
        model = AutoEncoder.load_from_checkpoint(pretrained_filename)

    # Evaluate on the validation loader first, then on the test loader.
    val_result = trainer.test(model, val_loader, verbose=False)
    test_result = trainer.test(model, test_loader, verbose=False)
    return model, {"test": test_result, "val": val_result}



In [26]:
# One autoencoder per latent size; each entry keeps the model together with its results.
model_dict = {}
for latent_dim in (64, 128, 256, 384):
    model_ld, result_ld = train_cifar(latent_dim)
    model_dict[latent_dim] = dict(model=model_ld, result=result_ld)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /Users/joesh/deep_autoencoders/saved_models/cifar10_64/lightning_logs

  | Name    | Type    | Params | In sizes       | Out sizes     
----------------------------------------------------------------------
0 | encoder | Encoder | 168 K  | [2, 3, 32, 32] | [2, 64]       
1 | decoder | Decoder | 168 K  | [2, 64]        | [2, 3, 32, 32]
----------------------------------------------------------------------
337 K     Trainable params
0         Non-trainable params
337 K     Total params
1.348     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]