# Experiment

## Import

In [None]:
import pathlib
import random
from typing import Callable, List, Tuple

import keras
import matplotlib.pyplot as plt
import numpy as np
import pydicom as pdc
import tensorflow as tf
import tensorflow_io as tfio
from tensorflow.keras import layers
from tqdm import tqdm

## Preprocessing

### Helper Function

In [None]:
def get_patient_path_pair(
    project_root_path: pathlib.Path, patient_list: List[int]
) -> List[Tuple[str, str, str]]:
    """Pair each patient's InPhase DICOM, OutPhase DICOM and ground-truth PNG slices.

    For every patient ID the three file sets are globbed recursively under
    ``project_root_path`` (below ``MR/<id>/``), each set is sorted by path, and
    the sets are zipped position by position. The pairing is therefore
    order-based, not name-based: it relies on the three sorted lists lining up.

    Args:
        project_root_path: Directory that is searched recursively.
        patient_list: Patient IDs; each must match an ``MR/<id>`` directory name.

    Returns:
        A flat list (over all patients) of
        ``(in_phase_dcm, out_phase_dcm, ground_truth_png)`` tuples of POSIX
        path strings.

    Raises:
        ValueError: If, for a patient, the numbers of InPhase, OutPhase and
            ground-truth files differ. Zipping such lists would silently drop
            or misalign slices.
    """

    def _sorted_posix(pattern: str) -> List[str]:
        # Sort Path objects first (same ordering as before), then stringify.
        return [path.as_posix() for path in sorted(project_root_path.rglob(pattern))]

    patient_path_pair = []
    for idx in tqdm(patient_list, desc="Creating Path Pairing"):
        in_phase = _sorted_posix(f"MR/{idx}/T1DUAL/**/InPhase/*.dcm")
        out_phase = _sorted_posix(f"MR/{idx}/T1DUAL/**/OutPhase/*.dcm")
        ground_truth = _sorted_posix(f"MR/{idx}/**/T1DUAL/Ground/*.png")
        if not len(in_phase) == len(out_phase) == len(ground_truth):
            raise ValueError(
                f"Patient {idx}: found {len(in_phase)} InPhase, "
                f"{len(out_phase)} OutPhase and {len(ground_truth)} ground-truth "
                "files; the counts must match to pair slices."
            )
        patient_path_pair.extend(zip(in_phase, out_phase, ground_truth))
    return patient_path_pair


def get_train_val_test_split(split: str, patient_list: List[int]) -> dict:
    """Randomly split patient IDs into disjoint train, validation and test sets.

    Args:
        split: Exactly three ASCII digits giving the train, validation and test
            shares in tenths, e.g. ``"721"`` for 70% / 20% / 10%. The digits
            must sum to at most 10.
        patient_list: Patient IDs to split. The list is not modified.

    Returns:
        A dictionary with keys "train", "val" and "test", each a list of
        patient IDs. The train and validation sizes are rounded down, and the
        test set receives every remaining patient, so the three lists are
        disjoint and together contain every input ID.

    Raises:
        ValueError: If ``split`` is not three digits or its digits sum to more than 10.
    """
    if len(split) != 3 or any(c not in "0123456789" for c in split):
        raise ValueError(
            f"split must be three digits giving tenths (e.g. '721'), got {split!r}"
        )
    train_share, val_share, test_share = (int(c) for c in split)
    if train_share + val_share + test_share > 10:
        raise ValueError(f"split digits must sum to at most 10, got {split!r}")

    # Shuffle a copy so the caller's list keeps its order.
    shuffled = list(patient_list)
    random.shuffle(shuffled)

    n = len(shuffled)
    # Integer arithmetic avoids float truncation artefacts of int(n * 0.x).
    n_train = n * train_share // 10
    n_val = n * val_share // 10

    # Contiguous, non-overlapping slices; test takes the remainder.
    return {
        "train": shuffled[:n_train],
        "val": shuffled[n_train : n_train + n_val],
        "test": shuffled[n_train + n_val :],
    }


def get_dcm_img(dcm_path):
    """Read a DICOM file and decode its pixel data as ``tf.uint16``.

    Args:
        dcm_path: Path of the DICOM file, as accepted by ``tf.io.read_file``
            (a string or scalar string tensor).

    Returns:
        The tensor produced by ``tfio.image.decode_dicom_image`` with
        ``dtype=tf.uint16``. NOTE(review): tfio documents a 4-D
        ``[frames, height, width, channels]`` result; confirm for this data.
    """
    return tfio.image.decode_dicom_image(tf.io.read_file(dcm_path), dtype=tf.uint16)


def preprocess_dcm(dcm_path):
    """Load a DICOM slice and scale its uint16 pixels into [0, 1].

    Args:
        dcm_path: Path of the DICOM file, as accepted by ``get_dcm_img``.

    Returns:
        The decoded image divided by 65535, the largest uint16 value.
        NOTE(review): if the scanner only uses part of the uint16 range, the
        scaled values may cluster near 0; consider per-image normalization.
    """
    uint16_max = tf.cast(65535, tf.uint16)
    return get_dcm_img(dcm_path) / uint16_max


def preprocess_ground(png_path):
    """Decode a ground-truth PNG into a single-channel binary uint8 mask.

    Pixels equal to 63 become 1; every other pixel becomes 0.
    NOTE(review): 63 is presumably the label of the target organ in these
    masks; confirm against the dataset's label definition.
    """
    decoded = tf.io.decode_png(tf.io.read_file(png_path), channels=1)
    is_target = tf.equal(decoded, 63)
    return tf.cast(is_target, tf.uint8)


def parsed_path_to_dataset(features):
    """Turn one (in_phase, out_phase, ground_truth) path triple into an example.

    Args:
        features: Indexable triple of paths: InPhase DICOM, OutPhase DICOM,
            ground-truth PNG.

    Returns:
        ``((in_phase_image, out_phase_image), mask)`` built with
        ``preprocess_dcm`` and ``preprocess_ground``.
    """
    in_path, out_path, mask_path = features[0], features[1], features[2]
    images = (preprocess_dcm(in_path), preprocess_dcm(out_path))
    return images, preprocess_ground(mask_path)


def load_dicom_image(dicom_path):
    """
    Read a DICOM file and decode it with TensorFlow I/O as uint16 pixels.

    Args:
        dicom_path (str): Path of the DICOM file (a string or scalar string tensor).

    Returns:
        The tensor returned by ``tfio.image.decode_dicom_image`` with
        ``dtype=tf.uint16``. NOTE(review): tfio documents this as 4-D
        ``[frames, height, width, channels]``; confirm.
    """
    raw_bytes = tf.io.read_file(dicom_path)
    return tfio.image.decode_dicom_image(raw_bytes, dtype=tf.uint16)


def normalize_image(image):
    """
    Scale an image to [0, 1] by dividing by the largest uint16 value (2^16 - 1).

    Args:
        image (tf.Tensor): Image to scale, e.g. the uint16 output of ``load_dicom_image``.

    Returns:
        ``image`` divided by 65535.
    """
    return tf.divide(image, tf.cast(2**16 - 1, tf.uint16))


def load_ground_truth_mask(png_path):
    """
    Read a ground-truth PNG and turn it into a binary uint8 mask.

    The PNG is decoded as a single channel. Pixels equal to 63 map to 1 and
    all others map to 0. NOTE(review): 63 is presumably the target-organ label
    in this dataset; confirm.

    Args:
        png_path (str): Path of the PNG file.

    Returns:
        The binary mask as a uint8 tensor.
    """
    encoded = tf.io.read_file(png_path)
    gray = tf.io.decode_png(encoded, channels=1)
    return tf.cast(tf.math.equal(gray, 63), tf.uint8)


def parse_path_to_dataset(path_list):
    """
    Build one training example from an (in_phase, out_phase, ground_truth) path triple.

    Args:
        path_list (list): Indexable triple: InPhase DICOM path, OutPhase DICOM
            path, ground-truth PNG path.

    Returns:
        ``((in_phase_norm, out_phase_norm), mask)``. The images are scaled by
        ``normalize_image``, and the mask comes from ``load_ground_truth_mask``.
    """
    in_path, out_path, mask_path = path_list[0], path_list[1], path_list[2]
    in_raw = load_dicom_image(in_path)
    out_raw = load_dicom_image(out_path)
    mask = load_ground_truth_mask(mask_path)
    images = (normalize_image(in_raw), normalize_image(out_raw))
    return images, mask


def get_dataset(pair_path_list, batch_size):
    """Build an unbatched ``tf.data.Dataset`` of parsed examples.

    Args:
        pair_path_list: Sequence of (in_phase_dcm, out_phase_dcm,
            ground_truth_png) path triples, e.g. from ``get_patient_path_pair``.
        batch_size: Currently unused. Shuffling, batching and prefetching are
            disabled in this function.

    Returns:
        A dataset with one ``parse_path_to_dataset`` output per path triple.
    """
    # NOTE(review): the shuffle(batch_size * 10) -> batch(batch_size) ->
    # prefetch(AUTOTUNE) stages are disabled. Before re-enabling batch(),
    # confirm the decoded DICOM shape. If it carries a leading frame axis,
    # batching would add a second leading axis.
    path_ds = tf.data.Dataset.from_tensor_slices(pair_path_list)
    return path_ds.map(parse_path_to_dataset, num_parallel_calls=tf.data.AUTOTUNE)

In [None]:
# Get the project root path (the parent of the current working directory)
project_root_path = pathlib.Path.cwd().parent

# Seed Python's global RNG so that the random patient split below is
# reproducible between runs of the notebook.
SEED = 42
random.seed(SEED)

# Set a list of patient IDs
patient_list = [
    1,
    2,
    3,
    5,
    8,
    10,
    13,
    15,
    19,
    20,
    21,
    22,
    31,
    32,
    33,
    34,
    36,
    37,
    38,
    39,
]

# Get patient split (train/val/test shares in tenths: 70% / 20% / 10%)
patient_split = get_train_val_test_split("721", patient_list)

# Get train, test, and val split
train_path_pair = get_patient_path_pair(project_root_path, patient_split["train"])
val_path_pair = get_patient_path_pair(project_root_path, patient_split["val"])
test_path_pair = get_patient_path_pair(project_root_path, patient_split["test"])

# Create Dataset
batch_size = 10
train_dataset = get_dataset(train_path_pair, batch_size)
val_dataset = get_dataset(val_path_pair, batch_size)
test_dataset = get_dataset(test_path_pair, batch_size)

## Model

### Helper Function

In [None]:
def conv_block(
    node_name: str,
    n_filter: int,
    batch_norm: bool,
    strides: int,
    kernel: int,
) -> Callable:
    """
    Return a layer factory for a plain Conv2D with optional batch normalization.

    No activation is applied, so the block's output is linear.
    NOTE(review): a MobileNetV2 stem convolution is usually followed by ReLU6;
    confirm that leaving this block linear is intended.

    Args:
        node_name (str): Prefix for the Keras layer names.
        n_filter (int): Number of output filters of the convolution.
        batch_norm (bool): Whether to apply BatchNormalization after the convolution.
        strides (int): Stride of the convolution.
        kernel (int): Kernel size of the convolution.

    Returns:
        Callable: Function that applies the block to an input tensor.
    """

    def layer(input_tensor: tf.Tensor) -> tf.Tensor:
        """
        Apply Conv2D ("same" padding) and, if enabled, BatchNormalization.

        Args:
            input_tensor (tf.Tensor): Input tensor to the layer.

        Returns:
            tf.Tensor: Output tensor of the layer.
        """
        x = layers.Conv2D(
            n_filter,
            kernel,
            strides=strides,
            padding="same",
            name=f"{node_name}_conv_block",
        )(input_tensor)
        if batch_norm:
            x = layers.BatchNormalization(name=f"{node_name}_conv_block_bnorm")(x)
        return x

    return layer


def pointwise_block(
    node_name: str,
    n_filter: int,
    batch_norm: bool,
    linear: bool,
    strides: int = 1,
    kernel: int = 1,
) -> Callable:
    """
    Return a convolution layer factory with optional batch normalization and ReLU6.

    With the default ``kernel=1`` this is a pointwise (1x1) convolution.

    Args:
        node_name (str): Prefix for the Keras layer names.
        n_filter (int): Number of filters for the convolutional layer.
        batch_norm (bool): Whether to apply batch normalization after the convolution.
        linear (bool): If True, the block is linear and no activation is applied.
            If False, a ReLU6 activation follows the convolution (and batch norm).
        strides (int, optional): Stride of the convolutional layer. Defaults to 1.
        kernel (int, optional): Kernel size of the convolutional layer. Defaults to 1.

    Returns:
        callable: Function that applies the block to an input tensor.
    """

    def layer(input_tensor: tf.Tensor) -> tf.Tensor:
        """
        Apply Conv2D, then optional BatchNormalization, then ReLU6 unless ``linear``.

        Args:
            input_tensor (tf.Tensor): Input tensor to the layer.

        Returns:
            tf.Tensor: Output tensor of the layer.
        """
        x = layers.Conv2D(
            n_filter, kernel, strides=strides, padding="same", name=f"{node_name}_pwise"
        )(input_tensor)
        if batch_norm:
            x = layers.BatchNormalization(name=f"{node_name}_pwise_bnorm")(x)
        if not linear:
            x = layers.Activation(tf.nn.relu6, name=f"{node_name}_pwise_relu6")(x)
        return x

    return layer


def depthwise_block(
    node_name: str, batch_norm: bool, strides: int = 1, kernel: int = 3
) -> Callable:
    """Build a DepthwiseConv2D -> [BatchNormalization] -> ReLU6 layer factory.

    Args:
        node_name (str): Prefix for the Keras layer names.
        batch_norm (bool): Whether to insert BatchNormalization after the convolution.
        strides (int, optional): Spatial stride of the depthwise convolution. Defaults to 1.
        kernel (int, optional): Size of the depthwise kernel. Defaults to 3.

    Returns:
        Callable: Function that applies the block to an input tensor.
    """

    def layer(input_tensor: tf.Tensor) -> tf.Tensor:
        """Apply the depthwise stage to ``input_tensor`` and return the result."""
        out = layers.DepthwiseConv2D(
            kernel, strides=strides, padding="same", name=f"{node_name}_dwise"
        )(input_tensor)
        out = (
            layers.BatchNormalization(name=f"{node_name}_dwise_bnorm")(out)
            if batch_norm
            else out
        )
        return layers.Activation(tf.nn.relu6, name=f"{node_name}_dwise_relu6")(out)

    return layer


def inverted_residual_bottleneck_block(
    node_name: str,
    n_filter: int,
    strides: int,
    t_expansion: int,
    batch_norm: bool,
    residual: bool = False,
) -> Callable:
    """
    Build an inverted-residual bottleneck: expand -> depthwise -> linear compress.

    The expansion is a ReLU6 pointwise conv with ``in_channels * t_expansion``
    filters. It is followed by a 3x3 depthwise conv (ReLU6) with the given
    stride, and then a linear pointwise conv down to ``n_filter`` channels.

    Args:
        node_name: Prefix for the Keras layer names.
        n_filter: Number of channels of the output tensor.
        strides: Stride of the depthwise convolution.
        t_expansion: Multiplier applied to the input channel count for the expansion layer.
        batch_norm: Whether every convolution is followed by batch normalization.
        residual: Whether to add the input tensor to the output. The Add layer
            needs matching shapes (``strides=1`` and input channels == ``n_filter``).

    Returns:
        A callable mapping an input tensor to the block's output tensor.
    """

    def layer(input_tensor: tf.Tensor) -> tf.Tensor:
        in_channels = keras.backend.int_shape(input_tensor)[-1]

        expand = pointwise_block(
            node_name=node_name + "_expand",
            n_filter=in_channels * t_expansion,
            batch_norm=batch_norm,
            linear=False,
        )
        depthwise = depthwise_block(
            node_name=node_name + "_depthwise",
            batch_norm=batch_norm,
            strides=strides,
            kernel=3,
        )
        compress = pointwise_block(
            node_name=node_name + "_compress",
            n_filter=n_filter,
            batch_norm=batch_norm,
            strides=1,
            kernel=1,
            linear=True,
        )

        # Keras layers are created lazily, in call order: expand, depthwise, compress.
        out = compress(depthwise(expand(input_tensor)))
        if not residual:
            return out
        return layers.Add(name=f"{node_name}_add")([out, input_tensor])

    return layer


def sequence_inv_res_bot_block(
    node_name, n_filter, batch_norm, strides, t_expansion, n_iter
) -> Callable:
    """
    Chain ``n_iter`` inverted residual bottleneck blocks.

    The first block uses ``strides`` and has no residual connection. Every
    later block uses stride 1 and a residual connection. The first block is
    always built, even when ``n_iter`` is less than 1.

    Args:
        node_name (str): Base name. Block ``i`` is named ``x_{node_name}_iter{i}``.
        n_filter (int): Output channels of every block in the sequence.
        batch_norm (bool): Whether the blocks use batch normalization.
        strides (int): Stride of the first block's depthwise convolution.
        t_expansion (int): Expansion factor passed to every block.
        n_iter (int): Number of blocks in the sequence.

    Returns:
        Callable: A callable that applies the sequence of blocks to an input tensor.
    """

    def layer(input_tensor):
        out = input_tensor
        # Index 0 always runs; indices 1..n_iter-1 follow with residual connections.
        for index in [0, *range(1, n_iter)]:
            is_first = index == 0
            out = inverted_residual_bottleneck_block(
                node_name=f"x_{node_name}_iter{index}",
                n_filter=n_filter,
                batch_norm=batch_norm,
                strides=strides if is_first else 1,
                t_expansion=t_expansion,
                residual=not is_first,
            )(out)
        return out

    return layer


def upsample_block(
    node_name, n_filter, batch_norm, n_kernel=2, mode="upsample"
) -> Callable:
    """
    Return a factory for a 2x spatial upsampling layer.

    Args:
        node_name (str): Base name of the Keras layers (prefixed with ``x_``).
        n_filter (int): Filters of the transposed convolution ("transpose" mode only).
        batch_norm (bool): Whether to add batch normalization ("transpose" mode only).
        n_kernel (int): Kernel size of the transposed convolution ("transpose" mode only).
        mode (str): "upsample" applies ``UpSampling2D(size=2)`` and ignores the
            other arguments. "transpose" applies a stride-2 ``Conv2DTranspose``,
            then optional batch normalization, then ReLU.

    Returns:
        A callable that applies the upsampling to an input tensor.

    Raises:
        ValueError: Immediately, if ``mode`` is neither "upsample" nor "transpose".
    """

    def _upsample(input_tensor):
        return layers.UpSampling2D(size=2, name=f"x_{node_name}_upsample")(
            input_tensor
        )

    def _transpose(input_tensor):
        out = layers.Conv2DTranspose(
            filters=n_filter,
            kernel_size=n_kernel,
            strides=2,
            name=f"x_{node_name}_transpose",
            padding="same",
        )(input_tensor)
        if batch_norm:
            out = layers.BatchNormalization(name=f"x_{node_name}_transpose_bn")(out)
        return layers.Activation("relu", name=f"x_{node_name}_transpose_activation")(
            out
        )

    if mode == "upsample":
        return _upsample
    if mode == "transpose":
        return _transpose
    raise ValueError("Mode can only be either upsample or transpose")

In [None]:
# Multi input (multimodal) MobileNetV2-style encoder-decoder.
# The in-phase and out-of-phase slices are concatenated along the channel axis.
# Five stride-2 stages (stem, enc_2, enc_3, enc_4, enc_6) reduce 256 -> 8, and
# five stride-2 transposed convolutions bring the resolution back to 256.

input_in_phase = layers.Input(name="in_phase_input", shape=(256, 256, 1))
input_out_phase = layers.Input(name="out_phase_input", shape=(256, 256, 1))

x = layers.Concatenate()([input_in_phase, input_out_phase])
x = conv_block(node_name="x_0", n_filter=32, batch_norm=True, kernel=3, strides=2)(x)

# Encoder

x = sequence_inv_res_bot_block(
    node_name="enc_1", n_filter=16, batch_norm=True, strides=1, t_expansion=1, n_iter=1
)(x)
x = sequence_inv_res_bot_block(
    node_name="enc_2", n_filter=24, batch_norm=True, strides=2, t_expansion=6, n_iter=2
)(x)
x = sequence_inv_res_bot_block(
    node_name="enc_3", n_filter=32, batch_norm=True, strides=2, t_expansion=6, n_iter=3
)(x)
x = sequence_inv_res_bot_block(
    node_name="enc_4", n_filter=64, batch_norm=True, strides=2, t_expansion=6, n_iter=4
)(x)
x = sequence_inv_res_bot_block(
    node_name="enc_5", n_filter=96, batch_norm=True, strides=1, t_expansion=6, n_iter=3
)(x)
x = sequence_inv_res_bot_block(
    node_name="enc_6", n_filter=160, batch_norm=True, strides=2, t_expansion=6, n_iter=3
)(x)
x = sequence_inv_res_bot_block(
    node_name="enc_7", n_filter=320, batch_norm=True, strides=1, t_expansion=6, n_iter=1
)(x)

x = pointwise_block(
    node_name="mid_8",
    n_filter=1280,
    batch_norm=True,
    linear=False,
)(x)

# Decoder
x = upsample_block(
    node_name="dec_9", n_filter=320, batch_norm=True, n_kernel=2, mode="transpose"
)(x)

x = upsample_block(
    node_name="dec_10", n_filter=160, batch_norm=True, n_kernel=2, mode="transpose"
)(x)

x = upsample_block(
    node_name="dec_11", n_filter=96, batch_norm=True, n_kernel=2, mode="transpose"
)(x)

x = upsample_block(
    node_name="dec_12", n_filter=64, batch_norm=True, n_kernel=2, mode="transpose"
)(x)

x = upsample_block(
    node_name="dec_13", n_filter=32, batch_norm=True, n_kernel=2, mode="transpose"
)(x)

x = layers.Conv2D(1, 3, strides=1, padding="same", name="output_pwise")(x)

x = layers.BatchNormalization(name="output_pwise_bnorm")(x)
# Single-channel binary mask head: use sigmoid, not softmax. A softmax over one
# channel is identically 1.0, so its output would carry no information. The
# sigmoid gives a per-pixel probability, as binary (focal) cross-entropy expects.
x = layers.Activation("sigmoid", name="output_pwise_sigmoid")(x)


model = tf.keras.Model(
    [input_in_phase, input_out_phase], x, name="multimodal_encoder_decoder_mobilenetV2"
)
# tf.keras.utils.plot_model(model, to_file="images/MobileNetv2.png", show_shapes=True)

In [None]:
# Print the layer-by-layer architecture with output shapes and parameter counts.
model.summary()

In [None]:
# Binary focal cross-entropy is used as the training loss and is also tracked
# as a metric, together with accuracy ("acc").
model.compile(
    loss=tf.keras.losses.binary_focal_crossentropy,
    metrics=[tf.keras.metrics.binary_focal_crossentropy, "acc"],
    optimizer="adam",
)

In [None]:
# Batching belongs in the tf.data pipeline (see get_dataset). Keras documents
# that `batch_size` must not be specified when `x` is a tf.data.Dataset, so it
# is not passed here.
# NOTE(review): get_dataset currently returns unbatched elements; confirm the
# element shapes match what the model expects.
model.fit(x=train_dataset, validation_data=val_dataset, epochs=5)