In [1]:
import time
import datetime
from pathlib import Path

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

from data import example_to_tensor, normalize
from train import EarlyStopping
from utils import plot_slice, plot_animated_volume

# Report the TensorFlow version this notebook is running against.
print(f"Tensorflow: {tf.__version__}")
# Only let ERROR-level messages through TensorFlow's (v1-compat) logger.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Render matplotlib figures inline in the cell output.
%matplotlib inline
# Default figure size (width, height) for every plot created below.
plt.rcParams["figure.figsize"] = [15, 7]

Tensorflow: 2.3.0


In [13]:
# When True, the training loop prints the epoch number and both losses.
verbose_training = True
# Hyperparameters
epochs = 1  # number of iterations of the training loop
learning_rate = 0.0005  # passed to the Adam optimizer
# Passed to EarlyStopping; presumably the number of non-improving epochs
# tolerated before stopping -- confirm in train.EarlyStopping.
patience = 20
batch_size = 2  # samples per padded batch (training and validation)
test_num_samples = 10  # samples taken for the test split
validation_num_samples = 10  # samples taken for the validation split
# Per-sample shape including the trailing channel axis; used as the model input
# shape and as the padding target (padded_shapes) when batching.
input_shape = (248, 128, 128, 1)  # downscale 4
# input_shape = (488, 256, 256, 1)  # downscale 2
# input_shape = (964, 512, 512, 1)  # original

In [14]:
data_dir = Path("data")
# Glob patterns (relative to data_dir) of the TFRecord files to load.
tfrecord_patterns = ("tcia-0.25/*.tfrecord", "nrrd-0.25/*.tfrecord")
# Path.glob yields entries in an arbitrary, filesystem-dependent order. Sort
# each group so the record order (and anything derived from it, such as a
# seeded split) is the same on every run and on every machine.
tfrecord_fnames = [
    str(p)
    for pattern in tfrecord_patterns
    for p in sorted(data_dir.glob(pattern))
]
# Fail fast: an empty file list would otherwise produce an empty dataset and a
# training loop that silently does nothing.
if not tfrecord_fnames:
    raise FileNotFoundError(
        f"No .tfrecord files found under {data_dir.resolve()} "
        f"(patterns: {', '.join(tfrecord_patterns)})"
    )

# Records are read file after file, then passed through the project's
# example_to_tensor and normalize functions (see the data module).
full_dataset = tf.data.TFRecordDataset(tfrecord_fnames)
full_dataset = full_dataset.map(
    example_to_tensor, num_parallel_calls=tf.data.experimental.AUTOTUNE
)
full_dataset = full_dataset.map(
    normalize, num_parallel_calls=tf.data.experimental.AUTOTUNE
)
# Append a trailing axis of size 1 to every element (the channel axis of input_shape).
full_dataset = full_dataset.map(
    lambda x: tf.expand_dims(x, axis=-1),
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
)

In [15]:
# Seed of the one-off shuffle that defines the train/validation/test split.
split_seed = 42

# test_dataset, val_dataset and train_dataset are three separate pipelines that
# each re-run this shuffle whenever they are iterated. With the default
# reshuffle_each_iteration=True every iteration draws a new order, so the
# take()/skip() calls below would pick different samples each time and test or
# validation samples would leak into training. A seeded shuffle that is not
# reshuffled yields the same order for every iterator, keeping the splits
# disjoint and stable.
dataset = full_dataset.shuffle(
    buffer_size=32, seed=split_seed, reshuffle_each_iteration=False
)

# First test_num_samples samples -> test split, one sample per batch.
test_dataset = dataset.take(test_num_samples)
test_dataset = test_dataset.batch(1)
dataset = dataset.skip(test_num_samples)

# Next validation_num_samples samples -> validation split, padded to input_shape.
val_dataset = dataset.take(validation_num_samples)
val_dataset = val_dataset.padded_batch(
    batch_size=batch_size, padded_shapes=input_shape
)

# Everything that remains -> training split.
train_dataset = dataset.skip(validation_num_samples)
train_dataset = train_dataset.padded_batch(
    batch_size=batch_size, padded_shapes=input_shape
)
train_dataset = train_dataset.cache()  # must be called before shuffle
# Batch-level shuffle; this one is meant to change from epoch to epoch.
train_dataset = train_dataset.shuffle(buffer_size=64, reshuffle_each_iteration=True)
# Cap every pass over the training data at 16 batches.
train_dataset = train_dataset.take(16)
train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)
train_dataset

<PrefetchDataset shapes: (None, 248, 128, 128, 1), types: tf.float32>

In [16]:
def benchmark(dataset, num_epochs=4):
    """Time ``num_epochs`` full passes over ``dataset``.

    Every element stands in for one training step, simulated by a 10 ms sleep.
    The total wall-clock duration is printed with ``tf.print``.
    """
    began_at = time.perf_counter()
    for _ in range(num_epochs):
        for _ in dataset:
            # Simulated training step
            time.sleep(0.01)
    elapsed = time.perf_counter() - began_at
    tf.print("Execution time:", elapsed)


#benchmark(train_dataset)

In [17]:
def conv_block(x, filters, kernel_size=3, dropout_rate=0.1, pool_size=2):
    """
    Encoder stage, applied in this order:
    - 3D convolution ("same" padding) with SELU activation
    - Alpha dropout
    - 3D max pooling

    x is the input layer using the Keras Functional API; the pooled output
    tensor is returned.
    """
    stages = (
        keras.layers.Conv3D(
            filters=filters,
            kernel_size=kernel_size,
            padding="same",
            kernel_initializer="lecun_normal",
            bias_initializer="lecun_normal",
            activation="selu",
        ),
        keras.layers.AlphaDropout(dropout_rate),
        keras.layers.MaxPool3D(pool_size=pool_size),
    )
    # Thread the tensor through the layers one after the other.
    for layer in stages:
        x = layer(x)
    return x

In [18]:
def deconv_block(x, filters, kernel_size=3, dropout_rate=0.1, pool_size=2):
    """
    Decoder stage, applied in this order:
    - 3D up-sampling by ``pool_size``
    - 3D convolution ("same" padding) with SELU activation
    - Alpha dropout

    x is the input layer using the Keras Functional API; the output tensor of
    the dropout layer is returned.
    """
    stages = (
        keras.layers.UpSampling3D(size=pool_size),
        keras.layers.Conv3D(
            filters=filters,
            kernel_size=kernel_size,
            padding="same",
            kernel_initializer="lecun_normal",
            bias_initializer="lecun_normal",
            activation="selu",
        ),
        keras.layers.AlphaDropout(dropout_rate),
    )
    # Thread the tensor through the layers one after the other.
    for layer in stages:
        x = layer(x)
    return x

In [19]:
# Encoder: three conv blocks with 16, 32 and 64 filters; each block ends with
# a max pooling of size 2.
encoder_inputs = keras.Input(input_shape)
x = encoder_inputs
for n_filters in (16, 32, 64):
    x = conv_block(x, filters=n_filters)
encoder_outputs = x
encoder = keras.Model(encoder_inputs, encoder_outputs, name="encoder")
encoder.summary()

Model: "encoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         [(None, 248, 128, 128, 1) 0         
_________________________________________________________________
conv3d_6 (Conv3D)            (None, 248, 128, 128, 16) 448       
_________________________________________________________________
alpha_dropout_6 (AlphaDropou (None, 248, 128, 128, 16) 0         
_________________________________________________________________
max_pooling3d_3 (MaxPooling3 (None, 124, 64, 64, 16)   0         
_________________________________________________________________
conv3d_7 (Conv3D)            (None, 124, 64, 64, 32)   13856     
_________________________________________________________________
alpha_dropout_7 (AlphaDropou (None, 124, 64, 64, 32)   0         
_________________________________________________________________
max_pooling3d_4 (MaxPooling3 (None, 62, 32, 32, 32)    0   

In [20]:
# Decoder: mirrors the encoder with three deconv blocks (64, 32, 16 filters),
# each starting with an up-sampling of size 2. Its input has the shape of the
# encoder output (batch axis dropped).
decoder_inputs = keras.Input(encoder.output_shape[1:])
x = decoder_inputs
for n_filters in (64, 32, 16):
    x = deconv_block(x, filters=n_filters)
# Dense acts on the last axis: the 16 feature channels are reduced to a single
# sigmoid-activated output channel.
decoder_outputs = keras.layers.Dense(1, activation="sigmoid")(x)
decoder = keras.Model(decoder_inputs, decoder_outputs, name="decoder")
decoder.summary()

Model: "decoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_4 (InputLayer)         [(None, 31, 16, 16, 64)]  0         
_________________________________________________________________
up_sampling3d_3 (UpSampling3 (None, 62, 32, 32, 64)    0         
_________________________________________________________________
conv3d_9 (Conv3D)            (None, 62, 32, 32, 64)    110656    
_________________________________________________________________
alpha_dropout_9 (AlphaDropou (None, 62, 32, 32, 64)    0         
_________________________________________________________________
up_sampling3d_4 (UpSampling3 (None, 124, 64, 64, 64)   0         
_________________________________________________________________
conv3d_10 (Conv3D)           (None, 124, 64, 64, 32)   55328     
_________________________________________________________________
alpha_dropout_10 (AlphaDropo (None, 124, 64, 64, 32)   0   

In [21]:
# Chain encoder and decoder into a single end-to-end model.
autoencoder = keras.Sequential(layers=[encoder, decoder], name="autoencoder")
# autoencoder.load_weights("models/autoencoder/20200924-235155/best_epoch_ckpt")
# autoencoder = keras.models.load_model("models/autoencoder/20200924-235155")

# The reconstruction must have exactly the per-sample shape of the input.
assert input_shape == autoencoder.output_shape[1:]
autoencoder.summary()

Model: "autoencoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
encoder (Functional)         (None, 31, 16, 16, 64)    69664     
_________________________________________________________________
decoder (Functional)         (None, 248, 128, 128, 1)  179841    
Total params: 249,505
Trainable params: 249,505
Non-trainable params: 0
_________________________________________________________________


In [22]:
# Reconstruction loss: mean squared error between a batch and its reconstruction.
loss_fn = keras.losses.MeanSquaredError()
# `lr` is only a deprecated alias kept for backward compatibility; use the
# documented `learning_rate` argument of Adam instead.
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

In [23]:
# Every run gets its own timestamped log and model directories.
start_time = datetime.datetime.now()
start_time_str = start_time.strftime("%Y%m%d-%H%M%S")
log_dir = f"logs/experiments/{start_time_str}/"
model_dir = f"models/experiments/{start_time_str}/"
ckpt_dir = model_dir + "best_epoch_ckpt"
writer = tf.summary.create_file_writer(log_dir)

early_stopping = EarlyStopping(patience)

# Running means of the per-batch losses. Created once and reset at the end of
# every epoch instead of being rebuilt on each iteration.
train_loss_metric = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
val_loss_metric = tf.keras.metrics.Mean("val_loss", dtype=tf.float32)

for epoch in tqdm(range(epochs), disable=False):

    ### TRAIN ###

    for batch in train_dataset:
        with tf.GradientTape() as tape:
            # training=True is required for the AlphaDropout layers to be
            # active: a model called directly runs in inference mode by default.
            predictions = autoencoder(batch, training=True)
            # Keras losses take (y_true, y_pred); the target of an autoencoder
            # is its own input.
            loss_value = loss_fn(batch, predictions)
        gradients = tape.gradient(loss_value, autoencoder.trainable_variables)
        optimizer.apply_gradients(zip(gradients, autoencoder.trainable_variables))
        train_loss_metric.update_state(loss_value)

    train_loss_mean = train_loss_metric.result()
    with writer.as_default():
        tf.summary.scalar("Training loss", train_loss_mean, step=epoch)
        # Weight histograms are indexed by epoch, so write them once per epoch
        # rather than after every batch with the same step value.
        for param in autoencoder.trainable_variables:
            tf.summary.histogram(param.name, param, step=epoch)
    train_loss_metric.reset_states()

    ### VALIDATION ###

    for batch in val_dataset:
        # Inference mode: dropout disabled.
        predictions = autoencoder(batch, training=False)
        val_loss_metric.update_state(loss_fn(batch, predictions))

    val_loss_mean = val_loss_metric.result()
    with writer.as_default():
        tf.summary.scalar("Validation loss", val_loss_mean, step=epoch)
    val_loss_metric.reset_states()

    if verbose_training:
        print()
        print(f"Epoch : {epoch}")
        print(f"Training loss: {train_loss_mean}")
        print(f"Validation loss: {val_loss_mean}")

    ### EARLY STOPPING ###

    early_stopping.update(val_loss_mean)
    if early_stopping.early_stop:
        # Stop and roll back to the weights saved in the checkpoint.
        autoencoder.load_weights(ckpt_dir)
        break
    elif early_stopping.not_improving_epochs == 0:
        # Presumably 0 means the validation loss just improved (confirm in
        # train.EarlyStopping): checkpoint the current weights.
        autoencoder.save_weights(ckpt_dir)

# NOTE(review): when the loop ends without early stopping, the model saved here
# holds the weights of the last epoch, which may not be the checkpointed ones.
autoencoder.save(model_dir)

end_time = datetime.datetime.now()
training_time = str(end_time - start_time).split(".")[0]  # drop the microseconds

with writer.as_default():
    tf.summary.text(
        "Hyperparameters",
        f"batch size = {batch_size}; "
        f"patience = {patience}; "
        f"learning rate = {learning_rate}; "
        f"training time = {training_time}",
        step=0,
    )
# Make sure every buffered summary reaches disk before TensorBoard reads it.
writer.flush()

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


Epoch : 0
Training loss: 0.04380191117525101
Validation loss: 0.01683100126683712



In [24]:
# Load (or reload) the TensorBoard notebook extension, then open TensorBoard on
# the logs directory written by the training cell.
%reload_ext tensorboard
# NOTE(review): --bind_all makes TensorBoard listen on all network interfaces,
# not only localhost -- confirm this exposure is intended on shared machines.
%tensorboard --logdir=logs --bind_all

Reusing TensorBoard on port 6006 (pid 21841), started 1:34:14 ago. (Use '!kill 21841' to kill it.)

In [None]:
# Position of the inspected sample inside the batch (test batches hold a
# single sample).
batch_index = 0

# Take one batch from the test split and push it through the two halves of the
# autoencoder separately, keeping the intermediate encoding for inspection.
original = next(iter(test_dataset))
encoder_out = autoencoder.layers[0](original)  # first layer: encoder
decoder_out = autoencoder.layers[1](encoder_out)  # second layer: decoder

In [None]:
# Side-by-side slices: input, encoding, reconstruction.
z_index = 20
fig, ax = plt.subplots(ncols=3)
# The encoding is shorter along axis 1, so it gets its own slice index (one
# third of its extent) instead of z_index.
panels = (
    (original, z_index),
    (encoder_out, encoder_out.shape[1] // 3),
    (decoder_out, z_index),
)
for panel_ax, (volume, slice_index) in zip(ax, panels):
    plot_slice(volume, batch_index, slice_index, panel_ax)

In [None]:
# Animate the input volume of the selected sample (see utils.plot_animated_volume).
plot_animated_volume(original, batch_index)

In [None]:
# Animate the encoder output of the selected sample, with fps=10.
plot_animated_volume(encoder_out, batch_index, fps=10)

In [None]:
# Animate the reconstructed volume of the selected sample (see utils.plot_animated_volume).
plot_animated_volume(decoder_out, batch_index)