In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tqdm.notebook import tqdm

from data import tfrecord_dataset, train_test_split, kfolds
from layers import SeluConv3D, SeluDense
from plot import plot_volume_animation
from config import (
    CT_0_TFRECORD,
    CT_1_TFRECORD,
    CT_2_TFRECORD,
    CT_3_TFRECORD,
    CT_4_TFRECORD,
    SCAN_SHAPE,
)

In [None]:
# Single seed shared by every random number generator the notebook relies on.
SEED = 5
# NumPy's legacy global generator.
np.random.seed(SEED)
# TensorFlow keeps its own global seed. Without it, operations that are not given an
# explicit seed (e.g. `Dataset.shuffle` and the random ops behind weight initialisation)
# pick a non-deterministic seed, so two runs of the notebook would not be comparable.
tf.random.set_seed(SEED)

In [None]:
def _count_elements(ds):
    """Return how many elements `ds` yields, by iterating over it once."""
    return sum(1 for _ in ds)


def _constant_labels(value, count):
    """Return a dataset that yields the int8 label `[value]` exactly `count` times."""
    return tf.data.Dataset.from_tensor_slices(np.int8([[value]])).repeat(count)


# NOTE: `Dataset.zip` stops at its shortest input. The expected counts are therefore
# checked on the raw datasets *before* zipping: checking the zipped pairs would still
# pass if a file held more elements than the hard-coded count, and the surplus would be
# dropped silently.

# Negative class: everything in CT_0_TFRECORD gets label 0.
neg_x = tfrecord_dataset(CT_0_TFRECORD)
neg_samples = 254  # expected element count, verified right below
print(f"{neg_samples = }")
neg_found = _count_elements(neg_x)
assert neg_found == neg_samples, f"expected {neg_samples} negative elements, found {neg_found}"
neg_dataset = tf.data.Dataset.zip((neg_x, _constant_labels(0, neg_samples)))

# Positive class: everything in CT_1..CT_4 gets label 1.
pos_x = tfrecord_dataset([CT_1_TFRECORD, CT_2_TFRECORD, CT_3_TFRECORD, CT_4_TFRECORD])
pos_samples = 856  # expected element count, verified right below
print(f"{pos_samples = }")
pos_found = _count_elements(pos_x)
assert pos_found == pos_samples, f"expected {pos_samples} positive elements, found {pos_found}"
pos_dataset = tf.data.Dataset.zip((pos_x, _constant_labels(1, pos_samples)))

# Full labelled dataset: all negatives followed by all positives (not shuffled here).
dataset = neg_dataset.concatenate(pos_dataset)
samples = neg_samples + pos_samples
# End-to-end check that the concatenated pipeline yields every labelled pair.
assert _count_elements(dataset) == samples
dataset

In [None]:
def build_model():
    """Assemble the (uncompiled) 3D convolutional binary classifier.

    The network takes inputs of shape ``SCAN_SHAPE`` and stacks four stages, each made of
    a ``SeluConv3D`` layer (kernel size 3) followed by ``MaxPool3D(2)``, with 32, 64, 128
    and 256 filters respectively. The result is flattened and fed to a single sigmoid
    unit.

    Returns:
        A ``keras.Sequential`` model named ``"3d_cnn"``.
    """
    stack = [keras.layers.InputLayer(SCAN_SHAPE, name="input_layer")]
    # The filter count doubles at every stage; layer names are numbered from 1.
    for stage, n_filters in enumerate((32, 64, 128, 256), start=1):
        stack.append(
            SeluConv3D(filters=n_filters, kernel_size=3, name=f"selu_conv3d_{stage}")
        )
        stack.append(keras.layers.MaxPool3D(2, name=f"maxpool3d_{stage}"))
    stack.append(keras.layers.Flatten(name="flatten"))
    stack.append(keras.layers.Dense(1, activation="sigmoid", name="final_dense"))
    return keras.Sequential(stack, name="3d_cnn")

# Build one instance here only to print the layer-by-layer summary.
m = build_model()
m.summary()

In [None]:
# Hyper-parameters for the single train/validation/test run.
test_perc = 0.1  # `test_perc` given to train_test_split for the test hold-out
val_perc = 0.1  # `test_perc` given to train_test_split for the validation hold-out
learning_rate = 1e-5  # Adam learning rate
batch_size = 8  # batch size of all three dataset pipelines
patience = 15  # EarlyStopping patience, in epochs

# Confusion-matrix counts plus the usual derived scores, each under a short name.
metrics = [
    metric_cls(name=short_name)
    for metric_cls, short_name in (
        (keras.metrics.TruePositives, "tp"),
        (keras.metrics.FalsePositives, "fp"),
        (keras.metrics.TrueNegatives, "tn"),
        (keras.metrics.FalseNegatives, "fn"),
        (keras.metrics.Precision, "precision"),
        (keras.metrics.Recall, "recall"),
        (keras.metrics.BinaryAccuracy, "accuracy"),
        (keras.metrics.AUC, "auc"),
    )
]

In [None]:
# Hold out the test split first; `samples` is passed as the cardinality of `dataset`.
trainval_dataset, test_dataset = train_test_split(
    dataset, cardinality=samples, test_perc=test_perc
)
# Evaluation splits are cached, batched and prefetched, but never shuffled.
test_dataset = test_dataset.cache()
test_dataset = test_dataset.batch(batch_size)
test_dataset = test_dataset.prefetch(tf.data.experimental.AUTOTUNE)

# Split what is left into training and validation data.
train_dataset, val_dataset = train_test_split(trainval_dataset, test_perc=val_perc)
val_dataset = val_dataset.cache()
val_dataset = val_dataset.batch(batch_size)
val_dataset = val_dataset.prefetch(tf.data.experimental.AUTOTUNE)

# cache() must come before shuffle() so that every epoch draws a new order instead of
# replaying a single cached one.
train_dataset = train_dataset.cache()
train_dataset = train_dataset.shuffle(buffer_size=1024, reshuffle_each_iteration=True)
train_dataset = train_dataset.batch(batch_size)
train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)

cnn = build_model()
cnn.compile(
    loss=keras.losses.MeanSquaredError(),
    optimizer=keras.optimizers.Adam(learning_rate),
    metrics=metrics,
)
# The epoch count is only an upper bound: EarlyStopping ends training once val_loss has
# not improved for `patience` epochs and restores the best weights, while
# ModelCheckpoint writes the best model seen so far to disk.
cnn.fit(
    x=train_dataset,
    validation_data=val_dataset,
    epochs=1000,
    verbose=1,
    callbacks=[
        keras.callbacks.ModelCheckpoint(
            "models/covid-3d-cnn.h5",
            monitor="val_loss",
            verbose=1,
            save_best_only=True,
        ),
        keras.callbacks.EarlyStopping(
            monitor="val_loss",
            restore_best_weights=True,
            patience=patience,
        ),
    ],
)

In [None]:
# Score the trained model on the held-out test split; the returned dict (loss plus every
# compiled metric, keyed by name) is the cell's output.
cnn.evaluate(x=test_dataset, return_dict=True)

In [5]:
# Hyper-parameters for the k-fold cross-validation run.
k = 10  # number of folds requested from kfolds
val_perc = 0.1  # `test_perc` given to train_test_split for the validation hold-out
learning_rate = 1e-5  # Adam learning rate
batch_size = 8  # batch size of all dataset pipelines
patience = 15  # EarlyStopping patience, in epochs

# New metric objects for the cross-validation run, keyed here by the short name under
# which each one is reported.
metrics = [
    metric_cls(name=short_name)
    for short_name, metric_cls in {
        "tp": keras.metrics.TruePositives,
        "fp": keras.metrics.FalsePositives,
        "tn": keras.metrics.TrueNegatives,
        "fn": keras.metrics.FalseNegatives,
        "precision": keras.metrics.Precision,
        "recall": keras.metrics.Recall,
        "accuracy": keras.metrics.BinaryAccuracy,
        "auc": keras.metrics.AUC,
    }.items()
]

In [6]:
# One running mean per compiled metric, keyed by metric name so it can be matched against
# the dict returned by `evaluate`. The loss is reported per fold but not averaged.
mean_metrics = {
    metric.name: keras.metrics.Mean(name=f"mean_{metric.name}") for metric in metrics
}
for fold_id, (trainval_dataset, test_dataset) in tqdm(
    enumerate(kfolds(k, dataset, cardinality=samples, seed=SEED)), total=k
):
    print(f" {fold_id = } ".center(50, "="))

    # A brand-new model is built for every fold. Keras keeps global state for each model
    # it creates, so without clearing it memory use grows fold after fold.
    keras.backend.clear_session()

    # Evaluation splits are cached, batched and prefetched, but never shuffled.
    test_dataset = (
        test_dataset.cache().batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
    )
    # Carve this fold's validation split out of its train+validation part.
    train_dataset, val_dataset = train_test_split(trainval_dataset, test_perc=val_perc)
    val_dataset = (
        val_dataset.cache().batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
    )
    train_dataset = (
        train_dataset.cache()  # must be called before shuffle
        .shuffle(buffer_size=256, reshuffle_each_iteration=True)
        .batch(batch_size)
        .prefetch(tf.data.experimental.AUTOTUNE)
    )

    # Train from scratch. The epoch count is only an upper bound: EarlyStopping ends
    # training once val_loss has not improved for `patience` epochs and restores the
    # best weights.
    cnn = build_model()
    cnn.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss=keras.losses.MeanSquaredError(),
        metrics=metrics,
    )
    cnn.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=1000,
        verbose=0,
        callbacks=[
            keras.callbacks.EarlyStopping(
                monitor="val_loss",
                patience=patience,
                restore_best_weights=True,
            ),
        ],
    )

    # Report this fold's test scores and fold them into the running means.
    test_metrics = cnn.evaluate(test_dataset, return_dict=True, verbose=0)
    for metric_name, metric_value in test_metrics.items():
        print(f"{metric_name}: {metric_value}")
        if metric_name in mean_metrics:
            mean_metrics[metric_name].update_state(metric_value)

# Cross-validated estimate: each metric averaged over all folds.
print(" average ".center(50, "="))
for metric_name, metric_value in mean_metrics.items():
    print(f"{metric_name}: {metric_value.result()}")

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=10.0), HTML(value='')))

loss: 0.15787240862846375
tp: 76.0
fp: 18.0
tn: 12.0
fn: 5.0
precision: 0.8085106611251831
recall: 0.9382715821266174
accuracy: 0.792792797088623
auc: 0.7814815044403076
loss: 0.1861683428287506
tp: 70.0
fp: 23.0
tn: 12.0
fn: 6.0
precision: 0.7526881694793701
recall: 0.9210526347160339
accuracy: 0.7387387156486511
auc: 0.7223684191703796
loss: 0.15484091639518738
tp: 84.0
fp: 21.0
tn: 4.0
fn: 2.0
precision: 0.800000011920929
recall: 0.9767441749572754
accuracy: 0.792792797088623
auc: 0.7609301805496216
loss: 0.17942418158054352
tp: 74.0
fp: 20.0
tn: 8.0
fn: 9.0
precision: 0.7872340679168701
recall: 0.891566276550293
accuracy: 0.7387387156486511
auc: 0.7018071413040161
loss: 0.1498459428548813
tp: 84.0
fp: 15.0
tn: 5.0
fn: 7.0
precision: 0.8484848737716675
recall: 0.9230769276618958
accuracy: 0.8018018007278442
auc: 0.6733516454696655
loss: 0.15912793576717377
tp: 84.0
fp: 20.0
tn: 3.0
fn: 4.0
precision: 0.807692289352417
recall: 0.9545454382896423
accuracy: 0.7837837934494019
auc: 0.69