In [2]:
from collections import defaultdict
import time
from statistics import mean

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tqdm.notebook import tqdm
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from data import classification_dataset, train_test_split, kfolds
from layers import SeluConv3D, SeluDense
from callbacks import TimeEpoch
from plot import plot_slice, plot_volume_animation, plot_loss_history
from config import (
    LIDC_SMALL_NEG_TFRECORD,
    LIDC_BIG_NEG_TFRECORD,
    LIDC_SMALL_POS_TFRECORD,
    LIDC_BIG_POS_TFRECORD,
    SPIE_SMALL_NEG_TFRECORD,
    SPIE_BIG_NEG_TFRECORD,
    SPIE_SMALL_POS_TFRECORD,
    SPIE_BIG_POS_TFRECORD,
    SMALL_PATCH_SHAPE,
    BIG_PATCH_SHAPE,
)

%matplotlib inline
plt.rcParams["figure.figsize"] = [15, 7]
matplotlib.rcParams.update({"font.size": 18})

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

In [3]:
# Load the LIDC classification data: each element pairs a small and a big patch
# with a label (see the repr below); return_size=True also yields the sample count.
lidc_dataset, lidc_samples = classification_dataset(
    LIDC_SMALL_NEG_TFRECORD, LIDC_BIG_NEG_TFRECORD,
    LIDC_SMALL_POS_TFRECORD, LIDC_BIG_POS_TFRECORD,
    return_size=True,
)
print(f"lidc_samples = {lidc_samples!r}")
lidc_dataset

lidc_samples = 754


<ShuffleDataset shapes: (((None, None, None, None), (None, None, None, None)), (1,)), types: ((tf.float32, tf.float32), tf.int8)>

In [None]:
def build_3d_cnn(filters=(32, 64, 128, 256)):
    """Build the two-branch 3D CNN binary classifier.

    One branch reads a ``SMALL_PATCH_SHAPE`` input, the other a ``BIG_PATCH_SHAPE``
    input that is first max-pooled by (2, 2, 2). Each branch is a stack of
    ``SeluConv3D(kernel_size=3)`` + ``MaxPool3D((1, 2, 2))`` stages, one stage per
    entry of ``filters``, followed by ``Flatten``. The two flattened feature
    vectors are concatenated and mapped to a single sigmoid unit.

    Layer names are part of the contract: other code selects layers by name
    (e.g. every layer whose name contains "conv"), so they are kept stable.

    Args:
        filters: number of filters of each convolutional stage, shared by both
            branches. The default reproduces the original 32/64/128/256 network.

    Returns:
        An uncompiled ``keras.Model`` named "3dcnn" with inputs
        ``[input_small, input_big]``.
    """

    def conv_tower(x, prefix):
        # Stage i (1-based): SeluConv3D -> MaxPool3D((1, 2, 2)), named
        # "{prefix}_selu_conv3d_{i}" / "{prefix}_maxpool_{i}", then a Flatten.
        for stage, n_filters in enumerate(filters, start=1):
            x = SeluConv3D(
                filters=n_filters,
                kernel_size=3,
                name=f"{prefix}_selu_conv3d_{stage}",
            )(x)
            x = keras.layers.MaxPool3D((1, 2, 2), name=f"{prefix}_maxpool_{stage}")(x)
        return keras.layers.Flatten(name=f"flatten_{prefix}")(x)

    input_small = keras.Input(SMALL_PATCH_SHAPE, name="input_small")
    x_small = conv_tower(input_small, "small")

    input_big = keras.Input(BIG_PATCH_SHAPE, name="input_big")
    # Extra pooling on the big patch before its conv tower.
    x_big = keras.layers.MaxPool3D((2, 2, 2), name="big_maxpool_0")(input_big)
    x_big = conv_tower(x_big, "big")

    x = keras.layers.concatenate([x_small, x_big], name="concatenate")
    x = keras.layers.Dense(1, activation="sigmoid", name="final_dense")(x)

    return keras.Model(inputs=[input_small, input_big], outputs=x, name="3dcnn")

In [None]:
# Hyper-parameters for the LIDC pre-training run.
batch_size = 16
learning_rate = 1e-5
patience = 20  # epochs without val_loss improvement before EarlyStopping fires
val_perc = 0.2  # share of the LIDC dataset held out as validation split

# Metrics tracked during LIDC training, instantiated in this order.
metrics = [
    metric_cls(name=metric_name)
    for metric_cls, metric_name in (
        (keras.metrics.TruePositives, "tp"),
        (keras.metrics.FalsePositives, "fp"),
        (keras.metrics.TrueNegatives, "tn"),
        (keras.metrics.FalseNegatives, "fn"),
        (keras.metrics.Precision, "precision"),
        (keras.metrics.Recall, "recall"),
        (keras.metrics.AUC, "auc"),
        (keras.metrics.BinaryAccuracy, "accuracy"),
    )
]

In [None]:
# Split LIDC into training and validation data.
# NOTE(review): lidc_dataset is a ShuffleDataset (see its repr above). If it
# reshuffles on every iteration and train_test_split slices it with take/skip,
# train and validation samples could overlap across epochs — confirm in data.py.
train_dataset, val_dataset = train_test_split(lidc_dataset, test_perc=val_perc)
val_dataset = val_dataset.batch(batch_size)

# cache() has to come before shuffle(): caching after it would store a single
# shuffled order and replay it every epoch.
train_dataset = train_dataset.cache()
train_dataset = train_dataset.shuffle(buffer_size=1024, reshuffle_each_iteration=True)
train_dataset = train_dataset.batch(batch_size)
train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)

In [None]:
# Train the 3D CNN on LIDC.
cnn = build_3d_cnn()
cnn.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=keras.losses.BinaryCrossentropy(),
    metrics=metrics,
)

# Callbacks: record per-epoch timing, keep the lowest-val_loss model on disk,
# and stop once val_loss has not improved for `patience` epochs.
time_callback = TimeEpoch()
lidc_callbacks = [
    time_callback,
    keras.callbacks.ModelCheckpoint(
        "models/lidc-3d-cnn.h5", monitor="val_loss", save_best_only=True, verbose=1
    ),
    keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=patience, restore_best_weights=True
    ),
]
history = cnn.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=1000,
    verbose=1,
    callbacks=lidc_callbacks,
)

In [None]:
# Reload the saved checkpoint (lowest val_loss) and score it on the validation split.
# NOTE: this split also drove checkpointing / early stopping, so these numbers
# are an optimistic estimate.
cnn = keras.models.load_model(filepath="models/lidc-3d-cnn.h5")
cnn.evaluate(x=val_dataset, return_dict=True)

In [None]:
# Plot the loss history of the LIDC training run (`history` from its fit call),
# using the project's plot.plot_loss_history helper.
plot_loss_history(history)
# Uncomment to export the figure:
# plt.savefig("lidc-loss.pdf")

In [None]:
"Mean time per epoch: {:.2f}s".format(np.mean(time_callback.times))

In [None]:
# Load the SPIE classification data in the same format as LIDC, plus its
# sample count (return_size=True).
spie_dataset, spie_samples = classification_dataset(
    SPIE_SMALL_NEG_TFRECORD, SPIE_BIG_NEG_TFRECORD,
    SPIE_SMALL_POS_TFRECORD, SPIE_BIG_POS_TFRECORD,
    return_size=True,
)
print(f"spie_samples = {spie_samples!r}")
spie_dataset

In [None]:
# Score the LIDC-trained model on the whole SPIE set without any SPIE training,
# one sample per batch.
cnn = keras.models.load_model(filepath="models/lidc-3d-cnn.h5")
cnn.evaluate(x=spie_dataset.batch(1), return_dict=True)

In [None]:
def build_pretrained_3d_cnn(freeze_conv_layers=True, model_path="models/lidc-3d-cnn.h5"):
    """Load the LIDC-trained 3D CNN as a starting point for fine-tuning.

    Args:
        freeze_conv_layers: if True, every layer whose name contains "conv"
            (the ``*_selu_conv3d_*`` layers created by ``build_3d_cnn``) is marked
            non-trainable, so fine-tuning only updates the remaining layers.
        model_path: checkpoint to load; defaults to the file written by the
            LIDC training run.

    Returns:
        The loaded ``keras.Model``. Compile it again before training: changes to
        ``layer.trainable`` only take effect once the model is (re)compiled.
    """
    pretrained_3d_cnn = keras.models.load_model(model_path)
    if freeze_conv_layers:
        for layer in pretrained_3d_cnn.layers:
            if "conv" in layer.name:
                layer.trainable = False
    return pretrained_3d_cnn

In [None]:
# Settings for the k-fold SPIE experiments.
k = 3  # number of folds; set k = spie_samples for leave-one-out CV (LOOCV)
val_perc = 0.1  # share of each training fold held out for early stopping
batch_size = 8
learning_rate = 1e-5
patience = 30  # EarlyStopping patience, in epochs
num_epochs = 2000  # epoch cap for the from-scratch models

In [None]:
# Per-fold comparison on SPIE: a 3D CNN trained from scratch ("wo_pt") vs. the
# LIDC-pretrained network with frozen conv layers ("w_pt"). Both models use the
# same train/val/test data within each fold.

# Fine-tuning budget for the pretrained model (the values this cell always used,
# named here instead of being hard-coded inside the fit call).
pretrained_learning_rate = 1e-5
pretrained_num_epochs = 100


def fit_and_evaluate(model, train_ds, val_ds, test_ds, lr, epochs, patience):
    """Compile and train `model`, then score it on `test_ds`.

    Training uses Adam(lr) + binary cross-entropy with EarlyStopping on
    ``val_loss`` (``patience`` epochs, best weights restored).

    Returns:
        (history, test_metrics, predictions): the ``fit`` History, the
        ``evaluate`` dict (``loss``, ``auc``, ``accuracy``), and one model output
        per batch of ``test_ds`` (expected to be batched with size 1).
    """
    model.compile(
        optimizer=keras.optimizers.Adam(lr),
        loss=keras.losses.BinaryCrossentropy(),
        metrics=[
            keras.metrics.AUC(name="auc", num_thresholds=1000),
            keras.metrics.BinaryAccuracy(name="accuracy"),
        ],
    )
    fit_history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
        verbose=0,
        callbacks=[
            keras.callbacks.EarlyStopping(
                monitor="val_loss",
                patience=patience,
                restore_best_weights=True,
            )
        ],
    )
    scores = model.evaluate(test_ds, return_dict=True, verbose=0)
    # Batch size is 1, so [0][0] is the single sample's sigmoid output.
    outputs = [
        model(test_x, training=False).numpy()[0][0]
        for test_x, _ in test_ds.as_numpy_iterator()
    ]
    return fit_history, scores, outputs


y_true = []

wo_pt_histories = []
wo_pt_predictions = []
wo_pt_aucs = []
wo_pt_accs = []

w_pt_histories = []
w_pt_predictions = []
w_pt_aucs = []
w_pt_accs = []

for train_dataset, test_dataset in tqdm(
    kfolds(k, spie_dataset, cardinality=spie_samples), total=k
):
    # Building many models in a loop accumulates Keras global state; reset it
    # at the start of every fold.
    keras.backend.clear_session()

    # The test fold is iterated three times (labels here, evaluate(), and the
    # prediction pass). Caching it makes every pass replay the same elements in
    # the same order, so y_true stays aligned with the predictions even if the
    # upstream dataset reshuffles per iteration (spie_dataset presumably is a
    # ShuffleDataset like lidc_dataset — TODO confirm in data.py).
    test_dataset = test_dataset.cache()
    y_true.extend(y[0] for _, y in test_dataset.as_numpy_iterator())
    test_dataset = test_dataset.batch(1)

    # NOTE(review): if train_test_split / kfolds slice a reshuffling source with
    # take/skip, train, validation and test samples could overlap — confirm.
    train_dataset, val_dataset = train_test_split(train_dataset, test_perc=val_perc)
    val_dataset = val_dataset.batch(batch_size)
    train_dataset = (
        train_dataset.cache()  # must be called before shuffle
        .shuffle(buffer_size=1024, reshuffle_each_iteration=True)
        .batch(batch_size)
        .prefetch(tf.data.experimental.AUTOTUNE)
    )

    # Baseline: fresh network trained only on this SPIE fold.
    cnn = build_3d_cnn()
    history, test_metrics, predictions = fit_and_evaluate(
        cnn, train_dataset, val_dataset, test_dataset,
        learning_rate, num_epochs, patience,
    )
    wo_pt_histories.append(history)
    wo_pt_aucs.append(round(test_metrics["auc"], 2))
    wo_pt_accs.append(round(test_metrics["accuracy"], 2))
    wo_pt_predictions.extend(predictions)

    # Transfer: LIDC-pretrained network with frozen conv layers.
    cnn = build_pretrained_3d_cnn()
    history, test_metrics, predictions = fit_and_evaluate(
        cnn, train_dataset, val_dataset, test_dataset,
        pretrained_learning_rate, pretrained_num_epochs, patience,
    )
    w_pt_histories.append(history)
    w_pt_aucs.append(round(test_metrics["auc"], 2))
    w_pt_accs.append(round(test_metrics["accuracy"], 2))
    w_pt_predictions.extend(predictions)

In [None]:
# Per-fold test AUCs without / with LIDC pre-training, and their difference.
print(f"{wo_pt_aucs = }")
print(f"{w_pt_aucs = }")
delta_auc = [round(after - before, 2) for before, after in zip(wo_pt_aucs, w_pt_aucs)]
print(f"{delta_auc = }")
print(f"W/O pre-training mean AUC: {mean(wo_pt_aucs)}")
print(f"W/ pre-training mean AUC: {mean(w_pt_aucs)}")

# Confusion matrices pooled over all folds; each model output is turned into a
# class by rounding it to the nearest integer.
for setting, fold_predictions in (("W/O", wo_pt_predictions), ("W/", w_pt_predictions)):
    print(f"{setting} pre-training confusion matrix:")
    predicted_classes = [round(p, 0) for p in fold_predictions]
    # Keep the f-string: formatting the tensor differs from print(tensor).
    print(f"{tf.math.confusion_matrix(y_true, predicted_classes)}")
    print("")

In [None]:
# Training (dashed) and validation (solid) loss of one fold, for the model
# trained from scratch and for the LIDC-pretrained one.
index = 1  # fold to inspect; must be a valid fold index (0 .. k-1)

fig, ax = plt.subplots()
for prefix, histories in (("w/o", wo_pt_histories), ("w/", w_pt_histories)):
    fold_history = histories[index].history
    ax.plot(fold_history["loss"], "--", label=f"{prefix} pre-training - train loss")
    ax.plot(fold_history["val_loss"], label=f"{prefix} pre-training - val loss")
ax.set(title=f"Fold {index}", xlabel="epoch", ylabel="binary cross-entropy loss")
ax.legend()
# Uncomment to export the figure:
# fig.savefig(f"fold_{index}.pdf")
plt.show()