In [1]:
import datetime
from functools import partial

import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt

from utils import example_to_tensor, train_test_split
from data_augmentation import random_rotate, random_flip
from plot import plot_slice, plot_volume_animation
from config import (
    SMALL_NEG_TFRECORD,
    SMALL_POS_TFRECORD,
    BIG_NEG_TFRECORD,
    BIG_POS_TFRECORD,
    SMALL_PATCH_SHAPE,
    BIG_PATCH_SHAPE,
    SEED,
)

%matplotlib inline
plt.rcParams["figure.figsize"] = [15, 7]

In [2]:
def normalize(volume):
    """Min-max normalize ``volume`` so that its values lie in [0, 1].

    The global minimum maps to 0 and the global maximum maps to 1.

    A constant volume (max == min) would otherwise evaluate 0 / 0 and turn every
    voxel into NaN, which silently poisons training. ``tf.math.divide_no_nan``
    maps such a volume to all zeros instead.
    """
    min_value = tf.reduce_min(volume)
    max_value = tf.reduce_max(volume)
    return tf.math.divide_no_nan(volume - min_value, max_value - min_value)

In [3]:
# Build one decode+normalize pipeline per TFRecord (small patches first, then big
# patches) and zip them, so each element is a (small_patch, big_patch) pair.
# NOTE(review): assumes both files store the same samples in the same order;
# Dataset.zip pairs by position and stops at the shorter input. TODO confirm.
neg_x = tf.data.Dataset.zip(
    tuple(
        tf.data.TFRecordDataset(tfrecord_path)
        .map(example_to_tensor, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        .map(normalize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        for tfrecord_path in (SMALL_NEG_TFRECORD, BIG_NEG_TFRECORD)
    )
)
# The count is obtained by iterating (and decoding) the whole dataset once.
num_neg_samples = sum(1 for _ in neg_x)
print(f"Number of negative samples: {num_neg_samples}")
neg_x

Number of negative samples: 370


<ZipDataset shapes: ((None, None, None, None), (None, None, None, None)), types: (tf.float32, tf.float32)>

In [4]:
# Attach the constant label 0 (an int8 tensor of shape (1,)) to every negative sample.
neg_dataset = tf.data.Dataset.zip(
    (
        neg_x,
        tf.data.Dataset.from_tensors(np.int8([0])).repeat(num_neg_samples),
    )
)
assert sum(1 for _ in neg_dataset) == num_neg_samples
neg_dataset

<ZipDataset shapes: (((None, None, None, None), (None, None, None, None)), (1,)), types: ((tf.float32, tf.float32), tf.int8)>

In [5]:
# Build one decode+normalize pipeline per TFRecord (small patches first, then big
# patches) and zip them, so each element is a (small_patch, big_patch) pair.
# NOTE(review): assumes both files store the same samples in the same order;
# Dataset.zip pairs by position and stops at the shorter input. TODO confirm.
pos_x = tf.data.Dataset.zip(
    tuple(
        tf.data.TFRecordDataset(tfrecord_path)
        .map(example_to_tensor, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        .map(normalize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        for tfrecord_path in (SMALL_POS_TFRECORD, BIG_POS_TFRECORD)
    )
)
# The count is obtained by iterating (and decoding) the whole dataset once.
num_pos_samples = sum(1 for _ in pos_x)
print(f"Number of positive samples: {num_pos_samples}")
pos_x

Number of positive samples: 379


<ZipDataset shapes: ((None, None, None, None), (None, None, None, None)), types: (tf.float32, tf.float32)>

In [6]:
# Attach the constant label 1 (an int8 tensor of shape (1,)) to every positive sample.
pos_dataset = tf.data.Dataset.zip(
    (
        pos_x,
        tf.data.Dataset.from_tensors(np.int8([1])).repeat(num_pos_samples),
    )
)
assert sum(1 for _ in pos_dataset) == num_pos_samples
pos_dataset

<ZipDataset shapes: (((None, None, None, None), (None, None, None, None)), (1,)), types: ((tf.float32, tf.float32), tf.int8)>

In [7]:
total_samples = num_pos_samples + num_neg_samples
# Negatives then positives, followed by one seeded, full-buffer shuffle whose
# order stays fixed across iterations (reshuffle_each_iteration=False).
dataset = (
    neg_dataset
    .concatenate(pos_dataset)
    .shuffle(total_samples, seed=SEED, reshuffle_each_iteration=False)
)
dataset

<ShuffleDataset shapes: (((None, None, None, None), (None, None, None, None)), (1,)), types: ((tf.float32, tf.float32), tf.int8)>

In [8]:
def ksplit_dataset(k, dataset, cardinality=None, seed=SEED):
    """Split ``dataset`` into ``k`` disjoint datasets of equal size.

    First, the elements are shuffled once with ``seed``, without reshuffling
    between iterations. Then they are cut into ``k`` consecutive chunks of
    ``cardinality // k`` elements each. The trailing ``cardinality % k``
    elements are dropped.

    Args:
        k: number of splits; must satisfy ``2 <= k <= cardinality``.
        dataset: the ``tf.data.Dataset`` to split.
        cardinality: number of elements in ``dataset``. If it is not given (or
            is falsy), it is counted by iterating over the dataset.
        seed: seed for the shuffle.

    Returns:
        A list of ``k`` datasets.
    """
    if not cardinality:
        cardinality = sum(1 for _ in dataset)
    assert 2 <= k <= cardinality

    shuffled = dataset.shuffle(
        buffer_size=cardinality, seed=seed, reshuffle_each_iteration=False
    )
    split_size = cardinality // k
    # Chunk i covers positions [i * split_size, (i + 1) * split_size).
    return [shuffled.skip(i * split_size).take(split_size) for i in range(k)]

In [9]:
def kfolds(k, dataset, cardinality=None, seed=SEED):
    """Yield ``(train_dataset, test_dataset)`` pairs for k-fold cross-validation.

    ``dataset`` is split into ``k`` disjoint folds with ``ksplit_dataset``. Each
    fold in turn serves as the test set. The other folds are concatenated in
    their original order to form the training set.

    Args:
        k: number of folds.
        dataset: the ``tf.data.Dataset`` to split.
        cardinality: number of elements in ``dataset``. If it is not given (or
            is falsy), it is counted by iterating over the dataset.
        seed: shuffle seed forwarded to ``ksplit_dataset``.

    Yields:
        ``k`` tuples ``(train_dataset, test_dataset)``.
    """
    if not cardinality:
        cardinality = sum(1 for _ in dataset)
    folds = ksplit_dataset(k, dataset, cardinality, seed)
    for test_index, test_dataset in enumerate(folds):
        train_dataset, *rest = folds[:test_index] + folds[test_index + 1 :]
        for fold in rest:
            train_dataset = train_dataset.concatenate(fold)
        yield train_dataset, test_dataset

In [10]:
def train_model(
    model,
    train_dataset,
    val_dataset,
    patience,
    monitor_metric,
    model_fname,
    log_dir,
    epochs=1000,
    mode="auto",
):
    """Train ``model`` with early stopping and return the best checkpoint.

    Training runs for at most ``epochs`` epochs. It stops early once
    ``monitor_metric`` has not improved for ``patience`` epochs. Each
    improvement is saved to ``model_fname`` (``save_best_only=True``).
    TensorBoard logs, including weight histograms every epoch, are written to
    ``log_dir``.

    Args:
        model: compiled Keras model to train.
        train_dataset: training data passed to ``model.fit``.
        val_dataset: validation data; ``monitor_metric`` is presumably a
            ``val_*`` metric computed on it.
        patience: number of epochs without improvement before stopping.
        monitor_metric: name of the logged metric to monitor.
        model_fname: path where the best model is checkpointed.
        log_dir: TensorBoard log directory.
        epochs: upper bound on the number of training epochs.
        mode: "min", "max" or "auto", passed to both ModelCheckpoint and
            EarlyStopping. With "auto", Keras guesses the direction from the
            metric name. Pass it explicitly for metrics where that guess is
            wrong (e.g. counts of true positives, which should be maximised).

    Returns:
        The model reloaded from ``model_fname``, i.e. the best one that was saved.
    """
    model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        verbose=0,
        callbacks=[
            keras.callbacks.ModelCheckpoint(
                model_fname,
                monitor=monitor_metric,
                verbose=0,
                save_best_only=True,
                mode=mode,
            ),
            keras.callbacks.EarlyStopping(
                monitor=monitor_metric,
                patience=patience,
                mode=mode,
            ),
            keras.callbacks.TensorBoard(
                log_dir=log_dir,
                histogram_freq=1,
                write_graph=False,
                profile_batch=0,  # profiling disabled
            ),
        ],
    )
    # The in-memory model holds the *last* epoch's weights; reload the best checkpoint.
    model = keras.models.load_model(model_fname)
    return model

In [16]:
# Hyperparameters
# Data: fraction of each training fold held out for validation, and number of CV folds.
val_perc, k = 0.1, 10
# Optimisation
patience = 30  # epochs without improvement before early stopping
batch_size = 16
learning_rate = 1e-5
# Regularisation
dropout_rate = 0.6  # AlphaDropout rate before the final dense layer

In [17]:
# Conv3D preset for self-normalizing networks: SELU activation with LeCun-normal
# kernel init and zero biases, using "same" padding.
SeluConv3D = partial(
    keras.layers.Conv3D,
    activation="selu",
    kernel_initializer="lecun_normal",
    bias_initializer="zeros",
    padding="same",
)

In [18]:
# Dense preset for self-normalizing networks: SELU activation with LeCun-normal
# kernel init and zero biases.
SeluDense = partial(
    keras.layers.Dense,
    bias_initializer="zeros",
    kernel_initializer="lecun_normal",
    activation="selu",
)

In [19]:
def _selu_conv_tower(x, branch):
    """Apply the 4-stage SELU Conv3D stack shared by both input branches.

    The stages use 32, 64, 128 and 256 filters with kernel size 3. A (1, 2, 2)
    max-pooling follows every stage except the last, and the result is
    flattened. Layer names follow ``selu_conv3d_{branch}_{i}``,
    ``maxpool_{branch}_{i}`` and ``flatten_{branch}``.
    """
    filters_per_stage = (32, 64, 128, 256)
    for stage, filters in enumerate(filters_per_stage, start=1):
        x = SeluConv3D(
            filters=filters,
            kernel_size=3,
            name=f"selu_conv3d_{branch}_{stage}",
        )(x)
        if stage < len(filters_per_stage):
            x = keras.layers.MaxPooling3D((1, 2, 2), name=f"maxpool_{branch}_{stage}")(x)
    return keras.layers.Flatten(name=f"flatten_{branch}")(x)


def build_and_compile_model():
    """Build and compile the two-branch 3D CNN baseline.

    The model has two inputs: a small patch (``SMALL_PATCH_SHAPE``) and a big
    patch (``BIG_PATCH_SHAPE``). Both go through the same SELU Conv3D tower,
    but the big patch is first max-pooled by (2, 2, 2). The two flattened
    feature vectors are concatenated and passed through a 128-unit SELU dense
    layer and alpha dropout (``dropout_rate``). A single sigmoid unit produces
    the output.

    The model is compiled with Adam (``learning_rate``) and binary
    cross-entropy. It tracks TP/FP/TN/FN counts and binary accuracy.
    ``dropout_rate`` and ``learning_rate`` are read from notebook globals at
    call time.

    Returns:
        The compiled ``keras.Model`` named "3dcnn".
    """
    input_small = keras.Input(SMALL_PATCH_SHAPE, name="input_small")
    x_small = _selu_conv_tower(input_small, "small")

    input_big = keras.Input(BIG_PATCH_SHAPE, name="input_big")
    # Downsample the big patch by 2 along each axis before it enters the tower.
    x_big = keras.layers.MaxPooling3D((2, 2, 2), name="maxpool_big_0")(input_big)
    x_big = _selu_conv_tower(x_big, "big")

    x = keras.layers.concatenate([x_small, x_big], name="concatenate")
    x = SeluDense(128, name="selu_dense")(x)
    # AlphaDropout (rather than Dropout) keeps SELU's self-normalizing statistics.
    x = keras.layers.AlphaDropout(dropout_rate, name="alpha_dropout")(x)
    x = keras.layers.Dense(1, activation="sigmoid", name="final_dense")(x)

    cnn = keras.Model(inputs=[input_small, input_big], outputs=x, name="3dcnn")

    cnn.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss="binary_crossentropy",
        metrics=[
            keras.metrics.TruePositives(name="tp"),
            keras.metrics.FalsePositives(name="fp"),
            keras.metrics.TrueNegatives(name="tn"),
            keras.metrics.FalseNegatives(name="fn"),
            keras.metrics.BinaryAccuracy(name="accuracy"),
        ],
    )
    return cnn

In [20]:
from statistics import mean
from pprint import pprint

# Test-set metrics collected per fold. Keys returned by evaluate() that are not
# listed here (e.g. "loss") are ignored.
metrics = {"tp": [], "fp": [], "tn": [], "fn": [], "accuracy": []}
for train_val_dataset, test_dataset in kfolds(k, dataset, cardinality=total_samples):
    test_dataset = test_dataset.batch(1)
    train_dataset, val_dataset = train_test_split(train_val_dataset, test_perc=val_perc)
    val_dataset = val_dataset.batch(1)
    # Cache the decoded samples once, then reshuffle the *individual samples*
    # every epoch before batching, so batch composition changes between epochs.
    # Shuffling after batching would only permute a fixed set of batches.
    # total_samples is an upper bound on the training split, so the buffer
    # covers the whole split. The seed makes the per-epoch orders reproducible.
    train_dataset = (
        train_dataset.cache()  # must be called before shuffle
        .shuffle(buffer_size=total_samples, seed=SEED, reshuffle_each_iteration=True)
        .batch(batch_size)
        .prefetch(tf.data.experimental.AUTOTUNE)
    )
    cnn = build_and_compile_model()
    start_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    model_fname = f"models/baseline-{start_time}.h5"
    log_dir = f"logs/baseline-{start_time}"
    cnn = train_model(
        cnn, train_dataset, val_dataset, patience, "val_accuracy", model_fname, log_dir
    )
    test_metrics = cnn.evaluate(test_dataset, return_dict=True, verbose=0)
    pprint(test_metrics)
    for metric_name, metric_value in test_metrics.items():
        if metric_name in metrics:
            metrics[metric_name].append(metric_value)

# Average every tracked metric across the k folds.
mean_metrics = {
    metric_name: mean(metric_values) for metric_name, metric_values in metrics.items()
}
pprint(mean_metrics)

{'accuracy': 0.7297297120094299,
 'fn': 12.0,
 'fp': 8.0,
 'loss': 0.7366518378257751,
 'tn': 25.0,
 'tp': 29.0}
{'accuracy': 0.7702702879905701,
 'fn': 5.0,
 'fp': 12.0,
 'loss': 0.5845815539360046,
 'tn': 27.0,
 'tp': 30.0}
{'accuracy': 0.7837837934494019,
 'fn': 5.0,
 'fp': 11.0,
 'loss': 0.7872048616409302,
 'tn': 27.0,
 'tp': 31.0}
{'accuracy': 0.7432432174682617,
 'fn': 9.0,
 'fp': 10.0,
 'loss': 0.5746363401412964,
 'tn': 27.0,
 'tp': 28.0}
{'accuracy': 0.7432432174682617,
 'fn': 9.0,
 'fp': 10.0,
 'loss': 1.1722667217254639,
 'tn': 26.0,
 'tp': 29.0}
{'accuracy': 0.7297297120094299,
 'fn': 8.0,
 'fp': 12.0,
 'loss': 0.8728437423706055,
 'tn': 24.0,
 'tp': 30.0}
{'accuracy': 0.7567567825317383,
 'fn': 4.0,
 'fp': 14.0,
 'loss': 1.4429597854614258,
 'tn': 28.0,
 'tp': 28.0}
{'accuracy': 0.7837837934494019,
 'fn': 9.0,
 'fp': 7.0,
 'loss': 0.5396794676780701,
 'tn': 24.0,
 'tp': 34.0}
{'accuracy': 0.6351351141929626,
 'fn': 12.0,
 'fp': 15.0,
 'loss': 1.4254175424575806,
 'tn': 21

In [None]:
# Inspect a single test example. This relies on `cnn` and `test_dataset` left
# over from the LAST cross-validation fold above (test batches have size 1, so
# index 0 selects the only sample in the batch).
patches, label = next(iter(test_dataset.skip(6).take(1)))
print(f"label: {label.numpy()[0, 0]}")
prediction = cnn(patches, training=False)
print(f"prediction: {prediction.numpy()[0, 0]}")
# patches[0] is the small-patch input of the (small, big) pair.
plot_volume_animation(patches[0][0, :])