# Variational AutoEncoders

Hi! Today we are going to learn about variational autoencoders (VAEs). We'll build one that encodes handwritten digits into a compact vector representation and restores them from it.

In [None]:
!pip install catalyst

In [None]:
from catalyst.utils import set_global_seed, get_device

In [None]:
# Fix the random seed through catalyst's helper so that runs are repeatable.
set_global_seed(42)
# Device selected by catalyst's get_device(); later cells move input tensors to it.
device = get_device()

We'll work with the `MNIST` dataset. We download it, show a few examples of the handwriting, and prepare the data to be loaded into the models.

In [None]:
from catalyst.contrib.datasets import mnist


# Training and validation splits of MNIST; download=True asks the dataset class to
# fetch the files (presumably into '.', the first argument — confirm against catalyst docs).
train = mnist.MNIST('.', train=True, download=True)
valid = mnist.MNIST('.', train=False, download=True)

In [None]:
import matplotlib.pyplot as plt


# Preview 16 training digits on a 4x4 grid; cell k shows the sample at dataset index 101 * k.
_, axs = plt.subplots(4, 4, figsize=(10, 10))

for cell, axis in enumerate(axs.ravel()):
    axis.imshow(train[101 * cell][0])

In [None]:
import torch
import torch.nn as nn

In [None]:
from catalyst.utils import get_loader


batch_size = 256  # samples per batch, used by both loaders below
num_workers = 4  # forwarded to get_loader as its num_workers argument

def transform(x):
    """Binarize the image of a single sample.

    Pixels strictly greater than 127 become 1.0 and all others 0.0; the
    ``'targets'`` entry is passed through unchanged.

    Args:
        x: mapping with an ``'image'`` entry (anything ``torch.FloatTensor``
            accepts) and a ``'targets'`` entry.

    Returns:
        A new dict holding the binarized float ``'image'`` and the original
        ``'targets'``.
    """
    pixels = torch.FloatTensor(x['image'])
    binary = (pixels > 127).to(pixels.dtype)
    return {'image': binary, 'targets': x['targets']}


def _open_sample(x):
    """Turn one ``(image, label)`` dataset item into the dict consumed by ``transform``."""
    return {'image': x[0].reshape(1, 28, 28), 'targets': x[1]}


# Settings shared by the training and validation loaders.
_loader_kwargs = dict(
    open_fn=_open_sample,
    dict_transform=transform,
    batch_size=batch_size,
    num_workers=num_workers,
    sampler=None,
)

# Training: shuffled, and the last incomplete batch is dropped.
train_data_loader = get_loader(train, shuffle=True, drop_last=True, **_loader_kwargs)

# Validation: dataset order is kept and no sample is dropped.
valid_data_loader = get_loader(valid, shuffle=False, drop_last=False, **_loader_kwargs)

A variational autoencoder consists of two parts: an encoder and a decoder. The encoder compresses an object into a vector. The decoder generates an approximate reconstruction (an 'image') of the object from that vector. In our case, the objects are images. We will use CNNs for encoding the images and upscaling (transposed) convolutions for decoding.

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder that returns the two halves of a ``2 * latent_size`` vector.

    Two stride-2 convolution stages (each followed by batch norm and ReLU)
    are flattened and projected by one linear layer to ``2 * latent_size``
    values. The linear layer expects ``16 * 7 * 7`` input features, which is
    what the two stages produce for 1x28x28 inputs.

    Args:
        latent_size: width of each of the two returned tensors.
    """

    def __init__(self, latent_size=2):
        super().__init__()
        self.latent_size = latent_size
        # The layers are kept in one flat Sequential: conv, bn, relu, conv, bn, relu, flatten.
        self.feature_extractor = nn.Sequential(
            *self._down_block(1, 4),
            *self._down_block(4, 16),
            nn.Flatten(),
        )
        self.latent_space = nn.Linear(16 * 7 * 7, 2 * latent_size)

    @staticmethod
    def _down_block(in_channels, out_channels):
        """Return the layers of one stride-2 conv -> batch norm -> ReLU stage."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]

    def forward(self, images):
        """Encode ``images``.

        Returns:
            A pair ``(first, second)``: the first ``latent_size`` columns and
            the remaining columns of the linear projection.
        """
        projected = self.latent_space(self.feature_extractor(images))
        width = self.latent_size
        return projected[:, :width], projected[:, width:]

In [None]:
from catalyst.contrib.nn.modules import Lambda


class Decoder(nn.Module):
    """Map latent points to single-channel image logits.

    A linear layer expands each latent vector to a 16x7x7 feature map, two
    transposed-convolution stages double the spatial size twice
    (7 -> 14 -> 28), and a final 3x3 convolution reduces the result to one
    channel. No activation is applied to the output.

    Args:
        image_size: stored on the instance; the layer sizes below do not
            depend on it.
        latent_size: dimensionality of the input latent points.
    """

    def __init__(self, image_size=(28, 28), latent_size=2):
        super().__init__()

        self.image_size = image_size
        self.latent_size = latent_size

        base_channels, base_side = 16, 7
        self.map_generator = nn.Sequential(
            nn.Linear(latent_size, base_channels * base_side * base_side),
            # Reshape the flat vector into a (batch, 16, 7, 7) feature map.
            Lambda(lambda x: x.view(x.size(0), base_channels, base_side, base_side)),
        )
        self.deconv = nn.Sequential(
            self.make_up_layer_(base_channels, 8),  # 7 -> 14
            self.make_up_layer_(8, 4),  # 14 -> 28
        )
        self.output = nn.Sequential(
            nn.Conv2d(4, 1, 3, padding=1),
        )

    def make_up_layer_(self, in_channels, out_channels):
        """Build one upsampling stage: transposed conv -> batch norm -> LeakyReLU(0.2)."""
        upsample = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )
        return nn.Sequential(
            upsample,
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, points):
        """Decode latent ``points`` into image logits."""
        upsampled = self.deconv(self.map_generator(points))
        return self.output(upsampled)

Join the encoder and decoder to create a VAE! We discussed this in the lecture, so we know how to train a VAE: we sample points in the latent space, pass them through the decoder, and compare the decoder's result with the original object. In addition, the points should be sampled from a normal distribution whose parameters are pushed towards $(0, I)$.

In [None]:
# Upper and lower bounds that the models' forward passes apply (via torch.clamp)
# to the encoder's second output before sampling.
LOG_SCALE_MAX = 2
LOG_SCALE_MIN = -10

def normal_sample(loc, log_scale):
    """Draw a reparameterized sample ``loc + exp(0.5 * log_scale) * eps``.

    ``eps`` is standard normal noise shaped like ``log_scale``. Because the
    multiplier is ``exp(0.5 * log_scale)``, ``log_scale`` plays the role of a
    log-variance.

    Args:
        loc: tensor of means.
        log_scale: tensor of log-variances.

    Returns:
        The sampled tensor.
    """
    std = (0.5 * log_scale).exp()
    eps = torch.randn_like(std)
    return loc + std * eps


class VAE(nn.Module):
    """Variational autoencoder built from ``Encoder`` and ``Decoder``.

    Args:
        image_size: forwarded to the decoder.
        latent_size: dimensionality of the latent space, shared by both parts.
    """

    def __init__(self, image_size=(28, 28), latent_size=2):
        super().__init__()

        self.encoder = Encoder(latent_size)
        self.decoder = Decoder(image_size, latent_size)

    def forward(self, images):
        """Encode ``images``, pick a latent point and decode it.

        In training mode the latent point is sampled with ``normal_sample``;
        otherwise the mean ``loc`` is decoded directly.

        Returns:
            dict with ``'decoder_result'`` (decoder output), ``'loc'`` and the
            clamped ``'log_scale'``.
        """
        loc, raw_log_scale = self.encoder(images)
        log_scale = raw_log_scale.clamp(LOG_SCALE_MIN, LOG_SCALE_MAX)

        if self.training:
            latent = normal_sample(loc, log_scale)
        else:
            latent = loc

        return {
            'decoder_result': self.decoder(latent),
            'loc': loc,
            'log_scale': log_scale,
        }

In [None]:
class KLVAELoss(nn.Module):
    """KL term of the VAE objective.

    Computes ``-0.5 * sum(1 + log_scale - loc^2 - exp(log_scale))`` over
    dimension 1 of each sample and averages the result over the batch; this
    is the KL divergence from ``N(loc, exp(log_scale))`` to ``N(0, I)``.
    """

    def forward(self, loc, log_scale):
        """Return the batch-mean KL value as a scalar tensor."""
        per_dim = 1 + log_scale - loc.pow(2) - log_scale.exp()
        per_sample = -0.5 * per_dim.sum(dim=1)
        return per_sample.mean()

We need to modify `BinaryCrossEntropyLoss` function, because it doesn't work properly with images.

To monitor the decoded images, we write a new callback. It logs the images to TensorBoard.

In [None]:
from catalyst.core import Callback, CallbackOrder


class LogFigureCallback(Callback):
    """Callback that writes sigmoid-activated decoder outputs to the tensorboard logger."""

    def __init__(self):
        super().__init__(CallbackOrder.External)

    def on_epoch_end(self, runner):
        """Log the decoder output stored in ``runner.batch`` under the tag ``image/epoch``.

        Nothing is logged unless ``runner.is_valid_loader`` is set.
        """
        if not runner.is_valid_loader:
            return
        writer = runner.loggers['_tensorboard'].loggers[runner.loader_key]
        decoded = torch.sigmoid(runner.batch['decoder_result'])
        writer.add_images('image/epoch', decoded)

Create model, criterion, optimizer. Train model!

In [None]:
from catalyst.contrib.nn.optimizers import RAdam


# Plain VAE with its default 2-dimensional latent space.
model = VAE()
# Criteria are addressed by key from the CriterionCallback instances configured below.
criterion = {
    'ae': nn.BCEWithLogitsLoss(),  # reconstruction term on the decoder's raw (pre-sigmoid) output
    'kl': KLVAELoss()  # KL regularizer computed from loc and log_scale
}
optimizer = RAdam(model.parameters(), lr=1e-2)

In [None]:
from catalyst import dl


callbacks = [
    # Reconstruction loss: decoder output vs. the input image, stored as 'loss_ae'.
    dl.CriterionCallback(
        input_key='decoder_result', target_key='image', metric_key='loss_ae', criterion_key='ae',
    ),
    # KL loss: 'loc' and 'log_scale' are routed through the input/target slots so that
    # the 'kl' criterion receives them as its two arguments; stored as 'loss_kl'.
    dl.CriterionCallback(
        input_key='loc', target_key='log_scale', metric_key='loss_kl', criterion_key='kl'
    ),
    # Total loss = 1.0 * loss_ae + 0.01 * loss_kl.
    dl.MetricAggregationCallback(
        metric_key='loss',
        mode='weighted_sum',
        metrics={'loss_ae': 1.0, 'loss_kl': 0.01},
    ),
    LogFigureCallback(),
]

In [None]:
class VAERunner(dl.SupervisedRunner):
    """Runner that feeds only the ``'image'`` entry of a batch to the model."""

    def predict_batch(self, batch):
        """Run the model on one batch for inference.

        Args:
            batch: dict with ``'image'`` and ``'targets'`` entries.

        Returns:
            dict containing the original ``'image'`` and ``'targets'`` plus
            every entry of the model's output dict.
        """
        prediction = {'image': batch['image'], 'targets': batch['targets']}
        # Use this runner's own device; the method must not depend on a
        # module-level variable that happens to be called `runner`.
        prediction.update(self.model(batch['image'].to(self.device)))
        return prediction

    def handle_batch(self, batch):
        """Add the model's output dict to ``self.batch`` so callbacks can read it by key."""
        self.batch.update(self.model(batch['image']))


runner = VAERunner()

In [None]:
from datetime import datetime
from pathlib import Path


logdir = Path('logs') / datetime.now().strftime('%Y%m%d-%H%M%S')

%reload_ext tensorboard
%tensorboard --logdir logs

In [None]:
# Train the plain VAE for a single epoch; checkpoints are ranked by the aggregated
# 'loss' metric on the 'valid' loader, and load_best_on_end requests the best one at the end.
runner.train(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    loaders={'train': train_data_loader, 'valid': valid_data_loader},
    callbacks=callbacks,
    num_epochs=1,
    verbose=True,
    logdir=logdir,
    valid_loader='valid',
    valid_metric='loss',
    load_best_on_end = True
)

One of the main features of a VAE is generating new objects. We can do this by mixing the latent representations of objects.

In [None]:
# Take the first validation batch; it is reused by the visualization cells below.
test_data = next(iter(valid_data_loader))
# Display the labels of that batch.
test_data['targets']

In [None]:
# Switch to eval mode and keep only the encoder's first output (the latent means).
model.eval()
locs, _ = model.encoder(test_data['image'].to(device))

In [None]:
import numpy as np


def plot_transition(i, j, steps=11):
    """Plot decodings of points interpolated between two encoded samples.

    Reads the module-level ``locs`` (latent means), ``model`` and ``device``.
    The interpolation weights are evenly spaced on [0, 1]: the leftmost panel
    decodes ``locs[i]`` and the rightmost decodes ``locs[j]``.

    Args:
        i: index into ``locs`` of the starting sample.
        j: index into ``locs`` of the final sample.
        steps: number of interpolation points (and panels) to draw.
    """
    # squeeze=False keeps a 2-D axes array even when steps == 1.
    _, axes = plt.subplots(1, steps, figsize=(15, 5), squeeze=False)

    start, end = locs[i], locs[j]
    # Decoding is for display only, so avoid building an autograd graph.
    with torch.no_grad():
        for axis, weight in zip(axes[0], np.linspace(0, 1, steps)):
            weight = float(weight)
            point = weight * end + (1 - weight) * start
            logits = model.decoder(point.unsqueeze(0).to(device))
            axis.imshow(torch.sigmoid(logits).squeeze().cpu().numpy())

In [None]:
%matplotlib inline
plot_transition(0, -3)

We can enhance the generated images in many ways. Here we choose to add a classification task: the model will classify an object based on its latent representation.

In [None]:

class VAEClassify(nn.Module):
    """VAE with an additional linear classifier on top of the latent point.

    Args:
        num_classes: number of outputs of the classifier head.
        image_size: forwarded to the decoder.
        latent_size: dimensionality of the latent space.
    """

    def __init__(self, num_classes=10, image_size=(28, 28), latent_size=10):
        super().__init__()

        self.encoder = Encoder(latent_size)
        self.decoder = Decoder(image_size, latent_size)
        self.clf = nn.Linear(latent_size, num_classes)

    def forward(self, images):
        """Encode ``images``, then decode and classify the chosen latent point.

        In training mode the latent point is sampled with ``normal_sample``;
        otherwise the mean ``loc`` is used.

        Returns:
            dict with ``'logits'``, ``'decoder_result'``, ``'loc'`` and the
            clamped ``'log_scale'``.
        """
        loc, raw_log_scale = self.encoder(images)
        log_scale = raw_log_scale.clamp(LOG_SCALE_MIN, LOG_SCALE_MAX)

        if self.training:
            latent = normal_sample(loc, log_scale)
        else:
            latent = loc

        reconstruction = self.decoder(latent)
        class_logits = self.clf(latent)
        return {
            'logits': class_logits,
            'decoder_result': reconstruction,
            'loc': loc,
            'log_scale': log_scale,
        }

In [None]:
from catalyst.contrib.nn.optimizers import RAdam


# VAE with a classifier head; its default latent space is 10-dimensional.
model = VAEClassify()
criterion = {
    'ce': nn.CrossEntropyLoss(),  # classification term on the 'logits' output
    'ae': nn.BCEWithLogitsLoss(),  # reconstruction term on the decoder's raw output
    'kl': KLVAELoss()  # KL regularizer computed from loc and log_scale
}
optimizer = RAdam(model.parameters(), lr=1e-2)

In [None]:
callbacks = [
    # Reconstruction loss: decoder output vs. the input image.
    dl.CriterionCallback(
        input_key='decoder_result', target_key='image', metric_key='loss_ae', criterion_key='ae',
    ),
    # KL loss: 'loc' and 'log_scale' are routed through the input/target slots.
    dl.CriterionCallback(
        input_key='loc', target_key='log_scale', metric_key='loss_kl', criterion_key='kl'
    ),
    # Classification loss: classifier logits vs. the digit labels.
    dl.CriterionCallback(
        input_key='logits', target_key='targets', metric_key='loss_ce', criterion_key='ce',
    ),
    # Total loss = 1.0 * loss_ae + 0.01 * loss_kl + 1.0 * loss_ce.
    dl.MetricAggregationCallback(
        metric_key='loss',
        mode='weighted_sum',
        metrics={'loss_ae': 1.0, 'loss_kl': 0.01, 'loss_ce': 1.0},
    ),
    dl.AccuracyCallback(input_key='logits', target_key='targets'),
    LogFigureCallback(),
]

In [None]:
# Fresh runner for the classifying VAE; note that this rebinds the global `runner`.
runner = VAERunner()

# Train for 10 epochs, logging into a new timestamped directory under logs/.
runner.train(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    loaders={'train': train_data_loader, 'valid': valid_data_loader},
    callbacks=callbacks,
    num_epochs=10,
    verbose=True,
    logdir=Path('logs') / datetime.now().strftime('%Y%m%d-%H%M%S'),
    valid_loader='valid',
    valid_metric='loss',
    load_best_on_end = True
)

Let's compare results with the usual VAE.

In [None]:
# Re-encode the same test batch with the newly trained model; `locs` is overwritten.
model.eval()
locs, _ = model.encoder(test_data['image'].to(device))

In [None]:
plot_transition(0, -3)

Let's check how our model restores noisy objects. The model isn't trained to denoise, but it can do this quite well.

In [None]:
# Show the first 12 images of the test batch on a 2x6 grid.
_, ax = plt.subplots(2, 6, figsize=(10, 4))

for k in range(12):
    row, col = divmod(k, 6)
    image = test_data['image'][k]
    pixels = image.squeeze().cpu().detach().numpy()
    ax[row][col].imshow(pixels)

In [None]:
# Show the same 12 images with uniform noise from [0, 1) added to every pixel.
_, ax = plt.subplots(2, 6, figsize=(10, 4))

for k in range(12):
    row, col = divmod(k, 6)
    image = test_data['image'][k]
    noise = torch.rand(image.size())
    noisy = image + noise
    ax[row][col].imshow(noisy.squeeze().cpu().detach().numpy())

In [None]:
# Encode noisy versions of the first 12 test images (uniform noise scaled by 0.2)
# and show what the decoder reconstructs from the latent means.
_, ax = plt.subplots(2, 6, figsize=(10, 4))

# Reconstruction is for display only, so no autograd graph is needed.
with torch.no_grad():
    for k in range(0, 12):
        image = test_data['image'][k]
        noise = torch.rand(image.size()) * 0.2
        # The encoder input gets a batch dimension of 1, so `point` is already
        # batched and on `device`; it is passed to the decoder as is.
        point, _ = model.encoder((image + noise).unsqueeze(0).to(device))
        decoded = torch.sigmoid(model.decoder(point)).squeeze()
        ax[k // 6][k % 6].imshow(decoded.cpu().numpy())

Finally, let's look at the latent space. We chose a 2D latent space, so it's easy to plot the points.

In [None]:
# Gather the input images, latent means and labels for the whole validation loader.
predictions = {'image': [], 'loc': [], 'target': []}

for pred in runner.predict_loader(loader=valid_data_loader):
    # Iterating a numpy array yields one row (sample) at a time.
    predictions['image'].extend(pred['image'].numpy().reshape(-1, 28, 28))
    predictions['loc'].extend(pred['loc'].cpu().numpy())
    predictions['target'].extend(pred['targets'].numpy())

In [None]:
# Use the first two coordinates of every latent mean as plot coordinates.
# NOTE(review): `runner` currently holds the VAEClassify model, whose default
# latent_size is 10, so this is a 2-coordinate slice of a 10-D space — confirm intent.
predictions['x'] = [o[0] for o in predictions['loc']]
predictions['y'] = [o[1] for o in predictions['loc']]

In [None]:
import seaborn as sns

# Apply seaborn's default plot styling.
sns.set()

# Scatter the latent coordinates, colored by digit label.
_, ax = plt.subplots(1, 1, figsize=(10, 10))
sns.scatterplot(x='x', y='y', hue='target', data=predictions, ax=ax, legend='full')