In [1]:
import datetime
from pathlib import Path

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

from data import example_to_tensor, normalize
from train import EarlyStopping
from utils import plot_slice, plot_animated_volume

# Record the TensorFlow version in the notebook output for reproducibility.
print(f"Tensorflow: {tf.__version__}")
# Only let TensorFlow log messages at ERROR level through.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Render matplotlib figures inline and set the default figure size (inches).
%matplotlib inline
plt.rcParams["figure.figsize"] = [15, 7]

Tensorflow: 2.3.0


In [2]:
verbose_training = True  # print the per-epoch losses inside the training loop
# Hyperparameters
epochs = 1000  # upper bound on epochs; the loop may break earlier on early stopping
learning_rate = 0.0005  # passed to the Adam optimizer
patience = 20  # passed to EarlyStopping
batch_size = 2  # volumes per padded batch (training and validation)
test_size = 2  # number of images, taken from the dataset before batching
validation_size = 2  # number of batches, taken from the dataset after batching
# Per-sample shape: used both as the padding target when batching and as the
# model input shape. The last axis is the channel axis.
input_shape = (248, 128, 128, 1)  # downscale 4
# input_shape = (488, 256, 256, 1)  # downscale 2
# input_shape = (964, 512, 512, 1)  # original

In [3]:
data_dir = Path("data")
# Path.glob yields entries in filesystem order, which is not guaranteed to be
# stable. The test/validation split later in the notebook is positional
# (take/skip), so each group is sorted to make the split reproducible across
# runs and machines. Group order (tcia first, then nrrd) is kept.
tfrecord_fnames = [
    str(p)
    for pattern in ("tcia-0.25/*.tfrecord", "nrrd-0.25/*.tfrecord")
    for p in sorted(data_dir.glob(pattern))
]
# Fail early with a clear message instead of silently building an empty dataset.
if not tfrecord_fnames:
    raise FileNotFoundError(f"No .tfrecord files found under {data_dir.resolve()}")

# Parse each serialized record, normalize it, then append a channel axis.
full_dataset = tf.data.TFRecordDataset(tfrecord_fnames)
full_dataset = full_dataset.map(example_to_tensor)
full_dataset = full_dataset.map(normalize)
full_dataset = full_dataset.map(
    lambda x: tf.expand_dims(x, axis=-1)
)  # add the channel dimension

In [4]:
# Test set: the first `test_size` volumes, one per batch, without padding.
test_dataset = full_dataset.take(test_size).batch(1)

# Everything after the test volumes is padded to `input_shape` and batched.
dataset = full_dataset.skip(test_size).padded_batch(
    batch_size=batch_size, padded_shapes=input_shape
)

# Validation set: the first `validation_size` batches of the remainder.
val_dataset = dataset.take(validation_size)

# Training set: the remaining batches, reshuffled every epoch, capped at 10
# batches per iteration and prefetched.
train_dataset = (
    dataset.skip(validation_size)
    .shuffle(buffer_size=64, reshuffle_each_iteration=True)
    .take(10)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
train_dataset

<PrefetchDataset shapes: (None, 248, 128, 128, 1), types: tf.float32>

In [5]:
def conv_block(x, filters, kernel_size=3, dropout_rate=0.1, pool_size=2):
    """
    Encoder stage: convolve, apply dropout, then downsample.

    Applied in order to `x` (a tensor from the Keras Functional API):
    - Conv3D with "same" padding, SELU activation and LeCun-normal
      initializers for both kernel and bias
    - AlphaDropout with rate `dropout_rate`
    - MaxPool3D with window `pool_size`

    Returns the output tensor of the pooling layer.
    """
    stages = (
        keras.layers.Conv3D(
            filters=filters,
            kernel_size=kernel_size,
            padding="same",
            kernel_initializer="lecun_normal",
            bias_initializer="lecun_normal",
            activation="selu",
        ),
        keras.layers.AlphaDropout(dropout_rate),
        keras.layers.MaxPool3D(pool_size=pool_size),
    )
    # Thread the tensor through each layer in sequence.
    for stage in stages:
        x = stage(x)
    return x

In [6]:
def deconv_block(x, filters, kernel_size=3, dropout_rate=0.1, pool_size=2):
    """
    Decoder stage: upsample, convolve, then apply dropout.

    Applied in order to `x` (a tensor from the Keras Functional API):
    - UpSampling3D by a factor of `pool_size`
    - Conv3D with "same" padding, SELU activation and LeCun-normal
      initializers for both kernel and bias
    - AlphaDropout with rate `dropout_rate`

    Returns the output tensor of the dropout layer.
    """
    stages = (
        keras.layers.UpSampling3D(size=pool_size),
        keras.layers.Conv3D(
            filters=filters,
            kernel_size=kernel_size,
            padding="same",
            kernel_initializer="lecun_normal",
            bias_initializer="lecun_normal",
            activation="selu",
        ),
        keras.layers.AlphaDropout(dropout_rate),
    )
    # Thread the tensor through each layer in sequence.
    for stage in stages:
        x = stage(x)
    return x

In [7]:
# Encoder: three conv blocks with 16, 32 and 64 filters, each halving every
# spatial dimension through its max-pool.
encoder_inputs = keras.Input(input_shape)
x = encoder_inputs
for n_filters in (16, 32):
    x = conv_block(x, filters=n_filters)
encoder_outputs = conv_block(x, filters=64)

encoder = keras.Model(encoder_inputs, encoder_outputs, name="encoder")
encoder.summary()

Model: "encoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 248, 128, 128, 1) 0         
_________________________________________________________________
conv3d (Conv3D)              (None, 248, 128, 128, 16) 448       
_________________________________________________________________
alpha_dropout (AlphaDropout) (None, 248, 128, 128, 16) 0         
_________________________________________________________________
max_pooling3d (MaxPooling3D) (None, 124, 64, 64, 16)   0         
_________________________________________________________________
conv3d_1 (Conv3D)            (None, 124, 64, 64, 32)   13856     
_________________________________________________________________
alpha_dropout_1 (AlphaDropou (None, 124, 64, 64, 32)   0         
_________________________________________________________________
max_pooling3d_1 (MaxPooling3 (None, 62, 32, 32, 32)    0   

In [21]:
# Inspect the variables of the first Conv3D layer (index 0 is the InputLayer).
# NOTE(review): this dumps full weight arrays into the output; consider
# showing only names/shapes before sharing the notebook.
encoder.layers[1].weights

[<tf.Variable 'conv3d/kernel:0' shape=(3, 3, 3, 1, 16) dtype=float32, numpy=
 array([[[[[-0.24914922, -0.06729963,  0.3083456 , -0.11570771,
             0.01487178, -0.253089  ,  0.37422928, -0.07756257,
             0.12920101, -0.3742187 ,  0.3637306 ,  0.31659406,
             0.05248392, -0.02182407,  0.17800558,  0.02088513]],
 
          [[-0.2701462 ,  0.16340272, -0.04656548, -0.14909935,
            -0.1578288 , -0.03961995,  0.12116717, -0.09932431,
             0.24974833,  0.3894202 , -0.0832682 ,  0.20334795,
             0.29502818, -0.01803628,  0.13189994,  0.18931863]],
 
          [[-0.05091888,  0.09065039,  0.17710423,  0.16874634,
             0.25140384,  0.02985234,  0.3027352 , -0.10572255,
             0.08837077,  0.18199475, -0.0925707 ,  0.3712463 ,
            -0.22585873,  0.3458302 ,  0.02183132,  0.02729191]]],
 
 
         [[[-0.15410095,  0.37810928,  0.10806335, -0.06271126,
             0.13017185, -0.25226414,  0.15698378,  0.08408548,
            

In [9]:
# Decoder: mirrors the encoder with three upsampling blocks (64, 32, 16
# filters), then a Dense(1) + sigmoid over the channel axis to produce a
# single-channel reconstruction.
decoder_inputs = keras.Input(encoder.output_shape[1:])
x = decoder_inputs
for n_filters in (64, 32, 16):
    x = deconv_block(x, filters=n_filters)
decoder_outputs = keras.layers.Dense(1, activation="sigmoid")(x)

decoder = keras.Model(decoder_inputs, decoder_outputs, name="decoder")
decoder.summary()

Model: "decoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         [(None, 31, 16, 16, 64)]  0         
_________________________________________________________________
up_sampling3d_3 (UpSampling3 (None, 62, 32, 32, 64)    0         
_________________________________________________________________
conv3d_6 (Conv3D)            (None, 62, 32, 32, 64)    110656    
_________________________________________________________________
activation_6 (Activation)    (None, 62, 32, 32, 64)    0         
_________________________________________________________________
alpha_dropout_6 (AlphaDropou (None, 62, 32, 32, 64)    0         
_________________________________________________________________
up_sampling3d_4 (UpSampling3 (None, 124, 64, 64, 64)   0         
_________________________________________________________________
conv3d_7 (Conv3D)            (None, 124, 64, 64, 32)   5532

In [16]:
# Chain the encoder and decoder into a single end-to-end model.
autoencoder = keras.Sequential([encoder, decoder])
# To start from a previous run instead, uncomment one of:
# autoencoder.load_weights("models/autoencoder/20200924-235155/best_epoch_ckpt")
# autoencoder = keras.models.load_model("models/autoencoder/20200924-235155")

# The reconstruction must have the same per-sample shape as the input.
assert input_shape == autoencoder.output_shape[1:]
autoencoder.summary()

Model: "sequential_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
encoder (Functional)         (None, 31, 16, 16, 64)    69664     
_________________________________________________________________
decoder (Functional)         (None, 248, 128, 128, 1)  179841    
Total params: 249,505
Trainable params: 249,505
Non-trainable params: 0
_________________________________________________________________


In [None]:
# Reconstruction loss: mean squared error between a batch and its reconstruction.
loss_fn = keras.losses.MeanSquaredError()
# `learning_rate` is the documented argument name; `lr` is only a deprecated
# alias and is rejected by newer Keras releases.
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

In [None]:
# Custom training loop for the autoencoder.
#
# Each epoch: one pass over `train_dataset` with gradient updates, one pass
# over `val_dataset` without updates, TensorBoard logging of both mean losses
# and the weight histograms, then an early-stopping check on the validation
# loss. The best weights so far are checkpointed to `ckpt_dir`; when early
# stopping triggers they are restored before the model is saved.
start_time = datetime.datetime.now()
start_time_str = start_time.strftime("%Y%m%d-%H%M%S")
log_dir = f"logs/autoencoder/{start_time_str}/"
model_dir = f"models/autoencoder/{start_time_str}/"
ckpt_dir = model_dir + "best_epoch_ckpt"
writer = tf.summary.create_file_writer(log_dir)

early_stopping = EarlyStopping(patience)

# Running means of the per-batch losses; created once and reset every epoch.
train_loss_metric = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
val_loss_metric = tf.keras.metrics.Mean("val_loss", dtype=tf.float32)

for epoch in tqdm(range(epochs), disable=False):

    ### TRAIN ###

    for batch in train_dataset:
        with tf.GradientTape() as tape:
            # training=True is required in a custom loop, otherwise the
            # AlphaDropout layers run in inference mode and never drop units.
            predictions = autoencoder(batch, training=True)
            # Keras losses take (y_true, y_pred); the input is its own target.
            loss_value = loss_fn(batch, predictions)
        gradients = tape.gradient(loss_value, autoencoder.trainable_variables)
        optimizer.apply_gradients(zip(gradients, autoencoder.trainable_variables))
        train_loss_metric.update_state(loss_value)

    train_loss_mean = train_loss_metric.result()
    with writer.as_default():
        tf.summary.scalar("Training loss", train_loss_mean, step=epoch)
        # One histogram per variable per epoch (the step is the epoch index),
        # taken after the last update of the epoch.
        for param in autoencoder.trainable_variables:
            tf.summary.histogram(param.name, param, step=epoch)
    train_loss_metric.reset_states()

    ### VALIDATION ###

    for batch in val_dataset:
        predictions = autoencoder(batch, training=False)
        val_loss_metric.update_state(loss_fn(batch, predictions))

    val_loss_mean = val_loss_metric.result()
    with writer.as_default():
        tf.summary.scalar("Validation loss", val_loss_mean, step=epoch)
    val_loss_metric.reset_states()

    if verbose_training:
        print()
        print(f"Epoch : {epoch}")
        print(f"Training loss: {train_loss_mean}")
        print(f"Validation loss: {val_loss_mean}")

    ### EARLY STOPPING ###

    early_stopping.update(val_loss_mean)
    if early_stopping.early_stop:
        # Roll back to the best checkpoint before leaving the loop.
        autoencoder.load_weights(ckpt_dir)
        break
    elif early_stopping.not_improving_epochs == 0:
        # Validation loss just improved: keep these weights as the best so far.
        autoencoder.save_weights(ckpt_dir)

# NOTE(review): if the loop runs through all epochs without early stopping,
# the weights saved here are the last epoch's, not necessarily the best ones.
autoencoder.save(model_dir)

end_time = datetime.datetime.now()
training_time = str(end_time - start_time).split(".")[0]  # drop microseconds

with writer.as_default():
    tf.summary.text(
        "Hyperparameters",
        f"batch size = {batch_size}; "
        f"patience = {patience}; "
        f"learning rate = {learning_rate}; "
        f"training time = {training_time}",
        step=0,
    )
# Push buffered events to disk so TensorBoard sees the final summaries.
writer.flush()

In [None]:
# Launch TensorBoard inside the notebook, reading every run under logs/.
# --bind_all makes the server listen on all network interfaces.
%reload_ext tensorboard
%tensorboard --logdir=logs --bind_all

In [None]:
# Index of the sample to visualise within the batch.
batch_index = 0

# First batch of the test set, pushed through the two halves of the
# autoencoder separately so the latent volume can be inspected as well.
original = next(iter(test_dataset))
encoder_out = autoencoder.layers[0](original)
decoder_out = autoencoder.layers[1](encoder_out)

In [None]:
# Side-by-side comparison of one slice: original, latent, reconstruction.
# NOTE(review): plot_slice comes from utils; the arguments are presumably
# (volume, batch index, slice index along the first spatial axis, axes) — confirm.
z_index = 20  # slice shown for the original and the reconstruction
fig, ax = plt.subplots(ncols=3)
plot_slice(original, batch_index, z_index, ax[0])
# The latent volume is shorter along axis 1, so its slice index is derived
# from its own shape (one third of the way in) rather than from z_index.
plot_slice(encoder_out, batch_index, encoder_out.shape[1] // 3, ax[1])
plot_slice(decoder_out, batch_index, z_index, ax[2])

In [None]:
# Animated view of the original test volume (helper from utils; presumably
# steps through the volume's slices — confirm).
plot_animated_volume(original, batch_index)

In [None]:
# Animated view of the encoder output for the same sample, with fps=10
# passed explicitly to the helper.
plot_animated_volume(encoder_out, batch_index, fps=10)

In [None]:
# Animated view of the reconstruction produced by the decoder.
plot_animated_volume(decoder_out, batch_index)