# Preface

In this notebook, we explore the use of autoencoders for image compression and denoising.

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from pathlib import Path
from PIL import Image
sns.set(font_scale=1.5, style='dark')

# The Pokemon Dataset

So far, our applications in this class have been rather serious. Here, to demonstrate autoencoders, we will use a fun dataset: the Pokemon dataset!

See the [Kaggle dataset page](https://www.kaggle.com/vishalsubbiah/pokemon-images-and-types) for more information.

In [None]:
import kaggle
kaggle.api.authenticate()

kaggle.api.dataset_download_files(
    'vishalsubbiah/pokemon-images-and-types',
    path='./pokemon',
    quiet=False,
    unzip=True,
    force=False,
)

The images come in multiple formats, including PNG and JPG, so we do some pre-processing and convert each of them into a (120, 120, 3) array representing an RGB image. Transparent PNG pixels are filled with white.

In [None]:
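names = []
images = []

fill_color = (255, 255, 255)  # white background for transparent pixels
image_dir = Path('./pokemon/images/images')

for img_path in sorted(image_dir.iterdir()):  # sorted for a reproducible order
    im = Image.open(img_path)
    names.append(img_path.stem)
    if img_path.suffix.lower() == '.png':
        # Blend the PNG onto a white background, using its alpha channel as the mask
        im = im.convert('RGBA')
        bg = Image.new('RGB', im.size, fill_color)
        bg.paste(im, mask=im.split()[-1])
        im = bg
    else:
        im = im.convert('RGB')  # make sure every image has exactly 3 channels
    images.append(np.asarray(im))
images = np.asarray(images) / 255.0  # scale pixel values to [0, 1]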
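For the PNG files, pasting onto the white background with the alpha channel as the mask amounts (up to integer rounding) to per-pixel alpha blending, `alpha * foreground + (1 - alpha) * white`. The NumPy sketch below spells this out; `composite_on_white` is a hypothetical helper for illustration, not something the preprocessing cell uses.

```python
import numpy as np

def composite_on_white(rgba):
    """Alpha-blend a uint8 RGBA array of shape (H, W, 4) onto a white background.

    Returns a float array of shape (H, W, 3) with values in [0, 1].
    """
    rgba = rgba.astype(np.float64) / 255.0
    rgb, alpha = rgba[..., :3], rgba[..., 3:]  # alpha kept as (H, W, 1) for broadcasting
    white = np.ones_like(rgb)
    return alpha * rgb + (1.0 - alpha) * white

# A 1x3 image: opaque red, fully transparent red, half-transparent black
pixels = np.array([[[255, 0, 0, 255], [255, 0, 0, 0], [0, 0, 0, 128]]], dtype=np.uint8)
print(composite_on_white(pixels).round(2))
```

The opaque pixel keeps its color, the fully transparent one becomes white, and the half-transparent black pixel ends up mid-gray.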

We will write a function to plot the images.

In [None]:
def plot_images(images, n_plots=5):
    """Show the first `n_plots` images side by side, without axes."""
    fig, ax = plt.subplots(1, n_plots, figsize=(5*n_plots, 4))

    for i, a in zip(images, ax):
        a.imshow(i)
        a.axis('off')

In [None]:
plot_images(images=images)

We hold out a test set to evaluate how well our autoencoders generalize to images they have not seen during training.

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
x_train, x_test = train_test_split(images, test_size=0.1, random_state=123)

In [None]:
x_train.shape

# Fully Connected Autoencoder

We start with the simplest autoencoder consisting of fully connected layers alone.

In [None]:
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Flatten, Reshape
from tqdm.keras import TqdmCallback

In [None]:
# Encoder

encoder = Sequential()
encoder.add(Flatten(input_shape=(120, 120, 3)))
encoder.add(Dense(units=128, activation='relu'))
encoder.add(Dense(units=32, activation='relu'))

# Decoder

decoder = Sequential()
decoder.add(Dense(units=128, activation='relu', input_shape=(32, )))
decoder.add(Dense(units=120*120*3, activation='sigmoid'))
decoder.add(Reshape(target_shape=(120, 120, 3)))

autoencoder = Sequential([encoder, decoder])

In [None]:
autoencoder.summary()

As we can see, there are over 11 million parameters! This is a huge network. Let us compile and train it.
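The parameter count can be checked by hand: a `Dense` layer with $n_{in}$ inputs and $n_{out}$ units has $n_{in} n_{out} + n_{out}$ weights and biases, while `Flatten` and `Reshape` have none. A quick sketch (the helper name `dense_params` is ours):

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected (Dense) layer."""
    return n_in * n_out + n_out

n_pixels = 120 * 120 * 3                          # 43,200 values after Flatten
layer_sizes = [n_pixels, 128, 32, 128, n_pixels]  # encoder, then decoder

total = sum(dense_params(a, b) for a, b in zip(layer_sizes[:-1], layer_sizes[1:]))
print(f"{total:,}")  # 11,110,880
```

Almost all of these parameters sit in the first and last `Dense` layers, because each of them connects all 43,200 pixel values to 128 units.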

In [None]:
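def train_and_save(model, path, **kwargs):
    """Load cached weights from `path` if they exist; otherwise compile the model,
    train it with `model.fit(**kwargs)`, and save the weights to `path`."""
    if path.exists():
        model.load_weights(str(path))
    else:
        model.compile(loss='binary_crossentropy', optimizer='adam')
        _ = model.fit(**kwargs)
        model.save_weights(str(path))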
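`train_and_save` compiles the models with the `binary_crossentropy` loss. This works because the decoder ends in a sigmoid and the targets are pixel values in [0, 1], so each pixel can be treated as a soft binary target. The NumPy sketch below is our own re-implementation for illustration (averaged over all pixel values, with an assumed clipping constant `eps` to keep the logarithms finite), not the Keras source:

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Binary cross-entropy averaged over all pixel values."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # keep log() finite
    losses = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return losses.mean()

x = np.full((2, 4, 4, 3), 0.3)          # a batch of flat gray "images"
print(binary_crossentropy(x, x))        # smallest possible value for this target, but not 0
print(binary_crossentropy(x, 1.0 - x))  # a worse reconstruction gives a larger loss
```

With targets strictly between 0 and 1, the loss is minimized by a perfect reconstruction but does not reach zero there: its minimum is the binary entropy of the target. Loss values are therefore most useful for comparing models, not as an absolute measure of quality.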

In [None]:
train_and_save(
    model=autoencoder,
    path=Path('./pokemon_ae_fcnn.h5'),
    x=x_train,
    y=x_train,
    batch_size=64,
    validation_data=(x_test, x_test),
    verbose=0,
    epochs=200,
    callbacks=[TqdmCallback(verbose=1)],
)

Let us check the reconstruction results on the test set.

In [None]:
x_test_pred = autoencoder.predict(x_test)

In [None]:
plot_images(x_test)
plot_images(x_test_pred)

Observe that although the reconstructed images are not random, they are far from satisfactory. 

In fact, checking the reconstructions on the training set confirms that this is not a problem of overfitting: the model does poorly even on the images it was trained on.

In [None]:
x_train_pred = autoencoder.predict(x_train)

In [None]:
plot_images(x_train)
plot_images(x_train_pred)
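To compare the two sets quantitatively rather than by eye, one option is the mean squared reconstruction error per image: if the training error is about as high as the test error, the model is underfitting rather than overfitting. `per_image_mse` is a hypothetical helper, and MSE is our choice of metric here (training itself used binary cross-entropy).

```python
import numpy as np

def per_image_mse(x, x_pred):
    """Mean squared reconstruction error of each image; returns shape (n_images,)."""
    return ((x - x_pred) ** 2).mean(axis=tuple(range(1, x.ndim)))

# Usage with the arrays computed above (not run here):
# print('train MSE:', per_image_mse(x_train, x_train_pred).mean())
# print('test MSE: ', per_image_mse(x_test, x_test_pred).mean())

# Small synthetic check: predictions off by 0.1 everywhere give an MSE of about 0.01
x = np.zeros((4, 120, 120, 3))
print(per_image_mse(x, x + 0.1))
```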

# Convolutional Autoencoder

Since we are dealing with images, a convolutional network, which exploits their local spatial structure, is likely to capture the relevant features much better than a fully connected one.

In [None]:
from tensorflow.keras.layers import AveragePooling2D, Conv2D, UpSampling2D

In the encoder, we reduce the spatial dimensions with pooling operations, which preserve the spatial structure of the image; the decoder reverses this with upsampling.

In [None]:
# Encoder

encoder = Sequential()
encoder.add(
    Conv2D(
        filters=16,
        kernel_size=5,
        padding='same',
        activation='relu',
        input_shape=(120, 120, 3)))
encoder.add(AveragePooling2D())
encoder.add(
    Conv2D(filters=32, kernel_size=3, padding='same', activation='relu'))
encoder.add(AveragePooling2D())
encoder.add(
    Conv2D(filters=64, kernel_size=3, padding='same', activation='relu'))
encoder.add(AveragePooling2D())
encoder.add(
    Conv2D(filters=16, kernel_size=3, padding='same', activation='relu'))
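Because the `Conv2D` layers use `padding='same'`, only the pooling layers change the spatial size. A small sketch of the shape arithmetic (the helper `pooled_size` is hypothetical) shows why the decoder below expects `input_shape=(15, 15, 16)`:

```python
def pooled_size(size, n_pools, pool_size=2):
    """Side length after n_pools AveragePooling2D layers with the default 2x2 window."""
    for _ in range(n_pools):
        size //= pool_size  # 'valid' pooling with stride = pool_size floors odd sizes
    return size

side = pooled_size(120, n_pools=3)  # 120 -> 60 -> 30 -> 15
latent_dim = side * side * 16       # the last encoder Conv2D has 16 filters
input_dim = 120 * 120 * 3
print(side, latent_dim, input_dim / latent_dim)  # 15 3600 12.0
```

The code therefore has 3,600 values, a 12-fold reduction of the 43,200 input values, though much larger than the 32-dimensional code of the fully connected model.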

In [None]:
# Decoder

decoder = Sequential()
decoder.add(
    Conv2D(
        filters=64,
        kernel_size=3,
        padding='same',
        activation='relu',
        input_shape=(15, 15, 16)))
decoder.add(UpSampling2D())
decoder.add(
    Conv2D(filters=32, kernel_size=3, padding='same', activation='relu'))
decoder.add(UpSampling2D())
decoder.add(
    Conv2D(filters=16, kernel_size=3, padding='same', activation='relu'))
decoder.add(UpSampling2D())
decoder.add(
    Conv2D(filters=3, kernel_size=5, padding='same', activation='sigmoid'))

In [None]:
autoencoder = Sequential([encoder, decoder])

Note that this model has far fewer parameters!

In [None]:
autoencoder.summary()
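The count reported by `summary()` can be verified by hand. A `Conv2D` layer with a $k \times k$ kernel, $c_{in}$ input channels and $c_{out}$ filters has $k^2 c_{in} c_{out} + c_{out}$ parameters, independent of the image size, and the pooling and upsampling layers have none. A sketch (the helper `conv2d_params` is ours):

```python
def conv2d_params(kernel_size, c_in, c_out):
    """Weights plus biases of a Conv2D layer with a square kernel."""
    return kernel_size * kernel_size * c_in * c_out + c_out

# (kernel_size, input channels, filters) for each Conv2D layer defined above
encoder_convs = [(5, 3, 16), (3, 16, 32), (3, 32, 64), (3, 64, 16)]
decoder_convs = [(3, 16, 64), (3, 64, 32), (3, 32, 16), (5, 16, 3)]

total = sum(conv2d_params(*layer) for layer in encoder_convs + decoder_convs)
print(f"{total:,}")  # 67,155
```

That is more than 150 times fewer than the over 11 million parameters of the fully connected model, because each convolution kernel is shared across all positions in the image.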

In [None]:
train_and_save(
    model=autoencoder,
    path=Path('./pokemon_ae_cnn.h5'),
    x=x_train,
    y=x_train,
    batch_size=64,
    validation_data=(x_test, x_test),
    verbose=0,
    epochs=80,
    callbacks=[TqdmCallback(verbose=1)],
)

Let us now look at the predictions on the test set.

In [None]:
x_test_pred = autoencoder.predict(x_test)

In [None]:
plot_images(x_test)
plot_images(x_test_pred)

Much better, but still not perfect.

**Exercise**

Play with the above model to improve performance.

# Denoising using U-net

Often, we do not have to do all the architectural engineering ourselves. 

A very oft-used CNN autoencoder-type architecture is the *U-net*, developed in [this paper](https://arxiv.org/abs/1505.04597).

Well-known architectures have often already been implemented in Keras by others, and U-net is no exception. We will use the [`keras-unet`](https://pypi.org/project/keras-unet/) package, which you can install with
```
pip install keras-unet
```

In [None]:
from keras_unet.models import custom_unet

We will train a denoising autoencoder: the input is corrupted with noise, and the network is trained to reconstruct the clean image, i.e. we minimize
$$
    L\big(\mathbf{x}, \mathrm{Decoder}(\mathrm{Encoder}(\mathbf{x} + \mathrm{Noise}))\big)
$$
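Why add noise at all? Without it, a network with enough capacity could in principle just learn to copy its input. With noise, copying reproduces the corruption and is penalized, so the network has to learn to remove it. A toy NumPy check, using mean squared error as a stand-in for $L$ (an assumption; the notebook trains with binary cross-entropy) and a hypothetical `identity` model:

```python
import numpy as np

def mse(a, b):
    """Mean squared error, used here as a stand-in for the loss L."""
    return ((a - b) ** 2).mean()

def identity(z):
    """A trivial 'autoencoder' that copies its input unchanged."""
    return z

rng = np.random.default_rng(0)
x = rng.uniform(size=(8, 120, 120, 3))  # stand-in clean images
x_noisy = np.clip(x + rng.normal(scale=0.1, size=x.shape), 0.0, 1.0)

print(mse(x, identity(x)))        # 0.0: copying is perfect when there is no noise
print(mse(x, identity(x_noisy)))  # positive: copying a noisy input is penalized
```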

In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

We write a simple function that adds zero-mean Gaussian noise with standard deviation `std` to the input and clips the result back to the valid pixel range [0, 1].

In [None]:
def add_gaussian_noise(x, std=0.1):
    # Add zero-mean Gaussian noise, then clip back to the valid pixel range
    x_noisy = x + np.random.normal(scale=std, size=x.shape)
    return np.clip(x_noisy, 0.0, 1.0)

In [None]:
autoencoder = custom_unet(
    input_shape=(120, 120, 3),
    num_layers=3,
    num_classes=3,
)

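We can generate the noise on the fly with the `ImageDataGenerator` class, which we have previously used for data augmentation. Its `preprocessing_function` is applied only to the inputs `x`, not to the targets `y`, so the model sees a freshly corrupted image in each batch while being asked to reconstruct the clean one.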

In [None]:
data_gen = ImageDataGenerator(preprocessing_function=add_gaussian_noise)

In [None]:
generator = data_gen.flow(x=x_train, y=x_train, batch_size=64)
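For reference, the same on-the-fly corruption can be written as a plain Python generator, which may be handy if `ImageDataGenerator` is unavailable in your TensorFlow/Keras version. This is a sketch under our own assumptions: `noisy_batches` is a hypothetical helper, and since it loops forever, `fit` must be told how many batches make up an epoch via `steps_per_epoch`.

```python
import numpy as np

def noisy_batches(x, batch_size=64, std=0.1, seed=None):
    """Yield (noisy, clean) batch pairs indefinitely.

    The data is reshuffled on every pass, and fresh Gaussian noise is drawn
    for every batch, so the same corrupted image is unlikely to appear twice.
    """
    rng = np.random.default_rng(seed)
    n = len(x)
    while True:
        order = rng.permutation(n)
        for start in range(0, n, batch_size):
            clean = x[order[start:start + batch_size]]
            noisy = np.clip(clean + rng.normal(scale=std, size=clean.shape), 0.0, 1.0)
            yield noisy, clean

# Usage sketch (not run here):
# autoencoder.fit(noisy_batches(x_train), steps_per_epoch=int(np.ceil(len(x_train) / 64)), ...)
```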

In [None]:
train_and_save(
    model=autoencoder,
    path=Path('./pokemon_ae_denoise_unet.h5'),
    x=generator,
    validation_data=(x_test, x_test),
    verbose=0,
    epochs=80,
    callbacks=[TqdmCallback(verbose=1)],
)

Let us now test our model's performance on noise-corrupted test data.

In [None]:
x_test_noisy = add_gaussian_noise(x_test, std=0.1)
x_test_pred_noisy = autoencoder.predict(x_test_noisy)

In [None]:
plot_images(x_test)
plot_images(x_test_noisy)
plot_images(x_test_pred_noisy)

# Exercise

Explore the performance of the model under different noise distributions, e.g.
  * correlated Gaussian
  * uniform

How do we make the model more robust to different types of perturbations?
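As a possible starting point for the exercise, here are two alternative noise functions with the same call signature as `add_gaussian_noise`, so either can be passed as `preprocessing_function`. Both are our own sketches: uniform noise is scaled to have standard deviation `std` (a uniform distribution on $[-a, a]$ has variance $a^2/3$), and spatial correlation is introduced, as one choice among many, by blurring white Gaussian noise with `scipy.ndimage.gaussian_filter` and rescaling it to standard deviation `std`.

```python
import numpy as np
from scipy.ndimage import gaussian_filter

def add_uniform_noise(x, std=0.1, rng=None):
    """Zero-mean uniform noise with standard deviation `std`, clipped to [0, 1]."""
    rng = np.random.default_rng() if rng is None else rng
    half_width = std * np.sqrt(3.0)  # Var(U(-a, a)) = a**2 / 3
    noise = rng.uniform(-half_width, half_width, size=x.shape)
    return np.clip(x + noise, 0.0, 1.0)

def add_correlated_gaussian_noise(x, std=0.1, length_scale=2.0, rng=None):
    """Spatially correlated Gaussian noise for an (H, W, C) image or (N, H, W, C) batch.

    White noise is blurred with a Gaussian kernel of width `length_scale` pixels
    along the two spatial axes only, then rescaled to standard deviation `std`.
    """
    rng = np.random.default_rng() if rng is None else rng
    white = rng.normal(size=x.shape)
    sigma = (0,) * (x.ndim - 3) + (length_scale, length_scale, 0)  # no blur across batch/channels
    smooth = gaussian_filter(white, sigma=sigma)
    noise = std * smooth / smooth.std()
    return np.clip(x + noise, 0.0, 1.0)
```

The `length_scale` parameter controls how far the correlation extends; larger values give blotchier noise that is harder to tell apart from genuine image structure.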