In [1]:
import datetime
from statistics import mean
from pprint import pprint

import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from data import kfolds, train_test_split, classification_dataset
from train import train_model
from layers import SeluConv3D, SeluDense
from plot import plot_slice, plot_volume_animation
from config import (
    SPIE_SMALL_NEG_TFRECORD,
    SPIE_SMALL_POS_TFRECORD,
    SPIE_BIG_NEG_TFRECORD,
    SPIE_BIG_POS_TFRECORD,
    SMALL_PATCH_SHAPE,
    BIG_PATCH_SHAPE,
)

%matplotlib inline
plt.rcParams["figure.figsize"] = [15, 7]

In [18]:
# Hyperparameters
val_perc = 0.1
k = 3
patience = 15
extra_epochs = 10
batch_size = 16
learning_rate = 1e-5
dropout_rate = 0.6

In [13]:
spie_dataset, total_samples = classification_dataset(
    SPIE_SMALL_NEG_TFRECORD,
    SPIE_BIG_NEG_TFRECORD,
    SPIE_SMALL_POS_TFRECORD,
    SPIE_BIG_POS_TFRECORD,
    return_size=True,
)
total_samples

73

In [14]:
def build_model():
    small_encoder = keras.models.load_model("models/autoencoder-lidc.h5").get_layer(
        "encoder"
    )
    small_encoder._name = "small_encoder"
    small_encoder.trainable = False

    input_small = keras.Input(SMALL_PATCH_SHAPE, name="input_small")
    x_small = small_encoder(input_small)
    x_small = keras.layers.Flatten(name="flatten_small")(x_small)

    big_encoder = keras.models.load_model("models/autoencoder-lidc.h5").get_layer(
        "encoder"
    )
    big_encoder._name = "big_encoder"
    big_encoder.trainable = False

    input_big = keras.Input(BIG_PATCH_SHAPE, name="input_big")
    x_big = keras.layers.MaxPooling3D((2, 2, 2), name="big_maxpool_0")(input_big)
    x_big = big_encoder(x_big)
    x_big = keras.layers.Flatten(name="flatten_big")(x_big)

    x = keras.layers.concatenate([x_small, x_big], name="concatenate")

    x = SeluDense(128, name="selu_dense")(x)
    x = keras.layers.AlphaDropout(dropout_rate, name="alpha_dropout")(x)
    x = keras.layers.Dense(1, activation="sigmoid", name="final_dense")(x)

    cnn_3d = keras.Model(inputs=[input_small, input_big], outputs=x, name="3dcnn")

    return cnn_3d

In [15]:
cnn_3d = keras.models.load_model("models/baseline-lidc.h5")
cnn_3d.evaluate(spie_dataset.batch(1), return_dict=True)



{'loss': 2.647796154022217,
 'auc': 0.632882833480835,
 'accuracy': 0.5890411138534546}

In [16]:
metrics = [
    keras.metrics.AUC(name="auc"),
    keras.metrics.BinaryAccuracy(name="accuracy"),
]

In [None]:
mean_metrics = {
    metric.name: keras.metrics.Mean(f"{metric.name}_mean", dtype=tf.float32)
    for metric in metrics
}
fold_id = 0
for train_val_dataset, test_dataset in tqdm(
    kfolds(k, spie_dataset, cardinality=total_samples), total=k
):
    test_dataset = test_dataset.batch(1)
    train_dataset, val_dataset = train_test_split(train_val_dataset, test_perc=val_perc)
    val_dataset = val_dataset.batch(1)
    train_dataset = (
        train_dataset.batch(batch_size)
        .cache()  # must be called before shuffle
        .shuffle(buffer_size=64, reshuffle_each_iteration=True)
        .prefetch(tf.data.experimental.AUTOTUNE)
    )

    cnn_3d = build_model()
    cnn_3d.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss="binary_crossentropy",
        metrics=metrics,
    )

    log_dir = f"logs/pretrained-lidc-{fold_id}"
    model_fname = f"models/pretrained-lidc-{fold_id}.h5"
    cnn_3d = train_model(
        cnn_3d,
        train_dataset,
        val_dataset,
        patience,
        "val_accuracy",
        model_fname,
        log_dir,
    )

    test_metrics = cnn_3d.evaluate(test_dataset, return_dict=True, verbose=0)

    print(f" {fold_id=} ".center(40, "="))
    for metric_name, metric_value in test_metrics.items():
        if metric_name in mean_metrics:
            print(f"{metric_name}: {metric_value}")
            mean_metrics[metric_name].update_state(metric_value)

    fold_id += 1

print(" mean ".center(40, "="))
for metric_name, metric_value in mean_metrics.items():
    print(f"{metric_name}: {metric_value.result().numpy()}")

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3.0), HTML(value='')))

auc: 0.7892857193946838
accuracy: 0.5
auc: 0.7142857313156128
accuracy: 0.2916666567325592


In [None]:
patches, label = next(iter(test_dataset.skip(6)))
print(f"label: {label[0][0].numpy()}")
prediction = cnn(patches, training=False)
print(f"prediction: {prediction[0][0].numpy()}")
plot_volume_animation(patches[0][0, :])