In [1]:
import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras

from model import conv_block, deconv_block
from data import example_to_tensor, normalize, add_channel_axis, train_test_split
from plot import plot_slice, plot_volume_animation

# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Default figure size (width, height), in inches, for the figures created below.
plt.rcParams["figure.figsize"] = [15, 7]

In [2]:
# Shape given to the encoder's Input layer (batch axis excluded).
# Presumably (depth, height, width, channels) - TODO confirm against data.py.
input_shape = (48, 256, 256, 1)
# Glob pattern, relative to the data root directory, selecting the TFRecord files.
# tfrecord_glob = "LUNA16/*.tfrecord"
tfrecord_glob = "covid-*/*.tfrecord"

# Seed handed to train_test_split(seed=seed) for the train/validation split.
# It has to be defined here: the split cell reads `seed`, and nothing else in
# the notebook (imports included) provides it.
seed = 42

# Filters of each encoder conv block; the decoder walks the same list reversed.
encoder_filters = [32, 64, 128, 128]
# Upper bound on training epochs; early stopping may end training sooner.
epochs = 500
# Number of epochs without improvement tolerated by the EarlyStopping callback.
patience = 10
# Learning rate of the Adam optimizer.
learning_rate = 0.00001
# Dropout rate forwarded to every conv/deconv block.
dropout_rate = 0.0
# Number of scans per batch for both the training and the validation set.
batch_size = 4
# Fraction of the samples passed as `test_perc` to train_test_split (validation set).
val_perc = 0.2

In [3]:
def min_max_normalize(scan):
    """Rescale the values of ``scan`` linearly into the range [0, 1].

    The minimum and maximum are reduced over the whole tensor (no axis is
    passed to the reductions), so every element is shifted and scaled by the
    same two scalars.

    Args:
        scan: numeric tensor (or a value convertible to one) to rescale.

    Returns:
        A tensor with the same shape as ``scan`` in which the smallest input
        value maps to 0 and the largest to 1. When every value is identical
        the range is zero; the result is then all zeros rather than NaN.
    """
    min_value = tf.reduce_min(scan)
    max_value = tf.reduce_max(scan)
    value_range = max_value - min_value
    normalized = (scan - min_value) / value_range
    # A constant scan has zero range, so the division above produces NaN for
    # every element; replace the whole result with zeros in that case.
    return tf.where(
        tf.equal(value_range, 0), tf.zeros_like(normalized), normalized
    )

In [4]:
# Collect the TFRecord files in sorted order. Path.glob yields entries in
# arbitrary, filesystem-dependent order, which would make the order of the
# samples in `dataset` differ between machines or runs.
tfrecord_fnames = sorted(
    str(p) for p in Path("/pclhcb06/emilio/").glob(tfrecord_glob)
)
if not tfrecord_fnames:
    # Fail early with a readable message instead of an obscure TensorFlow
    # error when the pattern matches nothing (wrong mount, wrong glob, ...).
    raise FileNotFoundError(
        f"No TFRecord files match {tfrecord_glob!r} under /pclhcb06/emilio/"
    )

# Per-record pipeline built from the helpers imported from data.py:
# example_to_tensor -> normalize -> add_channel_axis, each mapped in parallel.
dataset = (
    tf.data.TFRecordDataset(tfrecord_fnames)
    .map(example_to_tensor, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .map(normalize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .map(add_channel_axis, num_parallel_calls=tf.data.experimental.AUTOTUNE)
)
# The sample count is hard-coded and later passed as `cardinality` to
# train_test_split; the commented alternative counts by iterating over the
# whole dataset. NOTE(review): keep this in sync with the selected glob.
# num_samples = sum(1 for _ in dataset)
# num_samples = 1018  # LUNA16
num_samples = 500  # covid
print(f"Number of samples: {num_samples}")
dataset

Number of samples: 500


<ParallelMapDataset shapes: (None, None, None, 1), types: tf.float32>

In [5]:
# Pull the first element of the pipeline to check its shape and dtype.
next(iter(dataset))

<tf.Tensor: shape=(48, 256, 256, 1), dtype=float32, numpy=
array([[[[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]],

        ...,

        [[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]]],


       [[[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.],
         ...,
         [0.],
   

In [None]:
# Pair every scan with itself as (input, target) to perform unsupervised
# (reconstruction) training. Mapping one pipeline, instead of zipping the
# dataset with itself, reads and decodes each record only once and guarantees
# that input and target are the very same tensor.
duplicated_dataset = dataset.map(lambda scan: (scan, scan))
duplicated_dataset

In [None]:
# Split the (input, target) pairs into a training and a validation dataset.
train_dataset, val_dataset = train_test_split(
    duplicated_dataset, test_perc=val_perc, cardinality=num_samples, seed=seed
)

# Validation pipeline: batch -> cache -> prefetch.
val_dataset = val_dataset.batch(batch_size)
val_dataset = val_dataset.cache()
val_dataset = val_dataset.prefetch(tf.data.experimental.AUTOTUNE)

# Training pipeline: batch -> cache -> shuffle -> prefetch.
train_dataset = train_dataset.batch(batch_size)
# cache() must be called before shuffle(); shuffle() then reorders the cached
# batches anew at every iteration (reshuffle_each_iteration=True).
train_dataset = train_dataset.cache()
train_dataset = train_dataset.shuffle(buffer_size=64, reshuffle_each_iteration=True)
train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)
train_dataset

In [None]:
def build_and_compile_autoencoder(filters, dropout_rate, learning_rate):
    """Create the encoder/decoder pair and return them as one compiled model.

    The encoder stacks one ``conv_block`` per entry of ``filters``. The
    decoder mirrors it with one ``deconv_block`` per entry, in reverse order,
    followed by a ``Dense(1)`` layer with sigmoid activation. Both halves are
    named sub-models ("encoder" and "decoder") of a ``Sequential`` model named
    "autoencoder", so they can be retrieved with ``get_layer``.

    The encoder input shape is read from the notebook-level ``input_shape``.

    Args:
        filters: sequence with the number of filters of each encoder block.
        dropout_rate: dropout rate forwarded to every conv/deconv block.
        learning_rate: learning rate of the Adam optimizer.

    Returns:
        The autoencoder, compiled with Adam and mean squared error loss.
    """
    # Encoder: input volume -> stack of conv blocks.
    enc_in = keras.layers.Input(input_shape)
    features = enc_in
    for n_filters in filters:
        features = conv_block(features, filters=n_filters, dropout_rate=dropout_rate)
    encoder = keras.Model(enc_in, features, name="encoder")

    # Decoder: takes the encoder output (batch axis dropped) and applies the
    # deconv blocks with the filter counts in reverse order.
    dec_in = keras.layers.Input(encoder.output_shape[1:])
    upsampled = dec_in
    for n_filters in reversed(filters):
        upsampled = deconv_block(
            upsampled, filters=n_filters, dropout_rate=dropout_rate
        )
    reconstruction = keras.layers.Dense(1, activation="sigmoid")(upsampled)
    decoder = keras.Model(dec_in, reconstruction, name="decoder")

    autoencoder = keras.Sequential([encoder, decoder], name="autoencoder")

    optimizer = keras.optimizers.Adam(learning_rate)
    loss = keras.losses.MeanSquaredError()
    autoencoder.compile(optimizer=optimizer, loss=loss)
    return autoencoder

In [None]:
# Build the model once to inspect its architecture; the training cell builds
# its own fresh instance before fitting.
autoencoder = build_and_compile_autoencoder(
    filters=encoder_filters,
    dropout_rate=dropout_rate,
    learning_rate=learning_rate,
)
# Print the two halves first, then the combined model.
for part in ("encoder", "decoder"):
    autoencoder.get_layer(part).summary()
autoencoder.summary()

In [None]:
# Build a fresh model so that training always starts from newly initialised
# weights, even when this cell is re-run.
autoencoder = build_and_compile_autoencoder(
    encoder_filters, dropout_rate, learning_rate
)
# Quantity watched by both the checkpoint and the early-stopping callbacks.
monitor_metric = "val_loss"

# The timestamp ties the checkpoint file and the TensorBoard run together.
start_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
best_checkpoint = f"models/autoencoder-{start_time}.h5"
# Make sure the checkpoint directory exists before training starts; otherwise
# the first attempt to write the .h5 file can fail after a full epoch of work.
Path(best_checkpoint).parent.mkdir(parents=True, exist_ok=True)
# Keep only the model with the best monitored value seen so far.
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    best_checkpoint, monitor=monitor_metric, verbose=1, save_best_only=True
)
# Stop once the monitored value has not improved for `patience` epochs.
early_stopping_cb = keras.callbacks.EarlyStopping(
    monitor=monitor_metric,
    patience=patience,
)
log_dir = f"logs/autoencoder-{start_time}"
# Record the hyperparameters of this run as a text summary in the same log
# directory that the TensorBoard callback writes to.
file_writer = tf.summary.create_file_writer(log_dir)
with file_writer.as_default():
    tf.summary.text(
        "Hyperparameters",
        f"{input_shape=}; "
        f"{encoder_filters=}; "
        f"{epochs=}; "
        f"{patience=}; "
        f"{batch_size=}; "
        f"{dropout_rate=}; "
        f"{learning_rate=}; "
        f"{val_perc=}",
        step=0,
    )
# Write the summary to disk now instead of waiting for the writer's periodic
# flush, so it is visible in TensorBoard as soon as training begins.
file_writer.flush()
# Per-epoch histograms; graph export and profiling are switched off.
tensorboard_cb = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
    write_graph=False,
    profile_batch=0,
)
autoencoder.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=epochs,
    callbacks=[checkpoint_cb, early_stopping_cb, tensorboard_cb],
)


In [None]:
# Load a saved model from a fixed checkpoint file.
# NOTE(review): this is a hard-coded file name, not the `best_checkpoint`
# written by the training cell - confirm that this is intended.
autoencoder = keras.models.load_model("models/autoencoder-20201029-125142.h5")

# Take the batch after the first one; the target half of the pair is dropped.
original, _ = next(iter(train_dataset.skip(1)))

# Run the two halves separately (inference mode) to keep the intermediate
# encoder output as well as the final reconstruction.
encoder_model = autoencoder.get_layer("encoder")
decoder_model = autoencoder.get_layer("decoder")
encoder_out = encoder_model(original, training=False)
decoder_out = decoder_model(encoder_out, training=False)

In [None]:
# Which sample of the batch to show, and which slice of the input/output volume.
batch_index = 3
z_index = 20

# Three panels side by side: input, encoder output, reconstruction.
fig, ax = plt.subplots(ncols=3)
original_ax, encoded_ax, decoded_ax = ax

# The encoder output is sliced at one third of its own second axis rather
# than at z_index.
encoded_z_index = encoder_out.shape[1] // 3

plot_slice(original[batch_index, :], z_index, original_ax)
plot_slice(encoder_out[batch_index, :], encoded_z_index, encoded_ax)
plot_slice(decoder_out[batch_index, :], z_index, decoded_ax)

In [None]:
# Animate the first volume of the batch. The helper imported from `plot` at
# the top of the notebook is named `plot_volume_animation`.
plot_volume_animation(original[0, :])