In [1]:
import datetime
from pathlib import Path

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

from data import example_to_tensor
from train import EarlyStopping
from utils import plot_slice, plot_animated_volume

print(f"Tensorflow: {tf.__version__}")
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

%matplotlib inline
plt.rcParams["figure.figsize"] = [15, 7]

Tensorflow: 2.3.0


In [2]:
verbose_training = True
# Hyperparameters
epochs = 1000
learning_rate = 0.0001
patience = 20
batch_size = 2
test_size = 4
validation_size = 4
xy_size = 128  # downscale 4
#z_size = 96  # downscale 4 (nrrd)
z_size = 248  # downscale 4 (tcia)
#xy_size = 256  # downscale 2
# z_size = 176  # downscale 2 (nrrd)
#z_size = 488  # downscale 2 (tcia)
# xy_size = 512   # original
# z_size = 368    # original

In [3]:
def normalize(t):
    "Normalize the input tensor in [0, 1]"
    max_value = tf.reduce_max(t)
    min_value = tf.reduce_min(t)
    return (t - min_value) / (max_value - min_value)

In [4]:
data_dir = Path("data")
tfrecord_fnames = [str(p) for p in data_dir.glob("*-0.25/*.tfrecord")]

full_dataset = tf.data.TFRecordDataset(tfrecord_fnames)
full_dataset = full_dataset.map(example_to_tensor)
full_dataset = full_dataset.map(normalize)
full_dataset = full_dataset.map(lambda x: tf.expand_dims(x, axis=-1))  # add the channel dimension

In [5]:
dataset = full_dataset
test_dataset = dataset.take(test_size)
test_dataset = test_dataset.batch(1)
dataset = dataset.skip(test_size)
dataset = dataset.padded_batch(
    batch_size=batch_size, 
    padded_shapes=[z_size, xy_size, xy_size, 1],
)
val_dataset = dataset.take(validation_size)
train_dataset = dataset.skip(validation_size)
train_dataset = train_dataset.shuffle(
    buffer_size=64, 
    reshuffle_each_iteration=True
)
train_dataset = train_dataset.take(10)
train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)
train_dataset

<PrefetchDataset shapes: (None, 248, 128, 128, 1), types: tf.float32>

In [6]:
encoder = keras.models.Sequential(
    [
        keras.layers.Conv3D(
            input_shape=[z_size, xy_size, xy_size, 1],
            filters=32,
            kernel_size=3,
            padding="same",
            kernel_initializer="lecun_normal",
            bias_initializer="lecun_normal",
        ),
        keras.layers.Activation("selu"),
        keras.layers.AlphaDropout(0.25),
        keras.layers.MaxPool3D(pool_size=2),
        keras.layers.Conv3D(
            filters=64,
            kernel_size=3,
            padding="same",
            kernel_initializer="lecun_normal",
            bias_initializer="lecun_normal",
        ),
        keras.layers.Activation("selu"),
        keras.layers.AlphaDropout(0.25),
        keras.layers.MaxPool3D(pool_size=2),
    ]
)
encoder.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv3d (Conv3D)              (None, 248, 128, 128, 32) 896       
_________________________________________________________________
activation (Activation)      (None, 248, 128, 128, 32) 0         
_________________________________________________________________
alpha_dropout (AlphaDropout) (None, 248, 128, 128, 32) 0         
_________________________________________________________________
max_pooling3d (MaxPooling3D) (None, 124, 64, 64, 32)   0         
_________________________________________________________________
conv3d_1 (Conv3D)            (None, 124, 64, 64, 64)   55360     
_________________________________________________________________
activation_1 (Activation)    (None, 124, 64, 64, 64)   0         
_________________________________________________________________
alpha_dropout_1 (AlphaDropou (None, 124, 64, 64, 64)   0

In [7]:
decoder = keras.models.Sequential(
    [
        keras.layers.UpSampling3D(
            input_shape=encoder.layers[-1].output.shape[1:], size=2,
        ),
        keras.layers.Conv3D(
            filters=64,
            kernel_size=3,
            padding="same",
            kernel_initializer="lecun_normal",
            bias_initializer="lecun_normal",
        ),
        keras.layers.Activation("selu"),
        keras.layers.AlphaDropout(0.25),
        keras.layers.UpSampling3D(size=2,),
        keras.layers.Conv3D(
            filters=32,
            kernel_size=3,
            padding="same",
            kernel_initializer="lecun_normal",
            bias_initializer="lecun_normal",
        ),
        keras.layers.Activation("selu"),
        keras.layers.AlphaDropout(0.25),
        keras.layers.Dense(1),
        keras.layers.Activation("sigmoid"),
    ]
)
decoder.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
up_sampling3d (UpSampling3D) (None, 124, 64, 64, 64)   0         
_________________________________________________________________
conv3d_2 (Conv3D)            (None, 124, 64, 64, 64)   110656    
_________________________________________________________________
activation_2 (Activation)    (None, 124, 64, 64, 64)   0         
_________________________________________________________________
alpha_dropout_2 (AlphaDropou (None, 124, 64, 64, 64)   0         
_________________________________________________________________
up_sampling3d_1 (UpSampling3 (None, 248, 128, 128, 64) 0         
_________________________________________________________________
conv3d_3 (Conv3D)            (None, 248, 128, 128, 32) 55328     
_________________________________________________________________
activation_3 (Activation)    (None, 248, 128, 128, 32)

In [8]:
autoencoder = keras.models.Sequential([encoder, decoder])
# autoencoder.load_weights("models/autoencoder/20200924-235155/best_epoch_ckpt")
# autoencoder = keras.models.load_model("models/autoencoder/20200924-235155")

#strategy = tf.distribute.MirroredStrategy()
#print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
#
#with strategy.scope():
#    autoencoder = keras.models.Sequential([encoder, decoder])

autoencoder.summary()

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
sequential (Sequential)      (None, 62, 32, 32, 64)    56256     
_________________________________________________________________
sequential_1 (Sequential)    (None, 248, 128, 128, 1)  166017    
Total params: 222,273
Trainable params: 222,273
Non-trainable params: 0
_________________________________________________________________


In [9]:
loss_fn = keras.losses.MeanSquaredError()
optimizer = keras.optimizers.Adam(lr=learning_rate)

In [12]:
start_time = datetime.datetime.now()
start_time_str = start_time.strftime("%Y%m%d-%H%M%S")
log_dir = f"logs/autoencoder/{start_time_str}/"
model_dir = f"models/autoencoder/{start_time_str}/"
ckpt_dir = model_dir + "best_epoch_ckpt"
writer = tf.summary.create_file_writer(log_dir)

early_stopping = EarlyStopping(patience)

for epoch in tqdm(range(epochs), disable=False):
    
    ### TRAIN ###
    
    train_loss_metric = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
    for batch in train_dataset:
        with tf.GradientTape() as tape:
            predictions = autoencoder(batch)
            loss_value = loss_fn(predictions, batch)
        gradients = tape.gradient(loss_value, autoencoder.trainable_variables)
        optimizer.apply_gradients(zip(gradients, autoencoder.trainable_variables))
        train_loss_metric.update_state(loss_value)
        #with writer.as_default():
        #    for grad, param in zip(gradients, autoencoder.trainable_variables):
        #        tf.summary.histogram(param.name, param, step=epoch)
        #        # tf.summary.histogram(param.name + "/grad", grad, buckets=1, step=epoch)

    train_loss_mean = train_loss_metric.result()
    with writer.as_default():
        tf.summary.scalar("Training loss", train_loss_mean, step=epoch)
    train_loss_metric.reset_states()
   
   ### VALIDATION ###

    val_loss_metric = tf.keras.metrics.Mean("val_loss", dtype=tf.float32)
    #for batch in val_dataset:
    for batch in train_dataset:
        predictions = autoencoder(batch)
        val_loss_metric.update_state(loss_fn(predictions, batch))

    val_loss_mean = val_loss_metric.result()
    with writer.as_default():
        tf.summary.scalar("Validation loss", val_loss_mean, step=epoch)
    val_loss_metric.reset_states()

    if verbose_training:
        print()
        print(f"Epoch : {epoch}")
        print(f"Training loss: {train_loss_mean}")
        print(f"Validation loss: {val_loss_mean}")
    
    ### EARLY STOPPING ###
    
    early_stopping.update(val_loss_mean)
    if early_stopping.early_stop:
        autoencoder.load_weights(ckpt_dir)
        break
    elif early_stopping.not_improving_epochs == 0:
        autoencoder.save_weights(ckpt_dir)

autoencoder.save(model_dir)

end_time = datetime.datetime.now()
training_time = str(end_time - start_time).split(".")[0]

with writer.as_default():
    tf.summary.text(
        "Hyperparameters",
        f"batch size = {batch_size}; "
        f"patience = {patience}; "
        f"learning rate = {learning_rate}; "
        f"training time = {training_time}",
        step=0,
    )

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


Epoch : 0
Training loss: 0.9216552972793579
Validation loss: 0.9412267804145813

Epoch : 1
Training loss: 0.9317477941513062
Validation loss: 0.9260936975479126

Epoch : 2
Training loss: 0.9209515452384949
Validation loss: 0.9180641174316406

Epoch : 3
Training loss: 0.932479202747345
Validation loss: 0.9426981806755066



In [13]:
%reload_ext tensorboard
%tensorboard --logdir=logs --bind_all

Reusing TensorBoard on port 6006 (pid 142900), started 6 days, 14:39:20 ago. (Use '!kill 142900' to kill it.)

In [None]:
original = next(iter(test_dataset))
encoder_out = autoencoder.layers[0](original)
decoder_out = autoencoder.layers[1](encoder_out)
batch_index = 0

In [None]:
z_index = 20
fig, ax = plt.subplots(ncols=3)
plot_slice(original, batch_index, z_index, ax[0])
plot_slice(encoder_out, batch_index, encoder_out.shape[1] // 3, ax[1])
plot_slice(decoder_out, batch_index, z_index, ax[2])

In [None]:
plot_animated_volume(original, batch_index)

In [None]:
plot_animated_volume(encoder_out, batch_index, fps=10)

In [None]:
plot_animated_volume(decoder_out, batch_index)