In [1]:
import datetime
from statistics import mean
from pprint import pprint

import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt

from data import tfrecord_dataset, train_test_split
from train import best_num_epochs, train_model
from layers import SeluConv3D, SeluDense
from plot import plot_slice, plot_volume_animation
from config import (
    LIDC_SMALL_NEG_TFRECORD,
    LIDC_SMALL_POS_TFRECORD,
    SMALL_PATCH_SHAPE,
    BIG_PATCH_SHAPE,
    SEED,
)

%matplotlib inline
plt.rcParams["figure.figsize"] = [15, 7]

In [2]:
# Hyperparameters
learning_rate = 1e-5  # handed to the Adam optimizer in compile_model
batch_size = 16  # batch size of the training datasets
val_perc = 0.2  # passed as test_perc to train_test_split for the validation split
patience = 30  # forwarded to best_num_epochs; presumably early-stopping patience -- TODO confirm
extra_epochs = 5  # added on top of the selected epoch count for the final training run

In [3]:
# Load the negative-sample TFRecord and count its elements with one full pass.
neg_x = tfrecord_dataset(LIDC_SMALL_NEG_TFRECORD)
num_neg_samples = 0
for _ in neg_x:
    num_neg_samples += 1
print(f"Number of negative samples: {num_neg_samples}")
neg_x

Number of negative samples: 375


<ParallelMapDataset shapes: (None, None, None, None), types: tf.float32>

In [4]:
# Pair every negative sample with itself, giving (input, target) elements
# where the target is the input.
neg_pairs = (neg_x, neg_x)
neg_dataset = tf.data.Dataset.zip(neg_pairs)
# Sanity check: zipping must not change the number of elements.
neg_pair_count = sum(1 for _ in neg_dataset)
assert neg_pair_count == num_neg_samples
neg_dataset

<ZipDataset shapes: ((None, None, None, None), (None, None, None, None)), types: (tf.float32, tf.float32)>

In [5]:
# Load the positive-sample TFRecord and count its elements with one full pass.
pos_x = tfrecord_dataset(LIDC_SMALL_POS_TFRECORD)
num_pos_samples = sum(1 for _ in pos_x)
# Message spelling fixed ("posative" -> "positive").
print(f"Number of positive samples: {num_pos_samples}")
pos_x

Number of posative samples: 379


<ParallelMapDataset shapes: (None, None, None, None), types: tf.float32>

In [6]:
# Pair every positive sample with itself, giving (input, target) elements
# where the target is the input.
pos_pairs = (pos_x, pos_x)
pos_dataset = tf.data.Dataset.zip(pos_pairs)
# Sanity check: zipping must not change the number of elements.
pos_pair_count = sum(1 for _ in pos_dataset)
assert pos_pair_count == num_pos_samples
pos_dataset

<ZipDataset shapes: ((None, None, None, None), (None, None, None, None)), types: (tf.float32, tf.float32)>

In [7]:
# Full dataset: every negative pair followed by every positive pair.
num_samples = num_neg_samples + num_pos_samples
dataset = neg_dataset.concatenate(pos_dataset)
# Sanity check: the concatenation must contain exactly the expected total.
dataset_size = sum(1 for _ in dataset)
assert dataset_size == num_samples
dataset

<ConcatenateDataset shapes: ((None, None, None, None), (None, None, None, None)), types: (tf.float32, tf.float32)>

In [8]:
def build_model():
    """Assemble the 3D convolutional autoencoder.

    The encoder stacks four SeluConv3D blocks (32, 64, 128 and 256 filters,
    kernel size 3), each followed by max pooling with pool size (1, 2, 2).
    The decoder takes the encoder's output shape as its input and applies
    four (1, 2, 2) upsampling steps, each followed by a SeluConv3D block
    (256, 128, 64 and 32 filters), ending in a single-unit sigmoid Dense
    layer.

    Returns:
        An uncompiled keras.Sequential named "autoencoder" that chains the
        "encoder" and "decoder" sub-models; both can be retrieved by name
        with get_layer().
    """
    scale = (1, 2, 2)
    encoder_filters = (32, 64, 128, 256)
    decoder_filters = encoder_filters[::-1]

    # Encoder: (conv -> pool) repeated once per filter count.
    encoder_layers = [
        keras.layers.InputLayer(SMALL_PATCH_SHAPE, name="encoder_input"),
    ]
    for stage, num_filters in enumerate(encoder_filters, start=1):
        encoder_layers.append(
            SeluConv3D(
                filters=num_filters,
                kernel_size=3,
                name=f"encoder_selu_conv3d_{stage}",
            )
        )
        encoder_layers.append(
            keras.layers.MaxPooling3D(scale, name=f"maxpool_{stage}")
        )
    encoder = keras.Sequential(encoder_layers, name="encoder")

    # Decoder: (upsample -> conv) repeated once per filter count, then the
    # sigmoid output layer. Its input shape is the encoder output shape
    # without the batch dimension.
    decoder_layers = [
        keras.layers.InputLayer(encoder.output_shape[1:], name="decoder_input"),
    ]
    for stage, num_filters in enumerate(decoder_filters, start=1):
        decoder_layers.append(
            keras.layers.UpSampling3D(scale, name=f"upsampling_{stage}")
        )
        decoder_layers.append(
            SeluConv3D(
                filters=num_filters,
                kernel_size=3,
                name=f"decoder_selu_conv3d_{stage}",
            )
        )
    decoder_layers.append(
        keras.layers.Dense(1, activation="sigmoid", name="decoder_final_dense")
    )
    decoder = keras.Sequential(decoder_layers, name="decoder")

    return keras.Sequential(
        [
            keras.layers.InputLayer(SMALL_PATCH_SHAPE, name="autoencoder_input"),
            encoder,
            decoder,
        ],
        name="autoencoder",
    )

In [9]:
def compile_model(model):
    """Compile ``model`` in place with Adam and a mean-squared-error loss.

    The optimizer's learning rate is read from the notebook-level
    ``learning_rate`` variable at call time.

    Args:
        model: The Keras model to compile.
    """
    adam = keras.optimizers.Adam(learning_rate)
    mse = keras.losses.MeanSquaredError()
    model.compile(optimizer=adam, loss=mse)

In [10]:
# Split off val_perc of the samples for validation.
train_split, val_split = train_test_split(
    dataset,
    test_perc=val_perc,
    cardinality=num_samples,
)
autotune = tf.data.experimental.AUTOTUNE

val_dataset = val_split.batch(1).cache().prefetch(autotune)

# NOTE(review): shuffling happens after batching, so batch order changes each
# epoch but the contents of each batch stay fixed -- confirm this is intended.
train_batches = train_split.batch(batch_size).cache()  # must be called before shuffle
train_dataset = train_batches.shuffle(
    buffer_size=64,
    reshuffle_each_iteration=True,
).prefetch(autotune)

autoencoder = build_model()
compile_model(autoencoder)

log_dir = "logs/autoencoder-lidc"

# Search for the epoch count to use in the final run, monitoring "val_loss"
# in "min" mode.
num_epochs = best_num_epochs(
    autoencoder,
    train_dataset,
    val_dataset,
    patience,
    "val_loss",
    log_dir,
    verbose_training=1,
    metric_mode="min",
)
f"{num_epochs = }"

Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 46/1000
Epoch 47/1000
Epoch 48/1000
Epoch 49/1000
Epoch 50/1000
Epoch 51/1000
Epoch 52/1000
Epoch 53/1000
Epoch 54/1000
Epoch 55/1000
Epoch 56/1000
Epoch 57/1000
Epoch 58/1000
Epoch 59/1000
Epoch 60/1000
Epoch 61/1000
Epoch 62/1000
Epoch 63/1000
Epoch 64/1000
Epoch 65/1000
Epoch 66/1000
Epoch 67/1000
Epoch 68/1000
Epoch 69/1000
Epoch 70/1000
Epoch 71/1000
Epoch 72/1000
E

Epoch 82/1000
Epoch 83/1000
Epoch 84/1000
Epoch 85/1000
Epoch 86/1000
Epoch 87/1000
Epoch 88/1000
Epoch 89/1000
Epoch 90/1000
Epoch 91/1000
Epoch 92/1000
Epoch 93/1000
Epoch 94/1000
Epoch 95/1000
Epoch 96/1000
Epoch 97/1000
Epoch 98/1000
Epoch 99/1000
Epoch 100/1000
Epoch 101/1000
Epoch 102/1000
Epoch 103/1000
Epoch 104/1000
Epoch 105/1000
Epoch 106/1000
Epoch 107/1000
Epoch 108/1000
Epoch 109/1000
Epoch 110/1000
Epoch 111/1000
Epoch 112/1000
Epoch 113/1000
Epoch 114/1000
Epoch 115/1000
Epoch 116/1000
Epoch 117/1000
Epoch 118/1000
Epoch 119/1000
Epoch 120/1000
Epoch 121/1000
Epoch 122/1000
Epoch 123/1000
Epoch 124/1000
Epoch 125/1000
Epoch 126/1000
Epoch 127/1000
Epoch 128/1000
Epoch 129/1000
Epoch 130/1000
Epoch 131/1000
Epoch 132/1000
Epoch 133/1000
Epoch 134/1000
Epoch 135/1000
Epoch 136/1000
Epoch 137/1000
Epoch 138/1000
Epoch 139/1000
Epoch 140/1000
Epoch 141/1000
Epoch 142/1000
Epoch 143/1000
Epoch 144/1000
Epoch 145/1000
Epoch 146/1000
Epoch 147/1000
Epoch 148/1000
Epoch 149/100

Epoch 161/1000
Epoch 162/1000
Epoch 163/1000
Epoch 164/1000
Epoch 165/1000
Epoch 166/1000
Epoch 167/1000
Epoch 168/1000
Epoch 169/1000
Epoch 170/1000
Epoch 171/1000
Epoch 172/1000
Epoch 173/1000
Epoch 174/1000
Epoch 175/1000
Epoch 176/1000
Epoch 177/1000
Epoch 178/1000
Epoch 179/1000
Epoch 180/1000
Epoch 181/1000
Epoch 182/1000
Epoch 183/1000
Epoch 184/1000
Epoch 185/1000
Epoch 186/1000
Epoch 187/1000
Epoch 188/1000
Epoch 189/1000
Epoch 190/1000
Epoch 191/1000
Epoch 192/1000
Epoch 193/1000
Epoch 194/1000
Epoch 195/1000
Epoch 196/1000
Epoch 197/1000
Epoch 198/1000
Epoch 199/1000
Epoch 200/1000
Epoch 201/1000
Epoch 202/1000
Epoch 203/1000
Epoch 204/1000
Epoch 205/1000
Epoch 206/1000
Epoch 207/1000
Epoch 208/1000
Epoch 209/1000
Epoch 210/1000
Epoch 211/1000
Epoch 212/1000
Epoch 213/1000
Epoch 214/1000
Epoch 215/1000
Epoch 216/1000
Epoch 217/1000
Epoch 218/1000
Epoch 219/1000
Epoch 220/1000
Epoch 221/1000
Epoch 222/1000
Epoch 223/1000
Epoch 224/1000
Epoch 225/1000
Epoch 226/1000
Epoch 227/

Epoch 241/1000
Epoch 242/1000
Epoch 243/1000
Epoch 244/1000
Epoch 245/1000
Epoch 246/1000
Epoch 247/1000
Epoch 248/1000
Epoch 249/1000
Epoch 250/1000
Epoch 251/1000
Epoch 252/1000
Epoch 253/1000
Epoch 254/1000
Epoch 255/1000
Epoch 256/1000
Epoch 257/1000
Epoch 258/1000
Epoch 259/1000
Epoch 260/1000
Epoch 261/1000
Epoch 262/1000
Epoch 263/1000
Epoch 264/1000
Epoch 265/1000
Epoch 266/1000
Epoch 267/1000
Epoch 268/1000
Epoch 269/1000
Epoch 270/1000
Epoch 271/1000
Epoch 272/1000
Epoch 273/1000
Epoch 274/1000
Epoch 275/1000
Epoch 276/1000
Epoch 277/1000
Epoch 278/1000
Epoch 279/1000
Epoch 280/1000
Epoch 281/1000
Epoch 282/1000
Epoch 283/1000
Epoch 284/1000
Epoch 285/1000
Epoch 286/1000
Epoch 287/1000
Epoch 288/1000
Epoch 289/1000
Epoch 290/1000
Epoch 291/1000
Epoch 292/1000
Epoch 293/1000
Epoch 294/1000
Epoch 295/1000
Epoch 296/1000
Epoch 297/1000
Epoch 298/1000
Epoch 299/1000
Epoch 300/1000
Epoch 301/1000
Epoch 302/1000
Epoch 303/1000
Epoch 304/1000
Epoch 305/1000
Epoch 306/1000
Epoch 307/

Epoch 320/1000
Epoch 321/1000
Epoch 322/1000
Epoch 323/1000
Epoch 324/1000
Epoch 325/1000
Epoch 326/1000
Epoch 327/1000
Epoch 328/1000
Epoch 329/1000
Epoch 330/1000
Epoch 331/1000
Epoch 332/1000
Epoch 333/1000
Epoch 334/1000
Epoch 335/1000
Epoch 336/1000
Epoch 337/1000
Epoch 338/1000
Epoch 339/1000
Epoch 340/1000
Epoch 341/1000
Epoch 342/1000
Epoch 343/1000
Epoch 344/1000
Epoch 345/1000
Epoch 346/1000
Epoch 347/1000
Epoch 348/1000
Epoch 349/1000
Epoch 350/1000
Epoch 351/1000
Epoch 352/1000
Epoch 353/1000
Epoch 354/1000
Epoch 355/1000
Epoch 356/1000
Epoch 357/1000
Epoch 358/1000
Epoch 359/1000
Epoch 360/1000
Epoch 361/1000
Epoch 362/1000
Epoch 363/1000
Epoch 364/1000
Epoch 365/1000
Epoch 366/1000
Epoch 367/1000
Epoch 368/1000
Epoch 369/1000
Epoch 370/1000
Epoch 371/1000
Epoch 372/1000
Epoch 373/1000
Epoch 374/1000
Epoch 375/1000
Epoch 376/1000
Epoch 377/1000
Epoch 378/1000
Epoch 379/1000
Epoch 380/1000
Epoch 381/1000
Epoch 382/1000
Epoch 383/1000
Epoch 384/1000
Epoch 385/1000
Epoch 386/

Epoch 399/1000
Epoch 400/1000
Epoch 401/1000
Epoch 402/1000
Epoch 403/1000
Epoch 404/1000
Epoch 405/1000
Epoch 406/1000
Epoch 407/1000
Epoch 408/1000
Epoch 409/1000
Epoch 410/1000
Epoch 411/1000
Epoch 412/1000
Epoch 413/1000
Epoch 414/1000
Epoch 415/1000
Epoch 416/1000
Epoch 417/1000
Epoch 418/1000
Epoch 419/1000
Epoch 420/1000
Epoch 421/1000
Epoch 422/1000
Epoch 423/1000
Epoch 424/1000
Epoch 425/1000
Epoch 426/1000
Epoch 427/1000
Epoch 428/1000
Epoch 429/1000
Epoch 430/1000
Epoch 431/1000
Epoch 432/1000


'num_epochs = 401'

In [11]:
# Train a fresh model on the whole dataset (no validation split) for the
# selected number of epochs plus extra_epochs.
full_batches = dataset.batch(batch_size).cache()  # must be called before shuffle
train_dataset = full_batches.shuffle(
    buffer_size=64,
    reshuffle_each_iteration=True,
).prefetch(tf.data.experimental.AUTOTUNE)

autoencoder = build_model()
compile_model(autoencoder)

model_fname = "models/autoencoder-lidc.h5"
total_epochs = num_epochs + extra_epochs
autoencoder = train_model(
    autoencoder,
    train_dataset,
    total_epochs,
    model_fname,
)

48


In [None]:
# Take one single-sample batch from the validation data, skipping the first
# six, and run it through the encoder and decoder separately.
sample_batches = val_dataset.unbatch().batch(1).skip(6)
original, _ = next(iter(sample_batches))

encoder_model = autoencoder.get_layer("encoder")
decoder_model = autoencoder.get_layer("decoder")
encoder_out = encoder_model(original, training=False)
decoder_out = decoder_model(encoder_out, training=False)

plot_volume_animation(original[0, :])

In [None]:
# Show the same slice of the input, the encoder output and the decoder
# output next to each other.
batch_index = 0
z_index = 5
fig, ax = plt.subplots(ncols=3)
input_ax, encoded_ax, decoded_ax = ax

input_volume = original[batch_index, :]
encoded_volume = encoder_out[batch_index, :]
decoded_volume = decoder_out[batch_index, :]

plot_slice(input_volume, index=z_index, ax=input_ax)
plot_slice(encoded_volume, index=z_index, ax=encoded_ax)
plot_slice(decoded_volume, index=z_index, ax=decoded_ax)

In [None]:
# `plot_animated_volume` is never imported or defined in this notebook; the
# helper imported from `plot` is `plot_volume_animation`, so use that name.
# NOTE(review): this repeats the animation of `original` from an earlier
# cell -- possibly `decoder_out` was meant here; confirm.
plot_volume_animation(original[0, :])