In [9]:
import datetime
from statistics import mean
from pprint import pprint

import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt

from data import tfrecord_dataset, kfolds, train_test_split
from train import train_model
from layers import SeluConv3D, SeluDense
from plot import plot_slice, plot_volume_animation
from config import (
    SMALL_NEG_TFRECORD,
    SMALL_POS_TFRECORD,
    BIG_NEG_TFRECORD,
    BIG_POS_TFRECORD,
    SMALL_PATCH_SHAPE,
    BIG_PATCH_SHAPE,
    SEED,
)

%matplotlib inline
plt.rcParams["figure.figsize"] = [15, 7]

In [2]:
neg_x = tf.data.Dataset.zip(
    (
        tfrecord_dataset(SMALL_NEG_TFRECORD),
        tfrecord_dataset(BIG_NEG_TFRECORD),
    )
)
num_neg_samples = sum(1 for _ in neg_x)
print(f"Number of negative samples: {num_neg_samples}")
neg_x

Number of negative samples: 370


<ZipDataset shapes: ((None, None, None, None), (None, None, None, None)), types: (tf.float32, tf.float32)>

In [3]:
neg_dataset = tf.data.Dataset.zip(
    (neg_x, tf.data.Dataset.from_tensor_slices(np.int8([[0]])).repeat(num_neg_samples))
)
assert sum(1 for _ in neg_dataset) == num_neg_samples
neg_dataset

<ZipDataset shapes: (((None, None, None, None), (None, None, None, None)), (1,)), types: ((tf.float32, tf.float32), tf.int8)>

In [4]:
pos_x = tf.data.Dataset.zip(
    (
        tfrecord_dataset(SMALL_POS_TFRECORD),
        tfrecord_dataset(BIG_POS_TFRECORD),
    )
)
num_pos_samples = sum(1 for _ in pos_x)
print(f"Number of positive samples: {num_pos_samples}")
pos_x

Number of positive samples: 379


<ZipDataset shapes: ((None, None, None, None), (None, None, None, None)), types: (tf.float32, tf.float32)>

In [5]:
pos_dataset = tf.data.Dataset.zip(
    (pos_x, tf.data.Dataset.from_tensor_slices(np.int8([[1]])).repeat(num_pos_samples))
)
assert sum(1 for _ in pos_dataset) == num_pos_samples
pos_dataset

<ZipDataset shapes: (((None, None, None, None), (None, None, None, None)), (1,)), types: ((tf.float32, tf.float32), tf.int8)>

In [6]:
total_samples = num_neg_samples + num_pos_samples
dataset = neg_dataset.concatenate(pos_dataset).shuffle(
    buffer_size=total_samples, seed=SEED, reshuffle_each_iteration=False
)
dataset

<ShuffleDataset shapes: (((None, None, None, None), (None, None, None, None)), (1,)), types: ((tf.float32, tf.float32), tf.int8)>

In [7]:
# Hyperparameters
val_perc = 0.1
k = 10
patience = 30
batch_size = 16
learning_rate = 1e-5
dropout_rate = 0.6

In [8]:
def build_and_compile_model():
    input_small = keras.Input(SMALL_PATCH_SHAPE, name="input_small")
    x_small = SeluConv3D(
        filters=32,
        kernel_size=3,
        name="selu_conv3d_small_1",
    )(input_small)
    x_small = keras.layers.MaxPooling3D((1, 2, 2), name="maxpool_small_1")(x_small)
    x_small = SeluConv3D(
        filters=64,
        kernel_size=3,
        name="selu_conv3d_small_2",
    )(x_small)
    x_small = keras.layers.MaxPooling3D((1, 2, 2), name="maxpool_small_2")(x_small)
    x_small = SeluConv3D(
        filters=128,
        kernel_size=3,
        name="selu_conv3d_small_3",
    )(x_small)
    x_small = keras.layers.MaxPooling3D((1, 2, 2), name="maxpool_small_3")(x_small)
    x_small = SeluConv3D(
        filters=256,
        kernel_size=3,
        name="selu_conv3d_small_4",
    )(x_small)
    x_small = keras.layers.Flatten(name="flatten_small")(x_small)

    input_big = keras.Input(BIG_PATCH_SHAPE, name="input_big")
    x_big = keras.layers.MaxPooling3D((2, 2, 2), name="maxpool_big_0")(input_big)
    x_big = SeluConv3D(
        filters=32,
        kernel_size=3,
        name="selu_conv3d_big_1",
    )(x_big)
    x_big = keras.layers.MaxPooling3D((1, 2, 2), name="maxpool_big_1")(x_big)
    x_big = SeluConv3D(
        filters=64,
        kernel_size=3,
        name="selu_conv3d_big_2",
    )(x_big)
    x_big = keras.layers.MaxPooling3D((1, 2, 2), name="maxpool_big_2")(x_big)
    x_big = SeluConv3D(
        filters=128,
        kernel_size=3,
        name="selu_conv3d_big_3",
    )(x_big)
    x_big = keras.layers.MaxPooling3D((1, 2, 2), name="maxpool_big_3")(x_big)
    x_big = SeluConv3D(
        filters=256,
        kernel_size=3,
        name="selu_conv3d_big_4",
    )(x_big)
    x_big = keras.layers.Flatten(name="flatten_big")(x_big)

    x = keras.layers.concatenate([x_small, x_big], name="concatenate")
    x = SeluDense(128, name="selu_dense")(x)
    x = keras.layers.AlphaDropout(dropout_rate, name="alpha_dropout")(x)
    x = keras.layers.Dense(1, activation="sigmoid", name="final_dense")(x)

    cnn = keras.Model(inputs=[input_small, input_big], outputs=x, name="3dcnn")

    cnn.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss="binary_crossentropy",
        metrics=[
            keras.metrics.TruePositives(name="tp"),
            keras.metrics.FalsePositives(name="fp"),
            keras.metrics.TrueNegatives(name="tn"),
            keras.metrics.FalseNegatives(name="fn"),
            keras.metrics.BinaryAccuracy(name="accuracy"),
        ],
    )
    return cnn

In [10]:
metrics = {"tp": [], "fp": [], "tn": [], "fn": [], "accuracy": []}
for train_val_dataset, test_dataset in kfolds(k, dataset, cardinality=total_samples):
    test_dataset = test_dataset.batch(1)
    train_dataset, val_dataset = train_test_split(train_val_dataset, test_perc=val_perc)
    val_dataset = val_dataset.batch(1)
    train_dataset = (
        train_dataset.batch(batch_size)
        .cache()  # must be called before shuffle
        .shuffle(buffer_size=128, reshuffle_each_iteration=True)
        .prefetch(tf.data.experimental.AUTOTUNE)
    )
    cnn = build_and_compile_model()
    start_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    model_fname = f"models/baseline-{start_time}.h5"
    log_dir = f"logs/baseline-{start_time}"
    cnn = train_model(
        cnn, train_dataset, val_dataset, patience, "val_accuracy", model_fname, log_dir,
    )
    test_metrics = cnn.evaluate(test_dataset, return_dict=True, verbose=0)
    pprint(test_metrics)
    for metric_name, metric_value in test_metrics.items():
        if metric_name in metrics:
            metrics[metric_name].append(metric_value)
mean_metrics = {
    metric_name: mean(metric_values)
    for metric_name, metric_values in metrics.items()
    if metric_name in metrics
}
pprint(mean_metrics)

{'accuracy': 0.662162184715271,
 'fn': 9.0,
 'fp': 16.0,
 'loss': 0.5848488807678223,
 'tn': 17.0,
 'tp': 32.0}
{'accuracy': 0.7297297120094299,
 'fn': 17.0,
 'fp': 3.0,
 'loss': 0.8103933334350586,
 'tn': 36.0,
 'tp': 18.0}
{'accuracy': 0.5405405163764954,
 'fn': 6.0,
 'fp': 28.0,
 'loss': 0.9122430086135864,
 'tn': 10.0,
 'tp': 30.0}
{'accuracy': 0.6351351141929626,
 'fn': 23.0,
 'fp': 4.0,
 'loss': 1.1534898281097412,
 'tn': 33.0,
 'tp': 14.0}
{'accuracy': 0.6081081032752991,
 'fn': 1.0,
 'fp': 28.0,
 'loss': 0.9350642561912537,
 'tn': 8.0,
 'tp': 37.0}
{'accuracy': 0.5945945978164673,
 'fn': 8.0,
 'fp': 22.0,
 'loss': 0.7179050445556641,
 'tn': 14.0,
 'tp': 30.0}
{'accuracy': 0.5135135054588318,
 'fn': 2.0,
 'fp': 34.0,
 'loss': 1.3297961950302124,
 'tn': 8.0,
 'tp': 30.0}
{'accuracy': 0.5675675868988037,
 'fn': 2.0,
 'fp': 30.0,
 'loss': 0.8512935042381287,
 'tn': 1.0,
 'tp': 41.0}
{'accuracy': 0.5405405163764954,
 'fn': 33.0,
 'fp': 1.0,
 'loss': 1.781236171722412,
 'tn': 35.0,
 

In [None]:
patches, label = next(iter(test_dataset.skip(6)))
print(f"label: {label[0][0].numpy()}")
prediction = cnn(patches, training=False)
print(f"prediction: {prediction[0][0].numpy()}")
plot_volume_animation(patches[0][0, :])