In [1]:
from tensorflow import keras
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, Dense, Flatten, Reshape
from tensorflow.keras import Model, Sequential
from tensorflow.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
import math
from tqdm import tqdm
import os


In [2]:
DATA_FILE = "x.npy"

batch_size = 32

## Load Data
Load data and split it into batches. Split into batches because the model(s) has/have a custom training loop and is trained one batch at a time so splitting the data now makes it a little easier later.

In [3]:
x = np.load(DATA_FILE)
img_shape = x.shape[1:]
x = np.split(x[:len(x)//batch_size * batch_size], len(x)//batch_size)

Keep track of tensors and layers becuase layers are used to define `Sequential` models and tensors are used to define the input and output of `Model` models

In [4]:
encoder_conv_tensors = []
encoder_conv_layers = []

decoder_conv_tensors = []
decoder_conv_layers = []

## Define the Models
Imagine the autoencoder as a U shape, with the top two points in the U being the first layer of the encoder and the last layer of the decoder. The layers are created starting at the top of the U, going down.

In [5]:
input_tensor = Input(shape=img_shape)

num_layers = int(math.log2(img_shape[0])) # All layers have a stride of 2 and the images are preprocessed to be squares with a power of two side length
                                          # so `num_layers` layers will reduce the images down to a 1x1 feature stack
for i in range(num_layers):
    num_encoder_filters = 3*2**(i+1) # 
    num_decoder_filters = 3*2**i
    if i == 0: # The first layer in the encoder has a 5x5 kernel, and the last layer in the decoder has a sigmoid ouput
        encoder_layer = Conv2D(num_encoder_filters, 5, strides=2, padding="same", activation='selu')
        encoder_tensor = encoder_layer(input_tensor)
        decoder_layer = Conv2DTranspose(3, 5, strides=2, padding="same", activation='sigmoid')
        decoder_tensor = decoder_layer(encoder_tensor)
    else: # All other layers have 3x3 kernels and selu activation
        encoder_layer = Conv2D(num_encoder_filters, 3, strides=2, padding="same", activation='selu')
        encoder_tensor = encoder_layer(encoder_conv_tensors[-1])
        decoder_layer = Conv2DTranspose(num_decoder_filters, 3, strides=2, padding="same", activation='selu')
        decoder_tensor = decoder_layer(encoder_tensor)
    
    encoder_conv_tensors.append(encoder_tensor)
    encoder_conv_layers.append(encoder_layer)

    decoder_conv_tensors.append(decoder_tensor)
    decoder_conv_layers.append(decoder_layer)

decoder_conv_layers.reverse()

# Stacked autoencoder model to do inference on, but not to train
encoder = Sequential(encoder_conv_layers)
decoder = Sequential(decoder_conv_layers)
autoencoder = Sequential([encoder, decoder])
# Make a copy of the model to train in a stacked way to compare to recursively trained model
autoencoder = keras.models.clone_model(autoencoder)
autoencoder.compile(optimizer="adam", loss="MSE", metrics=["mean_absolute_error"])

# Model that takes image as input and outputs all the intermediate states
hidden_encoders = Model(inputs=[input_tensor], outputs=encoder_conv_tensors[:-1])

# Define the shallow models that get trained individually
models = []
for i in range(num_layers):
    model = Model(inputs=[input_tensor], outputs=[decoder_conv_tensors[i]])
    for layer in model.layers[:-2]: # For each shallow model freeze the weights in all but the last two layers
        layer.trainable = False
    model.compile(optimizer="adam", loss="MSE", metrics=["mean_absolute_error"])
    models.append(model)
    # Print the summary for every recursive model
    model.summary()

Model: "functional_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 128, 128, 3)]     0         
_________________________________________________________________
conv2d (Conv2D)              (None, 64, 64, 6)         456       
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 128, 128, 3)       453       
Total params: 909
Trainable params: 909
Non-trainable params: 0
_________________________________________________________________
Model: "functional_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 128, 128, 3)]     0         
_________________________________________________________________
conv2d (Conv2D)              (None, 64, 64, 6)         456       
__________________________________

## Train the models


In [6]:
def train_step(batch):
    targets = [batch] # first element of target is the input images
    hidden_targets = hidden_encoders.predict(batch) # the rest of the elements of target are the intermediate outputs
    targets.extend(hidden_targets)

    # Go through every shallow model and train it to reproduce the corrosponding intermediate output
    for i, model in enumerate(models):
        model.train_on_batch(batch, targets[i])
    
    # Train the stacked autoencoder for one batch as well
    autoencoder.train_on_batch(batch, batch)

epochs = 10
for i in range(epochs):
    print("\nEpoch %d/%d" % (i, epochs))
    for batch in tqdm(x):
        train_step(batch)

  0%|          | 0/413 [00:00<?, ?it/s]
Epoch 0/10
100%|██████████| 413/413 [02:03<00:00,  3.34it/s]
  0%|          | 0/413 [00:00<?, ?it/s]
Epoch 1/10
100%|██████████| 413/413 [01:56<00:00,  3.55it/s]
  0%|          | 0/413 [00:00<?, ?it/s]
Epoch 2/10
100%|██████████| 413/413 [01:56<00:00,  3.55it/s]
  0%|          | 0/413 [00:00<?, ?it/s]
Epoch 3/10
100%|██████████| 413/413 [01:55<00:00,  3.56it/s]
  0%|          | 0/413 [00:00<?, ?it/s]
Epoch 4/10
100%|██████████| 413/413 [01:55<00:00,  3.56it/s]
  0%|          | 0/413 [00:00<?, ?it/s]
Epoch 5/10
100%|██████████| 413/413 [01:56<00:00,  3.55it/s]
  0%|          | 0/413 [00:00<?, ?it/s]
Epoch 6/10
100%|██████████| 413/413 [01:55<00:00,  3.59it/s]
  0%|          | 0/413 [00:00<?, ?it/s]
Epoch 7/10
100%|██████████| 413/413 [01:52<00:00,  3.68it/s]
  0%|          | 0/413 [00:00<?, ?it/s]
Epoch 8/10
100%|██████████| 413/413 [01:52<00:00,  3.68it/s]
  0%|          | 0/413 [00:00<?, ?it/s]
Epoch 9/10
100%|██████████| 413/413 [01:52<00:00,  

In [10]:
# encoding_depth = 0

display_images = np.array(x[0][:10], dtype=float)

# encoder = Sequential(encoder_conv_layers)
# decoder = Sequential(decoder_conv_layers)

recursive_autoencoder = Sequential([encoder, decoder]) # The full autoencoder model that was trained in the recursive way

reconstructions = recursive_autoencoder.predict(display_images) # Reconstructions from the recursive autoencoder
basic_reconstructions = autoencoder.predict(display_images) # Reconstructions form the stacked autoencoder

# Make the plot

fig, axs = plt.subplots(5, 3)

titles = ["Original", "Stacked", "Recursive"]

for r, row in enumerate(axs):
    imgs = [x[0][r], basic_reconstructions[r], reconstructions[r]]
    for c, ax in enumerate(row):
        ax.imshow(imgs[c])
        if r == 0:
            ax.set_title(titles[c])

fig.set_size_inches(6, 9)
fig.savefig("stacked_vs_recursive.png")
fig.show()

ValueError: Unsupported dtype