# Variational Autoencoder for Images

- Small CNN-based encoder/decoder network
- Training on CIFAR-10
- Logging via TensorBoard
- [Helpful guide: Variational Autoencoder Demystified with PyTorch Implementation](https://towardsdatascience.com/variational-autoencoder-demystified-with-pytorch-implementation-3a06bee395ed)

In [None]:
# %pip install torch torchvision torcheval torchsummary tensorboard einops

## Define Model

In [None]:
import torch
import torch.nn.functional as F
import einops

class VAE(torch.nn.Module):
    """Small convolutional variational autoencoder for square images.

    Encoder: one stride-2 3x3 ``Conv2d`` + ``ReLU`` per entry of ``hidden_dims``,
    flattened and fed to two linear heads that predict the mean and the
    log-variance of a diagonal Gaussian over a ``latent_dim``-dimensional latent.
    Decoder: a linear layer back to the flattened conv shape, followed by
    mirrored stride-2 ``ConvTranspose2d`` layers ending in ``Tanh``, so
    reconstructions lie in [-1, 1].

    Parameters
    ----------
    img_size : int, optional
        Height and width of the square input images. Must be divisible by
        ``2 ** len(hidden_dims)`` so the decoder reproduces the input size.
        By default 32.
    input_dim : int, optional
        Number of image channels, by default 3.
    hidden_dims : sequence of int, optional
        Output channels of each encoder conv layer (used in reverse by the
        decoder). Must be non-empty. By default (32, 64).
    latent_dim : int, optional
        Dimensionality of the latent space, by default 8.

    Raises
    ------
    ValueError
        If ``hidden_dims`` is empty or ``img_size`` is not divisible by
        ``2 ** len(hidden_dims)``.
    """

    def __init__(
            self, img_size: int=32, input_dim: int=3,
            hidden_dims: 'list[int] | tuple[int, ...]'=(32, 64), latent_dim: int=8
    ):
        super().__init__()
        if not hidden_dims:
            raise ValueError('hidden_dims must contain at least one layer width.')
        downscale = 2 ** len(hidden_dims)
        if img_size % downscale != 0:
            raise ValueError(
                f'img_size={img_size} must be divisible by 2**len(hidden_dims)={downscale}.')
        self.input_dim = input_dim
        # Copy into a fresh list so neither the caller's sequence nor a shared
        # default object is ever aliased by the module.
        self.hidden_dims = list(hidden_dims)
        self.latent_dim = latent_dim

        # Each stride-2 conv (kernel 3, padding 1) halves the spatial size, so a
        # square img_size input ends up spatial x spatial after the encoder.
        # Needed to size the linear layers that work on the flattened conv output.
        self.spatial_dims_inner_conv = img_size // downscale
        print(f'Linear spatial dims: {self.spatial_dims_inner_conv}')
        # Channels of the last conv layer times the spatial dims squared.
        flat_dim = self.hidden_dims[-1] * self.spatial_dims_inner_conv ** 2

        encoder_layers = []
        in_channels = input_dim
        for h_dim in self.hidden_dims:
            encoder_layers.append(
                torch.nn.Conv2d(in_channels, h_dim, kernel_size=3, stride=2, padding=1))
            encoder_layers.append(torch.nn.ReLU())
            in_channels = h_dim
        self.encoder = torch.nn.Sequential(*encoder_layers)

        self.linear_mu = torch.nn.Linear(flat_dim, self.latent_dim)
        # NOTE: despite its name this head predicts log-variances (see encode()).
        # The attribute name is kept so existing state_dict checkpoints still load.
        self.linear_std = torch.nn.Linear(flat_dim, self.latent_dim)

        decoder_layers = []
        reversed_hidden_dims = list(reversed(self.hidden_dims))
        # Walk the hidden widths in reverse; each transposed conv (kernel 3,
        # stride 2, padding 1, output_padding 1) doubles the spatial size.
        for in_ch, out_ch in zip(reversed_hidden_dims, reversed_hidden_dims[1:]):
            decoder_layers.append(
                torch.nn.ConvTranspose2d(in_ch, out_ch,
                                         kernel_size=3, stride=2, padding=1, output_padding=1))
            decoder_layers.append(torch.nn.ReLU())
        # Final upsampling back to the image channels; Tanh bounds outputs to [-1, 1].
        self.decoder = torch.nn.Sequential(*decoder_layers,
                                           torch.nn.ConvTranspose2d(reversed_hidden_dims[-1],
                                                                    input_dim,
                                                                    kernel_size=3,
                                                                    stride=2,
                                                                    padding=1,
                                                                    output_padding=1),
                                           torch.nn.Tanh())
        self.linear_decoder = torch.nn.Linear(latent_dim, flat_dim)

    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Map images to the parameters of the approximate posterior.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            ``(mus, log_vars)``, each of shape (batch, latent_dim). ``log_vars``
            are log-variances, log(sigma^2).
        """
        x = self.encoder(x)
        # Flatten channels and spatial dims, keeping the batch dim unchanged.
        x = einops.rearrange(x, 'b c h w -> b (c h w)')
        mus = self.linear_mu(x)
        log_vars = self.linear_std(x)
        return mus, log_vars

    def draw_sample(self, mus: torch.Tensor, log_vars: torch.Tensor) -> torch.Tensor:
        """Sample z ~ N(mus, exp(log_vars)) with the reparameterization trick.

        A normal sample is built from standard normal noise ``eps`` scaled by the
        standard deviation and shifted by the mean, which keeps the sample
        differentiable with respect to ``mus`` and ``log_vars``.
        """
        # log_vars holds log(sigma^2), hence sigma = exp(0.5 * log_vars).
        std = torch.exp(0.5 * log_vars)
        eps = torch.randn_like(std)
        return mus + std * eps

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Map latent vectors of shape (batch, latent_dim) to reconstructions in [-1, 1]."""
        x_hat = self.linear_decoder(z)
        # Undo the encoder's flattening so the transposed convs get a feature map.
        x_hat = einops.rearrange(x_hat, 'b (c h w) -> b c h w',
                                 c=self.hidden_dims[-1],
                                 h=self.spatial_dims_inner_conv,
                                 w=self.spatial_dims_inner_conv)
        x_hat = self.decoder(x_hat)
        return x_hat

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode, sample and decode ``x``; return ``(x_hat, mus, log_vars)``."""
        # mus and log_vars parameterize latent_dim independent normal distributions.
        mus, log_vars = self.encode(x)
        # Draw one sample from each of those distributions.
        z = self.draw_sample(mus, log_vars)
        x_hat = self.decode(z)
        return x_hat, mus, log_vars

def elbo_loss(
        x: torch.Tensor, x_hat: torch.Tensor, mus: torch.Tensor, log_var: torch.Tensor,
        beta: float=1.0,
) -> torch.Tensor:
    """Calculate the VAE loss (negative ELBO) for a batch.

    Parameters
    ----------
    x : torch.Tensor
        Input images with values in [0, 1] (the pipeline only applies ``T.ToTensor``).
    x_hat : torch.Tensor
        Reconstructions in [-1, 1] (the decoder ends in ``Tanh``). They are
        rescaled to [0, 1] before being compared with ``x``.
    mus : torch.Tensor
        Posterior means, shape (batch, latent_dim).
    log_var : torch.Tensor
        Posterior log-variances log(sigma^2), shape (batch, latent_dim).
    beta : float, optional
        Weight of the KL term (beta-VAE), by default 1.0.

    Returns
    -------
    torch.Tensor
        Scalar loss: the per-pixel mean squared error plus ``beta`` times the KL
        divergence to N(0, 1) (summed over latent dims, averaged over the batch).
    """
    # Rescale the Tanh output from [-1, 1] to the [0, 1] range of the inputs.
    x_hat = (1 + x_hat) / 2
    # NOTE(review): this reconstruction term is averaged over pixels while the KL
    # below is summed over latent dims. That weights the KL far more heavily than
    # the textbook ELBO, which sums over pixels; tune `beta` with that in mind.
    recon = F.mse_loss(x_hat, x)
    # We want the latent distribution to be as close as possible to N(0, 1).
    # Per dimension: KL(N(mu, sigma^2) || N(0, 1)) = -1/2 * (1 + log sigma^2 - mu^2 - sigma^2).
    # Formula taken from https://github.com/AntixK/PyTorch-VAE/blob/a6896b944c918dd7030e7d795a8c13e5c6345ec7/models/vanilla_vae.py#L143C105-L143C105
    # Derivation: https://github.com/AntixK/PyTorch-VAE/issues/69
    # Using log-variances lets the encoder output any real value; predicting sigma
    # directly would require enforcing positivity.
    d_kl = torch.mean(-0.5 * torch.sum(1 + log_var - mus ** 2 - log_var.exp(), dim=1), dim=0)
    return recon + beta * d_kl

# from torchsummary import summary
# print(vae)
# summary(vae, (3, 32, 32), batch_size=8192)

## Train

In [None]:
from datetime import datetime
from pathlib import Path

import torch
from torch.utils.tensorboard import SummaryWriter
from torcheval.metrics.functional import peak_signal_noise_ratio
import torchvision
from torchvision.datasets import CIFAR10
import torchvision.transforms as T
from tqdm import tqdm


def validate(model: torch.nn.Module,
             epoch: int,
             log_writer: SummaryWriter,
             dataloader: torch.utils.data.DataLoader,
             test: bool=False
):
    """Evaluate ``model`` on ``dataloader`` and log averaged metrics to TensorBoard.

    Logs the mean ELBO loss, MSE and PSNR, plus an image grid of the first 8
    inputs and reconstructions of the first batch. MSE and PSNR are computed on
    reconstructions mapped to [0, 1], the same range ``elbo_loss`` uses.

    Parameters
    ----------
    model : torch.nn.Module
        Model returning ``(x_hat, mus, log_vars)``.
    epoch : int
        Global step used for the logged values.
    log_writer : SummaryWriter
        Tensorboard logger.
    dataloader : torch.utils.data.DataLoader
        Yields ``(images, labels)`` batches; labels are ignored.
    test : bool, optional
        Log under 'Test ...' instead of 'Validation ...' tags, by default False.
    """
    model.eval()
    # Run on whatever device the model lives on instead of assuming CUDA.
    device = next(model.parameters()).device
    elbo_losses = []
    psnr_scores = []
    mse_scores = []
    with torch.no_grad():
        for i, (x, _) in enumerate(dataloader):
            x = x.to(device)
            x_hat, mus, log_vars = model(x)
            elbo_losses.append(elbo_loss(x, x_hat, mus, log_vars))
            # The decoder ends in Tanh ([-1, 1]); map to the [0, 1] range of x,
            # exactly as elbo_loss does, so MSE/PSNR compare like with like.
            x_hat_01 = (1 + x_hat) / 2
            psnr_scores.append(peak_signal_noise_ratio(x_hat_01, x, data_range=1.0))
            mse_scores.append(F.mse_loss(x_hat_01, x))
            if i == 0:
                # make_grid takes a (B, C, H, W) batch and tiles it into one
                # (C, H', W') image, which add_image logs as a single picture.
                x_grid = torchvision.utils.make_grid(
                                x[:8], nrow=2, normalize=True, scale_each=True)
                x_hat_grid = torchvision.utils.make_grid(
                                x_hat_01[:8], nrow=2, normalize=True, scale_each=True)
                log_writer.add_image('Input Images', x_grid, epoch)
                log_writer.add_image('Reconstructed Images', x_hat_grid, epoch)
    avg_elbo = torch.stack(elbo_losses).mean()
    avg_mse_score = torch.stack(mse_scores).mean()
    avg_psnr_score = torch.stack(psnr_scores).mean()
    # The same loop serves validation and testing; adapt the log names accordingly.
    run_type = 'Test' if test else 'Validation'
    log_writer.add_scalar(f'{run_type} ELBO', avg_elbo, epoch)
    log_writer.add_scalar(f'{run_type} MSE', avg_mse_score, epoch)
    log_writer.add_scalar(f'{run_type} PSNR', avg_psnr_score, epoch)


def test(
        model: torch.nn.Module, log_writer: SummaryWriter, batch_size: int=512, num_workers: int=2
):
    """Evaluate ``model`` on the CIFAR10 test split and log 'Test ...' metrics.

    Parameters
    ----------
    model : torch.nn.Module
        Trained model returning ``(x_hat, mus, log_vars)``.
    log_writer : SummaryWriter
        Tensorboard logger the test metrics are written to.
    batch_size : int, optional
        by default 512
    num_workers : int, optional
        by default 2
    """
    transform = T.Compose([T.ToTensor()])
    # download=True makes this usable on its own; torchvision skips the
    # download when the files are already present and verified.
    test_set = CIFAR10(root='./data', download=True, train=False, transform=transform)
    test_dataloader = torch.utils.data.DataLoader(
                    test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    validate(model, 0, log_writer, test_dataloader, test=True)
    # Ensure the test metrics reach disk even if the caller never closes the writer.
    log_writer.flush()
    print('Testing Done.')


def _cifar10_train_val_loaders(
        batch_size: int, num_workers: int, train_fraction: float=0.8
) -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    """Split the CIFAR10 training set into train (first part) and validation (rest) loaders."""
    transform = T.Compose([T.ToTensor()])
    full_train_set = CIFAR10(root='./data', download=True, train=True, transform=transform)
    split = int(len(full_train_set) * train_fraction)
    # Subset indexes images and labels together, so each split keeps its own labels.
    train_set = torch.utils.data.Subset(full_train_set, range(split))
    val_set = torch.utils.data.Subset(full_train_set, range(split, len(full_train_set)))
    train_dataloader = torch.utils.data.DataLoader(
                    train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_dataloader = torch.utils.data.DataLoader(
                    val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_dataloader, val_dataloader


def _save_checkpoint(model: torch.nn.Module, checkpoint_dir: Path, epoch: int) -> Path:
    """Save the model weights as '<timestamp>_vae_<epoch>.pt' and return the path."""
    save_path = checkpoint_dir / f'{datetime.now()}_vae_{epoch}.pt'
    torch.save(model.state_dict(), save_path)
    return save_path


def train(model: torch.nn.Module,
          log_writer: SummaryWriter,
          checkpoint_dir: str,
          pretrained_path: str=None,
          num_epochs: int=1000,
          batch_size: int=512,
          num_workers: int=2,
          val_freq: int=100,
          checkpoint_freq: int=1000,
) -> torch.nn.Module:
    """Train the model, log metrics and save checkpoints.

    Only the model weights are checkpointed. When resuming, the Adam optimizer
    state starts from scratch.

    Parameters
    ----------
    model : torch.nn.Module
        Model to train. Batches are moved to the device of its parameters.
    log_writer : SummaryWriter
        Tensorboard logger. It is flushed and closed when training ends.
    checkpoint_dir: str
        Directory to save checkpoints to. Created if it does not exist.
    pretrained_path: str, optional
        Path to a pre-trained checkpoint whose file name ends in '_<epoch>.pt',
        by default None.
    num_epochs : int, optional
        Number of epochs to train. If there is a pretrained checkpoint given,
        these epochs are added on top of the checkpoint's epochs. By default 1000.
    batch_size : int, optional
        by default 512
    num_workers : int, optional
        by default 2
    val_freq : int, optional
        Validate every *val_freq* epochs, by default 100.
    checkpoint_freq : int, optional
        Save a checkpoint every *checkpoint_freq* epochs, by default 1000.
        A final checkpoint is always saved after the last trained epoch.

    Returns
    -------
    torch.nn.Module
        Trained model, in eval mode.
    """
    device = next(model.parameters()).device
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    train_dataloader, val_dataloader = _cifar10_train_val_loaders(batch_size, num_workers)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    pre_trained_epoch = 0
    if pretrained_path:
        model.load_state_dict(torch.load(pretrained_path, map_location=device))
        # Checkpoint names end in '_<epoch>.pt'; resume with the following epoch.
        pre_trained_epoch = int(Path(pretrained_path).stem.split('_')[-1]) + 1
    last_epoch = None
    with tqdm(desc='Training...',
              total=pre_trained_epoch + num_epochs,
              initial=pre_trained_epoch
    ) as pbar:
        for epoch in range(pre_trained_epoch, pre_trained_epoch + num_epochs):
            model.train()
            epoch_losses = []
            for x, _ in train_dataloader:
                x = x.to(device)
                optimizer.zero_grad()
                x_hat, mus, log_vars = model(x)
                loss = elbo_loss(x, x_hat, mus, log_vars)
                loss.backward()
                optimizer.step()
                # Detach so the list does not keep every batch's autograd graph alive.
                epoch_losses.append(loss.detach())
            log_writer.add_scalar('Train ELBO loss', torch.stack(epoch_losses).mean(), epoch)
            if epoch % 10 == 0:
                pbar.set_postfix({'loss': f'{loss.item():0.6f}'})
            if epoch % val_freq == 0:
                validate(model, epoch, log_writer, val_dataloader)
            if epoch % checkpoint_freq == 0:
                _save_checkpoint(model, checkpoint_dir, epoch)
            last_epoch = epoch
            pbar.update(1)
    log_writer.flush()
    log_writer.close()
    # Skip the final checkpoint when no epoch ran (num_epochs <= 0).
    if last_epoch is not None:
        _save_checkpoint(model, checkpoint_dir, last_epoch)
    print('Training Done.')
    model.eval()
    return model

def run_training(pretrained_file: str=None,
                 checkpoint_dir: str='/home/jo/git/vae-playground/data/checkpoints/',
                 num_epochs: int=10000,
                 val_freq: int=100,
) -> torch.nn.Module:
    """Build a VAE, train it (optionally resuming from a checkpoint) and test it.

    Parameters
    ----------
    pretrained_file : str, optional
        File name of a checkpoint inside ``checkpoint_dir`` to resume from,
        by default None (train from scratch).
    checkpoint_dir : str, optional
        Directory checkpoints are read from and written to.
    num_epochs : int, optional
        Number of (additional) epochs to train, by default 10000.
    val_freq : int, optional
        Validate every *val_freq* epochs, by default 100.

    Returns
    -------
    torch.nn.Module
        The trained VAE.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    log_writer = SummaryWriter()
    vae = VAE().to(device)
    # Join with pathlib so checkpoint_dir works with or without a trailing slash.
    pretrained_path = str(Path(checkpoint_dir) / pretrained_file) if pretrained_file else None
    vae = train(vae,
                log_writer=log_writer,
                checkpoint_dir=checkpoint_dir,
                pretrained_path=pretrained_path,
                num_epochs=num_epochs,
                val_freq=val_freq)
    test(vae, log_writer)
    return vae


# Guard the entry point so importing this code does not start a long training run.
# Inside a notebook kernel __name__ is '__main__', so the cell still runs as before.
if __name__ == '__main__':
    run_training(pretrained_file='2023-09-06 11:42:33.149736_vae_1999.pt')
