In [None]:
import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras

from model import build_encoder, build_autoencoder
from data import example_to_tensor, normalize, add_channel_axis, train_test_split
from train import train_func
from utils import duplicate_iterator, plot_slice, plot_animated_volume
from config import allocate_gpu_memory_only_when_needed, data_root_dir, seed

%matplotlib inline
plt.rcParams["figure.figsize"] = [15, 7]

In [None]:
# Runtime setup. The helper comes from config.py; judging by its name it makes
# TensorFlow allocate GPU memory on demand -- TODO confirm.
allocate_gpu_memory_only_when_needed(True)

# --- Data ---------------------------------------------------------------
input_shape = (24, 128, 128, 1)  # used as padded shape when batching and by the model builders
batch_size = 4

# --- Optimisation -------------------------------------------------------
epochs = 1000  # upper bound; both training runs also use early stopping
patience = 20  # epochs without improvement tolerated by early stopping
learning_rate = 1e-4

# --- Split fractions ----------------------------------------------------
unsupervised_val_perc = 0.2
test_perc = 0.1
val_perc = 0.1  # fraction of the training set that remains after the test split

In [None]:
# Negative volumes (labelled 0 further down): every TFRecord in the
# "CT-0-0.25" directory under the data root.
# Path.glob yields files in an arbitrary, filesystem-dependent order, so the
# list is sorted to keep the record order -- and the seeded splits built on
# top of it -- reproducible across machines and runs.
neg_tfrecord_fnames = sorted(
    str(p) for p in Path(data_root_dir).glob("CT-0-0.25/*.tfrecord")
)
# Fail early: an empty file list would silently produce an empty dataset.
assert neg_tfrecord_fnames, f"no negative TFRecord files found under {data_root_dir}"

# Decode -> normalize -> add channel axis, using the project helpers from data.py.
neg_x = (
    tf.data.TFRecordDataset(neg_tfrecord_fnames)
    .map(example_to_tensor, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .map(normalize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .map(add_channel_axis, num_parallel_calls=tf.data.experimental.AUTOTUNE)
)
neg_x

In [None]:
# Positive volumes (labelled 1 further down): every TFRecord in the
# "CT-1-0.25" ... "CT-4-0.25" directories under the data root.
# Path.glob yields files in an arbitrary, filesystem-dependent order, so the
# list is sorted to keep the record order -- and the seeded splits built on
# top of it -- reproducible across machines and runs.
pos_tfrecord_fnames = sorted(
    str(p) for p in Path(data_root_dir).glob("CT-[1-4]-0.25/*.tfrecord")
)
# Fail early: an empty file list would silently produce an empty dataset.
assert pos_tfrecord_fnames, f"no positive TFRecord files found under {data_root_dir}"

# Decode -> normalize -> add channel axis, using the project helpers from data.py.
pos_x = (
    tf.data.TFRecordDataset(pos_tfrecord_fnames)
    .map(example_to_tensor, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .map(normalize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .map(add_channel_axis, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # NOTE(review): leftover experiment that capped the positive set -- confirm
    # whether it is still needed or delete it.
    # .take(254)
)
pos_x

In [None]:
# Unsupervised (autoencoder) data: all volumes, labels ignored, split into
# train / validation with the shared seed.
unsupervised_train_ds, unsupervised_val_ds = train_test_split(
    neg_x.concatenate(pos_x),
    test_perc=unsupervised_val_perc,
    cardinality=None,
    seed=seed,
)

# The batched datasets get their own names because the lambdas below look them
# up lazily, every time the wrapping dataset is iterated.
unsupervised_val_batches = unsupervised_val_ds.padded_batch(batch_size, input_shape)
unsupervised_train_batches = unsupervised_train_ds.padded_batch(batch_size, input_shape)

# `from_generator` calls its first argument once per iteration of the dataset.
# It must therefore build a *fresh* iterator on each call: returning one
# pre-built generator object (single-use) would hand back an exhausted or
# half-consumed iterator whenever an iteration is not served by a complete
# cache (e.g. after an interrupted epoch).
# `duplicate_iterator` is a project helper; the element signature below implies
# it yields (input, target) pairs of float32 tensors -- presumably the same
# batch twice for reconstruction, TODO confirm.
unsupervised_val_ds = (
    tf.data.Dataset.from_generator(
        lambda: duplicate_iterator(unsupervised_val_batches),
        (tf.float32, tf.float32),
    )
    .cache()
    .prefetch(tf.data.experimental.AUTOTUNE)
)
unsupervised_train_ds = (
    tf.data.Dataset.from_generator(
        lambda: duplicate_iterator(unsupervised_train_batches),
        (tf.float32, tf.float32),
    )
    .cache()  # must come before shuffle so every epoch is reshuffled
    .shuffle(buffer_size=64, reshuffle_each_iteration=True)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
unsupervised_train_ds

In [None]:
# Build the autoencoder and print the architecture of each half, then of the
# whole model.
autoencoder = build_autoencoder(input_shape)
for sub_model_name in ("encoder", "decoder"):
    autoencoder.get_layer(sub_model_name).summary()
autoencoder.summary()

In [None]:
# Compile the model that is actually fitted below: the autoencoder, trained
# with a mean-squared-error loss.
autoencoder.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=keras.losses.MeanSquaredError(),
)

# One timestamp identifies both the checkpoint file and the TensorBoard run.
start_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
best_checkpoint = f"models/autoencoder/{start_time}.h5"
# Make sure the target directory exists before the first checkpoint is
# written; not every Keras version creates it.
Path(best_checkpoint).parent.mkdir(parents=True, exist_ok=True)

checkpoint_cb = keras.callbacks.ModelCheckpoint(best_checkpoint, save_best_only=True)
early_stopping_cb = keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=patience,
)
tensorboard_cb = keras.callbacks.TensorBoard(
    log_dir=f"logs/autoencoder/{start_time}",
    histogram_freq=1,
    write_graph=False,
    profile_batch=0,
)
autoencoder.fit(
    unsupervised_train_ds,
    validation_data=unsupervised_val_ds,
    epochs=epochs,
    callbacks=[checkpoint_cb, early_stopping_cb, tensorboard_cb],
)
# Continue with the best checkpoint rather than the weights of the last epoch.
autoencoder = keras.models.load_model(best_checkpoint)

In [None]:
# Which element of the sampled batch the plotting cell below displays.
batch_index = 0

# Take the sixth validation batch and push it through both halves of the
# autoencoder in inference mode.
sample_batches = iter(unsupervised_val_ds.skip(5))
original, _ = next(sample_batches)
encoder_out = autoencoder.get_layer("encoder")(original, training=False)
decoder_out = autoencoder.get_layer("decoder")(encoder_out, training=False)

# Last expression: validation metrics of the restored model, shown as cell output.
autoencoder.evaluate(unsupervised_val_ds, verbose=0, return_dict=True)

In [None]:
# Left to right: input slice, a slice of the encoder output, reconstructed slice.
z_index = 5
fig, ax = plt.subplots(ncols=3)
panels = (
    (original, z_index),
    (encoder_out, encoder_out.shape[1] // 3),  # encoder output uses its own depth index
    (decoder_out, z_index),
)
for panel_ax, (volume_batch, depth) in zip(ax, panels):
    plot_slice(volume_batch[batch_index, :], depth, panel_ax)

In [None]:
# Count the negative volumes with one full pass over the dataset, then pair
# each volume with a constant label 0 (int8, shape (1,)).
num_neg = 0
for _volume in neg_x:
    num_neg += 1

neg_label = tf.constant([0], dtype=tf.int8)
neg_y = tf.data.Dataset.from_tensors(neg_label).repeat(num_neg)
neg_dataset = tf.data.Dataset.zip((neg_x, neg_y))
neg_dataset

In [None]:
# Count the positive volumes with one full pass over the dataset, then pair
# each volume with a constant label 1 (int8, shape (1,)).
num_pos = 0
for _volume in pos_x:
    num_pos += 1

pos_label = tf.constant([1], dtype=tf.int8)
pos_y = tf.data.Dataset.from_tensors(pos_label).repeat(num_pos)
pos_dataset = tf.data.Dataset.zip((pos_x, pos_y))
pos_dataset

In [None]:
# Supervised data: labelled volumes split into test, then validation, then train.
# NOTE(review): when `pretrained` is used further down, the encoder comes from
# an autoencoder fitted on neg_x + pos_x, i.e. on the pool these test volumes
# are drawn from -- confirm that this is acceptable for the evaluation.
labelled_padded_shapes = (input_shape, (1,))
split_options = dict(cardinality=None, seed=seed)

# First cut: hold out the test set; `dataset` keeps whatever remains.
dataset = neg_dataset.concatenate(pos_dataset)
dataset, test_dataset = train_test_split(dataset, test_perc=test_perc, **split_options)
test_dataset = test_dataset.padded_batch(1, padded_shapes=labelled_padded_shapes)

# Second cut: validation is taken from the remainder, not from the full set.
train_dataset, val_dataset = train_test_split(dataset, test_perc=val_perc, **split_options)

val_dataset = val_dataset.padded_batch(
    batch_size, padded_shapes=labelled_padded_shapes, drop_remainder=True
)
val_dataset = val_dataset.cache().prefetch(tf.data.experimental.AUTOTUNE)

train_dataset = train_dataset.padded_batch(
    batch_size, padded_shapes=labelled_padded_shapes, drop_remainder=True
)
train_dataset = (
    train_dataset.cache()  # cache first, so the shuffle below differs per epoch
    .shuffle(buffer_size=64, reshuffle_each_iteration=True)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
train_dataset

In [None]:
from collections import Counter

# Class balance of each split, printed in the order train, validation, test.
for split_ds in (train_dataset, val_dataset, test_dataset):
    print(Counter(label.numpy()[0] for _, label in split_ds.unbatch()))

In [None]:
# Choose the classifier backbone: the "encoder" sub-model of a previously
# saved autoencoder checkpoint, or a freshly built encoder.
pretrained = True
encoder = (
    keras.models.load_model("models/autoencoder/20201020-161512.h5").get_layer(
        "encoder"
    )
    if pretrained
    else build_encoder(input_shape)
)
encoder.summary()

In [None]:
# Classification head stacked on the encoder: flatten, one SELU hidden layer
# (LeCun-normal initialisation, alpha dropout), and a single sigmoid output unit.
classifier_head = [
    keras.layers.Flatten(),
    keras.layers.Dense(
        512,
        activation="selu",
        kernel_initializer="lecun_normal",
        bias_initializer="lecun_normal",
    ),
    keras.layers.AlphaDropout(0.1),
    keras.layers.Dense(1, activation="sigmoid"),
]
cnn = keras.Sequential([encoder, *classifier_head], name="cnn")
cnn.summary()

In [None]:
# Binary classifier: cross-entropy loss plus confusion-matrix counts and
# threshold / ranking metrics. The callbacks below monitor "val_auc".
cnn.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=keras.losses.BinaryCrossentropy(),
    metrics=[
        keras.metrics.TruePositives(name="tp"),
        keras.metrics.FalsePositives(name="fp"),
        keras.metrics.TrueNegatives(name="tn"),
        keras.metrics.FalseNegatives(name="fn"),
        keras.metrics.BinaryAccuracy(name="accuracy"),
        keras.metrics.Precision(name="precision"),
        keras.metrics.Recall(name="recall"),
        keras.metrics.AUC(name="auc"),
    ],
)

# One run name identifies both the checkpoint file and the TensorBoard run;
# runs that start from the pretrained encoder get a "pretrained-" prefix.
start_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
run_name = f"{'pretrained-' if pretrained else ''}{start_time}"
best_checkpoint = f"models/{run_name}.h5"
# Make sure the target directory exists before the first checkpoint is
# written; not every Keras version creates it.
Path(best_checkpoint).parent.mkdir(parents=True, exist_ok=True)

checkpoint_cb = keras.callbacks.ModelCheckpoint(
    best_checkpoint, monitor="val_auc", mode="max", verbose=1, save_best_only=True
)
early_stopping_cb = keras.callbacks.EarlyStopping(
    monitor="val_auc", patience=patience, mode="max"
)
tensorboard_cb = keras.callbacks.TensorBoard(
    log_dir=f"logs/{run_name}",
    histogram_freq=1,
    write_graph=False,
    profile_batch=0,
)
cnn.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=epochs,
    callbacks=[checkpoint_cb, early_stopping_cb, tensorboard_cb],
)
# Continue with the best checkpoint rather than the weights of the last epoch.
cnn = keras.models.load_model(best_checkpoint)

In [None]:
# Reload the best classifier checkpoint written by the training cell and
# report its metrics on the validation split (shown as cell output).
cnn = keras.models.load_model(best_checkpoint)
cnn.evaluate(val_dataset, verbose=0, return_dict=True)

In [None]:
# Pick the third test batch (test batches hold one volume each), compare the
# model output with the true label, then animate the volume.
test_batches = iter(test_dataset.skip(2))
x, y = next(test_batches)
prediction = cnn(x, training=False)
print(f"real: {y.numpy()}, prediction: {prediction.numpy()}")

first_volume = x[0, :]
plot_animated_volume(first_volume, fps=5)