# Variational AutoEncoders

Hi! Today we are going to learn about variationals autoencoders. We'll code them to encode handwritten numbers and restore them from the compact vector representation.

In [3]:
# !pip install -U catalyst

In [4]:
from catalyst.utils import set_global_seed, get_device
from datetime import datetime
from pathlib import Path

import tkinter
import matplotlib
matplotlib.use('TkAgg')

set_global_seed(42)
device = get_device()

In [5]:
import catalyst
catalyst.__version__

'21.04.2'

We'll work with `MNIST` dataset. Download it, show examples of the writting and prepare the dataset to be loaded into models.

In [6]:
from catalyst.contrib.datasets import mnist


train = mnist.MNIST(".", train=True, download=True)
valid = mnist.MNIST(".", train=False, download=True)

In [7]:
import matplotlib.pyplot as plt


_, axs = plt.subplots(4, 4, figsize=(6.4 * 1.5, 4.8 * 1.5))

for i in range(16):
    axs[i // 4][i % 4].imshow(train[100 * i + i][0])
plt.show()

In [8]:
import torch
import torch.nn as nn

In [177]:
import numpy as np
import typing as tp
from catalyst.utils import get_loader


batch_size = 256
num_workers = 0


def transform(x: np.array) -> tp.Dict[str, torch.Tensor]:
    image = torch.FloatTensor(x["image"])
    image = torch.where(image > 127, 1, 0).float() # Use torch.where, to replace 256 values to ones or zeros
    return {'image': image, "targets": x["targets"]}


train_data_loader = get_loader(
    train,
    open_fn=lambda x : {'image': x[0].reshape(1, 28, 28), 'targets': x[1]},
    dict_transform=transform,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=True,
    sampler=None,
    drop_last=True,
)

valid_data_loader = get_loader(
    valid,
    open_fn=lambda x : {'image': x[0].reshape(1, 28, 28), 'targets': x[1]},
    dict_transform=transform,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
    sampler=None,
    drop_last=False,
)

In [178]:
next(iter(valid_data_loader))['image'].shape

torch.Size([256, 1, 28, 28])

A variational autoencoder consists of two parts: encoder and decoder. The encoder shrinks objects into some vector. The decoder generates an proximate an "image" of object. In our case, objects are images. We will use CNNs for encoding images and UpScale Convolution operations for decoding.

In [81]:
class Encoder(nn.Module):
    def __init__(self, latent_size: int = 2):
        super().__init__()

        # Create encoder model!
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=3, padding=1),
            nn.BatchNorm2d(4),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(4, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.AdaptiveMaxPool2d((1, 1)),
            nn.Flatten()
        )
        self.latent_space = nn.Linear(64, 2 * latent_size)
        
        self.latent_size = latent_size
        
    def forward(self, images: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        features = self.feature_extractor(images)
        latent = self.latent_space(features)
        return latent[:, :self.latent_size], latent[:, self.latent_size:]

In [82]:
from catalyst.contrib.nn.modules import Lambda


class Decoder(nn.Module):
    def __init__(
        self,
        image_size: tp.Tuple[int, int] = (28, 28),
        latent_size: int = 2
    ):
        super().__init__()
        
        self.image_size = image_size
        self.latent_size = latent_size

        # Create Decoder model!
        self.map_generator = nn.Sequential(
            nn.Linear(latent_size, 64 * 49),
            Lambda(lambda x: x.view(x.size(0), 64, 7, 7)),
        )
        self.deconv = nn.Sequential(
            self.make_up_layer_(64, 16), # 7 -> 14
            self.make_up_layer_(16, 4), # 14 -> 28
        )
        self.output = nn.Sequential(
            nn.Conv2d(4, 1, 3, padding=1),
        )
            
    def forward(self, points: torch.Tensor) -> torch.Tensor:
        feature_map = self.map_generator(points)
        feature_map = self.deconv(feature_map)
        return self.output(feature_map)
            
    def make_up_layer_(self, in_channels: int, out_channels: int) -> torch.Tensor:
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 
                               kernel_size=3, padding=1,
                               output_padding=1, stride=2),
            nn.BatchNorm2d(out_channels),
            nn.Dropout(0.1),
            nn.LeakyReLU(0.01)
        )

Joint the encoder and decoder to create VAE! We have discussed in the lecture about it, and we knew how to train VAE. We need sample points in latent space, pass them forward through the decoder and compare a decoder result with original object. Also we should sample points from some normal distribution, which parameters approach to $(0, I)$.

In [83]:
LOG_SCALE_MAX = 2
LOG_SCALE_MIN = -10

def normal_sample(loc: torch.Tensor, log_scale: torch.Tensor) -> torch.Tensor:
    scale = torch.exp(0.5 * log_scale)
    return loc + scale * torch.randn_like(scale)


class VAE(nn.Module):
    def __init__(self, image_size: tp.Tuple[int, int] = (28, 28), latent_size: int = 2):
        super().__init__()

        self.encoder = Encoder(latent_size)
        self.decoder = Decoder(image_size, latent_size)
        
    def forward(self, images: torch.Tensor) -> tp.Dict[str, torch.Tensor]:
        loc, log_scale = self.encoder(images) # get loc and scale for sampling
        log_scale = torch.clamp(log_scale, LOG_SCALE_MIN, LOG_SCALE_MAX)

        z_ = normal_sample(loc, log_scale) if self.training else loc
        x_ = self.decoder(z_) # recreate object from z_

        return {
            "decoder_result": x_,
            "loc": loc,
            "log_scale": log_scale
        }

In [84]:
class KLVAELoss(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, loc: torch.Tensor, log_scale: torch.Tensor) -> torch.Tensor:
        return (-0.5 * torch.sum(1 + log_scale - loc.pow(2) - log_scale.exp(), dim=1)).mean()

We need to modify `BinaryCrossEntropyLoss` function, because it doesn't work properly with images.

In [85]:
class ImageCELoss(nn.BCEWithLogitsLoss):
    def __init__(self):
        super().__init__()
        
    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        input = input.view(-1) # reshape input to (batch_size * ...)
        target = target.view(-1) # reshape target to (batch_size * ...) [bs, C_0, .., C_n]
        return super().forward(input, target)

To monitor decoded images, we have to write a new callback function. It will log image into the tensorboard.

In [86]:
from catalyst import dl
from catalyst.core import Callback, CallbackOrder


class LogFigureCallback(Callback):
    def __init__(self):
        super().__init__(CallbackOrder.External)

    def on_epoch_end(self, runner: dl.Runner):
        if runner.is_valid_loader:
            tb_callback = runner.loggers["tensorboard"]
            logger = tb_callback.loggers[runner.loader_key]
            decoder_result = runner.output["decoder_result"]
            logger.add_images(
                "image/epoch", 
                torch.sigmoid(decoder_result), # create image from decoder result
#                 global_step=runner.epoch,
            )

Create model, criterion, optimizer. Train model!

In [136]:
from catalyst.contrib.nn.optimizers import RAdam


model = VAE()
criterion = {
    "ae": ImageCELoss(),
    "kl": KLVAELoss()
}
optimizer = RAdam(model.parameters(), lr=1e-2)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])

In [137]:
x = torch.ones((256, 1, 28, 28))

out = model(x)

In [138]:
callbacks = [
    dl.CriterionCallback(
        input_key="decoder_result", target_key="features", metric_key="loss_ae", criterion_key="ae",
    ),
    dl.CriterionCallback(
        input_key='log_scale', target_key="loc", metric_key="loss_kl", criterion_key="kl"
    ),
    dl.MetricAggregationCallback(
        metric_key="loss",
        mode="weighted_sum",
        metrics={"loss_ae": 1.0, "loss_kl": 0.01},
    ),
    LogFigureCallback(),
    dl.SchedulerCallback(),
    dl.CheckpointCallback(logdir=Path("logs") / datetime.now().strftime("%Y%m%d-%H%M%S"),
                          loader_key="valid", metric_key="loss", minimize=True),
]

In [139]:
class VAERunner(dl.SupervisedRunner):
    def predict_batch(self, batch: tp.Dict[str, torch.Tensor]) -> torch.Tensor:
        prediction = {"features": batch["image"], "targets": batch["targets"]}
        prediction.update(self.model(batch["image"]))
        return prediction
    
    def handle_batch(self, batch: tp.Dict[str, torch.Tensor]):
        self.output = self.model(batch["image"])
        
        self.batch = {
            'features': batch["image"],
            "targets": batch["targets"],
            'decoder_result': self.output['decoder_result'],
            'loc': self.output['loc'],
            'log_scale': self.output['log_scale'],
        }
        
runner = VAERunner(input_key='images')

In [115]:
logdir = Path("logs") / datetime.now().strftime("%Y%m%d-%H%M%S")

%reload_ext tensorboard
%tensorboard --logdir logs

Reusing TensorBoard on port 6006 (pid 14820), started 0:38:39 ago. (Use '!kill 14820' to kill it.)

In [140]:
runner.train(
    engine=dl.DeviceEngine(device),
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    loggers={"tensorboard": dl.TensorboardLogger(logdir=logdir)},
    loaders={"train": train_data_loader, "valid": valid_data_loader},
    callbacks=callbacks,
    num_epochs=1,
    logdir=logdir,
    load_best_on_end=True,
    verbose=True,
)

1/1 * Epoch (train):   0%|          | 0/234 [00:00<?, ?it/s]

train (1/1) loss: 0.3507111072540283 | loss_ae: 0.3441738784313202 | loss_ae/mean: 0.3441738784313202 | loss_ae/std: 0.14102233656137952 | loss_kl: 0.6537229418754578 | loss_kl/mean: 0.6537229418754578 | loss_kl/std: 1.8657114680712537 | lr: 0.01 | momentum: 0.9


1/1 * Epoch (valid):   0%|          | 0/40 [00:00<?, ?it/s]

valid (1/1) loss: 0.23700270056724548 | loss_ae: 0.2275197058916092 | loss_ae/mean: 0.2275197058916092 | loss_ae/std: 0.010229723740033607 | loss_kl: 0.9483001232147217 | loss_kl/mean: 0.9483001232147217 | loss_kl/std: 0.07174206305925224 | lr: 0.01 | momentum: 0.9
* Epoch (1/1) lr: 0.01 | momentum: 0.9
Top best models:
logs\20210511-132850/train.1.pth	0.2370


One of the main feature of VAE it's a generating new objects. We can do this by mixing latent representation of objects.

In [141]:
test_data = next(iter(valid_data_loader))
test_data["targets"]

tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4, 9, 6, 6, 5,
        4, 0, 7, 4, 0, 1, 3, 1, 3, 4, 7, 2, 7, 1, 2, 1, 1, 7, 4, 2, 3, 5, 1, 2,
        4, 4, 6, 3, 5, 5, 6, 0, 4, 1, 9, 5, 7, 8, 9, 3, 7, 4, 6, 4, 3, 0, 7, 0,
        2, 9, 1, 7, 3, 2, 9, 7, 7, 6, 2, 7, 8, 4, 7, 3, 6, 1, 3, 6, 9, 3, 1, 4,
        1, 7, 6, 9, 6, 0, 5, 4, 9, 9, 2, 1, 9, 4, 8, 7, 3, 9, 7, 4, 4, 4, 9, 2,
        5, 4, 7, 6, 7, 9, 0, 5, 8, 5, 6, 6, 5, 7, 8, 1, 0, 1, 6, 4, 6, 7, 3, 1,
        7, 1, 8, 2, 0, 2, 9, 9, 5, 5, 1, 5, 6, 0, 3, 4, 4, 6, 5, 4, 6, 5, 4, 5,
        1, 4, 4, 7, 2, 3, 2, 7, 1, 8, 1, 8, 1, 8, 5, 0, 8, 9, 2, 5, 0, 1, 1, 1,
        0, 9, 0, 3, 1, 6, 4, 2, 3, 6, 1, 1, 1, 3, 9, 5, 2, 9, 4, 5, 9, 3, 9, 0,
        3, 6, 5, 5, 7, 2, 2, 7, 1, 2, 8, 4, 1, 7, 3, 3, 8, 8, 7, 9, 2, 2, 4, 1,
        5, 9, 8, 7, 2, 3, 0, 4, 4, 2, 4, 1, 9, 5, 7, 7])

In [142]:
model.eval()
locs, _ = model.encoder(test_data["image"].to(device)) # get model prediction on test_data

In [143]:
import numpy as np


def plot_transition(i: int, j: int):
    _, ax = plt.subplots(1, 11, figsize=(15, 5))
    
    line = np.linspace(0, 1, 11)
    for k in range(0, 11):
        point = line[k] * locs[j] + (1 - line[k]) * locs[i]
        decoded = model.decoder(point.unsqueeze(0).to(device)).squeeze() # create image from point
        ax[k].imshow(torch.sigmoid(decoded).detach().cpu().numpy().squeeze()) # plot decoded!
    plt.show()

In [144]:
plot_transition(0, -3)

We can enhance generated images by many ways. And we choose to add classification task. The model will classify object based on the corresponding latent representation.

In [145]:
class VAEClassify(nn.Module):
    def __init__(
        self,
        num_classes: int = 10,
        image_size: tp.Tuple[int, int] = (28, 28),
        latent_size: int = 10,
    ):
        super().__init__()

        self.encoder = Encoder(latent_size)
        self.decoder = Decoder(image_size, latent_size)
        self.clf = nn.Linear(latent_size, num_classes)
        
    def forward(self, images: torch.Tensor) -> tp.Dict[str, torch.Tensor]:
        loc, log_scale = self.encoder(images)
        log_scale = torch.clamp(log_scale, LOG_SCALE_MIN, LOG_SCALE_MAX)

        z_ = normal_sample(loc, log_scale) if self.training else loc
        x_ = self.decoder(z_)

        logits = self.clf(z_)
        return {
            "logits": logits, 
            "decoder_result": x_,
            "loc": loc,
            "log_scale": log_scale
        }

In [146]:
from catalyst.contrib.nn.optimizers import RAdam


model = VAEClassify()
criterion = {
    "ce": nn.CrossEntropyLoss(),
    "ae": ImageCELoss(),
    "kl": KLVAELoss()
}
optimizer = RAdam(model.parameters(), lr=1e-2)

In [154]:
callbacks = [
    dl.CriterionCallback(
        input_key="decoder_result", target_key="features", metric_key="loss_ae", criterion_key="ae",
    ),
    dl.CriterionCallback(
        input_key='log_scale', target_key="loc", metric_key="loss_kl", criterion_key="kl"
    ),
    dl.CriterionCallback(
        input_key="logits", target_key="targets", metric_key="loss_ce", criterion_key="ce",
    ),
    dl.MetricAggregationCallback(
        metric_key="loss",
        mode="weighted_sum",
        metrics={"loss_ae": 1.0, "loss_kl": 0.01, "loss_ce": 1.0},
    ),
    dl.AccuracyCallback(input_key="logits", target_key="targets"),
    LogFigureCallback(),
]

In [185]:
class VAERunner(dl.SupervisedRunner):
    def predict_batch(self, batch: tp.Dict[str, torch.Tensor]) -> torch.Tensor:
        predict  = self.model(batch["image"].to(self.device))
        prediction = {
            'features': batch["image"],
            "targets": batch["targets"],
            'decoder_result': predict['decoder_result'],
            'loc': predict['loc'],
            'log_scale': predict['log_scale'],
            'logits': predict['logits']
        }
        return prediction
    
    def handle_batch(self, batch: tp.Dict[str, torch.Tensor]):
        self.output = self.model(batch["image"])
        
        self.batch = {
            'features': batch["image"],
            "targets": batch["targets"],
            'decoder_result': self.output['decoder_result'],
            'loc': self.output['loc'],
            'log_scale': self.output['log_scale'],
            'logits': self.output['logits']
        }
        
runner = VAERunner(input_key='images')

In [186]:
runner.train(
    engine=dl.DeviceEngine(device),
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    loggers={"tensorboard": dl.TensorboardLogger(logdir=logdir)},
    loaders={"train": train_data_loader, "valid": valid_data_loader},
    callbacks=callbacks,
    num_epochs=1,
    logdir=logdir,
    load_best_on_end=True,
    verbose=True,
)

1/1 * Epoch (train):   0%|          | 0/234 [00:00<?, ?it/s]

train (1/1) accuracy: 0.9724559187889099 | accuracy/std: 0.010148449210692417 | accuracy01: 0.9724559187889099 | accuracy01/std: 0.010148449210692417 | loss: 0.3903220295906067 | loss_ae: 0.1972823143005371 | loss_ae/mean: 0.1972823143005371 | loss_ae/std: 0.004739165075189417 | loss_ce: 0.0954778641462326 | loss_ce/mean: 0.0954778641462326 | loss_ce/std: 0.02652218401901059 | loss_kl: 9.756183624267578 | loss_kl/mean: 9.756183624267578 | loss_kl/std: 0.5093581846486388 | lr: 0.01 | momentum: 0.9


1/1 * Epoch (valid):   0%|          | 0/40 [00:00<?, ?it/s]

valid (1/1) accuracy: 0.98089998960495 | accuracy/std: 0.01270346459834009 | accuracy01: 0.98089998960495 | accuracy01/std: 0.01270346459834009 | loss: 0.33337920904159546 | loss_ae: 0.18005740642547607 | loss_ae/mean: 0.18005740642547607 | loss_ae/std: 0.007082436971827612 | loss_ce: 0.06260063499212265 | loss_ce/mean: 0.06260063499212265 | loss_ce/std: 0.034774403623831246 | loss_kl: 9.072114944458008 | loss_kl/mean: 9.072114944458008 | loss_kl/std: 0.21327841851032248 | lr: 0.01 | momentum: 0.9
* Epoch (1/1) lr: 0.001 | momentum: 0.9
Top best models:
logs\20210511-132334\checkpoints/train.1.pth	1.0000


Let's compare results with the usual VAE.

In [187]:
model.eval()
locs, _ = model.encoder(test_data["image"].to(device))

In [188]:
plot_transition(0, -3)

Let's check how our model restore noised objects. The model aren't trained to restore, but it can do this very well.

In [189]:
_, ax = plt.subplots(2, 6, figsize=(10, 4))
    
for k in range(0, 12):
    image = test_data["image"][k]
    ax[k // 6][k % 6].imshow(image.squeeze().cpu().detach().numpy())
plt.show()

In [190]:
_, ax = plt.subplots(2, 6, figsize=(10, 4))
    
for k in range(0, 12):
    image = test_data["image"][k]
    noise = torch.rand(image.size()) # let's make some noise
    ax[k // 6][k % 6].imshow((image + noise).squeeze().cpu().detach().numpy())
plt.show()

In [191]:
_, ax = plt.subplots(2, 6, figsize=(10, 4))
    
for k in range(0, 12):
    image = test_data["image"][k]
    noise = torch.rand(image.size()) * 0.4 # let's make some noise
    point, _ = model.encoder((image + noise).unsqueeze(0).to(device)) # get noised objects vector representation from latent space
    decoded = torch.sigmoid(model.decoder(point.unsqueeze(0).to(device)).squeeze()) # decode points
    ax[k // 6][k % 6].imshow(decoded.cpu().detach().numpy()) # plot decoded
plt.show()

In the end, let's look at the latent space. We choose 2D plain space, so it's easy to plot the points.

In [196]:
predictions = {"image": [], "loc": [], "target": []}

for pred in runner.predict_loader(loader=valid_data_loader):
    # Put predicted loc and targets from pred into predictions
    predictions["image"].extend(o.reshape(28, 28) for o in pred["features"].numpy())
    predictions["loc"].extend(i for i in pred["loc"].cpu().numpy())
    predictions["target"].extend(i for i in pred["targets"].numpy())

In [197]:
predictions["x"] = [o[0] for o in predictions["loc"]]
predictions["y"] = [o[1] for o in predictions["loc"]]

In [202]:
import seaborn as sns

sns.set()

# Use sns.scatterplot to show points !
_, ax = plt.subplots(1, 1, figsize=(10, 10))
sns.scatterplot(x="x", y="y", hue="target", data=predictions, ax=ax, legend='full')
plt.show()