In [4]:
import tensorflow as tf
import matplotlib.pyplot as plt
import gc
from google.colab import drive
from skimage import measure
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import matplotlib.patches as mpatches
from tensorflow.keras import backend as K
import numpy as np
import sys, os, platform

print('Tensorflow version:',tf.__version__)
print(f"Python Platform: {platform.platform()}")
print(f"Python {sys.version}")
gpu = len(tf.config.list_physical_devices("GPU")) > 0
print("GPU is", "available" if gpu else "NOT AVAILABLE")
%matplotlib inline

Tensorflow version: 2.13.0-rc1 
At least 2.13 required
Python Platform: macOS-13.4.1-arm64-arm-64bit
Tensor Flow Version: 2.13.0-rc1
Python 3.10.11 (v3.10.11:7d4cc5aa85, Apr  4 2023, 19:05:19) [Clang 13.0.0 (clang-1300.0.29.30)]
GPU is NOT AVAILABLE


In [None]:
# drive.mount('/content/drive/')

In [5]:
# Every volume is pinned to (TARGET_SLICES, DIMENSION, DIMENSION, 1) in decode_record.
TARGET_SLICES = 64  # size of the first (slice) axis of each volume
DIMENSION = 256  # in-plane height and width of each slice
# Paths default to the Colab/Drive layout. Outside Colab, set the PROJECT_FOLDER / DS_FOLDER
# environment variables instead of editing this cell, e.g.
#   PROJECT_FOLDER=/Users/dutking/LOCAL/AI_uni/radlogix  DS_FOLDER=dataset/tfrecords_1to1
PROJECT_FOLDER = os.environ.get(
    "PROJECT_FOLDER", "/content/drive/MyDrive/Dmitrii_Utkin/model_3dunet"
)
DS_FOLDER = os.environ.get("DS_FOLDER", "dataset")
TRAIN_DS_FILE = "train.tfrecords"
VAL_DS_FILE = "val.tfrecords"
TEST_DS_FILE = "test.tfrecords"

In [6]:
def parse_record(record):
    """Parse one serialized example into its three raw byte-string features.

    Returns a dict with the scalar string entries ``shape``, ``label`` and ``feature``.
    Turning those bytes into tensors is left to ``decode_record``.
    """
    raw_bytes = tf.io.FixedLenFeature([], tf.string)
    schema = {key: raw_bytes for key in ("shape", "label", "feature")}
    return tf.io.parse_single_example(record, schema)


def decode_record(record):
    """Convert a parsed record into a ``(feature, label)`` pair of fixed-shape tensors.

    ``record`` is the dict returned by ``parse_record``. Each entry holds raw little-endian
    bytes: ``feature`` as float64, ``label`` as int16 and ``shape`` as int64. Both arrays
    are reshaped to the stored ``shape`` and then given the static shape
    ``(TARGET_SLICES, DIMENSION, DIMENSION, 1)``. The label is cast to float32; the
    feature stays float64.

    NOTE(review): the model's Input uses the default (float32) dtype. Casting the feature
    to float32 here, before ``.cache()``, would likely halve cache memory. Confirm the
    extra precision is not needed before changing it.
    """

    def _raw(key, dtype):
        # Reinterpret the stored bytes as a flat vector of `dtype`.
        return tf.io.decode_raw(record[key], out_type=dtype, little_endian=True)

    volume_shape = _raw("shape", "int64")
    static_shape = (TARGET_SLICES, DIMENSION, DIMENSION, 1)

    label = tf.cast(tf.reshape(_raw("label", "int16"), volume_shape), dtype=tf.float32)
    feature = tf.reshape(_raw("feature", "float64"), volume_shape)  # already float64

    label.set_shape(static_shape)
    feature.set_shape(static_shape)
    return (feature, label)

In [7]:
def _load_split(filename, shuffle_buffer=None, reshuffle_each_iteration=True, batch_size=1):
    """Build the input pipeline for one GZIP-compressed TFRecord split.

    Records are parsed and decoded into ``(feature, label)`` pairs, cached in memory after
    the first full pass, optionally shuffled, then batched and prefetched.

    Parameters
    ----------
    filename : str
        File name inside ``PROJECT_FOLDER/DS_FOLDER``.
    shuffle_buffer : int or None
        Shuffle buffer size; ``None``/0 disables shuffling. A buffer at least as large
        as the split gives a full shuffle.
    reshuffle_each_iteration : bool
        Draw a new order every epoch (``True``) or keep one fixed order (``False``).
    batch_size : int
        Number of volumes per batch.
    """
    dataset = (
        tf.data.TFRecordDataset(
            os.path.join(PROJECT_FOLDER, DS_FOLDER, filename),
            compression_type="GZIP",
        )
        .map(parse_record, num_parallel_calls=tf.data.AUTOTUNE)
        .map(decode_record, num_parallel_calls=tf.data.AUTOTUNE)
        # Cache decoded volumes so later epochs skip file reading and decoding.
        .cache()
    )
    if shuffle_buffer:
        dataset = dataset.shuffle(
            shuffle_buffer, reshuffle_each_iteration=reshuffle_each_iteration
        )
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)


BATCH_SIZE = 1
# Presumably the number of volumes in each split (an earlier note pointed at
# `train_ds.cardinality()`), so each shuffle covers its whole split — verify if the data changes.
TRAIN_SHUFFLE_BUFFER = 74
VAL_SHUFFLE_BUFFER = 20

with tf.device("CPU"):
    # Redraw the training order every epoch; with reshuffle_each_iteration=False the model
    # would see the identical sequence of volumes in every epoch.
    train_ds = _load_split(
        TRAIN_DS_FILE,
        TRAIN_SHUFFLE_BUFFER,
        reshuffle_each_iteration=True,
        batch_size=BATCH_SIZE,
    )
    # Order does not matter for averaged validation metrics; keep one fixed shuffle.
    val_ds = _load_split(
        VAL_DS_FILE,
        VAL_SHUFFLE_BUFFER,
        reshuffle_each_iteration=False,
        batch_size=BATCH_SIZE,
    )
    # Test data keeps file order so per-case reports are reproducible.
    test_ds = _load_split(TEST_DS_FILE, batch_size=BATCH_SIZE)

In [8]:
# Helper function to enable loss function to be flexibly used for
# both 2D or 3D image segmentation - source: https://github.com/frankkramer-lab/MIScnn
def identify_axis(shape):
    """Return the axes a per-class reduction should sum over for a tensor of `shape`.

    For rank 5 (3D volumes) this is ``[1, 2, 3]``; for rank 4 (2D images) it is
    ``[1, 2]``. In both cases the first (batch) and last (channel) axes are kept.
    Any other rank raises ``ValueError``.
    """
    spatial_axes_by_rank = {5: (1, 2, 3), 4: (1, 2)}
    rank = len(shape)
    if rank not in spatial_axes_by_rank:
        raise ValueError("Metric: Shape of tensor is neither 2D or 3D.")
    # Hand back a fresh list on every call, as callers may treat it as their own.
    return list(spatial_axes_by_rank[rank])


def dice_coefficient(delta=0.5, smooth=0.000001, threshold=0.5):
    """The Dice similarity coefficient, also known as the Sørensen–Dice index or simply Dice coefficient, is a statistical tool which measures the similarity between two sets of data.

    Predictions are binarised at ``threshold`` before counting. This makes it an
    evaluation metric (the thresholding has no useful gradient), not a loss. A case with
    an empty ground truth and an empty prediction scores 1 (``smooth / smooth``).

    Parameters
    ----------
    delta : float, optional
        weight of false negatives (false positives get ``1 - delta``), by default 0.5.
        With 0.5 the score equals the classic Dice ``2TP / (2TP + FN + FP)``.
    smooth : float, optional
        smoothing constant to prevent division by zero errors, by default 0.000001
    threshold : float, optional
        predicted probability at or above which a voxel counts as foreground, by default 0.5

    Returns
    -------
    callable
        ``metric(y_true, y_pred)`` returning the score averaged over batch and channels.
    """

    def metric(y_true, y_pred):
        axis = identify_axis(y_true.get_shape())
        y_pred = tf.where(y_pred >= threshold, 1.0, 0.0)
        # Calculate true positives (tp), false negatives (fn) and false positives (fp)
        tp = K.sum(y_true * y_pred, axis=axis)
        fn = K.sum(y_true * (1 - y_pred), axis=axis)
        fp = K.sum((1 - y_true) * y_pred, axis=axis)
        dice_class = (tp + smooth) / (tp + delta * fn + (1 - delta) * fp + smooth)
        # Average class scores
        return K.mean(dice_class)

    # Keras uses the function's __name__ as the metric name in History / CSV logs.
    # Without this it was logged as "loss_function".
    metric.__name__ = "dice_coefficient"
    return metric


def focal_tversky_loss(delta=0.7, gamma=0.75, smooth=0.000001):
    """A Novel Focal Tversky loss function with improved Attention U-Net for lesion segmentation
    Link: https://arxiv.org/abs/1810.07842

    Computes ``mean((1 - TI) ** gamma)``, where the per-sample, per-channel Tversky index is
    ``TI = (TP + smooth) / (TP + delta*FN + (1 - delta)*FP + smooth)``. Counts are soft
    (taken from probabilities clipped into ``[eps, 1 - eps]``).

    Parameters
    ----------
    delta : float, optional
        weight of false negatives (false positives get ``1 - delta``), by default 0.7
    gamma : float, optional
        focal parameter controls degree of down-weighting of easy examples, by default 0.75
    smooth : float, optional
        constant added to numerator and denominator of the index, by default 0.000001
    """

    def loss_function(y_true, y_pred):
        eps = K.epsilon()
        # Keep probabilities strictly inside (0, 1).
        probs = K.clip(y_pred, eps, 1.0 - eps)
        reduce_axes = identify_axis(y_true.get_shape())

        true_pos = K.sum(y_true * probs, axis=reduce_axes)
        false_neg = K.sum(y_true * (1 - probs), axis=reduce_axes)
        false_pos = K.sum((1 - y_true) * probs, axis=reduce_axes)

        denominator = true_pos + delta * false_neg + (1 - delta) * false_pos + smooth
        tversky_index = (true_pos + smooth) / denominator
        return K.mean(K.pow((1 - tversky_index), gamma))

    return loss_function


def symmetric_focal_tversky_loss(delta=0.7, gamma=0.75):
    """Symmetric focal Tversky loss for binary segmentation.

    Both classes get the focal term ``(1 - TI_c) ** (1 - gamma)``, and the two terms are
    averaged.

    Expects ``y_true``/``y_pred`` with exactly two channels on the last axis: index 0 is
    treated as background and index 1 as foreground. The single-channel output of
    ``create_model()`` does not meet this; indexing channel 1 would fail.

    Parameters
    ----------
    delta : float, optional
        controls weight given to false positive and false negatives, by default 0.7
        (false negatives are weighted by ``delta``, false positives by ``1 - delta``)
    gamma : float, optional
        focal parameter controls degree of down-weighting of easy examples, by default 0.75
    """

    def loss_function(y_true, y_pred):
        # Clip values to prevent division by zero error
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1.0 - epsilon)

        axis = identify_axis(y_true.get_shape())
        # Calculate true positives (tp), false negatives (fn) and false positives (fp)
        tp = K.sum(y_true * y_pred, axis=axis)
        fn = K.sum(y_true * (1 - y_pred), axis=axis)
        fp = K.sum((1 - y_true) * y_pred, axis=axis)
        dice_class = (tp + epsilon) / (tp + delta * fn + (1 - delta) * fp + epsilon)

        # (1 - d) * (1 - d)**(-gamma) == (1 - d)**(1 - gamma). The single power stays finite
        # when d rounds to exactly 1 in float32, whereas the product evaluates 0 * inf = NaN.
        back_dice = K.pow(1 - dice_class[:, 0], 1 - gamma)
        fore_dice = K.pow(1 - dice_class[:, 1], 1 - gamma)

        # Average class scores
        loss = K.mean(tf.stack([back_dice, fore_dice], axis=-1))
        return loss

    return loss_function


def asymmetric_focal_tversky_loss(delta=0.7, gamma=0.75):
    """Asymmetric focal Tversky loss for binary segmentation.

    Only the foreground term is focally modulated, as ``(1 - TI_fg) ** (1 - gamma)``. The
    background term is plain ``1 - TI_bg``. The two terms are averaged.

    Expects ``y_true``/``y_pred`` with exactly two channels on the last axis: index 0 is
    treated as background and index 1 as foreground. The single-channel output of
    ``create_model()`` does not meet this; indexing channel 1 would fail.

    Parameters
    ----------
    delta : float, optional
        controls weight given to false positive and false negatives, by default 0.7
        (false negatives are weighted by ``delta``, false positives by ``1 - delta``)
    gamma : float, optional
        focal parameter controls degree of down-weighting of easy examples, by default 0.75
    """

    def loss_function(y_true, y_pred):
        # Clip values to prevent division by zero error
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1.0 - epsilon)

        axis = identify_axis(y_true.get_shape())
        # Calculate true positives (tp), false negatives (fn) and false positives (fp)
        tp = K.sum(y_true * y_pred, axis=axis)
        fn = K.sum(y_true * (1 - y_pred), axis=axis)
        fp = K.sum((1 - y_true) * y_pred, axis=axis)
        dice_class = (tp + epsilon) / (tp + delta * fn + (1 - delta) * fp + epsilon)

        # calculate losses separately for each class, only enhancing foreground class
        back_dice = 1 - dice_class[:, 0]
        # (1 - d) * (1 - d)**(-gamma) == (1 - d)**(1 - gamma). The single power avoids
        # 0 * inf = NaN when d rounds to exactly 1.
        fore_dice = K.pow(1 - dice_class[:, 1], 1 - gamma)

        # Average class scores
        loss = K.mean(tf.stack([back_dice, fore_dice], axis=-1))
        return loss

    return loss_function

In [13]:
def conv_block(input, filters, activation="relu"):
    """
    Convolution block of a UNet, used by the encoder, the bottleneck and the decoder.

    Applies two 3x3x3 "same"-padded Conv3D layers, each with ``activation``. Then comes
    BatchNormalization over the last axis and a final ReLU. The final ReLU is always ReLU,
    regardless of ``activation``.

    NOTE(review): conv -> BN -> activation is the more common ordering. Changing it would
    invalidate previously saved weights, so it is kept as-is.
    """
    x = input
    for _ in range(2):
        x = tf.keras.layers.Conv3D(
            filters, (3, 3, 3), padding="same", activation=activation
        )(x)
    x = tf.keras.layers.BatchNormalization(axis=-1)(x)
    return tf.keras.activations.relu(x)


def encoder_block(input, filters):
    """
    Encoder block of a UNet: ``conv_block`` followed by 2x2x2 max pooling.

    Returns ``(skip, pooled)``. ``skip`` holds the pre-pooling features used by the skip
    connection; ``pooled`` is the downsampled tensor fed to the next stage.
    """
    skip = conv_block(input, filters)
    pooled = tf.keras.layers.MaxPooling3D(pool_size=(2, 2, 2))(skip)
    return skip, pooled


def attention_gate(g, s, filters):
    """Additive attention gate applied to an encoder skip connection.

    ``g`` is the gating signal (in this notebook, the upsampled decoder tensor from
    ``decoder_block``) and ``s`` is the encoder skip tensor. Each is projected by a 1x1x1
    Conv3D + BatchNormalization to ``filters`` channels. The projections are summed, passed
    through ReLU, projected again by a 1x1x1 Conv3D and squashed with a sigmoid. The result
    then multiplies ``s`` element-wise.

    NOTE(review): the last projection keeps ``filters`` channels, so the gate is per-channel,
    unlike the single-channel attention map of Oktay et al. The product assumes ``s`` has
    ``filters`` channels; this holds for every call in create_model().
    """

    def project(tensor):
        # 1x1x1 channel projection followed by batch normalization.
        projected = tf.keras.layers.Conv3D(filters, (1, 1, 1), padding="same")(tensor)
        return tf.keras.layers.BatchNormalization()(projected)

    gate_proj = project(g)
    skip_proj = project(s)

    attention = tf.keras.activations.relu(gate_proj + skip_proj)
    attention = tf.keras.layers.Conv3D(filters, (1, 1, 1), padding="same")(attention)
    attention = tf.keras.activations.sigmoid(attention)
    return attention * s


def decoder_block(input, filters, concat_layer):
    """One decoder stage of the attention U-Net.

    1. Upsample ``input`` 2x with a transposed convolution.
    2. Gate the encoder skip ``concat_layer`` through ``attention_gate``.
    3. Concatenate both along the last axis.
    4. Refine the result with ``conv_block``.
    """
    upsampled = tf.keras.layers.Conv3DTranspose(
        filters, (2, 2, 2), strides=(2, 2, 2), padding="same"
    )(input)
    gated_skip = attention_gate(upsampled, concat_layer, filters)
    merged = tf.keras.layers.concatenate([upsampled, gated_skip])
    return conv_block(merged, filters)


def create_model(input_shape=None, base_filters=64):
    """Build the 3D attention U-Net ("Unet") with a single sigmoid output channel.

    Parameters
    ----------
    input_shape : tuple, optional
        Per-sample input shape, channels last. Defaults to
        ``(TARGET_SLICES, DIMENSION, DIMENSION, 1)`` so the network matches the static shape
        ``decode_record`` assigns, instead of repeating those numbers here.
    base_filters : int, optional
        Filters in the first encoder stage, doubled at each of the four downsampling steps
        (64 -> 1024 at the bottleneck by default).

    Raises
    ------
    ValueError
        If a known spatial size is not divisible by 16. Four 2x poolings must be undone
        exactly by the transposed convolutions for the skip concatenations to line up.
    """
    if input_shape is None:
        input_shape = (TARGET_SLICES, DIMENSION, DIMENSION, 1)
    if any(dim is not None and dim % 16 for dim in input_shape[:3]):
        raise ValueError(
            f"Spatial dimensions of {input_shape} must be divisible by 16 (4 pooling levels)."
        )

    inputs = tf.keras.Input(shape=input_shape)
    f = base_filters
    # Encoder: keep the pre-pooling features of each stage for the skip connections.
    skip1, pooled1 = encoder_block(inputs, f)
    skip2, pooled2 = encoder_block(pooled1, f * 2)
    skip3, pooled3 = encoder_block(pooled2, f * 4)
    skip4, pooled4 = encoder_block(pooled3, f * 8)
    bottleneck = conv_block(pooled4, f * 16)
    # Decoder: each stage mirrors the encoder stage with the same filter count.
    up4 = decoder_block(bottleneck, f * 8, skip4)
    up3 = decoder_block(up4, f * 4, skip3)
    up2 = decoder_block(up3, f * 2, skip2)
    up1 = decoder_block(up2, f, skip1)
    # Per-voxel foreground probability.
    outputs = tf.keras.layers.Conv3D(1, (1, 1, 1), activation="sigmoid")(up1)
    return tf.keras.Model(inputs=[inputs], outputs=outputs, name="Unet")

In [None]:
# @tf.keras.saving.register_keras_serializable(name="adamw")
def opt_adamw(learning_rate=0.001, weight_decay=0.004, ema_momentum=0.75):
    """Build the AdamW optimizer used for training.

    The defaults reproduce the configuration encoded in MODEL_NAME ("adamw0001_ema075"). The
    optimizer keeps an exponential moving average of the weights (``use_ema=True``) with
    momentum ``ema_momentum``.

    Parameters
    ----------
    learning_rate : float, optional
        Step size, by default 0.001.
    weight_decay : float, optional
        Decoupled weight decay, by default 0.004.
    ema_momentum : float, optional
        EMA momentum, by default 0.75. Lower values follow the raw weights more closely.
    """
    return tf.keras.optimizers.AdamW(
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-07,
        amsgrad=False,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=True,
        ema_momentum=ema_momentum,
        jit_compile=True,
        name="AdamW",
    )


# Build the network, then attach the optimizer, the focal Tversky loss and the Dice metric.
model = create_model()
model.compile(
    optimizer=opt_adamw(),
    loss=focal_tversky_loss(),
    metrics=[dice_coefficient()],
)
# model.summary()
# tf.keras.utils.plot_model(model, to_file='model_3dunet.png', show_shapes=True, show_dtype=False, show_layer_names=True, expand_nested=False, dpi=70,)

In [None]:
# Run identifier: names both the checkpoint file and the CSV training log.
MODEL_NAME = "1to1ds_batch1_adamw0001_ema075"

# Keep only the weights from the epoch with the lowest validation loss.
# NOTE(review): this writes weights only, not a full .keras model archive, despite the
# suffix. Load it with model.load_weights(...), not tf.keras.models.load_model(...).
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(PROJECT_FOLDER, f"saved_models/{MODEL_NAME}.keras"),
    monitor="val_loss",
    mode="min",
    save_best_only=True,
    save_weights_only=True,
)

csv_logger = tf.keras.callbacks.CSVLogger(
    os.path.join(PROJECT_FOLDER, f"logs/{MODEL_NAME}.csv")
)
callbacks = [
    csv_logger,
    tf.keras.callbacks.TerminateOnNaN(),
    # tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=5),
    model_checkpoint_callback,
]

In [None]:
# `model` is freshly built by create_model(), and nothing above trains it or loads weights.
# Restore the best checkpoint when it exists so the test metrics are meaningful.
best_weights_path = os.path.join(PROJECT_FOLDER, f"saved_models/{MODEL_NAME}.keras")
if os.path.exists(best_weights_path):
    model.load_weights(best_weights_path)
else:
    print(f"No checkpoint found at {best_weights_path}; evaluating untrained weights.")
model.evaluate(test_ds, verbose=2)

In [2]:
def plot_train_report(
    history, metric_key="tversky_index", metric_title="TVERSKY INDEX", metric_abbr="TI"
):
    """Plot per-epoch training/validation loss and one training/validation metric.

    Parameters
    ----------
    history : tf.keras.callbacks.History
        Object returned by ``model.fit``. ``history.history`` must contain ``"loss"``,
        ``"val_loss"``, ``metric_key`` and ``"val_" + metric_key``.
    metric_key : str, optional
        History key of the metric in the right panel. Metrics are presumably logged under
        their function name, so check ``history.history.keys()`` if this raises KeyError.
    metric_title : str, optional
        Title of the right panel.
    metric_abbr : str, optional
        Short name used in the legend and y-axis label of the right panel.
    """
    logs = history.history
    # Derive the epoch axis from the logs. A fixed range(1, 151) breaks any run whose
    # length is not exactly 150 epochs (x and y lengths differ).
    epochs = np.arange(1, len(logs["loss"]) + 1)
    ticks = np.arange(1, len(epochs) + 1, 10)

    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

    # Left panel: training and validation loss.
    ax[0].plot(epochs, logs["loss"], label="Training Loss")
    ax[0].plot(epochs, logs["val_loss"], label="Validation Loss")
    ax[0].set_title("FOCAL TVERSKY LOSS")
    ax[0].set_xlabel("Epochs")
    ax[0].set_ylabel("Loss")
    ax[0].set_xticks(ticks)
    ax[0].legend(loc="best")

    # Right panel: training and validation metric.
    ax[1].plot(epochs, logs[metric_key], label=f"Training {metric_abbr}")
    ax[1].plot(epochs, logs[f"val_{metric_key}"], label=f"Validation {metric_abbr}")
    ax[1].set_title(metric_title)
    ax[1].set_xlabel("Epochs")
    ax[1].set_ylabel(metric_abbr)
    ax[1].set_xticks(ticks)
    ax[1].legend(loc="best")

    plt.show()


def overlay_images(x, y, y_pred):
    """Show the middle slice of a volume with ground truth (left) and prediction (right).

    ``x``, ``y`` and ``y_pred`` are indexed as ``[:, slice, :, :, :]``, so they are taken to
    be rank-5 batches with the slice axis at position 1. Each mask is drawn over the
    "bone"-coloured image with per-pixel alpha ``0.5 * mask``. The right panel's title shows
    the slice index and the Tversky index of that slice.

    NOTE(review): ``TverskyIndex`` is not defined anywhere in this notebook. It must come
    from an earlier session or module; confirm.
    """
    mid_slice = int(np.round(x.shape[1] / 2))
    ti = TverskyIndex()

    image_slice, true_slice, pred_slice = (
        volume[:, mid_slice, :, :, :] for volume in (x, y, y_pred)
    )
    ti.update_state(true_slice, pred_slice)

    image_2d = np.squeeze(image_slice)
    masks_2d = (np.squeeze(true_slice), np.squeeze(pred_slice))

    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))
    for panel, mask in zip(ax, masks_2d):
        panel.imshow(image_2d, cmap="bone", interpolation="none")
        panel.imshow(mask, cmap="jet", alpha=(0.5 * mask))

    # plt.subplots leaves the last axes (ax[1]) current, so this is where plt.title wrote.
    ax[1].set_title(f"Slice: {mid_slice}. Tversky Index: {ti.result().numpy()} ")
    plt.show()


def plot_3d(
    images,
    labels=("Y True", "Y Pred"),
    colors=((0.5, 0.5, 1), (0.9, 0.1, 0.9)),
    alpha=(0.5, 0.5),
    threshold=(0, 0),
):
    """Render one or more 3D masks as marching-cubes surfaces on a shared 3D axis.

    Parameters
    ----------
    images : sequence of numpy.ndarray
        3D volumes; each is transposed with ``(2, 0, 1)`` and flipped on its last axis
        before meshing.
    labels, colors, alpha, threshold : sequences, optional
        Per-image legend label, face colour, mesh transparency and marching-cubes iso-level.
        Each must have at least ``len(images)`` entries. The defaults are tuples rather than
        lists so no call can mutate a shared default.
    """
    fig = plt.figure(figsize=(12, 5))
    ax = fig.add_subplot(111, projection="3d")
    legend_handles = []

    for idx, image in enumerate(images):
        # Orient the scan vertically so the patient's head is at the top, facing the camera.
        volume = image.transpose(2, 0, 1)[:, :, ::-1]

        if idx == 0:
            # Axis limits follow the shape of the first reoriented volume.
            ax.set_xlim(0, volume.shape[0])
            ax.set_ylim(0, volume.shape[1])
            ax.set_zlim(0, volume.shape[2])

        # Extract the surface mesh (vertices and triangular faces) with marching cubes.
        verts, faces, _, _ = measure.marching_cubes(volume, threshold[idx])

        # Build a triangle collection from the mesh, then colour and add it to the axes.
        mesh = Poly3DCollection(verts[faces], alpha=alpha[idx])
        mesh.set_facecolor(colors[idx])
        ax.add_collection3d(mesh)
        legend_handles.append(mpatches.Patch(color=colors[idx], label=labels[idx]))

    ax.legend(handles=legend_handles)
    # Display the visualization.
    plt.show()


def plot_results(model, dataset, voxels_threshold=30):
    """Predict, score and visualise every case in ``dataset``, then print a tally.

    For each ``(x, y)`` batch the function:
      * binarises the prediction at 0.5;
      * prints the focal Tversky loss and Tversky index;
      * shows the middle slice via ``overlay_images`` and 3D surfaces via ``plot_3d``;
      * classifies the case at volume level.

    An effusion counts as "found" when more than ``voxels_threshold`` voxels are predicted
    positive.

    NOTE(review): ``FocalTverskyLoss`` and ``TverskyIndex`` are not defined in this
    notebook. They must come from an earlier session or module; confirm.
    """
    VOXELS_THRESHOLD = voxels_threshold
    ftl = FocalTverskyLoss()
    correct_predictions = 0
    total_cases = 0

    for x, y, *z in dataset:
        total_cases += 1
        ti = TverskyIndex()

        prediction = model.predict(x)
        # create_model() has one output, so predict() returns a single array with the batch
        # axis first. Unpacking it into two names fails for batch size 1. A multi-output
        # model returns a list instead; take its first output.
        if isinstance(prediction, (list, tuple)):
            prediction = prediction[0]
        y_pred = tf.where(prediction >= 0.5, 1.0, 0.0)
        y_true = tf.cast(y, dtype=tf.float32)

        loss = ftl(y_true, y_pred)
        ti.update_state(y_true, y_pred)
        overlay_images(x, y, y_pred)
        print("\n\n")

        y = tf.squeeze(y).numpy()
        y_pred = tf.squeeze(y_pred).numpy()
        print("FocalTverskyLoss:", loss.numpy(), "| TverskyIndex:", ti.result().numpy())

        true_voxels = np.count_nonzero(y)
        pred_voxels = np.count_nonzero(y_pred)
        predicted_present = np.max(y_pred) > 0 and pred_voxels > VOXELS_THRESHOLD

        if np.max(y) > 0:
            if predicted_present:
                print(
                    f"CORRECT:  Existing effusion event FOUND: {true_voxels}/{pred_voxels} voxels"
                )
                correct_predictions += 1
                plot_3d([y, y_pred])
            else:
                print(
                    f"INCORRECT:  Existing effusion event NOT FOUND: {true_voxels}/{pred_voxels} voxels"
                )
                plot_3d([y], labels=["Y_TRUE"])
        elif predicted_present:
            print(
                f"INCORRECT:  Non-existing effusion event FOUND +: {true_voxels}/{pred_voxels} voxels"
            )
            plot_3d([y_pred], labels=["Y_PRED"])
        else:
            print(
                f"CORRECT: Non-existing effusion event NOT FOUND -: {true_voxels}/{pred_voxels} voxels"
            )
            correct_predictions += 1

        print("\n \n")

    # Report against the number of cases actually seen, not a fixed total.
    print(
        f"TOTAL CORRECT PREDICTIONS (with threshold of {VOXELS_THRESHOLD} voxels): {correct_predictions}/{total_cases}"
    )
    print("\n \n")
    print("\n \n")

In [None]:
model.evaluate(test_ds, verbose=2)
print("\n\n")
# `history` exists only if a `history = model.fit(...)` cell ran earlier in this kernel, and
# no visible cell defines it. Guard here instead of failing with NameError.
if "history" in globals():
    plot_train_report(history)
else:
    print("No `history` from model.fit in this session; skipping the training report.")
print("\n\n")
plot_results(model, test_ds)