# Variational Autoencoder for Augmentation of Laser-Based Laryngeal Imaging

For a deep-learning-based algorithm to match features via a registration task, intensive data augmentation of the training data set is essential. The training data consist of moving images $m(x)$ that represent the spatial configuration of laser points projected onto the vocal fold surface.
The images $m(x)$ are based on the x-y-coordinates of the individual laser points within the image: $m(x)$ is generated by plotting the laser points and then smoothing the image. To achieve intensive augmentation, we train a variational autoencoder (VAE) and then use it to generate images that are variations of the training images and represent feasible configurations of laser points projected onto a vocal fold.

In a VAE, an encoder-decoder architecture is used. The coding is split into a vector representing the mean $\mu$ and one representing the logarithmic variance $\log{(\mathbf{\sigma^2})}$. The distribution is assumed to be Gaussian. A sampling layer after the coding randomly draws a value from this distribution, which makes it possible to create a smooth latent space. The network can be trained in an unsupervised manner, as the decoder should learn to reconstruct the input of the encoder as well as possible. After training, only the decoder is used. Feeding random points from the latent space into it then generates x- and y-coordinates that follow the distribution learned from the ground truth.

## Import Statements
The notebook was developed with Keras on the TensorFlow 2.2.0 backend.

In [None]:
import os
import glob
import json
import random
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from matplotlib import pyplot as plt
from scipy.ndimage import gaussian_filter
from tensorflow.keras.callbacks import CSVLogger
from sklearn.manifold import TSNE

## Hardware Configuration
Check for a GPU and enable memory growth so that TensorFlow does not allocate all GPU memory up front.

In [None]:
if len(tf.config.experimental.get_visible_devices('GPU')):
    physical_devices = tf.config.experimental.list_physical_devices('GPU')
    tf.config.experimental.set_memory_growth(physical_devices[0], True)

## Functionality
The following functionality helps to create the images for training the registration task.

In [None]:
def apply_smoothing(image, sigma=1.0, sigma_back=10.0):
    image_orig = image
    image = gaussian_filter(image, sigma=sigma)
    image_back = gaussian_filter(image, sigma=sigma_back)

    image = (image / image.max()) * 255
    image_back = (image_back / image_back.max()) * 255
    image = 0.3 * image_orig + 0.3 * image + 0.3 * image_back

    return image

In [None]:
def reconstruct_image(x_position, y_position):
    image = np.zeros((image_height, image_width))
    index_2_xy = dict()
    point_index = 0
    for x, y in zip(x_position.flatten(), y_position.flatten()):
        x = int(x*image_width)
        y = int(y*image_height)
        
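        # Skip points near the origin (missing points) and points outside the image
        if 5 <= x < image_width and 5 <= y < image_height: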
            image[y][x] = 1
            index_2_xy[str(point_index)] = [x, y]
            
        point_index += 1
        
    image = apply_smoothing(image, sigma=2, sigma_back=15)
    return image, index_2_xy
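
The mapping from normalized coordinates to pixels can be illustrated without the smoothing step. The following is a minimal sketch; the helper name `rasterize_points`, the `margin` parameter and the toy points are assumptions for illustration and not part of the notebook.

```python
import numpy as np

def rasterize_points(x_norm, y_norm, width, height, margin=5):
    """Plot normalized (0..1) coordinates as single pixels (hypothetical helper)."""
    image = np.zeros((height, width))
    for x, y in zip(np.ravel(x_norm), np.ravel(y_norm)):
        col, row = int(x * width), int(y * height)
        # Ignore points close to the origin and points that fall outside the image
        if margin <= col < width and margin <= row < height:
            image[row, col] = 1
    return image

# Three toy points: one at the origin, one valid, one on the far border
image = rasterize_points(np.array([0.0, 0.5, 1.0]),
                         np.array([0.0, 0.25, 1.0]), 728, 728)
print(int(image.sum()))  # 1
```

Only the middle point is drawn: the first lies below the margin and the last maps to pixel index 728, which is outside a 728-pixel-wide image.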

## Model parameters
The grid dimensions have to be known, as well as the image dimensions for scaling. The depth of the input and output layers is 2 here, as one channel represents the x-coordinates and a second channel represents the y-coordinates.

In [None]:
grid_width = 18
grid_height = 18

image_width = 728
image_height = 728

inout_layers = 2

## Hyperparameters
The dimension of the latent space can be adapted to optimize the network. 

In [None]:
latent_dim = 8
epochs = 100
batch_size = 4
coding_loss_factor = 0.2
epsilon_factor = 1.0

## Sampling Layer
In the sampling layer, the codings $\mathbf{\mu}$ and $\log{(\mathbf{\sigma^2})}$ are used to randomly sample a point from the underlying distribution. Because this random draw is repeated for every batch in every epoch, the decoder is trained on a smooth distribution rather than on isolated points. Each coding $\mathbf{z}$ used to train the decoder is computed as

$$\mathbf{z} = \mathbf{\mu} + e^{0.5\log{(\mathbf{\sigma^2})}}\,\mathbf{\epsilon}\,f$$

where $\mathbf{\epsilon}$ is a random vector drawn from a standard normal distribution and $f$ is a hyperparameter (`epsilon_factor` in the code) that controls the amount of randomness.

In [None]:
class Sampling(layers.Layer):
    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon * epsilon_factor
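
As a framework-free check of the formula above, the following NumPy sketch reproduces the sampling step. The function name `sample_coding` and the toy values are assumptions for illustration and are not part of the notebook.

```python
import numpy as np

def sample_coding(z_mean, z_log_var, factor=1.0, rng=None):
    """Draw z = mu + exp(0.5 * log_var) * epsilon * factor (hypothetical helper)."""
    rng = np.random.default_rng() if rng is None else rng
    epsilon = rng.standard_normal(z_mean.shape)
    return z_mean + np.exp(0.5 * z_log_var) * epsilon * factor

rng = np.random.default_rng(0)
z_mean = np.zeros((10000, 8))
z_log_var = np.full((10000, 8), np.log(0.25))  # corresponds to sigma = 0.5

z = sample_coding(z_mean, z_log_var, factor=1.0, rng=rng)
z_no_noise = sample_coding(z_mean, z_log_var, factor=0.0, rng=rng)

print(z.std())                       # empirical spread, close to sigma
print(np.allclose(z_no_noise, z_mean))  # True: factor 0 returns the mean
```

With `factor=0.0` the sampling collapses to the mean, which is why the factor can be used to dial the randomness up or down.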

## Encoder Model
In the encoder, strided convolutional layers reduce the dimensionality of the input. For an 18x18 grid, the input is zero-padded to 20x20, which makes it easier to keep the encoder and decoder symmetric. The encoder is a Keras model that outputs the mean, the logarithmic variance and a sampled value.

In [None]:
# Input shape is (rows, columns, channels), matching the layout of the training data
encoder_inputs = keras.Input(shape=(grid_height+2, grid_width+2, inout_layers))
x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()

## Decoder Model
The decoder takes as input either a sampled value or a randomly selected value with the size of the latent space.
From this input it reconstructs a tensor with the dimensions of the encoder input. The decoder is constructed as a Keras model.

In [None]:
latent_inputs = keras.Input(shape=(latent_dim,))
y_val = (grid_height+2)//4
x_val = (grid_width+2)//4

x = layers.Dense(y_val*x_val*64, activation="relu")(latent_inputs)
x = layers.Reshape((y_val, x_val, 64))(x)
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
decoder_outputs = layers.Conv2DTranspose(inout_layers, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()
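
The reason for padding the grid to 20x20 can be verified with simple shape arithmetic. The sketch below assumes the usual `padding="same"` behaviour (a strided convolution yields `ceil(size / stride)` outputs, a strided transposed convolution yields `size * stride`); the helper names are hypothetical.

```python
import math

def conv_same_out(size, stride):
    # Output length of a strided convolution with padding="same"
    return math.ceil(size / stride)

def conv_transpose_same_out(size, stride):
    # Output length of a strided transposed convolution with padding="same"
    return size * stride

def round_trip(size, n_layers=2, stride=2):
    """Return (bottleneck size, reconstructed size) for a symmetric stack."""
    down = size
    for _ in range(n_layers):
        down = conv_same_out(down, stride)
    up = down
    for _ in range(n_layers):
        up = conv_transpose_same_out(up, stride)
    return down, up

print(round_trip(20))  # (5, 20)
print(round_trip(18))  # (5, 20)
```

A 20x20 input is reproduced exactly, whereas an unpadded 18x18 input would come back as 20x20 and no longer match the encoder input.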

## VAE Model
The VAE is composed of the encoder and the decoder. The decoder only takes the random coding $\mathbf{z}$ from the encoder. However, the codings $\mathbf{\mu}$ and $\log{(\mathbf{\sigma^2})}$ are used to compute a loss $\epsilon_{coding}$ that regularizes the latent space and ensures a smooth distribution. A binary cross-entropy defines the reconstruction loss $\epsilon_{reconstruction}$.
Although cross-entropy is usually a classification loss, it is used here for regression: the coordinates are scaled to $[0,1]$ and treated like probabilities. In the code, $\epsilon_{coding}$ is additionally weighted by `coding_loss_factor`. The overall loss is then:

$$\epsilon = \epsilon_{reconstruction} + \epsilon_{coding}$$

In [None]:
class VariationalAutoEncoder(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super(VariationalAutoEncoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def train_step(self, data):
        if isinstance(data, tuple):
            data = data[0]
            
        # The gradient tape records the forward pass so it can be differentiated
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            
            # Reconstruction loss
            reconstruction = self.decoder(z)
            reconstruction_loss = tf.reduce_mean(
                keras.losses.binary_crossentropy(data, reconstruction)
            )
            # Scale by the number of cells in the padded grid; the parentheses
            # are required, otherwise operator precedence gives a different factor
            reconstruction_loss *= (grid_height + 2) * (grid_width + 2)
            
            # Regularization Loss for well formed latent space and no overfitting
            coding_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
            coding_loss = tf.reduce_mean(coding_loss)
            coding_loss *= -(0.5 * coding_loss_factor)
            
            # Total Loss
            total_loss = reconstruction_loss + coding_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "coding_loss": coding_loss,
        }
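
The regularization term can be inspected in isolation. The following NumPy sketch mirrors the expression used in `train_step`; the function name `coding_loss` and the test inputs are assumptions for illustration.

```python
import numpy as np

def coding_loss(z_mean, z_log_var, factor=0.2):
    """-0.5 * factor * mean(1 + log_var - mu^2 - exp(log_var)) (hypothetical helper)."""
    kl_terms = 1.0 + z_log_var - np.square(z_mean) - np.exp(z_log_var)
    return -0.5 * factor * np.mean(kl_terms)

# Codings that already match a standard normal distribution are not penalized
standard = coding_loss(np.zeros((4, 8)), np.zeros((4, 8)))
# Shifting the mean away from zero increases the loss
shifted = coding_loss(np.ones((4, 8)), np.zeros((4, 8)))
# Shrinking the variance also increases the loss
narrow = coding_loss(np.zeros((4, 8)), np.full((4, 8), -2.0))

print(abs(standard), shifted, narrow)
```

The term is zero exactly when the mean is 0 and the log-variance is 0, so it pulls every coding towards a standard normal distribution, with `factor` controlling how strongly.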

## Prepare data
Load the data and scale them to lie between 0.0 and 1.0. Then apply zero-padding to obtain a 20x20 shape for this example.

In [None]:
# Define a path from where to load the data
base_path = 'Data/LASTEN/train'

# Load data into numpy arrays
def sort_key(path):
    return int(path.split(os.sep)[-1].split(".")[0].split("_")[0]) 
    
globs = glob.glob(base_path+'/*.json')
files = sorted(globs, key=sort_key)
file_length = len(files)
# Float arrays, so that non-integer coordinates are not truncated
x_position = np.zeros((file_length, grid_height, grid_width), dtype=float)
y_position = np.zeros((file_length, grid_height, grid_width), dtype=float)
for file_id, file in enumerate(files):
    with open(file) as json_file:
        data = json.load(json_file)
        
        for key, value in data.items():
            key = int(key)

            # Row-major point index: row = key // width, column = key % width
            y = key // grid_width
            x = key % grid_width

            x_position[file_id][y][x] = value[0]
            y_position[file_id][y][x] = value[1]


# Define an offset
offset = 0.0

# Non existing points mapping
x_position = np.where(x_position<=0, -offset, x_position) + offset
y_position = np.where(y_position<=0, -offset, y_position) + offset

x_position = x_position / (image_width + offset)
y_position = y_position / (image_height + offset)

# Zero padding
x_position = np.pad(x_position,[(0,0),(1,1),(1,1)],constant_values=0)
y_position = np.pad(y_position,[(0,0),(1,1),(1,1)],constant_values=0)

x_position = x_position[:,:,:,np.newaxis]
y_position = y_position[:,:,:,np.newaxis]

xy_data = np.concatenate((x_position, y_position), axis=3)
print("Shape of 'xy_data': {}".format(xy_data.shape))
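
The scaling and padding steps can be tried on synthetic data without the JSON files. This is a minimal sketch; the regular toy grid and the helper name `to_network_input` are assumptions for illustration.

```python
import numpy as np

grid_h, grid_w, image_size = 18, 18, 728  # same sizes as in the notebook

# One synthetic frame: a regular grid of pixel coordinates (illustrative data only)
xs = np.tile(np.linspace(100, 600, grid_w), (grid_h, 1))[np.newaxis]
ys = np.tile(np.linspace(100, 600, grid_h)[:, np.newaxis], (1, grid_w))[np.newaxis]

def to_network_input(x_pixels, y_pixels, size):
    """Scale pixel coordinates to [0, 1], zero-pad the grid and stack the channels."""
    x = np.pad(x_pixels / size, [(0, 0), (1, 1), (1, 1)], constant_values=0)
    y = np.pad(y_pixels / size, [(0, 0), (1, 1), (1, 1)], constant_values=0)
    return np.stack((x, y), axis=-1)

toy = to_network_input(xs, ys, image_size)
print(toy.shape)  # (1, 20, 20, 2)
```

Channel 0 holds the x-coordinates and channel 1 the y-coordinates, and the one-cell border added by the padding is zero in both channels.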

## Training
Initialize the VAE, compile it and fit the model.

In [None]:
vae = VariationalAutoEncoder(encoder, decoder)

vae.compile(optimizer=keras.optimizers.Adam())
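logger = CSVLogger("VAE_log")
# Pass the logger as a callback so that the losses are actually written to file
history = vae.fit(xy_data, epochs=epochs, batch_size=batch_size, callbacks=[logger])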
#vae.save_weights('weights/vae/vae_100_epochs_gt')
#vae.load_weights('weights/vae/vae_100_epochs_gt')

In [None]:
if history:
    print('Mean reconstruction loss over last 20 epochs: {}'.format(np.array(history.history['reconstruction_loss'][-20:]).mean()))
    print('Mean coding loss over last 20 epochs: {}'.format(np.array(history.history['coding_loss'][-20:]).mean()))
    plt.semilogy(history.history['reconstruction_loss'], label='Reconstruction Loss')
    plt.title('Crossentropy Loss for Reconstruction of Input images')
    plt.ylabel('Crossentropy')
    plt.xlabel('No. epoch')
    plt.legend(loc="upper left")
    plt.show()

## Predict images and compare with ground truth
Get the mean value for ground truth images and decode the value to reconstruct the images.

In [None]:
res_encoder = encoder.predict(xy_data)
res_decoder = decoder.predict(res_encoder[0])
plt.rcParams["figure.figsize"] = (7,7)

### Prediction of x-coordinates by autoencoder
One frame from each of the 8 recordings in the dataset is displayed.

In [None]:
for i in range(8):
    val = 441 + i
    plt.subplot(val)
    # Channel 0 holds the x-coordinates, channel 1 the y-coordinates
    x = res_decoder[i*20,:,:,0][1:-1,1:-1]
    plt.imshow(x)
    plt.axis('off')
    plt.tight_layout()

### Ground-truth of x-coordinates
One frame from each of the 8 recordings in the dataset is displayed.

In [None]:
for i in range(8):
    val = 441 + i
    plt.subplot(val)
    # Channel 0 holds the x-coordinates, channel 1 the y-coordinates
    x = xy_data[i*20,:,:,0][1:-1,1:-1]
    plt.imshow(x)
    plt.axis('off')
    plt.tight_layout()

### Prediction of x-y-coordinates and display in image space
One frame from each of the 8 recordings in the dataset is displayed.

In [None]:
for i in range(8):
    val = 441 + i
    plt.subplot(val)
    x = res_decoder[i*20,:,:,0]
    y = res_decoder[i*20,:,:,1]
    plt.imshow(reconstruct_image(x, y)[0],cmap='gist_gray')
    plt.axis('off')
    plt.tight_layout()

### Ground truth of x-y-coordinates and display in image space
One frame from each of the 8 recordings in the dataset is displayed.

In [None]:
for i in range(8):
    val = 441 + i
    plt.subplot(val)
    x = xy_data[i*20,:,:,0]
    y = xy_data[i*20,:,:,1]
    plt.imshow(reconstruct_image(x, y)[0],cmap='gist_gray')
    plt.axis('off')
    plt.tight_layout()

## Postprocessing
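Apply a bilateral filter to denoise the coordinate images. In addition, irregularities in the x- and y-direction are removed: a point is set to zero if its x-coordinate does not increase from left to right along a row, or if its y-coordinate does not increase from the bottom to the top of a column.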

In [None]:
from skimage.restoration import denoise_bilateral
plt.rcParams["figure.figsize"] = (7,7)

def postprocess_coordinate_images(results, index):
    x = results[index,:,:,0]
    y = results[index,:,:,1]
    
    # Unpad
    x = x[1:-1,1:-1]
    y = y[1:-1,1:-1]
    
    x_max_orig = x.max()
    y_max_orig = y.max()
    
    # Smooth and equalize possible offset due to smoothing
    x = denoise_bilateral(x, mode='edge',sigma_spatial=100, win_size=3)
    y = denoise_bilateral(y, mode='edge',sigma_spatial=100, win_size=3)
    x = x - x.min()
    y = y - y.min()
    x = x * (x_max_orig / x.max())
    y = y * (y_max_orig / y.max())
    
    # Sort out irregularities in x
    for i in range(len(x)):        
        max_value=0.0
        for j in range(len(x[i])):
            if x[i][j] <= max_value:
                x[i][j] = 0.0
                y[i][j] = 0.0
            else: 
                max_value = x[i][j]
            
    # Sort out irregularities in y
    for i in range(len(y[0])):        
        max_value=0.0
        for j in reversed(range(len(y))):
            if y[j][i] <= max_value:
                x[j][i] = 0.0
                y[j][i] = 0.0
            else:
                max_value = y[j][i]
    
    return x, y        
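
The irregularity filter is easiest to follow on a single row. The sketch below applies the same rule as the x-direction loop above to a toy array; the function name `zero_non_increasing` and the values are assumptions for illustration.

```python
import numpy as np

def zero_non_increasing(row):
    """Keep a value only if it exceeds every value kept before it (hypothetical helper)."""
    cleaned = row.copy()
    running_max = 0.0
    for j, value in enumerate(row):
        if value <= running_max:
            cleaned[j] = 0.0   # out of order: treat as a missing point
        else:
            running_max = value
    return cleaned

row = np.array([0.1, 0.3, 0.2, 0.5])
print(zero_non_increasing(row).tolist())  # [0.1, 0.3, 0.0, 0.5]
```

The third value is removed because it is smaller than the second, while the fourth is kept because it exceeds the running maximum again.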

## Generate an image based on a vector from the latent space
Use this functionality to explore the latent space.

In [None]:
%%capture
%matplotlib inline
from ipywidgets import interact, widgets
from IPython.display import display

def f(**sliders):
    image_id = 0
    
    sliders_float = [float(value) for value in sliders.values()]
    
    res_encoder_offset = np.array([sliders_float])
    res_decoder = decoder.predict(res_encoder[0][image_id] + res_encoder_offset)
    
    # The prediction contains a single image, so it is always at index 0
    x, y = postprocess_coordinate_images(res_decoder, 0)

    img = reconstruct_image(x, y)[0]
    ax.imshow(img, cmap="gray")
    display(fig)

fig = plt.figure(figsize=(6, 4))
ax = fig.add_axes(plt.axes())
fig.canvas.draw()
ax.axis('off')


sliders = {str(i): widgets.FloatSlider(min=-2.0,max=2.0,step=0.5,value=0.0, continuous_update=False) for i in range(latent_dim)}


In [None]:
_ = interact(f, **sliders)

## Draw samples from the distribution for visualization

There are several possibilities to draw samples from the latent space:
1. Predict the latent vector of one sample image and perturb it.
2. Linearly interpolate between neighbors.
3. Linearly extrapolate with neighbors.

We show possibility 2 here to visualize the manifold.

### Linearly Interpolate between neighbors
First we define a function that linearly interpolates a coding between two existing codings with a factor $\alpha$ in $[0,1]$. The interpolated coding is defined as:
$$ \hat{\mathbf{z}} = (\mathbf{\mu}_2 - \mathbf{\mu}_1)\alpha + \mathbf{\mu}_1$$

In [None]:
def interpolate_codings(first_coding, second_coding, alpha):
    return (second_coding - first_coding) * alpha + first_coding

Here we select four ground truth images and first compute their codings with the encoder. Then we sample values with a factor of $\alpha$ between the first and the second image and between the third and the fourth image. Afterwards, we sample again between the previously sampled points to cover approximately the space between all four ground truth images.

In [None]:
first_image_id = 0
second_image_id = 108
third_image_id = 54
fourth_image_id = 18

first_coding = res_encoder[0][first_image_id]
second_coding = res_encoder[0][second_image_id]
third_coding = res_encoder[0][third_image_id]
fourth_coding = res_encoder[0][fourth_image_id]

alphas = list(np.arange(0.0,1.1,0.2))
size = len(alphas)

first_2_second_codings = np.array([interpolate_codings(first_coding, second_coding, alpha) for alpha in alphas])
third_2_fourth_codings = np.array([interpolate_codings(third_coding, fourth_coding, alpha) for alpha in alphas])

codings = np.zeros((size, size, latent_dim))

for i in range(size):
    first_2_second_coding = first_2_second_codings[i]
    third_2_fourth_coding = third_2_fourth_codings[i]
    
    vertical_codings = np.array([interpolate_codings(first_2_second_coding, third_2_fourth_coding, alpha) for alpha in alphas])
    
    codings[:, i, :] = vertical_codings

codings = codings.reshape((size * size, latent_dim))
res_decoder = decoder.predict(codings)

x_coords = list()
y_coords = list()

for i in range(size * size):
    x, y = postprocess_coordinate_images(res_decoder, i)
    x_coords.append(x)
    y_coords.append(y)
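
The two-stage interpolation above is a bilinear interpolation between four codings and can also be written in vectorized form. This is a sketch under the assumption of random toy codings; the function name `bilinear_codings` is hypothetical.

```python
import numpy as np

def bilinear_codings(c1, c2, c3, c4, alphas):
    """Grid of codings; grid[j, i] blends c1->c2 and c3->c4 with alphas[i],
    then blends those two results with alphas[j] (hypothetical helper)."""
    a = np.asarray(alphas)
    top = c1 + (c2 - c1) * a[:, np.newaxis]      # shape (size, latent_dim)
    bottom = c3 + (c4 - c3) * a[:, np.newaxis]   # shape (size, latent_dim)
    return top[np.newaxis] + (bottom - top)[np.newaxis] * a[:, np.newaxis, np.newaxis]

rng = np.random.default_rng(0)
c1, c2, c3, c4 = rng.normal(size=(4, 8))         # toy codings, latent_dim = 8
grid = bilinear_codings(c1, c2, c3, c4, np.linspace(0.0, 1.0, 6))

print(grid.shape)  # (6, 6, 8)
```

The four corners of the grid coincide with the four input codings, so the displayed manifold is spanned by the selected ground truth images.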

We can now display the learned manifold:

In [None]:
plt.rcParams["figure.figsize"] = (40,40)

for i in range(size * size):
    x = x_coords[i]
    y = y_coords[i]
    
    plt.subplot(size, size, i+1)
    plt.imshow(reconstruct_image(x, y)[0],cmap='gist_gray')
    plt.axis('off')
    plt.tight_layout()

plt.show()
#plt.savefig('learned_manifold.png')

## Draw samples from the distribution for generating a data set
Sample either by interpolation or by adding noise.

In [None]:
# Decide to sample by interpolation or by noise
is_interpolation = True

In [None]:
# Sample by interpolation
if is_interpolation:
    # Number of frames drawn from each recording for every pair of recordings
    sub_samples = 10
    # Number of interpolated codings (alpha values) between two drawn frames
    inter_samples = 3
    # Get the codings
    codings = res_encoder[0]
    # Split the recordings
    recordings = [codings[0:20],
                  codings[20:40],
                  codings[40:60],
                  codings[60:80],
                  codings[80:100],
                  codings[100:120],
                  codings[120:140],
                  codings[140:160]]

    def interpolate_codings(first_coding, second_coding, alpha):
        return (second_coding - first_coding) * alpha + first_coding

    def extrapolate_codings(first_coding, second_coding, alpha):
        return (first_coding - second_coding) * alpha + first_coding

    def get_random_list(samples, sub_samples):
        random_list = list(range(samples))
        random.shuffle(random_list)
        random_list = random_list[0:sub_samples]
        return random_list

    # 28 is the number of recording pairs (8 choose 2)
    interpolated_codings = np.zeros((28*sub_samples*inter_samples, latent_dim))
    alphas = np.linspace(0,1,inter_samples+2)[1:-1]

    counter = 0
    while recordings:
        first_recording = recordings.pop(0)

        for second_recording in recordings:
            first_random_points = get_random_list(20, sub_samples)
            second_random_points = get_random_list(20, sub_samples)

            while first_random_points:
                first_point = first_random_points.pop(0)
                second_point = second_random_points.pop(0)
                first_coding = first_recording[first_point]
                second_coding = second_recording[second_point]
                inter_codings = np.array([interpolate_codings(first_coding, second_coding, alpha) for alpha in alphas])
                start_index_inter = counter*inter_samples #*2
                end_index_inter = counter*inter_samples + inter_samples          
                interpolated_codings[start_index_inter:end_index_inter] = inter_codings

                counter +=1

    res_decoder = decoder.predict(interpolated_codings)
    variation_index = 40000
    for i in range(len(interpolated_codings)):
        new_x, new_y = postprocess_coordinate_images(res_decoder, i)
        image, index_2_xy = reconstruct_image(new_x, new_y)

        # Store JSON
        with open(base_path + os.sep + "{}.json".format(variation_index), "w") as json_out:
            json.dump(index_2_xy, json_out)

        # Store .png image
        plt.imsave(base_path + os.sep + "{}_mov.png".format(variation_index), image, cmap="gray")
        variation_index += 1
        
# Sample Noise
else:
    # Sample size per observation
    amount = 4
    
    variation_index = 40000
    for _ in range(amount):
        epsilon = np.random.normal(size=res_encoder[0].shape)
        variation = res_encoder[0] + np.exp(0.5 * res_encoder[1]) * epsilon
        res_decoder = decoder.predict(variation)

        # Use a separate loop variable and iterate over all decoded frames
        for i in range(len(res_decoder)):
            new_x, new_y = postprocess_coordinate_images(res_decoder, i)
            image, index_2_xy = reconstruct_image(new_x, new_y)
            # Store JSON
            with open(base_path + os.sep + "{}.json".format(variation_index), "w") as json_out:
                json.dump(index_2_xy, json_out)
            # Store .png image
            plt.imsave(base_path + os.sep + "{}_mov.png".format(variation_index), image, cmap="gray")
            variation_index += 1
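
The size of the interpolation buffer and the alpha values used in the interpolation branch can be checked with a few lines. This sketch only reproduces the bookkeeping; `n_recordings` is an assumed name for the number of recordings.

```python
import math
import numpy as np

n_recordings = 8      # recordings in the data set
sub_samples = 10      # frames drawn per recording pair
inter_samples = 3     # interpolated codings per frame pair

# Every unordered pair of recordings is visited once
n_pairs = math.comb(n_recordings, 2)
n_codings = n_pairs * sub_samples * inter_samples

# Dropping the end points excludes the original codings themselves
alphas = np.linspace(0, 1, inter_samples + 2)[1:-1]

print(n_pairs, n_codings)  # 28 840
print(alphas.tolist())     # [0.25, 0.5, 0.75]
```

With these settings, interpolation yields 840 new samples, all strictly between two ground truth codings.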