In [None]:
"""
# for colab
%pip install tensorflow==2.13
"""

In [40]:
import tensorflow as tf
import matplotlib.pyplot as plt
#import gc
#from google.colab import drive
from skimage import measure
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import matplotlib.patches as mpatches
from tensorflow.keras import backend as K
import numpy as np
import sys, os, platform

print('Tensorflow version:',tf.__version__)
print(f"Python Platform: {platform.platform()}")
print(f"Python {sys.version}")
gpu = len(tf.config.list_physical_devices("GPU")) > 0
print("GPU is", "available" if gpu else "NOT AVAILABLE")
%matplotlib inline

Tensorflow version: 2.13.0
Python Platform: macOS-13.4.1-arm64-arm-64bit
Python 3.11.4 (v3.11.4:d2340ef257, Jun  6 2023, 19:15:51) [Clang 13.0.0 (clang-1300.0.29.30)]
GPU is NOT AVAILABLE


In [None]:
# drive.mount('/content/drive/')

In [50]:
TARGET_SLICES = 192
DIMENSION = 256
PROJECT_FOLDER = "/Users/dutking/LOCAL/AI_uni/radlogix"  # "/content/drive/MyDrive/Dmitrii_Utkin/model_3dunet_custom"
DS_FOLDER = "dataset/tfrecords_1to1_deep_supervision"  # dataset
TRAIN_DS_FILE = "train.tfrecords"
VAL_DS_FILE = "val.tfrecords"
TEST_DS_FILE = "test.tfrecords"

In [51]:
def parse_train_record(record):
    name_to_features = {
        "shape": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.string),
        "feature": tf.io.FixedLenFeature([], tf.string),
        "mini_label": tf.io.FixedLenFeature([], tf.string),
    }
    return tf.io.parse_single_example(record, name_to_features)


def decode_train_record(record):
    feature = tf.io.decode_raw(
        record["feature"],
        out_type="float64",
        little_endian=True,
        fixed_length=None,
        name=None,
    )
    label = tf.io.decode_raw(
        record["label"],
        out_type="int16",
        little_endian=True,
        fixed_length=None,
        name=None,
    )
    shape = tf.io.decode_raw(
        record["shape"],
        out_type="int64",
        little_endian=True,
        fixed_length=None,
        name=None,
    )
    mini_label = tf.io.decode_raw(
        record["mini_label"],
        out_type="int16",
        little_endian=True,
        fixed_length=None,
        name=None,
    )
    label = tf.cast(tf.reshape(label, shape), dtype=tf.float32)
    feature = tf.cast(tf.reshape(feature, shape), dtype=tf.float32)
    mini_label = tf.cast(
        tf.reshape(
            mini_label,
            (int(TARGET_SLICES / 2), int(DIMENSION / 2), int(DIMENSION / 2), 1),
        ),
        dtype=tf.float32,
    )
    label.set_shape((TARGET_SLICES, DIMENSION, DIMENSION, 1))
    feature.set_shape((TARGET_SLICES, DIMENSION, DIMENSION, 1))
    mini_label.set_shape(
        (int(TARGET_SLICES / 2), int(DIMENSION / 2), int(DIMENSION / 2), 1)
    )
    return (feature, label, mini_label)


def parse_record(record):
    name_to_features = {
        "shape": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.string),
        "feature": tf.io.FixedLenFeature([], tf.string),
    }
    return tf.io.parse_single_example(record, name_to_features)


def decode_record(record):
    feature = tf.io.decode_raw(
        record["feature"],
        out_type="float64",
        little_endian=True,
        fixed_length=None,
        name=None,
    )
    label = tf.io.decode_raw(
        record["label"],
        out_type="int16",
        little_endian=True,
        fixed_length=None,
        name=None,
    )
    shape = tf.io.decode_raw(
        record["shape"],
        out_type="int64",
        little_endian=True,
        fixed_length=None,
        name=None,
    )
    label = tf.cast(tf.reshape(label, shape), dtype=tf.float32)
    feature = tf.cast(tf.reshape(feature, shape), dtype=tf.float64)
    label.set_shape((TARGET_SLICES, DIMENSION, DIMENSION, 1))
    feature.set_shape((TARGET_SLICES, DIMENSION, DIMENSION, 1))
    return (feature, label)

In [57]:
with tf.device("CPU"):
    train_ds = (
        tf.data.TFRecordDataset(
            os.path.join(PROJECT_FOLDER, DS_FOLDER, TRAIN_DS_FILE),
            compression_type="GZIP",
        )
        .map(parse_train_record, num_parallel_calls=tf.data.AUTOTUNE)
        .map(decode_train_record, num_parallel_calls=tf.data.AUTOTUNE)
    )

    # norm_ds = train_ds.map(lambda x, y, z: x)
    # normalization_layer = tf.keras.layers.Normalization(axis=None)
    # normalization_layer.adapt(norm_ds)
    # del norm_ds

    train_ds = (
        train_ds.cache()
        .shuffle(74, reshuffle_each_iteration=False)  # train_ds.cardinality().numpy()
        .batch(1)
        .prefetch(tf.data.AUTOTUNE)
    )

    val_ds = (
        tf.data.TFRecordDataset(
            os.path.join(PROJECT_FOLDER, DS_FOLDER, VAL_DS_FILE),
            compression_type="GZIP",
        )
        .map(parse_record, num_parallel_calls=tf.data.AUTOTUNE)
        .map(decode_record, num_parallel_calls=tf.data.AUTOTUNE)
    )
    val_ds = (
        val_ds.cache()
        .shuffle(20, reshuffle_each_iteration=False)
        .batch(1)
        .prefetch(tf.data.AUTOTUNE)
    )

    test_ds = (
        tf.data.TFRecordDataset(
            os.path.join(PROJECT_FOLDER, DS_FOLDER, TEST_DS_FILE),
            compression_type="GZIP",
        )
        .map(parse_record, num_parallel_calls=tf.data.AUTOTUNE)
        .map(decode_record, num_parallel_calls=tf.data.AUTOTUNE)
        .cache()
        .batch(1)
        .prefetch(tf.data.AUTOTUNE)
    )

In [78]:
tf.keras.saving.get_custom_objects().clear()


@tf.keras.saving.register_keras_serializable(name="tversky_index")
class TverskyIndex(tf.keras.metrics.Metric):
    def __init__(self, delta=0.5, smooth=0.000001, name="tversky_index", **kwargs):
        super(TverskyIndex, self).__init__(name=name, **kwargs)
        self.ti = self.add_weight(name="ti", initializer="zeros")
        self.count = self.add_weight(name="ti_c", initializer="zeros")
        self.delta = delta
        self.smooth = smooth

    @tf.function
    def identify_axis(self, shape):
        # Three dimensional
        if len(shape) == 5:
            return [1, 2, 3]
        # Two dimensional
        elif len(shape) == 4:
            return [1, 2]
        # Exception - Unknown
        else:
            raise ValueError("Metric: Shape of tensor is neither 2D or 3D.")

    @tf.function
    def update_state(self, y_true, y_pred, sample_weight=None):
        axis = self.identify_axis(y_true.get_shape())
        y_pred = tf.where(y_pred >= 0.5, 1.0, 0.0)  # MY
        # Calculate true positives (tp), false negatives (fn) and false positives (fp)
        tp = K.sum(y_true * y_pred, axis=axis)
        fn = K.sum(y_true * (1 - y_pred), axis=axis)
        fp = K.sum((1 - y_true) * y_pred, axis=axis)
        ti_class = (tp + self.smooth) / (
            tp + self.delta * fn + (1 - self.delta) * fp + self.smooth
        )
        # Average class scores
        ti = K.mean(ti_class)
        self.ti.assign_add(ti)
        self.count.assign_add(1.0)

    def result(self):
        if self.count == 0.0:
            return self.ti

        mean_ti = self.ti / self.count
        return mean_ti

    def reset_state(self):
        # The state of the metric will be reset at the start of each epoch.
        self.count.assign(0.0)
        self.ti.assign(0.0)

    def get_config(self):
        config = {"delta": self.delta, "smooth": self.smooth}
        base_config = super().get_config()
        return {**base_config, **config}


@tf.keras.saving.register_keras_serializable(name="focal_tversky_loss")
class FocalTverskyLoss(tf.keras.losses.Loss):
    def __init__(self, delta=0.7, gamma=0.75, smooth=0.000001, name="ftl", **kwargs):
        super(FocalTverskyLoss, self).__init__(name=name, **kwargs)
        self.delta = delta
        self.gamma = gamma
        self.smooth = smooth

    @tf.function
    def identify_axis(self, shape):
        # Three dimensional
        if len(shape) == 5:
            return [1, 2, 3]
        # Two dimensional
        elif len(shape) == 4:
            return [1, 2]
        # Exception - Unknown
        else:
            raise ValueError("Metric: Shape of tensor is neither 2D or 3D.")

    def call(self, y_true, y_pred):
        axis = self.identify_axis(y_true.get_shape())
        # Calculate true positives (tp), false negatives (fn) and false positives (fp)
        tp = K.sum(y_true * y_pred, axis=axis)
        fn = K.sum(y_true * (1 - y_pred), axis=axis)
        fp = K.sum((1 - y_true) * y_pred, axis=axis)

        tversky_class = (tp + self.smooth) / (
            tp + self.delta * fn + (1 - self.delta) * fp + self.smooth
        )
        # Average class scores
        ftl = K.mean(K.pow((1 - tversky_class), self.gamma))
        return ftl

    def get_config(self):
        config = {"delta": self.delta, "gamma": self.gamma, "smooth": self.smooth}
        base_config = super().get_config()
        return {**base_config, **config}


@tf.keras.saving.register_keras_serializable(name="ModelUnet3d")
class ModelUnet3d(tf.keras.Model):
    def __init__(self, *args, **kwargs):
        super(ModelUnet3d, self).__init__(*args, **kwargs)
        self.loss_weights = [1.0, 0.5]
        self.ti_metric = TverskyIndex()
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")

    def train_step(self, data):
        x, y, z = data

        with tf.GradientTape() as tape:
            y_preds = self(x, training=True)
            if type(y_preds) == list:
                y_pred, z_pred = y_preds
                loss = self.loss_weights[0] * self.compute_loss(y=y, y_pred=y_pred)
                loss += self.loss_weights[1] * self.compute_loss(
                    y=z, y_pred=z_pred
                )  # Deep Supervision Loss
                for metric in self.metrics:
                    if metric.name == "loss":
                        metric.update_state(loss)
                    else:
                        metric.update_state(y, y_preds[0])
            else:
                loss = self.compute_loss(y=y, y_pred=y_preds)
                for metric in self.metrics:
                    if metric.name == "loss":
                        metric.update_state(loss)
                    else:
                        metric.update_state(y, y_preds)
        trainable_vars = self.trainable_variables  # Network trainable parameters
        gradients = tape.gradient(loss, trainable_vars)  # Calculating gradients
        self.optimizer.apply_gradients(
            zip(gradients, trainable_vars)
        )  # Applying gradients to optimizer

        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        x, y = data
        y_pred, _ = self(x, training=False)
        loss = self.compute_loss(y=y, y_pred=y_pred)
        # loss = self.loss_fn(y, y_pred)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)

        return {m.name: m.result() for m in self.metrics}

    def get_config(self):
        config = {
            "inputs": self.inputs,
            "ti_metric": self.ti_metric,
            "loss_tracker": self.loss_tracker,
        }
        base_config = super().get_config()
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config):
        inputs_config = config.pop("inputs")
        inputs = keras.saving.deserialize_keras_object(inputs_config)
        ti_metric_config = config.pop("ti_metric")
        ti_metric = keras.saving.deserialize_keras_object(ti_metric_config)
        loss_tracker_config = config.pop("loss_tracker")
        loss_tracker = keras.saving.deserialize_keras_object(loss_tracker_config)
        return cls(inputs, ti_metric, loss_tracker, **config)

    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_states()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        # If you don't implement this property, you have to call
        # `reset_states()` yourself at the time of your choosing.
        return [self.loss_tracker, self.ti_metric]  #


def conv_block(inputs, filters, activation="relu", kernel_initializer="he_normal"):
    """
    Convolution block of a UNet encoder
    """
    x = tf.keras.layers.Conv3D(
        filters,
        (3, 3, 3),
        padding="same",
        activation=activation,
        kernel_initializer=kernel_initializer,
    )(inputs)
    x = tf.keras.layers.Conv3D(
        filters,
        (3, 3, 3),
        padding="same",
        activation=activation,
        kernel_initializer=kernel_initializer,
    )(x)
    x = tf.keras.layers.BatchNormalization(axis=-1)(x)
    x = tf.keras.activations.relu(x)
    return x


def encoder_block(inputs, filters):
    """
    Encoder block of a UNet passes the result from the convolution block
    above to a max pooling layer
    """
    x = conv_block(inputs, filters)
    p = tf.keras.layers.MaxPooling3D(pool_size=(2, 2, 2))(x)
    return x, p


def decoder_block(inputs, filters, concat_layer):
    # Upsample the feature maps
    x = tf.keras.layers.Conv3DTranspose(
        filters, (2, 2, 2), strides=(2, 2, 2), padding="same"
    )(inputs)
    x = tf.keras.layers.concatenate(
        [x, concat_layer]
    )  # Concatenation/Skip conncetion with conjuagte encoder
    x = conv_block(x, filters)  # Passed into the convolution block above
    return x


def create_model():
    input_shape = (192, 256, 256, 1)
    inputs = tf.keras.Input(shape=input_shape)
    # normalized_inputs = normalization_layer(inputs)
    d1, p1 = encoder_block(inputs, 1)  # 32
    d2, p2 = encoder_block(p1, 1)  # 64
    d3, p3 = encoder_block(p2, 1)  # 128
    d4, p4 = encoder_block(p3, 1)  # 200
    mid = conv_block(p4, 1)  # Midsection 256
    e2 = decoder_block(mid, 1, d4)  # Conjugate of encoder 4 - 200
    e3 = decoder_block(e2, 1, d3)  # Conjugate of encoder 3 - 128
    e4 = decoder_block(e3, 1, d2)  # Conjugate of encoder 2 - 64
    o1 = tf.keras.layers.Conv3D(1, (1, 1, 1), activation="sigmoid", name="preOutput")(
        e4
    )  # Output from 2nd last decoder (32,128,128,1)
    e5 = decoder_block(e4, 1, d1)  # Conjugate of encoder 1 - 32
    output = tf.keras.layers.Conv3D(
        1, (1, 1, 1), activation="sigmoid", name="finalOutput"
    )(e5)
    ml = ModelUnet3d(inputs=[inputs], outputs=[output, o1], name="Unet3D")
    return ml


@tf.keras.saving.register_keras_serializable(name="adamw")
def opt_adamw():
    return tf.keras.optimizers.AdamW(
        learning_rate=0.001,
        weight_decay=0.004,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-07,
        amsgrad=False,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        use_ema=True,
        ema_momentum=0.75,
        jit_compile=True,
        name="AdamW",
    )


# opt = tf.keras.optimizers.Adam(learning_rate=0.001, epsilon=None, amsgrad=False, beta_1=0.9, beta_2=0.99)
# opt = tf.keras.optimizers.Adam(learning_rate=0.1)

model = create_model()
model.compile(optimizer=opt_adamw(), loss=FocalTverskyLoss())

model.summary()
# tf.keras.utils.plot_model(model, to_file='model_3dunet.png', show_shapes=True, show_dtype=False, show_layer_names=True, expand_nested=False, dpi=70,)



Model: "Unet3D"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_20 (InputLayer)       [(None, 192, 256, 256, 1)]   0         []                            
                                                                                                  
 conv3d_342 (Conv3D)         (None, 192, 256, 256, 1)     28        ['input_20[0][0]']            
                                                                                                  
 conv3d_343 (Conv3D)         (None, 192, 256, 256, 1)     28        ['conv3d_342[0][0]']          
                                                                                                  
 batch_normalization_171 (B  (None, 192, 256, 256, 1)     4         ['conv3d_343[0][0]']          
 atchNormalization)                                                                          

In [65]:
MODEL_NAME = "custom"

model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(PROJECT_FOLDER, f"saved_models/{MODEL_NAME}.keras"),
    save_weights_only=False,
    monitor="val_loss",
    mode="min",
    save_best_only=True,
)

callbacks = [
    tf.keras.callbacks.CSVLogger(
        os.path.join(PROJECT_FOLDER, f"logs/{MODEL_NAME}.csv")
    ),
    tf.keras.callbacks.TerminateOnNaN(),
    # tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=5),
    model_checkpoint_callback,
]

In [76]:
model.predict(train_ds.take(1))
print("done")

2023-08-10 14:44:34.522731: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 34 of 74
2023-08-10 14:44:44.537509: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] Filling up shuffle buffer (this may take a while): 68 of 74
2023-08-10 14:44:46.323288: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:450] Shuffle buffer filled.
2023-08-10 14:44:46.693190: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


done


In [77]:
saved_model = tf.keras.models.save_model(
    model, os.path.join(PROJECT_FOLDER, "saved_models", f"{MODEL_NAME}.keras")
)

TypeError: Cannot serialize object KerasTensor(type_spec=TensorSpec(shape=(None, 192, 256, 256, 1), dtype=tf.float32, name='input_19'), name='input_19', description="created by layer 'input_19'") of type <class 'keras.src.engine.keras_tensor.KerasTensor'>. To be serializable, a class must implement the `get_config()` method.

In [20]:
best_model = tf.keras.models.load_model(
    os.path.join(PROJECT_FOLDER, "saved_models", f"{MODEL_NAME}.keras")
)
best_model.evaluate(test_ds, verbose=2)

TypeError: Unable to revive model from config. When overriding the `get_config()` method, make sure that the returned config contains all items used as arguments in the  constructor to <class '__main__.ModelUnet3d'>, which is the default behavior. You can override this default behavior by defining a `from_config(cls, config)` class method to specify how to create an instance of ModelUnet3d from its config.

Received config={'name': 'Unet3D', 'trainable': True}

Error encountered during deserialization: ModelUnet3d.__init__() missing 2 required positional arguments: 'inputs' and 'outputs'

In [None]:
history = model.fit(
    train_ds, validation_data=val_ds, verbose=2, epochs=2, callbacks=callbacks
)

model.evaluate(test_ds, verbose=2)

In [None]:
def plot_train_report(history):
    epochs = range(1, 131)

    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))
    # Plot and label the training and validation loss values
    ax[0].plot(epochs, history.history["loss"], label="Training Loss")
    ax[0].plot(epochs, history.history["val_loss"], label="Validation Loss")

    # Add in a title and axes labels
    ax[0].set_title("FOCAL TVERSKY LOSS")
    ax[0].set_xlabel("Epochs")
    ax[0].set_ylabel("Loss")

    # Set the tick locations
    ax[0].set_xticks(np.arange(1, 131, 10))

    # Display the plot
    ax[0].legend(loc="best")

    # Plot and label the training and validation loss values
    ax[1].plot(epochs, history.history["tversky_index"], label="Training TI")
    ax[1].plot(epochs, history.history["val_tversky_index"], label="Validation TI")

    # Add in a title and axes labels
    ax[1].set_title("TVERSKY INDEX")
    ax[1].set_xlabel("Epochs")
    ax[1].set_ylabel("TI")

    # Set the tick locations
    ax[1].set_xticks(np.arange(1, 131, 10))

    # Display the plot
    ax[1].legend(loc="best")
    plt.show()


def overlay_images(x, y, y_pred):
    mid_slice = int(np.round(x.shape[1] / 2))
    ti = TverskyIndex()

    x_slice = x[:, mid_slice, :, :, :]
    y_slice = y[:, mid_slice, :, :, :]
    y_pred_slice = y_pred[:, mid_slice, :, :, :]
    ti.update_state(y_slice, y_pred_slice)

    x_slice = np.squeeze(x_slice)
    y_slice = np.squeeze(y_slice)
    y_pred_slice = np.squeeze(y_pred_slice)

    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))
    ax[0].imshow(x_slice, cmap="bone", interpolation="none")
    ax[0].imshow(y_slice, cmap="jet", alpha=(0.5 * y_slice))
    ax[1].imshow(x_slice, cmap="bone", interpolation="none")
    ax[1].imshow(y_pred_slice, cmap="jet", alpha=(0.5 * y_pred_slice))

    plt.title(f"Slice: {mid_slice}. Tversky Index: {ti.result().numpy()} ")
    plt.show()


def plot_3d(
    images,
    labels=["Y True", "Y Pred"],
    colors=[[0.5, 0.5, 1], [0.9, 0.1, 0.9]],
    alpha=[0.5, 0.5],
    threshold=[0, 0],
):
    fig = plt.figure(figsize=(12, 5))
    ax = fig.add_subplot(111, projection="3d")
    patches = []

    for idx, image in enumerate(images):
        # Расположим сканирование вертикально,
        # так чтобы голова пациента была вверху, лицом к камере
        image = image.transpose(2, 0, 1)
        image = image[:, :, ::-1]

        if idx == 0:
            # Устанавливаем границы для каждой оси в соответствии с формой транспонированного изображения
            ax.set_xlim(0, image.shape[0])
            ax.set_ylim(0, image.shape[1])
            ax.set_zlim(0, image.shape[2])

        # Получаем вершины и грани 3D модели, используя marching_cubes
        verts, faces, _, _ = measure.marching_cubes(image, threshold[idx])
        # Создаем объект Figure в matplotlib и добавляем в него 3D подзаголовок

        # Создаем коллекцию треугольников из вершин и граней, устанавливаем цвет и прозрачность, добавляем в подзаголовок
        mesh = Poly3DCollection(verts[faces], alpha=alpha[idx])
        face_color = colors[idx]
        mesh.set_facecolor(face_color)
        ax.add_collection3d(mesh)
        patches.append(mpatches.Patch(color=colors[idx], label=labels[idx]))
    ax.legend(handles=patches)

    # Отображаем визуализацию
    plt.show()


def plot_results(model, dataset):
    ftl = FocalTverskyLoss()
    for x, y, *z in dataset:
        ti = TverskyIndex()
        y_pred, _ = model.predict(x)
        y_pred = tf.where(y_pred >= 0.5, 1.0, 0.0)
        loss = ftl(tf.cast(y, dtype=tf.float32), y_pred)
        ti.update_state(tf.cast(y, dtype=tf.float32), y_pred)
        overlay_images(x, y, y_pred)
        x = tf.squeeze(x).numpy()
        y = tf.squeeze(y).numpy()
        y_pred = tf.squeeze(y_pred).numpy()
        print("FocalTverskyLoss:", loss.numpy(), "| TverskyIndex:", ti.result().numpy())
        if np.max(y) > 0:
            if np.max(y_pred) > 0:
                print(
                    f"CORRECT:  Existing effusion event FOUND: {np.count_nonzero(y)}/{np.count_nonzero(y_pred)} voxels"
                )
                plot_3d([y, y_pred])
            else:
                print(
                    f"INCORRECT:  Existing effusion event NOT FOUND: {np.count_nonzero(y)}/{np.count_nonzero(y_pred)} voxels"
                )
                plot_3d([y])
        else:
            if np.max(y_pred) > 0:
                print(
                    f"INCORRECT:  Non-existing effusion event FOUND +: {np.count_nonzero(y)}/{np.count_nonzero(y_pred)} voxels"
                )
                plot_3d([y_pred])
            else:
                print(
                    f"CORRECT: Non-existing effusion event NOT FOUND -: {np.count_nonzero(y)}/{np.count_nonzero(y_pred)} voxels"
                )

        print("===== \n \n")

In [None]:
plot_train_report(history)
plot_results(model, test_ds)

In [None]:
best_model = tf.keras.models.load_model(
    os.path.join(PROJECT_FOLDER, "saved_models", f"{MODEL_NAME}.keras")
)
best_model.evaluate(test_ds, verbose=2)
plot_results(best_model, test_ds)