In [1]:
import datetime
from statistics import mean
from pprint import pprint

import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from data import kfolds, train_test_split, classification_dataset
from train import train_model
from layers import SeluConv3D, SeluDense
from plot import plot_slice, plot_volume_animation
from config import (
    LIDC_SMALL_NEG_TFRECORD,
    LIDC_SMALL_POS_TFRECORD,
    LIDC_BIG_NEG_TFRECORD,
    LIDC_BIG_POS_TFRECORD,
    SMALL_PATCH_SHAPE,
    BIG_PATCH_SHAPE,
)

# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Default figure size (width, height), in inches, for the plots in this notebook.
plt.rcParams["figure.figsize"] = [15, 7]

In [6]:
# Hyperparameters
# Passed as `test_perc` to `train_test_split` when carving the validation set
# out of each fold's train/val data.
val_perc = 0.1
# Number of folds requested from `kfolds` (also the progress-bar total).
k = 10
# Forwarded to `train_model`; presumably the early-stopping patience in
# epochs — TODO confirm in train.py.
patience = 30
# Batch size applied to the training dataset (validation/test use batches of 1).
batch_size = 16
# Learning rate handed to the Adam optimizer in `build_and_compile_model`.
learning_rate = 1e-5
# Rate of the AlphaDropout layer placed before the final dense layer.
dropout_rate = 0.6

In [7]:
# Build the classification dataset from the four LIDC TFRecord files and, via
# `return_size=True`, also get the number of samples it contains.
# NOTE(review): the positional order used here (small-neg, big-neg, small-pos,
# big-pos) must match the signature of `classification_dataset` — confirm
# against data.py.
dataset, total_samples = classification_dataset(
    LIDC_SMALL_NEG_TFRECORD,
    LIDC_BIG_NEG_TFRECORD,
    LIDC_SMALL_POS_TFRECORD,
    LIDC_BIG_POS_TFRECORD,
    return_size=True,
)
# Bare expression so the notebook displays the sample count as the cell output.
total_samples

754

In [8]:
def _conv_tower(x, prefix):
    """Apply the convolutional stack shared by the small and big branches.

    The stack is four `SeluConv3D` layers (32, 64, 128 and 256 filters,
    kernel size 3), with a `MaxPooling3D` of pool size (1, 2, 2) after each of
    the first three, followed by a `Flatten`.

    Args:
        x: Keras tensor to feed into the first convolution.
        prefix: Branch identifier ("small" or "big") used to build the layer
            names, e.g. "<prefix>_selu_conv3d_1", "<prefix>_maxpool_1" and
            "flatten_<prefix>".

    Returns:
        The flattened output tensor of the branch.
    """
    filters_per_stage = (32, 64, 128, 256)
    last_stage = len(filters_per_stage)
    for stage, filters in enumerate(filters_per_stage, start=1):
        x = SeluConv3D(
            filters=filters,
            kernel_size=3,
            name=f"{prefix}_selu_conv3d_{stage}",
        )(x)
        # No pooling after the last convolution: its output goes straight to
        # the Flatten layer.
        if stage < last_stage:
            x = keras.layers.MaxPooling3D(
                (1, 2, 2), name=f"{prefix}_maxpool_{stage}"
            )(x)
    return keras.layers.Flatten(name=f"flatten_{prefix}")(x)


def build_and_compile_model():
    """Build and compile the two-input 3D CNN binary classifier.

    The model has one branch per input ("input_small" with shape
    `SMALL_PATCH_SHAPE`, "input_big" with shape `BIG_PATCH_SHAPE`). The big
    input is first downsampled by a (2, 2, 2) max-pooling; both branches then
    go through the same convolutional stack (see `_conv_tower`). The flattened
    branch outputs are concatenated and passed through a 128-unit `SeluDense`,
    an `AlphaDropout` and a single sigmoid unit.

    Reads the module-level `learning_rate` and `dropout_rate` at call time.

    Returns:
        A compiled `keras.Model` named "3dcnn", using Adam, binary
        cross-entropy loss and the metrics "tp", "fp", "tn", "fn" and
        "accuracy".
    """
    # Layers are created in the same order as before (small branch first,
    # then big branch) so construction order and layer names are unchanged.
    input_small = keras.Input(SMALL_PATCH_SHAPE, name="input_small")
    x_small = _conv_tower(input_small, "small")

    input_big = keras.Input(BIG_PATCH_SHAPE, name="input_big")
    x_big = keras.layers.MaxPooling3D((2, 2, 2), name="big_maxpool_0")(input_big)
    x_big = _conv_tower(x_big, "big")

    x = keras.layers.concatenate([x_small, x_big], name="concatenate")
    x = SeluDense(128, name="selu_dense")(x)
    x = keras.layers.AlphaDropout(dropout_rate, name="alpha_dropout")(x)
    x = keras.layers.Dense(1, activation="sigmoid", name="final_dense")(x)

    cnn = keras.Model(inputs=[input_small, input_big], outputs=x, name="3dcnn")

    cnn.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss="binary_crossentropy",
        metrics=[
            keras.metrics.TruePositives(name="tp"),
            keras.metrics.FalsePositives(name="fp"),
            keras.metrics.TrueNegatives(name="tn"),
            keras.metrics.FalseNegatives(name="fn"),
            keras.metrics.BinaryAccuracy(name="accuracy"),
        ],
    )
    return cnn

In [9]:
# k-fold cross-validation: train a fresh model on every fold and collect its
# test metrics. The keys match the metric names set in `cnn.compile`.
metrics = {"tp": [], "fp": [], "tn": [], "fn": [], "accuracy": []}
# NOTE: `test_dataset` and `cnn` keep their last-fold values after the loop
# and are reused by the inspection cell further down, so keep these names.
for sess_id, (train_val_dataset, test_dataset) in enumerate(
    tqdm(kfolds(k, dataset, cardinality=total_samples), total=k)
):
    test_dataset = test_dataset.batch(1)
    train_dataset, val_dataset = train_test_split(train_val_dataset, test_perc=val_perc)
    val_dataset = val_dataset.batch(1)
    # NOTE(review): shuffle runs after batch, so it reorders whole batches;
    # the samples inside each batch stay together across epochs. If per-sample
    # shuffling is intended, use cache -> shuffle -> batch instead.
    train_dataset = (
        train_dataset.batch(batch_size)
        .cache()  # must be called before shuffle
        .shuffle(buffer_size=128, reshuffle_each_iteration=True)
        .prefetch(tf.data.experimental.AUTOTUNE)
    )
    cnn = build_and_compile_model()
    # One checkpoint file and one log directory per fold.
    model_fname = f"models/baseline-lidc-{sess_id}.h5"
    log_dir = f"logs/baseline-lidc-{sess_id}"
    cnn = train_model(
        cnn,
        train_dataset,
        val_dataset,
        patience,
        "val_accuracy",
        model_fname,
        log_dir,
    )
    test_metrics = cnn.evaluate(test_dataset, return_dict=True, verbose=0)
    # `evaluate` also reports entries that are not tracked here (e.g. the
    # loss), hence the membership check.
    for metric_name, metric_value in test_metrics.items():
        if metric_name in metrics:
            metrics[metric_name].append(metric_value)
    print(f"Last accuracy: {metrics['accuracy'][-1]}")
# Average every tracked metric over the folds.
mean_metrics = {
    metric_name: mean(metric_values)
    for metric_name, metric_values in metrics.items()
}
print("==============================================")
pprint(mean_metrics)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=10.0), HTML(value='')))

Last accuracy: 0.8399999737739563
Last accuracy: 0.7599999904632568
Last accuracy: 0.7200000286102295
Last accuracy: 0.7866666913032532
Last accuracy: 0.7599999904632568
Last accuracy: 0.6800000071525574
Last accuracy: 0.7599999904632568
Last accuracy: 0.653333306312561
Last accuracy: 0.7599999904632568
Last accuracy: 0.8266666531562805

{'accuracy': 0.7546666622161865, 'fn': 11.0, 'fp': 7.4, 'tn': 29.9, 'tp': 26.7}


In [None]:
# Inspect a single test sample with the most recently trained model.
# NOTE: relies on `test_dataset` and `cnn` still holding the values left by
# the last iteration of the cross-validation loop.
# `test_dataset` is batched with size 1, so skip(6) selects its 7th sample.
patches, label = next(iter(test_dataset.skip(6)))
print(f"label: {label[0][0].numpy()}")
# training=False runs the model in inference mode (dropout disabled).
prediction = cnn(patches, training=False)
print(f"prediction: {prediction[0][0].numpy()}")
# presumably patches[0] is the small-patch batch (first model input) and
# [0, :] picks its only element — TODO confirm against the dataset layout
plot_volume_animation(patches[0][0, :])