In [1]:
from collections import defaultdict

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from data import classification_dataset, train_test_split, kfolds
from layers import SeluConv3D, SeluDense
from plot import plot_slice, plot_volume_animation
from config import (
    LIDC_SMALL_NEG_TFRECORD,
    LIDC_BIG_NEG_TFRECORD,
    LIDC_SMALL_POS_TFRECORD,
    LIDC_BIG_POS_TFRECORD,
    SPIE_SMALL_NEG_TFRECORD,
    SPIE_BIG_NEG_TFRECORD,
    SPIE_SMALL_POS_TFRECORD,
    SPIE_BIG_POS_TFRECORD,
    SMALL_PATCH_SHAPE,
    BIG_PATCH_SHAPE,
)

%matplotlib inline
plt.rcParams["figure.figsize"] = [15, 7]

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

In [2]:
SEED = 5
np.random.seed(SEED)

In [3]:
lidc_dataset, lidc_samples = classification_dataset(
    LIDC_SMALL_NEG_TFRECORD,
    LIDC_BIG_NEG_TFRECORD,
    LIDC_SMALL_POS_TFRECORD,
    LIDC_BIG_POS_TFRECORD,
    return_size=True,
)
print(f"{lidc_samples = }")
lidc_dataset

lidc_samples = 754


<ShuffleDataset shapes: (((None, None, None, None), (None, None, None, None)), (1,)), types: ((tf.float32, tf.float32), tf.int8)>

In [4]:
def build_3d_cnn():
    input_small = keras.Input(SMALL_PATCH_SHAPE, name="input_small")
    x_small = SeluConv3D(
        filters=32,
        kernel_size=3,
        name="small_selu_conv3d_1",
    )(input_small)
    x_small = keras.layers.MaxPool3D((1, 2, 2), name="small_maxpool_1")(x_small)
    x_small = SeluConv3D(
        filters=64,
        kernel_size=3,
        name="small_selu_conv3d_2",
    )(x_small)
    x_small = keras.layers.MaxPool3D((1, 2, 2), name="small_maxpool_2")(x_small)
    x_small = SeluConv3D(
        filters=128,
        kernel_size=3,
        name="small_selu_conv3d_3",
    )(x_small)
    x_small = keras.layers.MaxPool3D((1, 2, 2), name="small_maxpool_3")(x_small)
    x_small = SeluConv3D(
        filters=256,
        kernel_size=3,
        name="small_selu_conv3d_4",
    )(x_small)
    x_small = keras.layers.MaxPool3D((1, 1, 2), name="small_maxpool_4")(x_small)
    x_small = keras.layers.Flatten(name="flatten_small")(x_small)

    input_big = keras.Input(BIG_PATCH_SHAPE, name="input_big")
    x_big = keras.layers.MaxPool3D((2, 2, 2), name="big_maxpool_0")(input_big)
    x_big = SeluConv3D(
        filters=32,
        kernel_size=3,
        name="big_selu_conv3d_1",
    )(x_big)
    x_big = keras.layers.MaxPool3D((1, 2, 2), name="big_maxpool_1")(x_big)
    x_big = SeluConv3D(
        filters=64,
        kernel_size=3,
        name="big_selu_conv3d_2",
    )(x_big)
    x_big = keras.layers.MaxPool3D((1, 2, 2), name="big_maxpool_2")(x_big)
    x_big = SeluConv3D(
        filters=128,
        kernel_size=3,
        name="big_selu_conv3d_3",
    )(x_big)
    x_big = keras.layers.MaxPool3D((1, 2, 2), name="big_maxpool_3")(x_big)
    x_big = SeluConv3D(
        filters=256,
        kernel_size=3,
        name="big_selu_conv3d_4",
    )(x_big)
    x_big = keras.layers.MaxPool3D((1, 1, 2), name="big_maxpool_4")(x_big)
    x_big = keras.layers.Flatten(name="flatten_big")(x_big)

    x = keras.layers.concatenate([x_small, x_big], name="concatenate")
    x = keras.layers.Dense(1, activation="sigmoid", name="final_dense")(x)

    cnn_3d = keras.Model(inputs=[input_small, input_big], outputs=x, name="3dcnn")

    return cnn_3d

In [5]:
learning_rate = 1e-5
val_perc = 0.1
patience = 10
batch_size = 16
metrics = [
    keras.metrics.AUC(name="auc"),
    keras.metrics.BinaryAccuracy(name="accuracy"),
]

In [6]:
train_dataset, val_dataset = train_test_split(
    lidc_dataset, test_perc=val_perc, seed=SEED
)
val_dataset = val_dataset.batch(batch_size)
train_dataset = (
    train_dataset.cache()  # must be called before shuffle
    .shuffle(buffer_size=1024, reshuffle_each_iteration=True)
    .batch(batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

In [7]:
cnn = build_3d_cnn()
cnn.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=keras.losses.MeanSquaredError(),
    metrics=metrics,
)
cnn.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=1000,
    verbose=1,
    callbacks=[
        keras.callbacks.ModelCheckpoint(
            "models/lidc-3d-cnn.h5",
            monitor="val_loss",
            save_best_only=True,
            verbose=1,
        ),
        keras.callbacks.EarlyStopping(
            monitor="val_loss",
            patience=patience,
            restore_best_weights=True,
        ),
    ],
)

Epoch 1/1000
     43/Unknown - 1s 16ms/step - loss: 0.2257 - auc: 0.6880 - accuracy: 0.6244
Epoch 00001: val_loss improved from inf to 0.20687, saving model to models/lidc-3d-cnn.h5
Epoch 2/1000
Epoch 00002: val_loss improved from 0.20687 to 0.17326, saving model to models/lidc-3d-cnn.h5
Epoch 3/1000
Epoch 00003: val_loss improved from 0.17326 to 0.16607, saving model to models/lidc-3d-cnn.h5
Epoch 4/1000
Epoch 00004: val_loss improved from 0.16607 to 0.16244, saving model to models/lidc-3d-cnn.h5
Epoch 5/1000
Epoch 00005: val_loss improved from 0.16244 to 0.15393, saving model to models/lidc-3d-cnn.h5
Epoch 6/1000
Epoch 00006: val_loss did not improve from 0.15393
Epoch 7/1000
Epoch 00007: val_loss improved from 0.15393 to 0.14808, saving model to models/lidc-3d-cnn.h5
Epoch 8/1000
Epoch 00008: val_loss did not improve from 0.14808
Epoch 9/1000
Epoch 00009: val_loss did not improve from 0.14808
Epoch 10/1000
Epoch 00010: val_loss improved from 0.14808 to 0.14208, saving model to model

Epoch 25/1000
Epoch 00025: val_loss did not improve from 0.13700
Epoch 26/1000
Epoch 00026: val_loss improved from 0.13700 to 0.13521, saving model to models/lidc-3d-cnn.h5
Epoch 27/1000
Epoch 00027: val_loss improved from 0.13521 to 0.13392, saving model to models/lidc-3d-cnn.h5
Epoch 28/1000
Epoch 00028: val_loss did not improve from 0.13392
Epoch 29/1000
Epoch 00029: val_loss did not improve from 0.13392
Epoch 30/1000
Epoch 00030: val_loss did not improve from 0.13392
Epoch 31/1000
Epoch 00031: val_loss did not improve from 0.13392
Epoch 32/1000
Epoch 00032: val_loss did not improve from 0.13392
Epoch 33/1000
Epoch 00033: val_loss improved from 0.13392 to 0.13193, saving model to models/lidc-3d-cnn.h5
Epoch 34/1000
Epoch 00034: val_loss improved from 0.13193 to 0.13165, saving model to models/lidc-3d-cnn.h5
Epoch 35/1000
Epoch 00035: val_loss did not improve from 0.13165
Epoch 36/1000
Epoch 00036: val_loss did not improve from 0.13165
Epoch 37/1000
Epoch 00037: val_loss did not impr

Epoch 50/1000
Epoch 00050: val_loss did not improve from 0.13046
Epoch 51/1000
Epoch 00051: val_loss improved from 0.13046 to 0.12993, saving model to models/lidc-3d-cnn.h5
Epoch 52/1000
Epoch 00052: val_loss improved from 0.12993 to 0.12845, saving model to models/lidc-3d-cnn.h5
Epoch 53/1000
Epoch 00053: val_loss did not improve from 0.12845
Epoch 54/1000
Epoch 00054: val_loss did not improve from 0.12845
Epoch 55/1000
Epoch 00055: val_loss did not improve from 0.12845
Epoch 56/1000
Epoch 00056: val_loss did not improve from 0.12845
Epoch 57/1000
Epoch 00057: val_loss did not improve from 0.12845
Epoch 58/1000
Epoch 00058: val_loss did not improve from 0.12845
Epoch 59/1000
Epoch 00059: val_loss improved from 0.12845 to 0.12817, saving model to models/lidc-3d-cnn.h5
Epoch 60/1000
Epoch 00060: val_loss did not improve from 0.12817
Epoch 61/1000
Epoch 00061: val_loss did not improve from 0.12817
Epoch 62/1000
Epoch 00062: val_loss did not improve from 0.12817
Epoch 63/1000
Epoch 00063:

<tensorflow.python.keras.callbacks.History at 0x7fb0b01c4c10>

In [8]:
cnn = keras.models.load_model("models/lidc-3d-cnn.h5")
cnn.evaluate(val_dataset, return_dict=True)



{'loss': 0.12520287930965424,
 'auc': 0.9074074029922485,
 'accuracy': 0.8666666746139526}

In [9]:
spie_dataset, spie_samples = classification_dataset(
    SPIE_SMALL_NEG_TFRECORD,
    SPIE_BIG_NEG_TFRECORD,
    SPIE_SMALL_POS_TFRECORD,
    SPIE_BIG_POS_TFRECORD,
    return_size=True,
)
print(f"{spie_samples = }")
spie_dataset

spie_samples = 73


<ShuffleDataset shapes: (((None, None, None, None), (None, None, None, None)), (1,)), types: ((tf.float32, tf.float32), tf.int8)>

In [10]:
cnn = keras.models.load_model("models/lidc-3d-cnn.h5")
cnn.evaluate(spie_dataset.batch(1), return_dict=True)



{'loss': 0.34322842955589294,
 'auc': 0.6981981992721558,
 'accuracy': 0.5616438388824463}

In [11]:
def build_pretrained_3d_cnn(freeze_conv_layers):
    pretrained_3d_cnn = keras.models.load_model("models/lidc-3d-cnn.h5")
    if freeze_conv_layers:
        for layer in pretrained_3d_cnn.layers:
            if "conv" in layer.name:
                layer.trainable = False
    return pretrained_3d_cnn

In [14]:
k = 3
learning_rate = 1e-5
val_perc = 0.1
patience = 10
num_epochs = 1000
batch_size = 8
metrics = [
    keras.metrics.AUC(name="auc"),
    keras.metrics.BinaryAccuracy(name="accuracy"),
]

In [15]:
lidc_mean_metrics = {
    f"{metric.name}": keras.metrics.Mean(name=f"mean_{metric.name}")
    for metric in metrics
}
wo_pt_mean_metrics = {
    f"{metric.name}": keras.metrics.Mean(name=f"mean_{metric.name}")
    for metric in metrics
}
w_pt_mean_metrics = {
    f"{metric.name}": keras.metrics.Mean(name=f"mean_{metric.name}")
    for metric in metrics
}
w_pt_conv_mean_metrics = {
    f"{metric.name}": keras.metrics.Mean(name=f"mean_{metric.name}")
    for metric in metrics
}
fold_id = 0
for trainval_dataset, test_dataset in tqdm(
    kfolds(k, spie_dataset, cardinality=spie_samples, seed=SEED), total=k
):
    print(f" {fold_id = } ".center(50, "="))

    test_dataset = test_dataset.batch(batch_size)
    train_dataset, val_dataset = train_test_split(trainval_dataset, test_perc=val_perc)
    val_dataset = val_dataset.batch(batch_size)
    train_dataset = (
        train_dataset.cache()  # must be called before shuffle
        .shuffle(buffer_size=1024, reshuffle_each_iteration=True)
        .batch(batch_size)
        .prefetch(tf.data.experimental.AUTOTUNE)
    )
    print(f"Train size: {sum(1 for _ in train_dataset.unbatch())}")
    print(f"Validation size: {sum(1 for _ in val_dataset.unbatch())}")
    print(f"Test size: {sum(1 for _ in test_dataset.unbatch())}")
    print()

    cnn = keras.models.load_model("models/lidc-3d-cnn.h5")
    test_metrics = cnn.evaluate(test_dataset, return_dict=True, verbose=0)
    print("LIDC training only: ")
    for metric_name, metric_value in test_metrics.items():
        print(f"{metric_name}: {metric_value}")
        if metric_name in lidc_mean_metrics:
            lidc_mean_metrics[metric_name].update_state(metric_value)
    print("")

    cnn = build_3d_cnn()
    cnn.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss=keras.losses.MeanSquaredError(),
        metrics=metrics,
    )
    cnn.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=num_epochs,
        verbose=0,
        callbacks=[
            keras.callbacks.EarlyStopping(
                monitor="val_loss",
                patience=patience,
                restore_best_weights=True,
            ),
        ],
    )
    test_metrics = cnn.evaluate(test_dataset, return_dict=True, verbose=0)
    print("Without pretraining: ")
    for metric_name, metric_value in test_metrics.items():
        print(f"{metric_name}: {metric_value}")
        if metric_name in wo_pt_mean_metrics:
            wo_pt_mean_metrics[metric_name].update_state(metric_value)
    print("")

    cnn = build_pretrained_3d_cnn(True)
    cnn.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss=keras.losses.MeanSquaredError(),
        metrics=metrics,
    )
    cnn.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=num_epochs,
        verbose=0,
        callbacks=[
            keras.callbacks.EarlyStopping(
                monitor="val_loss",
                patience=patience,
                restore_best_weights=True,
            ),
        ],
    )
    test_metrics = cnn.evaluate(test_dataset, return_dict=True, verbose=0)
    print("With pretraining (W/O conv): ")
    for metric_name, metric_value in test_metrics.items():
        print(f"{metric_name}: {metric_value}")
        if metric_name in w_pt_mean_metrics:
            w_pt_mean_metrics[metric_name].update_state(metric_value)
    print("")

    cnn = build_pretrained_3d_cnn(False)
    cnn.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss=keras.losses.MeanSquaredError(),
        metrics=metrics,
    )
    cnn.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=num_epochs,
        verbose=0,
        callbacks=[
            keras.callbacks.EarlyStopping(
                monitor="val_loss",
                patience=patience,
                restore_best_weights=True,
            ),
        ],
    )
    test_metrics = cnn.evaluate(test_dataset, return_dict=True, verbose=0)
    print("With pretraining (W/ conv): ")
    for metric_name, metric_value in test_metrics.items():
        print(f"{metric_name}: {metric_value}")
        if metric_name in w_pt_conv_mean_metrics:
            w_pt_conv_mean_metrics[metric_name].update_state(metric_value)

    fold_id += 1

print(" average ".center(50, "="))
print("LIDC training only: ")
for metric_name, metric_value in lidc_mean_metrics.items():
    print(f"{metric_name}: {metric_value.result()}")
print("")
print("Without pretraining: ")
for metric_name, metric_value in wo_pt_mean_metrics.items():
    print(f"{metric_name}: {metric_value.result()}")
print("")

print("With pretraining (W/O conv): ")
for metric_name, metric_value in w_pt_mean_metrics.items():
    print(f"{metric_name}: {metric_value.result()}")
print("")

print("With pretraining (W/ conv): ")
for metric_name, metric_value in w_pt_conv_mean_metrics.items():
    print(f"{metric_name}: {metric_value.result()}")

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3.0), HTML(value='')))

Train size: 45
Validation size: 4
Test size: 24

LIDC training only: 
loss: 0.39898553490638733
auc: 0.6259258985519409
accuracy: 0.4583333432674408

Without pretraining: 
loss: 0.2588021457195282
auc: 0.7370370030403137
accuracy: 0.5833333134651184

With pretraining (W/O conv): 
loss: 0.38833796977996826
auc: 0.6296296715736389
accuracy: 0.5

With pretraining (W/ conv): 
loss: 0.28150704503059387
auc: 0.6074073910713196
accuracy: 0.7083333134651184
Train size: 45
Validation size: 4
Test size: 24

LIDC training only: 
loss: 0.29896751046180725
auc: 0.735714316368103
accuracy: 0.625

Without pretraining: 
loss: 0.21459273993968964
auc: 0.7321428060531616
accuracy: 0.625

With pretraining (W/O conv): 
loss: 0.2157258242368698
auc: 0.7250000238418579
accuracy: 0.625

With pretraining (W/ conv): 
loss: 0.21164774894714355
auc: 0.7285714149475098
accuracy: 0.625
Train size: 44
Validation size: 4
Test size: 25

LIDC training only: 
loss: 0.3321920335292816
auc: 0.7083333134651184
accuracy: 0