# I.4 Variational autoencoder

## Preparation

In [None]:
from copy import deepcopy

import matplotlib.pyplot as plt
from matplotlib.image import imread
from mpl_toolkits import mplot3d
from matplotlib import gridspec
from PIL import Image
import io
import os
from urllib.request import urlopen
from skimage.segmentation import mark_boundaries

from tqdm.notebook import tqdm
import numpy as np
import requests
from scipy.stats import norm
import torch

from sklearn.metrics import classification_report
from torch.utils.tensorboard import SummaryWriter

from torchvision import datasets, transforms
import math

from sklearn.model_selection import ParameterGrid

torch.__version__


In [2]:
%load_ext tensorboard


In [None]:
# Choose the compute device once; later cells move the model onto `device`.
if not torch.cuda.is_available():
    # CPU fallback so the notebook still runs without a GPU.
    print("No GPU device!")
    device = 'cpu'
else:
    # Report which CUDA device is currently selected.
    print(f"GPU: {torch.cuda.get_device_name(torch.cuda.current_device())}")
    device = 'cuda'


## Data

In [4]:
class GaussianClustersDataset(torch.utils.data.Dataset):
    """Synthetic dataset of points drawn from several isotropic Gaussian clusters.

    Each cluster has a centre drawn uniformly from ``[-5, 5]^dim`` and a
    diagonal covariance ``s * I`` with ``s`` drawn uniformly from ``[0.5, 1.5]``.
    Every cluster contributes ``num_samples // num_clusters`` points, so the
    dataset length is smaller than ``num_samples`` when the division is inexact.

    Attributes:
        data: ``float32`` array of shape ``(len(self), dim)``.
        labels: integer array with the cluster index of every row of ``data``.
    """

    def __init__(self, num_samples=10000, num_clusters=5, dim=10, seed=42):
        """Generate the clusters.

        Args:
            num_samples: requested total number of points.
            num_clusters: number of Gaussian clusters (must be >= 1).
            dim: dimensionality of every point.
            seed: seed of the private random generator; the default reproduces
                the data previously obtained with ``np.random.seed(42)``.

        Raises:
            ValueError: if ``num_clusters`` is smaller than 1.
        """
        if num_clusters < 1:
            raise ValueError("num_clusters must be a positive integer")

        # A private legacy generator yields the same stream as
        # ``np.random.seed(seed)`` without clobbering the global NumPy RNG.
        rng = np.random.RandomState(seed)
        per_cluster = num_samples // num_clusters

        chunks = []
        labels = []

        for i in range(num_clusters):
            mean = rng.uniform(-5, 5, size=(dim,))  # cluster center
            cov = np.eye(dim) * rng.uniform(0.5, 1.5)  # diag cov matrix

            chunks.append(rng.multivariate_normal(mean, cov, per_cluster))
            labels.extend([i] * per_cluster)

        self.data = np.vstack(chunks).astype(np.float32)
        self.labels = np.array(labels)

    def __len__(self):
        """Return the number of generated points."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(point, cluster_label)`` pair stored at ``idx``."""
        return self.data[idx], self.labels[idx]


In [None]:
# Small 3-D instance of the dataset, used only to eyeball the clusters.
example_dataset = GaussianClustersDataset(num_samples=5000, num_clusters=5, dim=3)
sample_data = example_dataset.data
sample_labels = example_dataset.labels

fig = plt.figure(figsize=(20, 8))
ax = fig.add_subplot(111, projection='3d')

# One coordinate column per axis; colour encodes the cluster index.
sc = ax.scatter(
    sample_data[:, 0],
    sample_data[:, 1],
    sample_data[:, 2],
    c=sample_labels,
    cmap='viridis',
    alpha=0.5,
)
plt.colorbar(sc, label="Cluster ID")

ax.set_xlabel("Feature 1")
ax.set_ylabel("Feature 2")
ax.set_zlabel("Feature 3")
plt.title("Gaussian Clusters dataset (3D Visualization)")

plt.show()


## Training code

### General

In [6]:
def train_epoch(train_generator, model, loss_function, optimizer, scheduler=None, callback=None):
    """Run one pass over ``train_generator`` and return the mean per-sample loss.

    Args:
        train_generator: iterable of batches; element ``[0]`` of every batch is
            handed to ``train_on_batch`` as the model input.
        model: model being optimised.
        loss_function: forwarded to ``train_on_batch``.
        optimizer: optimizer instance forwarded to ``train_on_batch``.
        scheduler: optional object with a ``step()`` method; stepped once after
            the whole pass. ``None`` disables scheduling.
        callback: optional callable ``callback(model, batch_loss)`` invoked
            after every batch with gradient tracking disabled.

    Returns:
        float: batch losses averaged with the number of samples in each batch
        as weight, or NaN when ``train_generator`` yields no batches.
    """
    epoch_loss = 0.0
    total = 0

    for batch_of_x in train_generator:
        x_batch = batch_of_x[0]
        batch_loss = train_on_batch(model, x_batch, optimizer, loss_function)

        if callback is not None:
            with torch.no_grad():
                callback(model, batch_loss)

        # Weight by the number of samples: len(batch_of_x) would count the
        # fields of the batch (inputs, labels), not its rows.
        batch_len = len(x_batch)
        epoch_loss += batch_loss * batch_len
        total += batch_len

    # The scheduler is optional, as the default argument advertises.
    if scheduler is not None:
        scheduler.step()

    if total == 0:
        return float("nan")
    return epoch_loss / total


In [7]:
def trainer(count_of_epoch,
            batch_size,
            dataset,
            model,
            loss_function,
            optimizer,
            lr=0.001,
            callback=None):
    """Train ``model`` on ``dataset`` for ``count_of_epoch`` epochs.

    The optimizer is built from the ``optimizer`` factory with learning rate
    ``lr`` and is paired with an exponential learning-rate decay (gamma 0.95).
    Every epoch iterates a freshly created shuffled ``DataLoader`` and shows
    the resulting epoch loss in the outer progress bar.

    Args:
        count_of_epoch: number of epochs to run.
        batch_size: mini-batch size of the data loader.
        dataset: dataset to train on.
        model: model whose parameters are optimised.
        loss_function: forwarded to ``train_epoch``.
        optimizer: callable ``optimizer(params, lr=...)`` returning an optimizer.
        lr: learning rate handed to the optimizer factory.
        callback: optional per-batch hook forwarded to ``train_epoch``.
    """
    optima = optimizer(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optima, gamma=0.95)

    epoch_bar = tqdm(range(count_of_epoch), desc='epoch')
    epoch_bar.set_postfix({'train epoch loss': np.nan})

    for _ in epoch_bar:
        loader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=True,
        )

        # Number of batches, counting a trailing partial batch if there is one.
        full_batches, leftover = divmod(len(dataset), batch_size)
        batch_count = full_batches + (leftover > 0)

        batch_bar = tqdm(loader, leave=False, total=batch_count)

        mean_loss = train_epoch(
            train_generator=batch_bar,
            model=model,
            loss_function=loss_function,
            optimizer=optima,
            scheduler=scheduler,
            callback=callback,
        )

        epoch_bar.set_postfix({'train epoch loss': mean_loss})


### Specific

In [8]:
def train_on_batch(model, x_batch, optimizer, loss_function):
    """Perform one optimisation step on a single batch.

    Args:
        model: module exposing a ``device`` attribute; its forward output is an
            iterable that is unpacked into the loss call.
        x_batch: input tensor; moved to ``model.device`` before the forward pass.
        optimizer: optimizer instance bound to the model parameters.
        loss_function: callable applied as ``loss_function(*model_output)``.
            When ``None``, ``model.loss`` is used instead.

    Returns:
        float: the loss value of this batch.
    """
    model.train()
    optimizer.zero_grad()

    output = model(x_batch.to(model.device))

    # Honour the loss supplied by the caller instead of silently ignoring it;
    # the model's own loss is only the fallback.
    criterion = loss_function if loss_function is not None else model.loss
    loss = criterion(*output)
    loss.backward()

    optimizer.step()
    return loss.cpu().item()


In [9]:
class callback():
    """Per-batch logging hook that writes losses through a summary writer.

    Every call logs the training loss; every ``delimeter``-th call also
    evaluates the model on ``dataset`` and logs the mean per-sample loss.
    """

    def __init__(self, writer, dataset, loss_function, experiment_name, delimeter=100, batch_size=64):
        """Store the logging configuration.

        Args:
            writer: object with an ``add_scalar(tag, value, step)`` method.
            dataset: evaluation dataset yielding ``(x, label)`` pairs.
            loss_function: callable ``loss_function(reconstruction, x)``
                returning a scalar tensor.
            experiment_name: prefix of the scalar tags.
            delimeter: evaluation period, measured in calls.
            batch_size: batch size used while evaluating.
        """
        self.step = 0
        self.writer = writer
        self.delimeter = delimeter
        self.loss_function = loss_function
        self.experiment_name = experiment_name
        self.batch_size = batch_size

        self.dataset = dataset

    def forward(self, model, loss):
        """Log ``loss`` and, periodically, the loss on the evaluation dataset.

        Note that the model is left in evaluation mode after an evaluation.

        Args:
            model: module exposing ``device``; element ``[0]`` of its forward
                output is compared with the input by ``loss_function``.
            loss: training loss of the batch that has just been processed.
        """
        self.step += 1
        self.writer.add_scalar(f"{self.experiment_name}/LOSS/train", loss, self.step)

        if self.step % self.delimeter != 0:
            return

        batch_generator = torch.utils.data.DataLoader(dataset=self.dataset,
                                                      batch_size=self.batch_size)

        test_loss = 0
        model.eval()
        # Evaluation never needs gradients, whatever context the caller is in.
        with torch.no_grad():
            for x_batch, _ in batch_generator:
                x_batch = x_batch.to(model.device)

                output = model(x_batch)[0]

                # Weight by batch length so a shorter last batch is not over-counted.
                test_loss += self.loss_function(output, x_batch).cpu().item() * len(x_batch)

        test_loss /= len(self.dataset)

        self.writer.add_scalar(f"{self.experiment_name}/LOSS/test", test_loss, self.step)

    def __call__(self, model, loss):
        """Make the instance callable; delegates to ``forward``."""
        return self.forward(model, loss)


## VAE model

In [10]:
class VAE(torch.nn.Module):
    """Fully connected variational autoencoder for flat feature vectors.

    Layer widths are ``n_layers`` integers spaced linearly from ``input_dim``
    to ``hidden_dim``; the decoder mirrors the encoder. Because the first
    width equals ``input_dim``, the encoder starts with a square linear layer.
    """

    @property
    def device(self):
        """Device on which the first parameter of the module lives."""
        return next(self.parameters()).device

    @staticmethod
    def _dense_block(in_features, out_features):
        """Build one Linear -> BatchNorm1d -> LeakyReLU block."""
        return torch.nn.Sequential(
            torch.nn.Linear(in_features, out_features),
            torch.nn.BatchNorm1d(out_features),
            torch.nn.LeakyReLU())

    def __init__(self, latent_dim, input_dim, hidden_dim=200, n_layers=2):
        """Create the encoder, the latent heads and the decoder.

        Args:
            latent_dim: dimensionality of the latent code.
            input_dim: dimensionality of the input vectors.
            hidden_dim: width of the last encoder layer.
            n_layers: number of widths between ``input_dim`` and ``hidden_dim``.
        """
        super().__init__()

        self.latent_dim = latent_dim

        # .tolist() turns the NumPy integers into plain Python ints.
        hidden_dims = np.linspace(input_dim, hidden_dim, n_layers, dtype=int).tolist()

        self.fc_mu = torch.nn.Linear(hidden_dims[-1], latent_dim)
        self.fc_var = torch.nn.Linear(hidden_dims[-1], latent_dim)

        self.encoder_input = torch.nn.Linear(input_dim, hidden_dims[0])
        self.decoder_input = torch.nn.Linear(latent_dim, hidden_dims[-1])

        encoder_modules = [
            self._dense_block(hidden_dims[i], hidden_dims[i + 1])
            for i in range(len(hidden_dims) - 1)
        ]

        # The decoder walks through the same widths in the opposite direction.
        hidden_dims.reverse()

        decoder_modules = [
            self._dense_block(hidden_dims[i], hidden_dims[i + 1])
            for i in range(len(hidden_dims) - 1)
        ]

        self.encoder = torch.nn.Sequential(self.encoder_input, *encoder_modules)
        self.decoder = torch.nn.Sequential(self.decoder_input, *decoder_modules)

        self.final_layer = torch.nn.Sequential(
                            torch.nn.Linear(hidden_dims[-1], hidden_dims[-1]),
                            torch.nn.BatchNorm1d(hidden_dims[-1]),
                            torch.nn.LeakyReLU(),
                            torch.nn.Linear(hidden_dims[-1], input_dim))

    def encode(self, input):
        """
        Computes the parameters of the latent distribution for the input.
        Args:
            input: Tensor - the matrix of shape batch_size x input_dim.
        Returns:
            List[mu, log_var] - the normal distribution parameters.
            mu: Tensor - the matrix of shape batch_size x latent_dim.
            log_var: Tensor - the matrix of shape batch_size x latent_dim
                holding the logarithm of the variance.
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Two separate heads produce the mean and the log-variance
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z):
        """
        Maps the given latent codes back onto the input space.
        Args:
            z: Tensor - the matrix of shape batch_size x latent_dim.
        Returns:
            Tensor - the matrix of shape batch_size x input_dim.
        """
        result = self.decoder(z)
        result = self.final_layer(result)

        return result

    def sample_z(self, mu, logvar):
        """
        Generates a sample from N(mu, exp(logvar)) with the reparameterization trick.
        Args:
            mu: Tensor - the matrix of shape batch_size x latent_dim.
            logvar: Tensor - the matrix of shape batch_size x latent_dim
                holding the logarithm of the variance.
        Returns:
            Tensor - the tensor of shape batch_size x latent_dim - samples from normal distribution in latent space.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input, **kwargs):
        """
        Encodes the input, samples a latent code and decodes it.
        Args:
            input: Tensor - the matrix of shape batch_size x input_dim.
        Returns:
            List[reconstruction, input, mu, log_var] - the arguments of `loss`, in order.
        """
        mu, log_var = self.encode(input)
        z = self.sample_z(mu, log_var)

        return [self.decode(z), input, mu, log_var]

    def loss(self, *args, **kwargs):
        """
        Computes the VAE loss function.
        Args:
            args: reconstruction, input, mu, log_var - as returned by `forward`.
            kld_weight: float (keyword, optional) - multiplier of the KL term, 1.0 by default.
        Returns:
            Tensor - scalar: MSE reconstruction loss plus weighted KL divergence.
        """
        recons, input, mu, log_var = args[0], args[1], args[2], args[3]
        kld_weight = kwargs.get('kld_weight', 1.0)

        recons_loss = torch.nn.functional.mse_loss(recons, input)
        # KL divergence between N(mu, exp(log_var)) and N(0, I):
        # summed over latent dimensions, averaged over the batch
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

        loss = recons_loss + kld_weight * kld_loss
        return loss

    def generate_samples(self, num_samples:int, **kwargs):
        """
        Samples from the standard normal prior in latent space and decodes the samples.
        Args:
            num_samples: int - the number of samples to generate.
        Returns:
            Tensor - the matrix of shape num_samples x input_dim.
        """
        z = torch.randn(num_samples, self.latent_dim)
        z = z.to(self.device)

        samples = self.decode(z)
        return samples

    def generate(self, x, **kwargs):
        """
        Generate decoded sample after encoding.
        Args:
            x: Tensor - the matrix of shape batch_size x input_dim.
        Returns:
            Tensor - decoded sample after encoding of x.
        """
        return self.forward(x)[0]


## Training with different parameters

In [11]:
# Optimizer class, not an instance: `trainer` calls it with the model
# parameters and the learning rate to build the actual optimizer.
optimizer = torch.optim.Adam


In [12]:
# Hyper-parameter values to sweep. ParameterGrid yields one dict per
# combination of the listed values: 3 * 3 * 3 * 2 = 54 configurations.
grid = ParameterGrid({
    'hidden_dim': [64, 512, 1024],
    'input_dim': [2, 64, 256],
    'num_layers': [2, 16, 64],
    'latent_dim': [2, 128, ],
})


In [None]:
# Train one VAE per hyper-parameter combination and log it to its own
# TensorBoard run directory.
for params in grid:  # `params` rather than `set`, which shadowed the builtin
    synthetic_dataset = GaussianClustersDataset(num_samples=40000, num_clusters=5, dim=params['input_dim'])
    synthetic_train, synthetic_test = torch.utils.data.random_split(synthetic_dataset, [0.8, 0.2])

    # The depth from the grid must reach the model; otherwise every run uses
    # the default depth while the log directory name claims otherwise.
    autoencoder = VAE(latent_dim=params['latent_dim'],
                      input_dim=params['input_dim'],
                      hidden_dim=params['hidden_dim'],
                      n_layers=params['num_layers'])
    autoencoder.to(device)

    writer = SummaryWriter(log_dir = f"tensorboard/in={params['input_dim']}_hi={params['hidden_dim']}_lat={params['latent_dim']}_lay={params['num_layers']}")
    call = callback(writer, synthetic_test, torch.nn.MSELoss(), experiment_name="Overall", delimeter=100)

    trainer(count_of_epoch=10,
            batch_size=256,
            dataset=synthetic_train,
            model=autoencoder,
            loss_function=autoencoder.loss,
            optimizer = optimizer,
            lr = 0.001,
            callback = call)

#     if params['input_dim'] == 2:
#             distr = autoencoder.generate_samples(1000)
#             x, y = np.array(distr.detach().cpu()).T

#             plt.figure(figsize=(5, 4), dpi=100)
#             decoding = autoencoder.generate(distr).detach().cpu()

#             plt.scatter(x, y, label='Input')
#             plt.scatter(decoding[:, 0], decoding[:, 1], label='Generated')
#             plt.axis('equal')
#             plt.legend()
#             writer.add_figure('Visual', plt.gcf())
#             plt.show()

    # Flush pending events and release the file handle of this run.
    writer.close()


## Tensorboard

In [None]:
%tensorboard --logdir ./tensorboard/
