In [1]:
import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras

from model import build_encoder, build_autoencoder
from data import example_to_tensor, normalize, add_channel_axis, train_test_split
from utils import duplicate_iterator, plot_slice, plot_animated_volume
from config import allocate_gpu_memory_only_when_needed, data_root_dir, seed

%matplotlib inline
plt.rcParams["figure.figsize"] = [15, 7]

In [2]:
# Optional mixed-precision setup, currently disabled: uncomment to run the
# models with a float16 compute policy (the float32 line is the alternative).
# NOTE(review): the `experimental` mixed-precision module used below may not
# exist in newer TensorFlow releases — confirm against the installed version
# before re-enabling.
# from tensorflow.keras.mixed_precision import experimental as mixed_precision
#
# policy = mixed_precision.Policy("mixed_float16")
##policy = mixed_precision.Policy("float32")
# mixed_precision.set_policy(policy)
# print("Compute dtype: %s" % policy.compute_dtype)
# print("Variable dtype: %s" % policy.variable_dtype)

In [3]:
downscaling = 2
# Supported downscaling factors -> (padded model input shape, TFRecord glob
# relative to data_root_dir). Shapes are presumably (depth, height, width,
# channels) — TODO confirm against the data pipeline.
downscaling_presets = {
    4: ((24, 128, 128, 1), "CT-[0-4]-0.25/*.tfrecord"),
    2: ((48, 256, 256, 1), "CT-[0-4]-0.5/*.tfrecord"),
    1: ((96, 512, 512, 1), "CT-[0-4]/*.tfrecord"),
}
if downscaling not in downscaling_presets:
    raise RuntimeError("Downscaling not supported")
input_shape, tfrecord_glob = downscaling_presets[downscaling]

# Model and training hyperparameters.
encoder_num_filters = [32, 64, 128]  # passed to build_autoencoder
epochs = 1000  # upper bound; EarlyStopping (with `patience`) can end sooner
patience = 20  # epochs without validation improvement before stopping
batch_size = 8
learning_rate = 1e-4
val_perc = 0.2  # fraction of samples held out for validation

In [4]:
# Sort the file list: `Path.glob` yields files in filesystem order, which can
# differ between machines, so the seeded train/validation split further down
# is only reproducible when the records always arrive in the same order.
tfrecord_fnames = sorted(str(p) for p in Path(data_root_dir).glob(tfrecord_glob))
if not tfrecord_fnames:
    raise FileNotFoundError(
        f"No TFRecord files match {tfrecord_glob!r} under {data_root_dir!r}"
    )

# Decode each serialized example, normalize it and append a channel axis.
dataset = (
    tf.data.TFRecordDataset(tfrecord_fnames)
    .map(example_to_tensor, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .map(normalize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .map(add_channel_axis, num_parallel_calls=tf.data.experimental.AUTOTUNE)
)

# `map` is one-to-one, so counting the raw serialized records gives the same
# number as counting `dataset`, without decoding and normalizing every volume.
# (1110 samples for the CT-[0-4] records at the time of writing.)
num_samples = sum(1 for _ in tf.data.TFRecordDataset(tfrecord_fnames))
print(f"Number of samples: {num_samples}")
dataset

Number of samples: 1110


<ParallelMapDataset shapes: (None, None, None, 1), types: tf.float32>

In [8]:
unsupervised_train_ds, unsupervised_val_ds = train_test_split(
    dataset,
    test_perc=val_perc,
    cardinality=num_samples,
    seed=seed,
)

# Pad every volume up to `input_shape` so each batch has a fixed size.
val_batches = unsupervised_val_ds.padded_batch(batch_size, input_shape)
train_batches = unsupervised_train_ds.padded_batch(batch_size, input_shape)

# `duplicate_iterator` presumably yields (batch, batch) pairs, because an
# autoencoder's target is its own input (TODO confirm in utils).
# `from_generator` calls its argument every time the dataset is iterated, so
# the lambda must build a NEW generator on each call. A single shared
# generator object only survives one pass. A pass that stops early (e.g.
# `next(iter(...))`) leaves the cache unfinished, and the next pass would
# silently resume the half-consumed generator.
unsupervised_val_ds = (
    tf.data.Dataset.from_generator(
        lambda: duplicate_iterator(val_batches), (tf.float32, tf.float32)
    )
    .cache()
    .prefetch(tf.data.experimental.AUTOTUNE)
)
unsupervised_train_ds = (
    tf.data.Dataset.from_generator(
        lambda: duplicate_iterator(train_batches), (tf.float32, tf.float32)
    )
    # Cache must come before shuffle so every epoch reshuffles the cached
    # data. Note that whole batches are shuffled; the samples inside each
    # batch stay the same across epochs.
    .cache()
    .shuffle(buffer_size=64, reshuffle_each_iteration=True)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
unsupervised_train_ds

<PrefetchDataset shapes: (<unknown>, <unknown>), types: (tf.float32, tf.float32)>

In [None]:
# Build the autoencoder and print the encoder, decoder and full-model summaries.
autoencoder = build_autoencoder(input_shape, encoder_num_filters)
for sub_model_name in ("encoder", "decoder"):
    autoencoder.get_layer(sub_model_name).summary()
autoencoder.summary()

In [None]:
autoencoder.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=keras.losses.MeanSquaredError(),
)

# Name the checkpoint and the TensorBoard run by start time.
start_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
best_checkpoint = f"models/autoencoder-{start_time}.h5"
log_dir = f"logs/autoencoder-{start_time}"

# Keep only the weights with the lowest validation reconstruction loss, and
# stop once it has not improved for `patience` epochs.
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    best_checkpoint, monitor="val_loss", verbose=1, save_best_only=True
)
early_stopping_cb = keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=patience,
)
tensorboard_cb = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
    write_graph=False,
    profile_batch=0,
)

# Record the hyperparameters next to the TensorBoard logs *before* training,
# so the run is documented even if `fit` is interrupted. `val_perc` is the
# validation fraction defined in the config cell.
file_writer = tf.summary.create_file_writer(log_dir)
with file_writer.as_default():
    tf.summary.text(
        "Hyperparameters",
        f"{downscaling=}; "
        f"{encoder_num_filters=}; "
        f"{epochs=}; "
        f"{patience=}; "
        f"{batch_size=}; "
        f"{learning_rate=}; "
        f"{val_perc=}",
        step=0,
    )

autoencoder.fit(
    unsupervised_train_ds,
    validation_data=unsupervised_val_ds,
    epochs=epochs,
    callbacks=[checkpoint_cb, early_stopping_cb, tensorboard_cb],
)
# Continue with the best checkpoint rather than the last-epoch weights.
autoencoder = keras.models.load_model(best_checkpoint)

In [None]:
# Inspect a specific earlier autoencoder run (this replaces the model trained
# above) on the 5th element of the validation set.
autoencoder = keras.models.load_model("models/autoencoder-20201022-180425.h5")
batch_index = 0
original, _ = next(iter(unsupervised_val_ds.skip(4)))

# Run the two halves separately so the latent volume can be plotted too.
encoder_part = autoencoder.get_layer("encoder")
decoder_part = autoencoder.get_layer("decoder")
encoder_out = encoder_part(original, training=False)
decoder_out = decoder_part(encoder_out, training=False)

In [None]:
# Animate the same batch element that the slice plots use (`batch_index`).
plot_animated_volume(original[batch_index, :], fps=2)

In [None]:
# Side by side: input slice, a slice of the encoded volume (taken a third of
# the way along its axis 1), and the reconstructed slice at the same z_index.
z_index = 13
fig, ax = plt.subplots(ncols=3)
panels = [
    (original, z_index),
    (encoder_out, encoder_out.shape[1] // 3),
    (decoder_out, z_index),
]
for axis, (volume, slice_index) in zip(ax, panels):
    plot_slice(volume[batch_index, :], slice_index, axis)

In [None]:
pretrained = False
if not pretrained:
    # TODO(review): unlike `build_autoencoder` above, `encoder_num_filters` is
    # not passed here, so a from-scratch encoder may not match the logged
    # hyperparameters — confirm build_encoder's default filter counts.
    encoder = build_encoder(input_shape)
else:
    # Reuse the encoder half of a trained autoencoder as a frozen feature extractor.
    autoencoder_checkpoint = "models/autoencoder-20201023-112618.h5"
    encoder = keras.models.load_model(autoencoder_checkpoint).get_layer("encoder")
    encoder.trainable = False
encoder.summary()

In [None]:
# Binary classifier: the encoder followed by a small SELU dense head with a
# sigmoid output.
# NOTE(review): the bias initializer is "lecun_normal" too; zeros is the more
# common choice for SELU layers — confirm this is intentional.
classifier_head = [
    keras.layers.Flatten(),
    keras.layers.Dense(
        512,
        activation="selu",
        kernel_initializer="lecun_normal",
        bias_initializer="lecun_normal",
    ),
    # AlphaDropout preserves the self-normalizing statistics expected by SELU.
    keras.layers.AlphaDropout(0.3),
    keras.layers.Dense(1, activation="sigmoid"),
]
cnn = keras.Sequential([encoder, *classifier_head], name="cnn")
cnn.summary()

In [None]:
cnn.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=keras.losses.BinaryCrossentropy(),
    metrics=[
        keras.metrics.TruePositives(name="tp"),
        keras.metrics.FalsePositives(name="fp"),
        keras.metrics.TrueNegatives(name="tn"),
        keras.metrics.FalseNegatives(name="fn"),
        keras.metrics.BinaryAccuracy(name="accuracy"),
        keras.metrics.Precision(name="precision"),
        keras.metrics.Recall(name="recall"),
        # Named "auc", so Keras reports the validation value as "val_auc".
        keras.metrics.AUC(name="auc"),
    ],
)

# Name the checkpoint and logs by start time, prefixed when the encoder comes
# from the pretrained autoencoder so the two variants are easy to tell apart.
run_prefix = "pretrained-" if pretrained else ""
start_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
best_checkpoint = f"models/{run_prefix}{start_time}.h5"
log_dir = f"logs/{run_prefix}{start_time}"

# Keep the weights with the highest validation AUC; stop after `patience`
# epochs without improvement.
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    best_checkpoint, monitor="val_auc", mode="max", verbose=1, save_best_only=True
)
early_stopping_cb = keras.callbacks.EarlyStopping(
    monitor="val_auc", patience=patience, mode="max"
)
tensorboard_cb = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
    write_graph=False,
    profile_batch=0,
)

# Record the hyperparameters before training so an interrupted run is still
# documented. `val_perc` is the validation fraction from the config cell.
file_writer = tf.summary.create_file_writer(log_dir)
with file_writer.as_default():
    tf.summary.text(
        "Hyperparameters",
        f"{downscaling=}; "
        f"{pretrained=}; "
        f"{encoder_num_filters=}; "
        f"{epochs=}; "
        f"{patience=}; "
        f"{batch_size=}; "
        f"{learning_rate=}; "
        f"{val_perc=}",
        step=0,
    )

# TODO(review): `train_dataset` and `val_dataset` (and `test_dataset` used in
# later cells) are not defined in any cell visible here. The labelled-data
# pipeline must be added before this cell can run on a fresh kernel.
cnn.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=epochs,
    callbacks=[checkpoint_cb, early_stopping_cb, tensorboard_cb],
)
# Continue with the best-AUC checkpoint rather than the last-epoch weights.
cnn = keras.models.load_model(best_checkpoint)

In [None]:
# Evaluate a classifier trained without the pretrained encoder (its checkpoint
# name has no "pretrained-" prefix). NOTE(review): `test_dataset` is not
# defined in any visible cell.
from_scratch_checkpoint = "models/20201021-213411.h5"
cnn = keras.models.load_model(from_scratch_checkpoint)
cnn.evaluate(test_dataset, return_dict=True, verbose=0)

In [None]:
# Evaluate a classifier built on the pretrained autoencoder's encoder.
# NOTE(review): `test_dataset` is not defined in any visible cell.
pretrained_checkpoint = "models/pretrained-20201021-230015.h5"
cnn = keras.models.load_model(pretrained_checkpoint)
cnn.evaluate(test_dataset, return_dict=True, verbose=0)

In [None]:
# Compare label and prediction for the 7th element of the test set, then
# animate its first volume.
x, y = next(iter(test_dataset.skip(6)))
prediction = cnn(x, training=False)
real_labels, predicted_probs = y.numpy(), prediction.numpy()
print(f"real: {real_labels}, prediction: {predicted_probs}")
plot_animated_volume(x[0, :], fps=1)