In [1]:
import datetime
from pathlib import Path

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

from data import example_to_tensor
from train import EarlyStopping
from utils import plot_slice, plot_animated_volume

print(f"Tensorflow: {tf.__version__}")
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

%matplotlib inline
plt.rcParams["figure.figsize"] = [15, 7]

Tensorflow: 2.3.0


In [2]:
verbose_training = True
# Hyperparameters
epochs = 1000
learning_rate = 0.0001
patience = 5
batch_size = 2
test_size = 1
validation_size = 1
xy_size = 128  # downscale 4
#z_size = 96  # downscale 4 (nrrd)
z_size = 244  # downscale 4 (tcia)
# xy_size = 256  # downscale 2
# z_size = 176  # downscale 2
# xy_size = 512   # original
# z_size = 368    # original

In [3]:
def normalize(t):
    "Normalize the input tensor in [0, 1]"
    max_value = tf.reduce_max(t)
    min_value = tf.reduce_min(t)
    return (t - min_value) / (max_value - min_value)

In [16]:
data_dir = Path("data/tcia-0.25")
tfrecord_fnames = [str(p) for p in data_dir.glob("*.tfrecord")]
dataset = tf.data.TFRecordDataset(tfrecord_fnames)
dataset = dataset.map(example_to_tensor)
dataset = dataset.map(normalize)
dataset = dataset.map(lambda x: tf.expand_dims(x, axis=-1))  # add the channel dimension
dataset = dataset.padded_batch(
    batch_size=2, 
    padded_shapes=[z_size, xy_size, xy_size, 1],
)
#dataset = dataset.prefetch(1)
test_dataset = dataset.take(test_size)
dataset = dataset.skip(test_size)
val_dataset = dataset.take(validation_size)
train_dataset = dataset.skip(validation_size)
train_dataset = train_dataset.shuffle(buffer_size=32, reshuffle_each_iteration=True)
train_dataset

<ShuffleDataset shapes: (None, 244, 128, 128, 1), types: tf.float32>

In [17]:
encoder = keras.models.Sequential(
    [
        keras.layers.Conv3D(
            input_shape=[z_size, xy_size, xy_size, 1],
            filters=8,
            kernel_size=3,
            padding="same",
            kernel_initializer="lecun_normal",
            bias_initializer="lecun_normal",
        ),
        keras.layers.Activation("selu"),
        keras.layers.AlphaDropout(0.25),
        keras.layers.MaxPool3D(pool_size=2),
        keras.layers.Conv3D(
            filters=16,
            kernel_size=3,
            padding="same",
            kernel_initializer="lecun_normal",
            bias_initializer="lecun_normal",
        ),
        keras.layers.Activation("selu"),
        keras.layers.AlphaDropout(0.25),
        keras.layers.MaxPool3D(pool_size=2),
    ]
)
encoder.summary()

Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv3d_4 (Conv3D)            (None, 244, 128, 128, 8)  224       
_________________________________________________________________
activation_5 (Activation)    (None, 244, 128, 128, 8)  0         
_________________________________________________________________
alpha_dropout_4 (AlphaDropou (None, 244, 128, 128, 8)  0         
_________________________________________________________________
max_pooling3d_2 (MaxPooling3 (None, 122, 64, 64, 8)    0         
_________________________________________________________________
conv3d_5 (Conv3D)            (None, 122, 64, 64, 16)   3472      
_________________________________________________________________
activation_6 (Activation)    (None, 122, 64, 64, 16)   0         
_________________________________________________________________
alpha_dropout_5 (AlphaDropou (None, 122, 64, 64, 16)  

In [18]:
decoder = keras.models.Sequential(
    [
        keras.layers.UpSampling3D(
            input_shape=encoder.layers[-1].output.shape[1:], size=2,
        ),
        keras.layers.Conv3D(
            filters=16,
            kernel_size=3,
            padding="same",
            kernel_initializer="lecun_normal",
            bias_initializer="lecun_normal",
        ),
        keras.layers.Activation("selu"),
        keras.layers.AlphaDropout(0.25),
        keras.layers.UpSampling3D(size=2,),
        keras.layers.Conv3D(
            filters=8,
            kernel_size=3,
            padding="same",
            kernel_initializer="lecun_normal",
            bias_initializer="lecun_normal",
        ),
        keras.layers.Activation("selu"),
        keras.layers.AlphaDropout(0.25),
        keras.layers.Dense(1),
        keras.layers.Activation("sigmoid"),
    ]
)
decoder.summary()

Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
up_sampling3d_2 (UpSampling3 (None, 122, 64, 64, 16)   0         
_________________________________________________________________
conv3d_6 (Conv3D)            (None, 122, 64, 64, 16)   6928      
_________________________________________________________________
activation_7 (Activation)    (None, 122, 64, 64, 16)   0         
_________________________________________________________________
alpha_dropout_6 (AlphaDropou (None, 122, 64, 64, 16)   0         
_________________________________________________________________
up_sampling3d_3 (UpSampling3 (None, 244, 128, 128, 16) 0         
_________________________________________________________________
conv3d_7 (Conv3D)            (None, 244, 128, 128, 8)  3464      
_________________________________________________________________
activation_8 (Activation)    (None, 244, 128, 128, 8) 

In [19]:
autoencoder = keras.models.Sequential([encoder, decoder])
# autoencoder.load_weights("models/autoencoder/20200723-103317/best_epoch_ckpt")
# autoencoder = keras.models.load_model("models/autoencoder/20200823-110737/")
autoencoder.summary()

Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
sequential_3 (Sequential)    (None, 61, 32, 32, 16)    3696      
_________________________________________________________________
sequential_4 (Sequential)    (None, 244, 128, 128, 1)  10401     
Total params: 14,097
Trainable params: 14,097
Non-trainable params: 0
_________________________________________________________________


In [20]:
loss_fn = keras.losses.MeanSquaredError()
optimizer = keras.optimizers.Adam(lr=learning_rate)

In [None]:
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = f"logs/autoencoder/{current_time}/"
model_dir = f"models/autoencoder/{current_time}/"
ckpt_dir = model_dir + "best_epoch_ckpt"
writer = tf.summary.create_file_writer(log_dir)
early_stopping = EarlyStopping(patience)

for epoch in tqdm(range(epochs), disable=False):
    
    ### TRAIN ###
    
    train_loss_metric = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
    for batch in train_dataset:
        with tf.GradientTape() as tape:
            predictions = autoencoder(batch)
            loss_value = loss_fn(predictions, batch)
        gradients = tape.gradient(loss_value, autoencoder.trainable_variables)
        optimizer.apply_gradients(zip(gradients, autoencoder.trainable_variables))
        train_loss_metric.update_state(loss_value)
        #with writer.as_default():
        #    for grad, param in zip(gradients, autoencoder.trainable_variables):
        #        tf.summary.histogram(param.name, param, step=epoch)
        #        # tf.summary.histogram(param.name + "/grad", grad, buckets=1, step=epoch)

    train_loss_mean = train_loss_metric.result()
    with writer.as_default():
        tf.summary.scalar("Training loss", train_loss_mean, step=epoch)
    train_loss_metric.reset_states()
   
   ### VALIDATION ###

    val_loss_metric = tf.keras.metrics.Mean("val_loss", dtype=tf.float32)
    for batch in val_dataset:
        predictions = autoencoder(batch)
        val_loss_metric.update_state(loss_fn(predictions, batch))

    val_loss_mean = val_loss_metric.result()
    with writer.as_default():
        tf.summary.scalar("Validation loss", val_loss_mean, step=epoch)
    val_loss_metric.reset_states()

    if verbose_training:
        print()
        print(f"Epoch : {epoch}")
        print(f"Training loss: {train_loss_mean}")
        print(f"Validation loss: {val_loss_mean}")
    
    ### EARLY STOPPING ###
    
    early_stopping.update(val_loss_mean)
    if early_stopping.early_stop:
        autoencoder.load_weights(ckpt_dir)
        break
    elif early_stopping.not_improving_epochs == 0:
        autoencoder.save_weights(ckpt_dir)

autoencoder.save(model_dir)

with writer.as_default():
    tf.summary.text(
        "Hyperparameters",
        f"batch size = {batch_size}; "
        f"patience = {patience}; "
        f"learning rate = {learning_rate}",
        step=0,
    )

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Epoch : 0
Training loss: 0.009702341631054878
Validation loss: 0.0018241752404719591

Epoch : 1
Training loss: 0.0015542582841590047
Validation loss: 0.0012430717470124364

Epoch : 2
Training loss: 0.0011069163447245955
Validation loss: 0.0010182594414800406

Epoch : 3
Training loss: 0.0009315984207205474
Validation loss: 0.0009205001988448203

Epoch : 4
Training loss: 0.0008268418605439365
Validation loss: 0.0008787276456132531


In [None]:
%reload_ext tensorboard
%tensorboard --logdir=logs --bind_all

In [None]:
scan = next(iter(dataset))
prediction = autoencoder(scan)
plot_slice(prediction, 0, 5)

In [None]:
encoder_input = next(iter(test_dataset))
fig, ax = plt.subplots(ncols=3)
batch_index = 0
z_index = 30
plot_slice(encoder_input, batch_index, z_index, ax[0])
encoder_out = autoencoder.layers[0](encoder_input)
plot_slice(encoder_out, batch_index, encoder_out.shape[1] // 3, ax[1])
decoder_out = autoencoder.layers[1](encoder_out)
plot_slice(decoder_out, batch_index, z_index, ax[2])
#ax[0].imshow(tf.cast(i[sample, 0, :, :, 0], tf.float32), cmap="gray")
#encoder_out = autoencoder.layers[0](i)
#ax[1].imshow(tf.cast(encoder_out[sample, 0, :, :, 0], tf.float32), cmap="gray")
#decoder_out = autoencoder.layers[1](encoder_out)
#ax[2].imshow(tf.cast(decoder_out[sample, 0, :, :, 0], tf.float32), cmap="gray")