In [5]:
import datetime

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

from data import example_to_tensor
from train import EarlyStopping
from utils import plot_slice, plot_animated_volume

print(f"Tensorflow: {tf.__version__}")
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

%matplotlib inline
plt.rcParams["figure.figsize"] = [15, 7]

Tensorflow: 2.2.0


In [6]:
# Hyperparameters
epochs = 1000
learning_rate = 0.0001
patience = 3  # passed to EarlyStopping in the training cell
batch_size = 2

# Padded volume sizes (xy_size, z_size) for the available input resolutions.
# NOTE(review): the selected preset presumably has to match the scale of the
# TFRecord file read in the dataset cell (nrrd-0.25 -> downscale 4) -- confirm.
_VOLUME_SIZES = {
    "downscale4-nrrd": (128, 96),
    "downscale4-tcia": (128, 244),
    "downscale2": (256, 176),
    "original": (512, 368),
}
volume_preset = "downscale4-nrrd"
xy_size, z_size = _VOLUME_SIZES[volume_preset]

In [7]:
# Input pipeline: decode float32 volumes from the TFRecord file and pad each one
# to [z_size, xy_size, xy_size, 1] so samples can be stacked into batches.
# NOTE(review): padded_batch cannot crop, so any volume larger than these sizes
# along an axis will raise -- confirm the file was written at this resolution.
dataset = tf.data.TFRecordDataset("data/nrrd-0.25-float32.tfrecords")
dataset = dataset.map(lambda x: example_to_tensor(x, "float32"))
dataset = dataset.padded_batch(
    # Use the hyperparameter rather than a literal so changing it takes effect.
    batch_size=batch_size,
    padded_shapes=[z_size, xy_size, xy_size, 1],
)
# Optional toggles used while experimenting (currently disabled):
# dataset = dataset.skip(10)
# dataset = dataset.shuffle(buffer_size=10, reshuffle_each_iteration=True)
# Keep only the first 2 batches so each epoch stays small.
dataset = dataset.take(2)
dataset

<TakeDataset shapes: (None, 96, 128, 128, 1), types: tf.float32>

In [8]:
# Preview the first padded batch of the input pipeline as an animated volume.
preview_iter = iter(dataset)
scan = next(preview_iter)
plot_animated_volume(scan)

In [10]:
def _encoder_stage(filters, **conv_kwargs):
    """Return one encoder stage: Conv3D -> SELU -> AlphaDropout -> 2x max-pool.

    Extra keyword arguments (e.g. ``input_shape`` for the first stage) are
    forwarded to the Conv3D layer.
    """
    return [
        keras.layers.Conv3D(
            filters=filters,
            kernel_size=3,
            padding="same",
            kernel_initializer="lecun_normal",
            bias_initializer="lecun_normal",
            **conv_kwargs,
        ),
        keras.layers.Activation("selu"),
        keras.layers.AlphaDropout(0.25),
        keras.layers.MaxPool3D(pool_size=2),
    ]


# Two stages, each halving z, y and x, so the latent volume is 1/4 of the input
# along every spatial axis with 16 feature channels.
encoder = keras.models.Sequential(
    _encoder_stage(8, input_shape=[z_size, xy_size, xy_size, 1])
    + _encoder_stage(16)
)
encoder.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv3d_2 (Conv3D)            (None, 96, 128, 128, 8)   224       
_________________________________________________________________
activation_2 (Activation)    (None, 96, 128, 128, 8)   0         
_________________________________________________________________
alpha_dropout_2 (AlphaDropou (None, 96, 128, 128, 8)   0         
_________________________________________________________________
max_pooling3d_2 (MaxPooling3 (None, 48, 64, 64, 8)     0         
_________________________________________________________________
conv3d_3 (Conv3D)            (None, 48, 64, 64, 16)    3472      
_________________________________________________________________
activation_3 (Activation)    (None, 48, 64, 64, 16)    0         
_________________________________________________________________
alpha_dropout_3 (AlphaDropou (None, 48, 64, 64, 16)   

In [11]:
def _decoder_stage(filters, **upsample_kwargs):
    """Return one decoder stage: 2x upsample -> Conv3D -> SELU -> AlphaDropout.

    Extra keyword arguments (e.g. ``input_shape`` for the first stage) are
    forwarded to the UpSampling3D layer.
    """
    return [
        keras.layers.UpSampling3D(size=2, **upsample_kwargs),
        keras.layers.Conv3D(
            filters=filters,
            kernel_size=3,
            padding="same",
            kernel_initializer="lecun_normal",
            bias_initializer="lecun_normal",
        ),
        keras.layers.Activation("selu"),
        keras.layers.AlphaDropout(0.25),
    ]


# Mirror of the encoder: two 2x upsampling stages restore the input resolution,
# then a per-voxel projection maps the 8 channels to a single output channel.
decoder = keras.models.Sequential(
    _decoder_stage(16, input_shape=encoder.layers[-1].output.shape[1:])
    + _decoder_stage(8)
    + [
        # Dense on a 5-D tensor acts on the last (channel) axis only: 8 -> 1 per voxel.
        keras.layers.Dense(1),
        # NOTE(review): sigmoid output lies in (0, 1); assumes the input scans
        # are normalised to that range -- confirm in data.example_to_tensor.
        keras.layers.Activation("sigmoid"),
    ]
)
decoder.summary()

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
up_sampling3d (UpSampling3D) (None, 48, 64, 64, 16)    0         
_________________________________________________________________
conv3d_4 (Conv3D)            (None, 48, 64, 64, 16)    6928      
_________________________________________________________________
activation_4 (Activation)    (None, 48, 64, 64, 16)    0         
_________________________________________________________________
alpha_dropout_4 (AlphaDropou (None, 48, 64, 64, 16)    0         
_________________________________________________________________
up_sampling3d_1 (UpSampling3 (None, 96, 128, 128, 16)  0         
_________________________________________________________________
conv3d_5 (Conv3D)            (None, 96, 128, 128, 8)   3464      
_________________________________________________________________
activation_5 (Activation)    (None, 96, 128, 128, 8)  

In [12]:
# Full model: the encoder's latent volume feeds straight into the decoder.
autoencoder = keras.models.Sequential(layers=[encoder, decoder])
# Optional: start from a previous run instead of fresh weights, e.g.
# autoencoder.load_weights("models/autoencoder/20200723-103317/best_epoch_ckpt")
# autoencoder = keras.models.load_model("../../20200816-170759/")
autoencoder.summary()

Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
sequential_1 (Sequential)    (None, 24, 32, 32, 16)    3696      
_________________________________________________________________
sequential_2 (Sequential)    (None, 96, 128, 128, 1)   10401     
Total params: 14,097
Trainable params: 14,097
Non-trainable params: 0
_________________________________________________________________


In [13]:
# Reconstruction objective: mean squared error over all voxels of input vs. output.
loss_fn = keras.losses.MeanSquaredError()
# `learning_rate` is the documented argument name; `lr` is only a deprecated alias.
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

In [None]:
# Custom training loop: one optimizer step per batch, per-epoch loss and weight
# histograms logged to TensorBoard, early stopping that restores the best epoch.
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = f"logs/autoencoder/{current_time}/"
model_dir = f"models/autoencoder/{current_time}/"
ckpt_dir = model_dir + "best_epoch_ckpt"
writer = tf.summary.create_file_writer(log_dir)
train_loss = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
early_stopping = EarlyStopping(patience)
# Lowest epoch-mean loss seen so far; the checkpoint is only overwritten on improvement.
best_loss = float("inf")

# Record the hyperparameters before training so they are logged even if the
# loop is interrupted.
with writer.as_default():
    tf.summary.text(
        "Hyperparameters",
        f"batch size = {batch_size}; "
        f"patience = {patience}; "
        f"learning rate = {learning_rate}",
        step=0,
    )

for epoch in tqdm(range(epochs), disable=False):
    for batch_features in dataset:
        with tf.GradientTape() as tape:
            # training=True activates the AlphaDropout layers; a bare
            # `autoencoder(x)` call runs them in inference mode (no dropout).
            predictions = autoencoder(batch_features, training=True)
            # Keras losses take (y_true, y_pred); for an autoencoder the target
            # is the input itself.
            loss_value = loss_fn(batch_features, predictions)
        gradients = tape.gradient(loss_value, autoencoder.trainable_variables)
        optimizer.apply_gradients(zip(gradients, autoencoder.trainable_variables))
        train_loss(loss_value)

    loss_mean = train_loss.result()
    print(f"Training loss: {loss_mean}")
    with writer.as_default():
        tf.summary.scalar("loss", loss_mean, step=epoch)
        # One histogram per variable per epoch (after the epoch's last update),
        # instead of one per batch all written under the same step.
        for param in autoencoder.trainable_variables:
            tf.summary.histogram(param.name, param, step=epoch)
    train_loss.reset_states()

    # Checkpoint only on improvement so that `ckpt_dir` really holds the best
    # epoch's weights when early stopping restores it below. The first epoch
    # always saves, so the checkpoint exists before it can ever be loaded.
    if float(loss_mean) < best_loss:
        best_loss = float(loss_mean)
        autoencoder.save_weights(ckpt_dir)

    # NOTE(review): relies on EarlyStopping being fed the epoch loss and becoming
    # truthy once training should stop -- see train.EarlyStopping.
    early_stopping(loss_mean)
    if early_stopping:
        autoencoder.load_weights(ckpt_dir)
        autoencoder.save(model_dir)
        break
else:
    # Ran all `epochs` without early stopping: save the final weights.
    autoencoder.save(model_dir)

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))

In [None]:
# Load (or reload, if already loaded) the TensorBoard notebook extension and
# start it inline on the top-level `logs` directory, which contains the
# per-run `logs/autoencoder/<timestamp>/` directories written by the training cell.
# NOTE(review): `--bind_all` presumably makes the server listen on all network
# interfaces rather than only localhost -- confirm that exposure is intended.
%reload_ext tensorboard
%tensorboard --logdir=logs --bind_all

In [None]:
# Visual sanity check on one sample: the first z-slice of the input, of the
# encoder's first feature channel, and of the reconstruction.
batch = next(iter(dataset))
sample = 1  # index within the batch
assert sample < batch.shape[0], f"batch only has {batch.shape[0]} samples"

# Go through the layers of `autoencoder` (rather than the `encoder`/`decoder`
# globals) so this also works when the autoencoder was loaded from disk.
# training=False keeps AlphaDropout disabled.
encoder_out = autoencoder.layers[0](batch, training=False)
decoder_out = autoencoder.layers[1](encoder_out, training=False)

fig, axes = plt.subplots(ncols=3)
panels = [
    ("Input", batch),
    ("Encoder output (channel 0)", encoder_out),
    ("Reconstruction", decoder_out),
]
for ax, (title, volume) in zip(axes, panels):
    # Slice (sample, z=0, all y, all x, channel 0) and cast to float32 for imshow.
    ax.imshow(tf.cast(volume[sample, 0, :, :, 0], tf.float32), cmap="gray")
    ax.set_title(title)
    ax.axis("off")