# U-Net and Associated Functions

## Overview
This notebook contains the U-Net model and its associated functions. Every function has a docstring that defines its inputs, outputs, and their data types, and comments throughout each function provide additional clarity.

## Structure
- **[Import Necessary Packages](#import-necessary-packages)**: Importing the necessary packages used by functions in this notebook.
- **[Internal Functions](#internal-functions)**: Internal or private functions that are only called by other functions and should never be called directly by a script. 
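    - [`_get_colors`](#get-colors): Generates a dictionary of display colors, one per class.
    - [`_get_class_labels`](#get-class-labels): Validates the class labels and generates defaults when none are given.
    - [`_get_color_masks`](#get-colored-masks): Colors the ground truth and predicted masks for visualization.
    - [`_compute_basic_metrics`](#compute-basic-metrics): Computes precision, recall, and F1 score on the test data.
    - [`_view_predictions`](#view-predictions): Displays the images, their masks, and the predicted masks.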
- **[Public Functions](#public-functions)**: Public functions to be imported and called by other scripts, outside of this notebook. 
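    - [`iou_metric`](#intersection-over-union): Custom Intersection over Union (IoU) metric for model training.
    - [`dice_coeff`](#dice-coefficient): Custom Dice Coefficient metric for model training.
    - [`unet_model`](#u-net-model-definition): Builds and compiles the U-Net model.
    - [`plot_losses`](#plot-model-losses): Plots the training and validation losses.
    - [`test_unet`](#test-trained-u-net): Tests the trained model, prints metrics, and optionally displays predictions.
    - [`save_predicted_masks`](#save-predicted-masks): Saves the predicted masks to an output directory.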


## Usage
This notebook is not intended to be run on its own. Import its public functions into `main.ipynb` using the `import-ipynb` package. The notebook exists to clearly separate the U-Net-related functions that are only called internally (by other functions in this notebook) from those that are called in `main.ipynb`.

## Import Necessary Packages

We first import the necessary packages used by the functions in this notebook.

In [None]:
# Necessary imports
import os
from pathlib import Path
import warnings
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.patches as mpatches
import numpy as np
from PIL import Image
from sklearn.metrics import precision_score, recall_score, f1_score
from typing import Optional, Dict, List, Tuple, Any, Union

# Optional line to suppress unnecessary tensorflow warnings
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

from tensorflow.keras import layers, models, Model, backend

## Internal Functions

This section contains the internal (private) functions. They are called only by other functions in this notebook and are not intended to be used in any other way.

### Get Colors

In [None]:
# Internal function to get colors for the classes for display
def _get_colors(num_colors: int, color_map: str) -> Dict[int, np.ndarray]:
    """
    Generates a dictionary of colors for each class.

    Args:
        num_colors (int): The number of colors to generate.
        color_map (str): Matplotlib colormap.

    Returns:
        Dict[int, np.ndarray]: A dictionary where the key is the class index and the value is a numpy array representing the RGB colors.
    """
    # Select the color map from matplotlib's library
    if color_map:
        try:
            # plt.get_cmap is used because cm.get_cmap was removed in matplotlib 3.9
            colors = plt.get_cmap(color_map, num_colors)
        # If the color map is not recognized by matplotlib, warn the user and use the default
        except ValueError:
            warnings.warn(
                f"Unrecognized color map {color_map}. Will use default 'viridis'.",
                UserWarning,
            )
            colors = plt.get_cmap("viridis", num_colors)

    # If no color map is given, use the default
    else:
        colors = plt.get_cmap("viridis", num_colors)

    # Extract the RGB values from the colormap
    colors = colors(np.arange(num_colors))[:, :3]

    # Create a dictionary of class colors
    class_colors = {idx: colors[idx] for idx in range(num_colors)}
    return class_colors



### Get Class Labels

In [None]:
# Internal function to get or assign class labels
def _get_class_labels(
    class_labels: Union[Dict[int, str], None], num_classes: int
) -> Dict[int, str]:
    """
    Ensures the class labels dictionary is correctly set up. If class_labels is None,
    generates a dictionary of class labels.

    Args:
        class_labels (Optional[Dict[int, str]]): An optional dictionary of class labels.
        num_classes (int): The number of classes.

    Returns:
        Dict[int, str]: A dictionary where the key is the class index and the value is the class label as a string.
    """
    # If class_labels is not provided, generate default labels based on the number of classes
    if not class_labels:
        num_labels = 2 if num_classes == 1 else num_classes
        class_labels = {idx: str(idx) for idx in range(num_labels)}

    # Handle the case where only one class label is given for binary segmentation
    # (num_classes is 1 for a single sigmoid channel, or 2 for a two-channel output)
    elif len(class_labels) == 1 and num_classes in (1, 2):
        key = next(iter(class_labels))
        background_key = 1 if key == 0 else 0
        class_labels[background_key] = "Assigned Background"

    # If the number of class labels does not match the number of classes, reset the labels
    elif len(class_labels) != num_classes and not (
        len(class_labels) == 2 and num_classes == 1
    ):
        print(
            f"Number of class labels ({len(class_labels)}) does not match the number of classes ({num_classes}). Labels will be changed to class index numbers."
        )
        num_labels = 2 if num_classes == 1 else num_classes
        class_labels = {idx: str(idx) for idx in range(num_labels)}

    return class_labels

### Get Colored Masks

In [None]:
# Internal function to color the masks and the predicted masks based on the color map and class labels
def _get_color_masks(
    num_classes: int,
    masks: np.ndarray,
    predicted_masks: np.ndarray,
    class_labels: Dict[int, str],
    class_color_map: str,
) -> Tuple[np.ndarray, np.ndarray, Dict[int, Dict[str, Union[str, np.ndarray]]]]:
    """
    Colors the masks and predicted masks for visualization.

    Args:
        num_classes (int): The number of classes.
        masks (np.ndarray): The ground truth masks.
        predicted_masks (np.ndarray): The predicted masks.
        class_labels (Dict[int, str]): A dictionary where the key is the class index (int) and the value is a class label.
        class_color_map (str): Matplotlib colormap.

    Returns:
        Tuple[np.ndarray, np.ndarray, Dict[int, Dict[str, Union[str, np.ndarray]]]]: A tuple containing:
            - colored_masks (np.ndarray): The colored ground truth masks.
            - colored_preds (np.ndarray): The colored predicted masks.
            - class_colors_and_labels (Dict[int, Dict[str, Union[str, np.ndarray]]]): A dictionary where the key is the class index and each class index contains a color and label.
    """
    # Determine the number of colors needed based on the number of classes
    if num_classes == 1:
        num_colors = 2 # One color for background, one color for foreground
    else:
        num_colors = num_classes

    # Get the colors for each class using the specified colormap
    class_colors = _get_colors(num_colors, class_color_map)

    # Create a dictionary to store the color and label for each class
    class_colors_and_labels = {}
    for key in class_colors:
        class_colors_and_labels[key] = {
            "color": class_colors[key],
            "label": class_labels[key],
        }

    # Create a color map that can easily be applied to the masks and predictions
    color_map = np.array(
        [class_colors.get(color, [0, 0, 0]) for color in range(num_colors)],
        dtype=np.float32,
    )

    # Apply the color map to the ground truth masks and predicted masks
    colored_masks = color_map[masks]
    colored_preds = color_map[predicted_masks]

    return colored_masks, colored_preds, class_colors_and_labels
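
The coloring step above works because indexing a NumPy array with an integer array (`color_map[masks]`) replaces every class index with the matching row of the lookup table. This only works when the masks hold integer class indices, which `test_unet` ensures by casting them to `uint8`. A minimal plain-Python sketch of the same idea follows; `colorize` and the two-entry `palette` are hypothetical names used only for illustration and are not part of this notebook.

```python
# Plain-Python illustration of lookup-table coloring.
# `colorize` is a hypothetical helper, not part of the notebook.
from typing import List, Sequence


def colorize(
    mask: Sequence[Sequence[int]], palette: Sequence[Sequence[float]]
) -> List[List[List[float]]]:
    """Replace every class index in a 2D mask with its RGB triple."""
    return [[list(palette[class_idx]) for class_idx in row] for row in mask]


# Illustrative two-class palette: index 0 -> black, index 1 -> white
palette = [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0)]

# A 2x2 mask of class indices
mask = [
    [0, 1],
    [1, 0],
]

# Result has shape (height, width, 3), like color_map[masks] for one image
colored = colorize(mask, palette)
```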



### Compute Basic Metrics

In [None]:
# Internal function to compute some basic metrics that assess model performance on the test data.
def _compute_basic_metrics(
    masks: np.ndarray, predicted_masks: np.ndarray, metric_ave_method: str
) -> Dict[str, float]:
    """
    Computes basic metrics (precision, recall, and F1 score) for the given masks and predicted masks.

    Args:
        masks (np.ndarray): The ground truth masks.
        predicted_masks (np.ndarray): The predicted masks.
        metric_ave_method (str): Averaging method to use in the metrics.

    Returns:
        Dict[str, float]: A dictionary containing the precision, recall, and F1 score.
    """
    # Flatten the masks and predictions as necessary for the scikit-learn library metrics
    masks = masks.flatten()
    predicted_masks = predicted_masks.flatten()
    # Precision: Calculates the precision, which is the ratio of correctly predicted positive observations to the total predicted positives. It is a measure of the accuracy of the positive predictions.
    # Precision = true_positives / (true_positives + false_positives)
    precision = precision_score(
        masks, predicted_masks, zero_division=0, average=metric_ave_method
    )
    # Recall: Calculates the recall, also known as sensitivity or true positive rate, which is the ratio of correctly predicted positive observations to all the actual positives. It is a measure of how well the model can capture positive instances.
    # Recall = true_positives / (true_positives + false_negatives)
    recall = recall_score(
        masks, predicted_masks, zero_division=0, average=metric_ave_method
    )
    # F1 Score: Calculates the F1 score, which is the harmonic mean of precision and recall. It considers both false positives and false negatives and is useful when you need a balance between precision and recall.
    # F1 Score = 2 * (Precision * Recall) / (Precision + Recall)
    f1 = f1_score(masks, predicted_masks, zero_division=0, average=metric_ave_method)
    return {"precision": precision, "recall": recall, "f1 score": f1}
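
To make the three formulas in the comments above concrete, here is a small worked example for the binary case, written in plain Python rather than with scikit-learn. `binary_scores` is a hypothetical helper for illustration only, and the label vectors are made up. Returning 0 when a denominator is zero mirrors the `zero_division=0` argument used above.

```python
# Worked example of precision, recall, and F1 for flattened binary masks.
# `binary_scores` is a hypothetical helper, not part of the notebook.
from typing import Dict, Sequence


def binary_scores(y_true: Sequence[int], y_pred: Sequence[int]) -> Dict[str, float]:
    """Compute precision, recall, and F1 for the positive class (label 1)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)  # true positives
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)  # false positives
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)  # false negatives

    # Fall back to 0.0 when a denominator is zero
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (
        2 * precision * recall / (precision + recall)
        if (precision + recall)
        else 0.0
    )
    return {"precision": precision, "recall": recall, "f1 score": f1}


# Made-up flattened masks: 2 true positives, 1 false positive, 1 false negative
y_true = [1, 1, 1, 0, 0, 0]
y_pred = [1, 1, 0, 1, 0, 0]
scores = binary_scores(y_true, y_pred)
```

With these inputs, precision, recall, and F1 all equal 2/3.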



### View Predictions

In [None]:
# Internal function to view the predictions based on the test data
def _view_predictions(
    images: np.ndarray,
    colored_masks: np.ndarray,
    colored_preds: np.ndarray,
    image_names: List[str],
    class_colors_and_labels: Dict[int, Dict[str, Union[str, np.ndarray]]],
    display_count: int,
    images_per_figure: int,
) -> None:
    """
    Displays the images, their masks, and the model predictions of the masks.

    Args:
        images (np.ndarray): The input images.
        colored_masks (np.ndarray): The colored ground truth masks.
        colored_preds (np.ndarray): The colored predicted masks.
        image_names (List[str]): List of image names.
        class_colors_and_labels (Dict[int, Dict[str, Union[str, np.ndarray]]]): A dictionary where the key is the class index and each class index contains a color and label.
        display_count (int): The number of images to display.
        images_per_figure (int): The number of images per figure.

    Returns:
        None: This function does not return anything. It displays the figures.
    """
    # Determine the number of images to display
    img_cnt = min(display_count, len(images))
    
    # If image names are not provided, create names based on the number of images
    if image_names is None:
        image_names = [str(idx) for idx in range(len(images))]

    # Create legend patches (labels) for each mask class
    patches = [
        mpatches.Patch(
            color=class_colors_and_labels[idx]["color"],
            label=class_colors_and_labels[idx]["label"],
        )
        for idx in class_colors_and_labels
    ]

    # Calculate the number of figures needed
    num_figs = (img_cnt + images_per_figure - 1) // images_per_figure

    # Loop over each figure
    for fig_idx in range(num_figs):

        # Determine the number of images in the current figure
        img_in_fig = min(images_per_figure, img_cnt - fig_idx * images_per_figure)

        # Create subplots for the current figure
        fig, axes = plt.subplots(img_in_fig, 3, figsize=(10, 3 * img_in_fig))

        # Ensure axes is a 2D array even if there's only one row
        if img_in_fig == 1:
            axes = np.expand_dims(axes, axis=0)

        # Loop over each image to be included in the current figure
        for idx in range(img_in_fig):
            img_idx = (fig_idx * images_per_figure) + idx
            if img_idx >= img_cnt:
                break
            
            # Extract the image, mask, and predicted mask
            image, mask, prediction = (
                images[img_idx],
                colored_masks[img_idx],
                colored_preds[img_idx],
            )

            # Plot the image
            ax_img, ax_mask, ax_pred = axes[idx]
            ax_img.imshow(image)
            ax_img.axis("off")
            ax_img.set_title(f"Image {image_names[img_idx]}")

            # Plot the true mask
            ax_mask.imshow(mask)
            ax_mask.axis("off")
            ax_mask.set_title(f"Mask {image_names[img_idx]}")
            
            # Plot the predicted mask
            ax_pred.imshow(prediction)
            ax_pred.axis("off")
            ax_pred.set_title(f"Pred. Mask {image_names[img_idx]}")

        # Add legend to the figure using the color labels
        fig.legend(
            handles=patches,
            loc="upper right",
            bbox_to_anchor=(1, 1),
            bbox_transform=fig.transFigure,
        )
        # Adjust layout and display the figure without blocking execution
        plt.tight_layout()
        plt.show(block=False)
    return
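
The figure pagination above uses ceiling division to decide how many figures are needed and how many rows each one gets. The sketch below isolates that arithmetic in plain Python so it can be checked without plotting. `figure_layout` is a hypothetical helper for illustration, and it assumes `images_per_figure` is greater than zero.

```python
# Pagination arithmetic only; no plotting.
# `figure_layout` is a hypothetical helper, not part of the notebook.
from typing import List


def figure_layout(
    display_count: int, num_images: int, images_per_figure: int
) -> List[List[int]]:
    """Return the image indices that would appear in each figure."""
    # Never try to show more images than exist
    img_cnt = min(display_count, num_images)

    # Ceiling division: a partially filled last figure still counts
    num_figs = (img_cnt + images_per_figure - 1) // images_per_figure

    layout = []
    for fig_idx in range(num_figs):
        start = fig_idx * images_per_figure
        stop = min(start + images_per_figure, img_cnt)
        layout.append(list(range(start, stop)))
    return layout


# 7 of 10 images requested, 3 per figure -> two full figures and one with a single row
layout = figure_layout(display_count=7, num_images=10, images_per_figure=3)
```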



## Public Functions

The functions in this section are public: they can be called or passed outside this notebook and are intended for use in `main.ipynb`.

### Intersection Over Union 

In [None]:
# Custom metric used when training the model. Intersection Over Union (IOU)
def iou_metric(
    true_masks: np.ndarray, predicted_masks: np.ndarray, smooth: float = 1e-8
) -> float:
    """
    Calculates the Intersection over Union (IoU) metric for the given true and predicted masks.

    Args:
        true_masks (np.ndarray): The ground truth masks.
        predicted_masks (np.ndarray): The predicted masks.
        smooth (float, optional): A small value to avoid division by zero. Default is 1e-8.

    Returns:
        float: The IoU metric.
    """
    # Flatten the true and predicted masks to 1D arrays
    true_masks = backend.flatten(true_masks)
    predicted_masks = backend.flatten(predicted_masks)

    # Calculate the intersection of the true and predicted masks
    intersection = backend.sum(true_masks * predicted_masks)

    # Calculate the union of the true and predicted masks
    union = backend.sum(true_masks) + backend.sum(predicted_masks) - intersection

    # Calculate the Intersection Over Union (IOU)
    iou = (intersection + smooth) / (union + smooth)
    return iou
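
As a quick check of the formula above, the same computation can be done by hand on a tiny flattened mask. This is a plain-Python sketch rather than the Keras backend version. `iou` is a hypothetical stand-in used only for illustration, and the masks are made up.

```python
# Plain-Python version of the IoU formula, for hand-checking small cases.
# `iou` is a hypothetical stand-in, not the notebook's `iou_metric`.
from typing import Sequence


def iou(
    true_mask: Sequence[float], pred_mask: Sequence[float], smooth: float = 1e-8
) -> float:
    """Intersection over Union for flattened masks."""
    intersection = sum(t * p for t, p in zip(true_mask, pred_mask))
    union = sum(true_mask) + sum(pred_mask) - intersection
    return (intersection + smooth) / (union + smooth)


# One overlapping pixel, three pixels in the union -> IoU of about 1/3
true_mask = [1, 1, 0, 0]
pred_mask = [1, 0, 1, 0]
score = iou(true_mask, pred_mask)

# With two empty masks the smoothing term yields 1.0 instead of dividing by zero
empty_score = iou([0, 0], [0, 0])
```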


### Dice Coefficient

In [None]:
# Custom metric used when training the model. Dice Coefficient.
def dice_coeff(
    true_masks: np.ndarray, predicted_masks: np.ndarray, smooth: float = 1e-8
) -> float:
    """
    Calculates the Dice Coefficient for the given true and predicted masks.

    Args:
        true_masks (np.ndarray): The ground truth masks.
        predicted_masks (np.ndarray): The predicted masks.
        smooth (float, optional): A small value to avoid division by zero. Default is 1e-8.

    Returns:
        float: The Dice Coefficient.
    """
    
    # Flatten the true and predicted masks to 1D arrays
    true_masks = backend.flatten(true_masks)
    predicted_masks = backend.flatten(predicted_masks)

    # Calculate the intersection of the true and predicted masks
    intersection = backend.sum(true_masks * predicted_masks)

    # Calculate the Dice Coefficient
    dice = (2.0 * intersection + smooth) / (
        backend.sum(true_masks) + backend.sum(predicted_masks) + smooth
    )
    return dice
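
The Dice Coefficient is closely related to IoU: ignoring the smoothing term, Dice = 2·IoU / (1 + IoU). So the two metrics always rank predictions in the same order, with Dice giving the larger value except at 0 and 1. The plain-Python sketch below checks this on a small made-up example; `dice` and `dice_from_iou` are hypothetical helpers for illustration only.

```python
# Plain-Python Dice Coefficient and its relation to IoU.
# `dice` and `dice_from_iou` are hypothetical helpers, not part of the notebook.
from typing import Sequence


def dice(
    true_mask: Sequence[float], pred_mask: Sequence[float], smooth: float = 1e-8
) -> float:
    """Dice Coefficient for flattened masks."""
    intersection = sum(t * p for t, p in zip(true_mask, pred_mask))
    return (2.0 * intersection + smooth) / (
        sum(true_mask) + sum(pred_mask) + smooth
    )


def dice_from_iou(iou_value: float) -> float:
    """Convert an IoU value to the equivalent Dice value (no smoothing)."""
    return 2.0 * iou_value / (1.0 + iou_value)


# Intersection = 1, sum of mask sizes = 4 -> Dice of about 0.5
true_mask = [1, 1, 0, 0]
pred_mask = [1, 0, 1, 0]
dice_score = dice(true_mask, pred_mask)

# For the same masks IoU is 1/3 (intersection 1, union 3)
converted = dice_from_iou(1 / 3)
```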



### U-Net Model Definition

In [None]:
# Function to define the actual U-Net model
def unet_model(
    input_shape: Tuple[int, int, int],
    num_classes: int,
    num_blocks: int = 4,
    optimizer: str = "adam",
    metrics: List[str] = ["accuracy"],
):
    """
    Builds a U-Net model for image segmentation.

    Args:
        input_shape (Tuple[int, int, int]): The shape of the input images (height, width, channels).
        num_classes (int): The number of output classes for segmentation.
        num_blocks (int, optional): The number of encode/decode blocks in the U-Net. Default is 4.
        optimizer (str, optional): Optimization method to use for the model. The default is "adam".
        metrics (List[str], optional): Metrics to track during model training. Default is ["accuracy"].

    Returns:
        Model: A TensorFlow Keras Model representing the U-Net.
    """
    # Create input layer
    inputs = layers.Input(shape=input_shape)

    # Initialize components for subsequent encoder and decoder layers
    encoder_layers = []
    x = inputs  # updates at every layer
    filters = 64  # Initial number of filters, updates at every layer block
    
    # For each encoder block, create 3 layers: 2 ReLU convolution layers and 1 MaxPooling layer. Also store the output of the 2 convolution layers for the skip connection used when decoding.
    for idx in range(num_blocks):
        x = layers.Conv2D(filters, 3, activation="relu", padding="same")(x)
        x = layers.Conv2D(filters, 3, activation="relu", padding="same")(x)
        encoder_layers.append(x)  # Store the result for skip connections
        x = layers.MaxPooling2D(2)(x)
        filters *= 2 # Double the number of filters for the next block

    # Bottleneck block: 2 ReLU convolution layers
    x = layers.Conv2D(filters, 3, activation="relu", padding="same")(x)
    x = layers.Conv2D(filters, 3, activation="relu", padding="same")(x)

    # For each decoder block, create 4 layers: 1 UpSampling layer, 1 Concatenation with the skipped encoder layer from the same level, and 2 ReLU convolution layers.
    for idx in reversed(range(num_blocks)):
        filters //= 2 # Halve the number of filters for the next block
        x = layers.UpSampling2D(size=2)(x)
        x = layers.concatenate([x, encoder_layers[idx]], axis=-1) # Concatenate with the corresponding encoder layer
        x = layers.Conv2D(filters, 3, activation="relu", padding="same")(x)
        x = layers.Conv2D(filters, 3, activation="relu", padding="same")(x)

    # Select the appropriate activation and loss functions based on the number of classes in the input data.
    if num_classes == 1:
        activation_fn = "sigmoid"
        loss_fn = "binary_crossentropy"
    else:
        activation_fn = "softmax"
        loss_fn = "categorical_crossentropy"

    # Output layer
    outputs = layers.Conv2D(num_classes, 1, activation=activation_fn)(x)

    # Create the model as a TensorFlow Keras Model
    unet_model = models.Model(inputs, outputs)

    # Compile the model with the desired optimizer, loss function, and metrics
    unet_model.compile(optimizer=optimizer, loss=loss_fn, metrics=metrics)

    return unet_model
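
One practical consequence of the encoder and decoder loops above: each `MaxPooling2D(2)` halves the spatial size (rounding down), and each `UpSampling2D(size=2)` doubles it. The skip-connection concatenation therefore only lines up when the input height and width are divisible by `2 ** num_blocks`. The sketch below traces spatial sizes and filter counts in plain Python, without building a model. `unet_schedule` is a hypothetical helper for illustration, and the initial filter count of 64 is taken from the function above.

```python
# Trace spatial sizes and filter counts through the U-Net, without TensorFlow.
# `unet_schedule` is a hypothetical helper, not part of the notebook.
from typing import List, Tuple


def unet_schedule(
    height: int, width: int, num_blocks: int = 4, base_filters: int = 64
) -> List[Tuple[str, int, int, int]]:
    """Return (stage, height, width, filters) for every block."""
    factor = 2 ** num_blocks
    if height % factor or width % factor:
        raise ValueError(
            f"Height and width must be divisible by {factor} "
            f"for {num_blocks} blocks; got {height}x{width}."
        )

    stages = []
    filters = base_filters
    h, w = height, width

    # Encoder: convolutions at the current size, then pool and double the filters
    for _ in range(num_blocks):
        stages.append(("encoder", h, w, filters))
        h, w = h // 2, w // 2
        filters *= 2

    # Bottleneck: smallest spatial size, largest filter count
    stages.append(("bottleneck", h, w, filters))

    # Decoder: halve the filters, upsample, then convolve at the restored size
    for _ in range(num_blocks):
        filters //= 2
        h, w = h * 2, w * 2
        stages.append(("decoder", h, w, filters))
    return stages


stages = unet_schedule(128, 128)
```

For a 128x128 input with four blocks, the bottleneck works on 8x8 feature maps with 1024 filters. A 100x100 input is rejected because 100 is not divisible by 16.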



### Plot Model Losses

In [None]:
# Function to plot the losses from the training process of the U-Net model
def plot_losses(model_fit: Any) -> None:
    """
    Plots the training and validation losses from the model fitting process.

    Args:
        model_fit (Any): The result of the model fitting process (e.g., the history object returned by `model.fit`).

    Returns:
        None: This function does not return anything. It plots the training and validation losses.
    """
    # Create a new figure with specified size
    plt.figure(figsize=(6, 4))

    # Plot the training loss
    plt.plot(model_fit.history["loss"], label="Train Loss")
    
    # Plot the validation loss
    plt.plot(model_fit.history["val_loss"], label="Validation Loss")

    # Set axes labels and limits
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.xlim(left=0)
    plt.ylim(bottom=0)

    # Set the title 
    plt.title("Training Loss vs. Validation Loss")

    # Add legend
    plt.legend()

    # Tighten the layout, and don't block execution when showing the plot
    plt.tight_layout()
    plt.show(block=False)
    return
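
The `history` attribute read above is a plain dictionary that maps each tracked quantity to a list with one value per epoch. In Keras, the `val_loss` entry is only present when validation data was supplied to `model.fit`. The sketch below reads such a dictionary without plotting to find the epoch with the lowest validation loss. `best_epoch` is a hypothetical helper for illustration, and the loss values are made up.

```python
# Inspect a Keras-style history dictionary without plotting.
# `best_epoch` is a hypothetical helper, not part of the notebook.
from typing import Dict, List, Tuple


def best_epoch(
    history: Dict[str, List[float]], key: str = "val_loss"
) -> Tuple[int, float]:
    """Return the (zero-based) epoch index and value of the lowest loss."""
    if key not in history:
        raise KeyError(
            f"'{key}' not found; was validation data passed during training?"
        )
    losses = history[key]
    idx = min(range(len(losses)), key=losses.__getitem__)
    return idx, losses[idx]


# Made-up values shaped like `model_fit.history`
history = {
    "loss": [0.9, 0.6, 0.4, 0.3],
    "val_loss": [1.0, 0.7, 0.65, 0.8],
}
epoch, value = best_epoch(history)
```

Here the training loss keeps falling while the validation loss is lowest at epoch index 2, which is the kind of divergence the plot is meant to reveal.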


### Test Trained U-Net

In [None]:
# Function to test the U-Net model after training
def test_unet(
    model: Model,
    images: np.ndarray,
    masks: np.ndarray,
    num_classes: int,
    threshold: float = 0.5,
    image_names: Optional[List[str]] = None,
    class_labels: Optional[Dict[int, str]] = None,
    class_color_map: Optional[str] = None,
    display_figures: bool = True,
    display_count: int = 3,
    images_per_figure: int = 3,
) -> np.ndarray:
    """
    Tests the U-Net model by plotting the predictions on a set of images and their corresponding masks.

    Args:
        model (Model): The U-Net model to be tested.
        images (np.ndarray): The input images.
        masks (np.ndarray): The ground truth masks.
        num_classes (int): The number of classes.
        threshold (float, optional): Threshold to use if the number of classes is 1 (binary). Default is 0.5.
        image_names (Optional[List[str]], optional): List of image names. Default is None.
        class_labels (Optional[Dict[int, str]], optional): A dictionary where the key is the class index and the value is the class label. Default is None.
        class_color_map (Optional[str], optional): Matplotlib colormap. Default is None, which falls back to "viridis".
        display_figures (bool, optional): Whether to display the figures. Default is True.
        display_count (int, optional): The number of images to display. Default is 3.
        images_per_figure (int, optional): The number of images per figure. Default is 3.

    Returns:
        np.ndarray: The predicted masks.
    """

    # Predict masks using the U-Net model
    predicted_masks = model.predict(images)

    # Process predicted masks based on the number of classes
    if num_classes == 1:
        # Converting the predicted masks and masks back to the correct format for plotting and for performance metrics
        predicted_masks = (predicted_masks[..., 0] > threshold).astype(np.uint8)
        masks = (masks > threshold).astype(np.uint8)

        # Setting the method to be used in the metrics based on the number of classes
        metric_ave_method = "binary" 
    else:
        # Converting the predicted masks and masks back to the correct format for plotting and for performance metrics
        predicted_masks = np.argmax(predicted_masks, axis=-1).astype(np.uint8)
        masks = np.argmax(masks, axis=-1).astype(np.uint8)
        
        # Setting the method to be used in the metrics based on the number of classes
        metric_ave_method = None 

    # Get class labels, ensuring they are defined for potential figure legends
    class_labels = _get_class_labels(class_labels, num_classes)

    # Compute metrics for the predicted masks
    metrics = _compute_basic_metrics(masks, predicted_masks, metric_ave_method)
    print("Metrics based on test data:\n")
    
    # If multiple classes then print metrics for each class in the predicted masks
    if metric_ave_method != "binary":
        for idx in range(num_classes):

            # Verify that there are metrics for the given class index
            if 0 <= idx < len(metrics["precision"]):
                print(f"Class {class_labels[idx]} Individual Metrics:")
                print(f"Precision: {metrics['precision'][idx]}")
                print(f"Recall: {metrics['recall'][idx]}")
                print(f"F1 Score: {metrics['f1 score'][idx]}\n")
            else:
                print(f"Class {class_labels[idx]} Individual Metrics:")
                print(
                    f"Class {class_labels[idx]} not found in predicted masks. No metrics available."
                )

    # Else if only a background and foreground class (binary), print the metrics for the foreground
    else:
        print(f"Precision: {metrics['precision']}\n")
        print(f"Recall: {metrics['recall']}\n")
        print(f"F1 Score: {metrics['f1 score']}\n")

    # Display figures if the flag is set to True
    if display_figures:
        # Get the labels and colored masks
        colored_masks, colored_preds, class_colors_and_labels = _get_color_masks(
            num_classes, masks, predicted_masks, class_labels, class_color_map
        )

        # Display the figures that show the image, mask, and prediction
        _view_predictions(
            images,
            colored_masks,
            colored_preds,
            image_names,
            class_colors_and_labels,
            display_count,
            images_per_figure,
        )

    return predicted_masks
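
The two post-processing branches above turn model probabilities into class indices in different ways. The binary case compares a single sigmoid probability against `threshold` with a strict `>`. The multi-class case takes the index of the largest softmax probability per pixel. The plain-Python sketch below shows both on a few made-up pixels; `binarize` and `argmax_classes` are hypothetical helpers for illustration only.

```python
# Per-pixel post-processing, shown on flat lists instead of NumPy arrays.
# `binarize` and `argmax_classes` are hypothetical helpers, not part of the notebook.
from typing import List, Sequence


def binarize(probabilities: Sequence[float], threshold: float = 0.5) -> List[int]:
    """Binary case: 1 where the probability is strictly above the threshold."""
    return [1 if p > threshold else 0 for p in probabilities]


def argmax_classes(class_probabilities: Sequence[Sequence[float]]) -> List[int]:
    """Multi-class case: index of the largest probability for each pixel."""
    return [
        max(range(len(pixel)), key=pixel.__getitem__)
        for pixel in class_probabilities
    ]


# A probability exactly equal to the threshold maps to background (0)
binary_pred = binarize([0.1, 0.5, 0.9])

# Two pixels, three classes each
multi_pred = argmax_classes([[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]])
```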



### Save Predicted Masks

In [None]:
# Function to save the predicted masks to disk
def save_predicted_masks(
    predicted_masks: np.ndarray,
    output_dir: str,
    mask_names: Optional[List[str]] = None,
    file_ext: str = "tif",
) -> None:
    """
    Saves the predicted masks to the specified output directory.

    Args:
        predicted_masks (np.ndarray): Array containing the predicted masks.
        output_dir (str): Path to the directory where the masks will be saved.
        mask_names (Optional[List[str]], optional): List of names for the saved masks. If None, the masks are named by their index. Default is None.
        file_ext (str, optional): File extension for the saved masks. Default is "tif".

    Returns:
        None: This function does not return anything.
    """
    # Create the output directory if it doesn't exist
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Normalize file extension to lowercase and handle "tif" as "tiff"
    file_ext = file_ext.strip().lower()
    if file_ext == "tif":
        file_ext = "tiff"

    # If no mask names are provided, generate default names based on indices
    if not mask_names:
        mask_names = [str(idx) for idx in range(len(predicted_masks))]

    # Save each predicted mask with the corresponding name and file extension
    for predicted_mask, mask_name in zip(predicted_masks, mask_names):
        file_name = f"pred_{mask_name}.{file_ext}"
        # output_dir is a str, so wrap it in Path before joining with the file name
        Image.fromarray(predicted_mask).save(
            Path(output_dir) / file_name, format=file_ext.upper()
        )
    print(f"Data saved in {output_dir}.")
    return
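
The naming rules above (extension normalization, default index-based names, and the `pred_` prefix) can be previewed without writing any files. The sketch below builds the output paths with `pathlib` only. `mask_paths` is a hypothetical helper for illustration, and `"out"` is a made-up directory name.

```python
# Preview the output paths that the naming rules would produce; nothing is written.
# `mask_paths` is a hypothetical helper, not part of the notebook.
from pathlib import Path
from typing import List, Optional


def mask_paths(
    output_dir: str,
    count: int,
    mask_names: Optional[List[str]] = None,
    file_ext: str = "tif",
) -> List[Path]:
    """Return the path each of `count` masks would be saved to."""
    # Normalize the extension, treating "tif" as "tiff"
    file_ext = file_ext.strip().lower()
    if file_ext == "tif":
        file_ext = "tiff"

    # Fall back to index-based names when none are given
    if not mask_names:
        mask_names = [str(idx) for idx in range(count)]

    # Like zip() in the function above, stop at the shorter of masks and names
    return [
        Path(output_dir) / f"pred_{name}.{file_ext}"
        for name in mask_names[:count]
    ]


default_paths = mask_paths("out", 2)
named_paths = mask_paths("out", 2, mask_names=["a", "b", "c"], file_ext=" PNG ")
```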
