# U-Net and Associated Functions

## Overview
This notebook contains the U-Net and its associated functions. All functions have docstrings that define their inputs, outputs, and data types. Comments throughout each function provide additional clarity.

## Structure
- **[Import Necessary Packages](#import-necessary-packages)**: Importing the necessary packages used by functions in this notebook.
- **[Internal Functions](#internal-functions)**: Internal or private functions that are only called by other functions and should never be called directly by a script. 
    - [`_get_colors`](#get-colors): 
    - [`_get_class_labels`](#get-class-labels): 
    - [`_get_color_masks`](#get-colored-masks): 
    - [`_compute_basic_metrics`](#compute-basic-metrics):
    - [`_view_predictions`](#view-predictions):
- **[Public Functions](#public-functions)**: Public functions to be imported and called by other scripts, outside of this notebook. 
    - [`iou_metric`](#intersection-over-union): 
    - [`dice_coeff`](#dice-coefficient): 
    - [`unet_model`](#u-net-model-definition): 
    - [`plot_losses`](#plot-model-losses): 
    - [`test_unet`](#test-trained-u-net): 
    - [`save_predicted_masks`](#save-predicted-masks): 


## Usage
This notebook is not intended to be run on its own. The public functions in this notebook should be imported into `main.ipynb` using the `import-ipynb` package. The purpose of this notebook is to clearly separate and define the U-Net-related functions that are called internally (by other functions in this notebook) from those that are called in `main.ipynb`.

## Import Necessary Packages

We first import the necessary packages used by the functions in this notebook.

In [None]:
# Necessary imports
import os
from pathlib import Path
import warnings
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.patches as mpatches
import numpy as np
from PIL import Image
from sklearn.metrics import precision_score, recall_score, f1_score
from typing import Optional, Dict, List, Tuple, Any, Union

# Optional line to suppress unnecessary tensorflow warnings
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

from tensorflow.keras import layers, models, Model, backend

## Internal Functions

This section contains the internal, or private, functions. They are called only by other functions in this notebook and are not intended for use in any other manner.

### Get Colors
The `_get_colors` function is an internal function that generates a dictionary of colors for each class. It uses Matplotlib colormaps to create distinct colors for visualization purposes. This function is called in [`_get_color_masks`](#get-colored-masks), which combines the resulting `class_colors` dictionary with the `class_labels` dictionary to form a single `class_colors_and_labels` dictionary. That dictionary is in turn passed to [`_view_predictions`](#view-predictions) for use in the figures.


#### _get_colors
```
def _get_colors(num_colors: int, color_map: str) -> Dict[int, np.ndarray]:
```
#### Description

Generates a dictionary of colors for each class using Matplotlib colormaps.

#### Parameters
- `num_colors` (int): 

    The number of colors to generate.

- `color_map` (str): 

    Matplotlib colormap.

#### Returns

- Dict[int, np.ndarray]: 
    A dictionary where the key is the class index and the value is a numpy array representing the RGB colors.

#### Raises
- `UserWarning`: If the specified colormap is not recognized, it will warn the user and use the default 'viridis' colormap.

#### Example
```
import numpy as np
import matplotlib.cm as cm

# Generate colors for 5 classes using the 'plasma' colormap
class_colors = _get_colors(num_colors=5, color_map='plasma')
print(class_colors)
```
#### Notes
- The function selects the colormap from Matplotlib's library and extracts only the RGB color information.
- If the specified colormap is not recognized, it defaults to using the 'viridis' colormap and warns the user.
- The RGB values are extracted from the colormap and stored in a dictionary where each class index maps to its corresponding color.


In [None]:
# Internal function to get colors for the classes for display
def _get_colors(num_colors: int, color_map: str) -> Dict[int, np.ndarray]:
    """
    Generates a dictionary of colors for each class.

    Args:
        num_colors (int): The number of colors to generate.
        color_map (str): Matplotlib colormap.

    Returns:
        Dict[int, np.ndarray]: A dictionary where the key is the class index and the value is a numpy array representing the RGB colors.
    """
    # Select the color map from matplotlib's library, and keep only the RGB color information.
    # plt.get_cmap is used because cm.get_cmap was removed in Matplotlib 3.9.
    if color_map:
        try:
            colors = plt.get_cmap(color_map, num_colors)
        # If the color map is not recognized in matplotlib's color map, warn the user and use the default
        except ValueError:
            warnings.warn(
                f"Unrecognized color map {color_map}.\n" "Will use default 'viridis'.\n",
                UserWarning,
            )
            colors = plt.get_cmap("viridis", num_colors)

    # If no color map is given, use the default
    else:
        colors = plt.get_cmap("viridis", num_colors)

    # Extract the RGB values from the colormap
    colors = colors(np.arange(num_colors))[:, :3]

    # Create a dictionary of class colors
    class_colors = {idx: colors[idx] for idx in range(num_colors)}
    return class_colors



### Get Class Labels
The `_get_class_labels` function is an internal function that ensures the class labels dictionary is correctly set up. If `class_labels` is `None`, it generates a dictionary of default class labels. These labels are used when plotting the true and predicted masks and when printing the metrics for each class. This function is called in [`test_unet`](#test-trained-u-net), and the resulting `class_labels` dictionary is passed to [`_get_color_masks`](#get-colored-masks), which combines it with the `class_colors` dictionary to form a single `class_colors_and_labels` dictionary. That dictionary is in turn passed to [`_view_predictions`](#view-predictions) for use in the figures.

#### _get_class_labels
```
def _get_class_labels(
    class_labels: Union[Dict[int, str], None], num_classes: int
) -> Dict[int, str]:
```
#### Description

Ensures the class labels dictionary is correctly set up. If `class_labels` is `None`, generates a dictionary of default class labels.

#### Parameters
- `class_labels` (Optional[Dict[int, str]]): 

    An optional dictionary of class labels.

- `num_classes` (int): 

    The number of classes.

#### Returns

- Dict[int, str]: 

    A dictionary where the key is the class index and the value is the class label as a string.

#### Example
```
# Example class labels provided by the user
user_class_labels = {0: "Background", 1: "Class1"}

# Generate class labels for 2 classes
class_labels = _get_class_labels(class_labels=user_class_labels, num_classes=2)
print(class_labels)
```
#### Notes
- If `class_labels` is not provided, the function generates default labels based on the number of classes.
- Handles the case where there is only one class label but two classes (binary classification) by assigning a background label.
- If the number of class labels does not match the number of classes, it resets the labels to default index numbers.


In [None]:
# Internal function to get or assign class labels
def _get_class_labels(
    class_labels: Union[Dict[int, str], None], num_classes: int
) -> Dict[int, str]:
    """
    Ensures the class labels dictionary is correctly set up. If class_labels is None,
    generates a dictionary of class labels.

    Args:
        class_labels (Optional[Dict[int, str]]): An optional dictionary of class labels.
        num_classes (int): The number of classes.

    Returns:
        Dict[int, str]: A dictionary where the key is the class index and the value is the class label as a string.
    """
    # If class_labels is not provided, generate default labels based on the number of classes
    if not class_labels:
        num_labels = 2 if num_classes == 1 else num_classes
        class_labels = {idx: str(idx) for idx in range(num_labels)}

    # Handle the case where there is only one class label but two classes (binary classification)
    elif len(class_labels) == 1 and num_classes == 2:
        key, value = next(iter(class_labels.items()))
        background_key = 1 if key == 0 else 0
        class_labels[background_key] = "Assigned Background"

    # If the number of class labels does not match the number of classes, reset the labels
    elif len(class_labels) != num_classes and not (
        len(class_labels) == 2 and num_classes == 1
    ):
        print(
            f"Number of class labels ({len(class_labels)}) does not match the number of classes ({num_classes}). Labels will be changed to class index numbers."
        )
        num_labels = 2 if num_classes == 1 else num_classes
        class_labels = {idx: str(idx) for idx in range(num_labels)}

    return class_labels

### Get Colored Masks
The `_get_color_masks` function is an internal function that applies the colors from [`_get_colors`](#get-colors) to the true and predicted masks. This function also combines the `class_labels` dictionary from [`_get_class_labels`](#get-class-labels) with the `class_colors` dictionary from [`_get_colors`](#get-colors) to form a single `class_colors_and_labels` dictionary. This dictionary is then used by [`_view_predictions`](#view-predictions) to label the figures.

#### _get_color_masks
```
def _get_color_masks(
    num_classes: int,
    masks: np.ndarray,
    predicted_masks: np.ndarray,
    class_labels: Dict[int, str],
    class_color_map: str,
) -> Tuple[np.ndarray, np.ndarray, Dict[int, Dict[str, Union[str, np.ndarray]]]]:
```
#### Description

Colors the true masks and predicted masks for visualization, and builds a dictionary holding the color and label of each class.

#### Parameters
- `num_classes` (int): 

    The number of classes.

- `masks` (np.ndarray): 

    The ground truth masks.

- `predicted_masks` (np.ndarray): 

    The predicted masks.

- `class_labels` (Dict[int, str]): 

    A dictionary where the key is the class index and the value is a class label.

- `class_color_map` (str): 

    Matplotlib colormap.

    
#### Returns

- Tuple[np.ndarray, np.ndarray, Dict[int, Dict[str, Union[str, np.ndarray]]]]: 

    A tuple containing:
    - `colored_masks` (np.ndarray): 
    
        The colored ground truth masks.

    - `colored_preds` (np.ndarray): 
    
        The colored predicted masks.

    - `class_colors_and_labels` (Dict[int, Dict[str, Union[str, np.ndarray]]]): 
    
        A dictionary where the key is the class index and each class index contains a color and label.

#### Example
```
import numpy as np

# Example ground truth masks and predicted masks
masks = np.random.randint(0, 3, (10, 256, 256))
predicted_masks = np.random.randint(0, 3, (10, 256, 256))

# Example class labels
class_labels = {0: "Background", 1: "Class1", 2: "Class2"}

# Color the masks and predicted masks using the 'viridis' colormap
colored_masks, colored_preds, class_colors_and_labels = _get_color_masks(
    num_classes=3, masks=masks, predicted_masks=predicted_masks,
    class_labels=class_labels, class_color_map='viridis'
)
print(colored_masks.shape, colored_preds.shape)
print(class_colors_and_labels)
```
#### Notes
- Determines the number of colors needed based on the number of classes. If num_classes is 1, it uses 2 colors (one for background and one for foreground).
- Uses the `_get_colors` function to get the colors for each class using the specified colormap.
- Creates a dictionary to store the color and label for each class.
- Applies the color map to the ground truth masks and predicted masks to generate colored masks.
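
The last note relies on NumPy integer-array indexing: indexing a `(num_colors, 3)` lookup table with an integer mask returns one RGB triple per pixel. The sketch below illustrates the idea with a hypothetical table and a tiny mask; the color values are made up for illustration and are not taken from any Matplotlib colormap.

```python
import numpy as np

# Hypothetical lookup table: row i holds the RGB color for class i
color_lut = np.array(
    [
        [0.0, 0.0, 0.0],  # class 0
        [1.0, 0.0, 0.0],  # class 1
        [0.0, 0.0, 1.0],  # class 2
    ],
    dtype=np.float32,
)

# One 2x3 mask of integer class indices, shape (batch, height, width)
tiny_masks = np.array([[[0, 1, 2],
                        [2, 1, 0]]])

# Indexing the table with the mask appends an RGB axis to the mask shape
colored = color_lut[tiny_masks]
print(colored.shape)     # (1, 2, 3, 3)
print(colored[0, 0, 1])  # [1. 0. 0.]
```

For this to work the masks must hold integer class indices smaller than the number of rows in the table; a float-typed mask raises an `IndexError`.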

In [None]:
# Internal function to color the masks and the predicted masks based on the color map and class labels
def _get_color_masks(
    num_classes: int,
    masks: np.ndarray,
    predicted_masks: np.ndarray,
    class_labels: Dict[int, str],
    class_color_map: str,
) -> Tuple[np.ndarray, np.ndarray, Dict[int, Dict[str, Union[str, np.ndarray]]]]:
    """
    Colors the masks and predicted masks for visualization.

    Args:
        num_classes (int): The number of classes.
        masks (np.ndarray): The true masks.
        predicted_masks (np.ndarray): The predicted masks.
        class_labels (Dict[int, str]): A dictionary where the key is the class index (int) and the value is a class label.
        class_color_map (str): Matplotlib colormap.

    Returns:
        Tuple[np.ndarray, np.ndarray, Dict[int, Dict[str, Union[str, np.ndarray]]]]: A tuple containing:
            - colored_masks (np.ndarray): The colored true masks.
            - colored_preds (np.ndarray): The colored predicted masks.
            - class_colors_and_labels (Dict[int, Dict[str, Union[str, np.ndarray]]]): A dictionary where the key is the class index and each class index contains a color and label.
    """
    # Determine the number of colors needed based on the number of classes
    if num_classes == 1:
        num_colors = 2 # One color for background, one color for foreground
    else:
        num_colors = num_classes

    # Get the colors for each class using the specified colormap
    class_colors = _get_colors(num_colors, class_color_map)

    # Create a dictionary to store the color and label for each class
    class_colors_and_labels = {}
    for key in class_colors:
        class_colors_and_labels[key] = {
            "color": class_colors[key],
            "label": class_labels[key],
        }

    # Create a color map that can easily be applied to the masks and predictions
    color_map = np.array(
        [class_colors.get(color, [0, 0, 0]) for color in range(num_colors)],
        dtype=np.float32,
    )

    # Apply the color map to the true masks and predicted masks
    colored_masks = color_map[masks]
    colored_preds = color_map[predicted_masks]

    return colored_masks, colored_preds, class_colors_and_labels



### Compute Basic Metrics
The `_compute_basic_metrics` function is an internal function that computes basic metrics (precision, recall, and F1 score) for the given masks and predicted masks. It uses the scikit-learn library to calculate these metrics. The chosen metrics give basic information on the performance of our trained U-Net. This function is called in [`test_unet`](#test-trained-u-net). 

Precision: Calculates the precision, which is the ratio of correctly predicted positive observations to the total predicted positives. It is a measure of the accuracy of the positive predictions.

$$ \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} $$

Recall: Calculates the recall, also known as sensitivity or true positive rate, which is the ratio of correctly predicted positive observations to all the actual positives. It is a measure of how well the model can capture positive instances.

$$ \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} $$

F1 Score: Calculates the F1 score, which is the harmonic mean of precision and recall. It considers both false positives and false negatives and is useful when you need a balance between precision and recall.

$$ \text{F1 Score} = 2 \frac{(\text{Precision}) (\text{Recall})}{\text{Precision} + \text{Recall}} $$

- TP - True Positive
- FP - False Positive
- FN - False Negative
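
As a worked example of the three formulas above, the sketch below counts TP, FP, and FN by hand for a tiny flattened binary mask and derives the metrics from those counts. The arrays are made-up illustrative values, and the calculation corresponds to treating class 1 as the positive class.

```python
import numpy as np

# Tiny flattened binary example (1 = foreground, 0 = background)
y_true = np.array([1, 1, 1, 0, 0, 0, 1, 0])
y_pred = np.array([1, 1, 0, 0, 1, 1, 1, 0])

# Count the outcomes for the positive class
tp = int(np.sum((y_true == 1) & (y_pred == 1)))  # 3
fp = int(np.sum((y_true == 0) & (y_pred == 1)))  # 2
fn = int(np.sum((y_true == 1) & (y_pred == 0)))  # 1

precision = tp / (tp + fp)  # 3 / 5 = 0.6
recall = tp / (tp + fn)     # 3 / 4 = 0.75
f1 = 2 * precision * recall / (precision + recall)

print(precision, recall, round(f1, 4))
```

Here the F1 score (about 0.667) sits between precision and recall but closer to the lower of the two, which is the characteristic behavior of a harmonic mean.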

#### _compute_basic_metrics
```
def _compute_basic_metrics(
    masks: np.ndarray, predicted_masks: np.ndarray, metric_ave_method: str
) -> Dict[str, float]:
```
#### Description

Computes basic metrics (precision, recall, and F1 score) for the given masks and predicted masks.

#### Parameters

- `masks` (np.ndarray): 

    The true masks.

- `predicted_masks` (np.ndarray): 

    The predicted masks.

- `metric_ave_method` (str): 

    Averaging method to use in the metrics.

#### Returns

- Dict[str, float]: 

    A dictionary containing the precision, recall, and F1 score.

#### Example
```
# Example true masks and predicted masks
masks = np.random.randint(0, 2, (10, 256, 256))
predicted_masks = np.random.randint(0, 2, (10, 256, 256))

# Compute basic metrics with 'binary' averaging method
metrics = _compute_basic_metrics(masks=masks, predicted_masks=predicted_masks, metric_ave_method='binary')
print(metrics)
```
#### Notes
- Flattens the masks and predictions as necessary for the scikit-learn library metrics.
- Calculates the precision, which is the ratio of correctly predicted positive observations to the total predicted positives.
- Calculates the recall, also known as sensitivity or true positive rate, which is the ratio of correctly predicted positive observations to all the actual positives.
- Calculates the F1 score, which is the harmonic mean of precision and recall. It considers both false positives and false negatives and is useful when a balance between precision and recall is needed.

In [None]:
# Internal function to compute some basic metrics that assess model performance on the test data.
def _compute_basic_metrics(
    masks: np.ndarray, predicted_masks: np.ndarray, metric_ave_method: str
) -> Dict[str, float]:
    """
    Computes basic metrics (precision, recall, and F1 score) for the given masks and predicted masks.

    Args:
        masks (np.ndarray): The true masks.
        predicted_masks (np.ndarray): The predicted masks.
        metric_ave_method (str): Averaging method to use in the metrics.

    Returns:
        Dict[str, float]: A dictionary containing the precision, recall, and F1 score.
    """
    # Flatten the masks and predictions as necessary for the scikit-learn library metrics
    masks = masks.flatten()
    predicted_masks = predicted_masks.flatten()
    # Precision: Calculates the precision, which is the ratio of correctly predicted positive observations to the total predicted positives. It is a measure of the accuracy of the positive predictions.
    # Precision = true_positives / (true_positives + false_positives)
    precision = precision_score(
        masks, predicted_masks, zero_division=0, average=metric_ave_method
    )
    # Recall: Calculates the recall, also known as sensitivity or true positive rate, which is the ratio of correctly predicted positive observations to all the actual positives. It is a measure of how well the model can capture positive instances.
    # Recall = true_positives / (true_positives + false_negatives)
    recall = recall_score(
        masks, predicted_masks, zero_division=0, average=metric_ave_method
    )
    # F1 Score: Calculates the F1 score, which is the harmonic mean of precision and recall. It considers both false positives and false negatives and is useful when you need a balance between precision and recall.
    # F1 Score = 2 * (Precision * Recall) / (Precision + Recall)
    f1 = f1_score(masks, predicted_masks, zero_division=0, average=metric_ave_method)
    return {"precision": precision, "recall": recall, "f1 score": f1}



### View Predictions

The `_view_predictions` function is an internal function that displays the images, masks, and predicted masks, to visually compare the predictions with their true values. 

#### _view_predictions
```
def _view_predictions(
    images: np.ndarray,
    colored_masks: np.ndarray,
    colored_preds: np.ndarray,
    image_names: List[str],
    class_colors_and_labels: Dict[int, Dict[str, Union[str, np.ndarray]]],
    display_count: int,
    images_per_figure: int,
) -> None:
```
#### Description

Displays the images, their masks, and the model predictions of the masks.

#### Parameters

- `images` (np.ndarray): 

    The input images.

- `colored_masks` (np.ndarray): 

    The colored true masks.

- `colored_preds` (np.ndarray): 

    The colored predicted masks.

- `image_names` (List[str]): 

    List of image names.

- `class_colors_and_labels` (Dict[int, Dict[str, Union[str, np.ndarray]]]): 

    A dictionary where the key is the class index and each class index contains a color and label.

- `display_count` (int): 

    The number of images to display.

- `images_per_figure` (int): 

    The number of images per figure.

#### Example
```
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

# Example images, masks, and predicted masks
images = np.random.rand(10, 256, 256, 3)
colored_masks = np.random.rand(10, 256, 256, 3)
colored_preds = np.random.rand(10, 256, 256, 3)

# Example image names and class labels
image_names = [f"Image {i}" for i in range(10)]
class_colors_and_labels = {
    0: {"color": [1, 0, 0], "label": "Background"},
    1: {"color": [0, 1, 0], "label": "Class1"},
    2: {"color": [0, 0, 1], "label": "Class2"},
}

# View predictions
_view_predictions(
    images=images,
    colored_masks=colored_masks,
    colored_preds=colored_preds,
    image_names=image_names,
    class_colors_and_labels=class_colors_and_labels,
    display_count=5,
    images_per_figure=2,
)
```
#### Notes
- Determines the number of images to display based on `display_count` and the length of images.
- If `image_names` are not provided, it generates names based on the number of images.
- Creates legend patches for each mask class using their colors and labels.
- Calculates the number of figures needed and organizes the display into subplots.
- Plots the images, true masks, and predicted masks in the subplots.
- Adds a legend to the figure using the color labels and adjusts the layout for display.
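
The figure-count step in the notes above is a ceiling division. The sketch below works through it for hypothetical settings (5 images, at most 2 per figure) and also shows how many image rows each figure receives.

```python
# Hypothetical settings: 5 images to show, at most 2 per figure
img_cnt = 5
images_per_figure = 2

# Ceiling division using integer arithmetic only
num_figs = (img_cnt + images_per_figure - 1) // images_per_figure

# Rows (images) placed in each figure; the last figure takes the remainder
rows_per_fig = [
    min(images_per_figure, img_cnt - fig_idx * images_per_figure)
    for fig_idx in range(num_figs)
]

print(num_figs)      # 3
print(rows_per_fig)  # [2, 2, 1]
```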


In [None]:
# Internal function to view the predictions based on the test data
def _view_predictions(
    images: np.ndarray,
    colored_masks: np.ndarray,
    colored_preds: np.ndarray,
    image_names: List[str],
    class_colors_and_labels: Dict[int, Dict[str, Union[str, np.ndarray]]],
    display_count: int,
    images_per_figure: int,
) -> None:
    """
    Displays the images, their masks, and the model predictions of the masks.

    Args:
        images (np.ndarray): The input images.
        colored_masks (np.ndarray): The colored true masks.
        colored_preds (np.ndarray): The colored predicted masks.
        image_names (List[str]): List of image names.
        class_colors_and_labels (Dict[int, Dict[str, Union[str, np.ndarray]]]): A dictionary where the key is the class index and each class index contains a color and label.
        display_count (int): The number of images to display.
        images_per_figure (int): The number of images per figure.

    Returns:
        None: This function does not return anything. It displays the figures.
    """
    # Determine the number of images to display
    img_cnt = min(display_count, len(images))
    
    # If image names are not provided, create names based on the number of images
    if image_names is None:
        image_names = [str(idx) for idx in range(len(images))]

    # Create legend patches (labels) for each mask class
    patches = [
        mpatches.Patch(
            color=class_colors_and_labels[idx]["color"],
            label=class_colors_and_labels[idx]["label"],
        )
        for idx in class_colors_and_labels
    ]

    # Calculate the number of figures needed
    num_figs = (img_cnt + images_per_figure - 1) // images_per_figure

    # Loop over each figure
    for fig_idx in range(num_figs):

        # Determine the number of images in the current figure
        img_in_fig = min(images_per_figure, img_cnt - fig_idx * images_per_figure)

        # Create subplots for the current figure
        fig, axes = plt.subplots(img_in_fig, 3, figsize=(10, 3 * img_in_fig))

        # Ensure axes is a 2D array even if there's only one row
        if img_in_fig == 1:
            axes = np.expand_dims(axes, axis=0)

        # Loop over each image to be included in the current figure
        for idx in range(img_in_fig):
            img_idx = (fig_idx * images_per_figure) + idx
            if img_idx >= img_cnt:
                break
            
            # Extract the image, mask, and predicted mask
            image, mask, prediction = (
                images[img_idx],
                colored_masks[img_idx],
                colored_preds[img_idx],
            )

            # Plot the image
            ax_img, ax_mask, ax_pred = axes[idx]
            ax_img.imshow(image)
            ax_img.axis("off")
            ax_img.set_title(f"Image {image_names[img_idx]}")

            # Plot the true mask
            ax_mask.imshow(mask)
            ax_mask.axis("off")
            ax_mask.set_title(f"Mask {image_names[img_idx]}")
            
            # Plot the predicted mask
            ax_pred.imshow(prediction)
            ax_pred.axis("off")
            ax_pred.set_title(f"Pred. Mask {image_names[img_idx]}")

        # Add legend to the figure using the color labels
        fig.legend(
            handles=patches,
            loc="upper right",
            bbox_to_anchor=(1, 1),
            bbox_transform=fig.transFigure,
        )
        # Adjust layout and display the figure without blocking execution
        plt.tight_layout()
        plt.show(block=False)
    return



## Public Functions

The functions in this section are public functions, which can be called or passed outside this notebook and are intended for use in `main.ipynb`.

### Intersection Over Union 
The `iou_metric` function is technically a public function that is passed in `main.ipynb` to [`unet_model`](#u-net-model-definition) as one of our custom `metrics`, and to the TensorFlow Keras [`load_model`](https://www.tensorflow.org/api_docs/python/tf/keras/models/load_model) as one of our `custom_objects`. Note that we never actually **call** the function; it is only ever **passed**, because the model itself invokes it.

Intersection Over Union (IOU), also known as the [Jaccard Index](https://en.wikipedia.org/wiki/Jaccard_index), is a commonly used metric in image segmentation. It measures how well a predicted segmentation matches the true segmentation. It's calculated by dividing the area where the predicted and true overlap (the intersection) by the area covered by either the predicted or true (union). A perfect match would result in an IOU of 1. 
$$ IOU = \frac{\text{TP}}{\text{TP} + \text{FP} + \text{FN} } = \frac{|\text{pred} \cap \text{true}|}{|\text{pred} \cup \text{true}|}$$

- TP - True Positive
- FP - False Positive
- FN - False Negative
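
To check that the two forms of the IoU formula above agree, the sketch below evaluates both on a tiny pair of binary masks using plain NumPy. The masks are made-up illustrative values, and this is a stand-in for the idea only, not the Keras-backend implementation defined later in this section.

```python
import numpy as np

# Tiny made-up binary masks
true = np.array([[1, 1, 0],
                 [0, 1, 0]])
pred = np.array([[1, 0, 0],
                 [0, 1, 1]])

# Form 1: from the confusion counts
tp = np.sum((true == 1) & (pred == 1))  # 2
fp = np.sum((true == 0) & (pred == 1))  # 1
fn = np.sum((true == 1) & (pred == 0))  # 1
iou_counts = tp / (tp + fp + fn)

# Form 2: from intersection and union
intersection = np.sum(true * pred)
union = np.sum(true) + np.sum(pred) - intersection
iou_sets = intersection / union

print(iou_counts, iou_sets)  # 0.5 0.5
```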

#### iou_metric
```
def iou_metric(
    true_masks: np.ndarray, predicted_masks: np.ndarray, smooth: float = 1e-8
) -> float:
```
#### Description

Calculates the Intersection over Union (IoU) metric for the given true and predicted masks.

#### Parameters

- `true_masks` (np.ndarray): 

    The true masks.

- `predicted_masks` (np.ndarray): 

    The predicted masks.

- `smooth` (float, optional): 

    A small value to avoid division by zero. Default is 1e-8.

#### Returns

- `float`: 

    The IoU metric.

#### Example
```
import numpy as np
import tensorflow.keras.backend as backend

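# Example true masks and predicted masks (cast to float so they can be combined with the float `smooth` term)
true_masks = np.random.randint(0, 2, (10, 256, 256)).astype(np.float32)
predicted_masks = np.random.randint(0, 2, (10, 256, 256)).astype(np.float32)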

# Calculate IoU metric
iou = iou_metric(true_masks=true_masks, predicted_masks=predicted_masks)
print(f"IoU: {iou}")
```
#### Notes
- Flattens the true and predicted masks to 1D arrays for calculation.
- Calculates the intersection of the true and predicted masks.
- Calculates the union of the true and predicted masks.
- Adds a small smooth value to avoid division by zero.
- Returns the IoU metric, which is the ratio of the intersection to the union of the true and predicted masks.
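
The effect of the `smooth` term is easiest to see on edge cases. The sketch below uses a hypothetical NumPy helper, `iou_numpy`, that mirrors the steps in the notes above (it is not the Keras-backend function itself) and evaluates it on empty and full masks.

```python
import numpy as np


def iou_numpy(true, pred, smooth=1e-8):
    """Hypothetical NumPy stand-in that follows the steps listed in the notes."""
    true = true.astype(np.float64).ravel()
    pred = pred.astype(np.float64).ravel()
    intersection = np.sum(true * pred)
    union = np.sum(true) + np.sum(pred) - intersection
    return (intersection + smooth) / (union + smooth)


empty = np.zeros((4, 4))
full = np.ones((4, 4))

print(iou_numpy(empty, empty))  # 1.0, both masks empty: smooth / smooth
print(iou_numpy(full, empty))   # very close to 0, no overlap
print(iou_numpy(full, full))    # 1.0, perfect match
```

Without the `smooth` term, the first case would be a division of zero by zero; with it, two empty masks count as a perfect match.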

In [None]:
# Custom metric used when training the model. Intersection Over Union (IOU)
def iou_metric(
    true_masks: np.ndarray, predicted_masks: np.ndarray, smooth: float = 1e-8
) -> float:
    """
    Calculates the Intersection over Union (IoU) metric for the given true and predicted masks.

    Args:
        true_masks (np.ndarray): The true masks.
        predicted_masks (np.ndarray): The predicted masks.
        smooth (float, optional): A small value to avoid division by zero. Default is 1e-8.

    Returns:
        float: The IoU metric.
    """
    # Flatten the true and predicted masks to 1D arrays
    true_masks = backend.flatten(true_masks)
    predicted_masks = backend.flatten(predicted_masks)

    # Calculate the intersection of the true and predicted masks
    intersection = backend.sum(true_masks * predicted_masks)

    # Calculate the union of the true and predicted masks
    union = backend.sum(true_masks) + backend.sum(predicted_masks) - intersection

    # Calculate the Intersection Over Union (IOU)
    iou = (intersection + smooth) / (union + smooth)
    return iou


### Dice Coefficient
The `dice_coeff` function is technically a public function that is passed in `main.ipynb` to [`unet_model`](#u-net-model-definition) as one of our custom `metrics`, and to the TensorFlow Keras [`load_model`](https://www.tensorflow.org/api_docs/python/tf/keras/models/load_model) as one of our `custom_objects`. Note that we never actually **call** the function; it is only ever **passed**, because the model itself invokes it.

The Dice Coefficient, also known as the [Dice-Sorensen Coefficient](https://en.wikipedia.org/wiki/Dice-S%C3%B8rensen_coefficient), is a commonly used metric in image segmentation. It measures how well a predicted segmentation matches the true segmentation, similarly to [IOU](#intersection-over-union), but it puts more emphasis on the overlap. It is calculated as twice the area where the predicted and true masks overlap (the intersection), divided by the sum of the number of pixels in the predicted mask and the number of pixels in the true mask. The factor of two effectively weights the intersection. A perfect match would result in a Dice coefficient of 1.
$$ Dice = \frac{2 \text{TP}}{(2 \text{TP}) + \text{FP} + \text{FN} } = \frac{2 |\text{pred} \cap \text{true}|}{|\text{pred}| + |\text{true}|}$$

- TP - True Positive
- FP - False Positive
- FN - False Negative
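
As a worked example of the Dice formula above, the sketch below computes it with plain NumPy on a tiny pair of made-up binary masks and compares it with IoU on the same masks. The two metrics are related by the identity Dice = 2 IoU / (1 + IoU), which follows from the TP/FP/FN forms of both formulas. This is an illustration only, not the Keras-backend implementation defined later in this section.

```python
import numpy as np

# Tiny made-up binary masks
true = np.array([[1, 1, 0],
                 [0, 1, 0]])
pred = np.array([[1, 0, 0],
                 [0, 1, 1]])

intersection = np.sum(true * pred)                 # 2
dice = 2 * intersection / (np.sum(true) + np.sum(pred))

# IoU on the same masks, for comparison
union = np.sum(true) + np.sum(pred) - intersection  # 4
iou = intersection / union

# Dice recovered from IoU via the identity above
dice_from_iou = 2 * iou / (1 + iou)

print(round(float(dice), 4), float(iou), round(float(dice_from_iou), 4))
```

For any imperfect, non-zero overlap, Dice is larger than IoU, which is one way to read the statement that Dice puts more emphasis on the overlap.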

#### dice_coeff
```
def dice_coeff(
    true_masks: np.ndarray, predicted_masks: np.ndarray, smooth: float = 1e-8
) -> float:
```
#### Description

Calculates the Dice Coefficient for the given true and predicted masks.

#### Parameters

- `true_masks` (np.ndarray): 

    The true masks.

- `predicted_masks` (np.ndarray): 

    The predicted masks.

- `smooth` (float, optional): 

    A small value to avoid division by zero. Default is 1e-8.

#### Returns

- `float`: 

    The Dice coefficient.

#### Example
```
import numpy as np
import tensorflow.keras.backend as backend

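# Example true masks and predicted masks (cast to float so they can be combined with the float `smooth` term)
true_masks = np.random.randint(0, 2, (10, 256, 256)).astype(np.float32)
predicted_masks = np.random.randint(0, 2, (10, 256, 256)).astype(np.float32)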

# Calculate Dice Coefficient metric
dice = dice_coeff(true_masks=true_masks, predicted_masks=predicted_masks)
print(f"Dice Coefficient: {dice}")
```
#### Notes
- Flattens the true and predicted masks to 1D arrays for calculation.
- Calculates the intersection of the true and predicted masks.
- Adds a small smooth value to avoid division by zero.
- Returns the Dice coefficient.


In [None]:
# Custom metric used when training the model. Dice Coefficient.
def dice_coeff(
    true_masks: np.ndarray, predicted_masks: np.ndarray, smooth: float = 1e-8
) -> float:
    """
    Calculates the Dice Coefficient for the given true and predicted masks.

    Args:
        true_masks (np.ndarray): The true masks.
        predicted_masks (np.ndarray): The predicted masks.
        smooth (float, optional): A small value to avoid division by zero. Default is 1e-8.

    Returns:
        float: The Dice Coefficient.
    """
    
    # Flatten the true and predicted masks to 1D arrays
    true_masks = backend.flatten(true_masks)
    predicted_masks = backend.flatten(predicted_masks)

    # Calculate the intersection of the true and predicted masks
    intersection = backend.sum(true_masks * predicted_masks)

    # Calculate the Dice Coefficient
    dice = (2.0 * intersection + smooth) / (
        backend.sum(true_masks) + backend.sum(predicted_masks) + smooth
    )
    return dice



### U-Net Model Definition
The `unet_model` is a public function that defines and builds the U-Net as a TensorFlow Keras [`Model`](https://www.tensorflow.org/api_docs/python/tf/keras/Model), specifically using their [Functional API](https://www.tensorflow.org/guide/keras/functional_api). 

The U-Net is a well-established type of convolutional neural network (CNN) used in image segmentation.

<figure id="unet-example">  
    <img src="resources/Example_architecture_of_U-Net_for_producing_k_256-by-256_image_masks_for_a_256-by-256_RGB_image.png"
    alt="example architecture of U-Net for producing k 256-by-256 image masks for a 256-by-256 RGB image"
    width="500"/>
    <figcaption>Example U-Net Architecture </figcaption>
</figure> 

The [example above](#unet-example) shows the general architecture of a U-Net applied to an RGB image (3 channels) of size 256x256 pixels. The image is encoded down through the U-Net, passed through a bottleneck layer (Down | conv4), and then decoded back up into a mask of size 256x256 pixels with `k` classes.
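
To make the encode/decode path concrete, the following pure-Python sketch traces the spatial size and filter count at each stage. It assumes the scheme used by `unet_model` in this notebook (64 starting filters, filters doubled and spatial size halved per encoder block, and the reverse per decoder block); the figure above may use different filter counts. The helper name `trace_unet_shapes` is hypothetical. The sketch also rejects input sizes that are not divisible by `2 ** num_blocks`, since otherwise the upsampled tensor and the stored encoder tensor would not have matching sizes when concatenated.

```python
def trace_unet_shapes(height, width, num_blocks=4, base_filters=64):
    """Return (stage, height, width, filters) for each block's output."""
    factor = 2 ** num_blocks
    if height % factor or width % factor:
        raise ValueError(
            f"height and width must be divisible by {factor} "
            f"for {num_blocks} pooling steps"
        )

    shapes = []
    h, w, f = height, width, base_filters

    # Encoder: record the convolution output, then halve size and double filters
    for block in range(num_blocks):
        shapes.append((f"encoder {block + 1}", h, w, f))
        h, w, f = h // 2, w // 2, f * 2

    # Bottleneck: smallest spatial size, largest filter count
    shapes.append(("bottleneck", h, w, f))

    # Decoder: double size and halve filters, mirroring the encoder
    for block in reversed(range(num_blocks)):
        h, w, f = h * 2, w * 2, f // 2
        shapes.append((f"decoder {block + 1}", h, w, f))

    return shapes


shapes = trace_unet_shapes(256, 256)
for stage, h, w, f in shapes:
    print(f"{stage:<11} {h:>3} x {w:>3} x {f}")
```

For a 256x256 input with four blocks, the bottleneck works on 16x16 feature maps with 1024 filters, and the last decoder block returns to 256x256 with 64 filters before the output layer maps them to the class channels.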

#### unet_model
```
def unet_model(
    input_shape: Tuple[int, int, int],
    num_classes: int,
    num_blocks: int = 4,
    optimizer: str = "adam",
    metrics: List[str] = ["accuracy"],
):
```
#### Description

Builds a U-Net model for image segmentation.

#### Parameters

- `input_shape` (Tuple[int, int, int]): 

    The shape of the input images (height, width, channels).

- `num_classes` (int): 

    The number of output classes for segmentation.

- `num_blocks` (int, optional): 

    The number of encode/decode blocks in the U-Net. Default is 4.

- `optimizer` (str, optional): 

    Optimization method to use for the model. The default is "adam".

- `metrics` (List[str], optional): 

    Metrics to evaluate during model training. Default is `["accuracy"]`.

#### Returns

- `Model`: 

    A TensorFlow Keras Model representing the U-Net.

#### Example
```
from tensorflow.keras import layers, models

# Define the input shape and number of classes
input_shape = (256, 256, 3)
num_classes = 3

# Build the U-Net model
model = unet_model(input_shape=input_shape, num_classes=num_classes, num_blocks=4, optimizer="adam", metrics=["accuracy"])

# Display the model summary
model.summary()
```
#### Notes
- Creates input layer using the specified `input_shape`.
- Builds encoder blocks with 2 ReLU convolution layers and 1 MaxPooling layer, storing intermediate results for skip connections.
- Includes a bottleneck block with 2 ReLU convolution layers.
- Builds decoder blocks with 1 UpSampling layer, concatenation with the corresponding encoder layer, and 2 ReLU convolution layers.
- Selects the appropriate activation and loss functions based on the number of classes.
- Creates the output layer with the appropriate activation function.
- Compiles the model with the specified optimizer, loss function, and metrics.
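
Because the multi-class branch pairs `softmax` with `categorical_crossentropy`, Keras expects the true masks to be one-hot encoded along the last axis, with shape `(batch, height, width, num_classes)`. The following is a minimal NumPy sketch of that conversion, assuming the masks start out as integer class indices. The helper name `one_hot_masks` is hypothetical and is not part of this notebook.

```python
import numpy as np


def one_hot_masks(label_masks, num_classes):
    """Convert integer class-index masks to one-hot masks on a new last axis."""
    label_masks = np.asarray(label_masks)
    # Row i of the identity matrix is the one-hot vector for class i
    return np.eye(num_classes, dtype=np.float32)[label_masks]


# One 2x2 mask with class indices 0, 1 and 2
labels = np.array([[[0, 1],
                    [2, 1]]])
encoded = one_hot_masks(labels, num_classes=3)

print(encoded.shape)                # (1, 2, 2, 3)
print(encoded[0, 1, 0])             # [0. 0. 1.]
print(np.argmax(encoded, axis=-1))  # recovers the original labels
```

Taking `np.argmax` over the last axis reverses the encoding, which is the same step `test_unet` applies to multi-class masks and predictions.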



In [None]:
# Function to define the actual U-Net model
def unet_model(
    input_shape: Tuple[int, int, int],
    num_classes: int,
    num_blocks: int = 4,
    optimizer: str = "adam",
    metrics: List[str] = ["accuracy"],
):
    """
    Builds a U-Net model for image segmentation.

    Args:
        input_shape (Tuple[int, int, int]): The shape of the input images (height, width, channels).
        num_classes (int): The number of output classes for segmentation.
        num_blocks (int, optional): The number of encode/decode blocks in the U-Net. Default is 4.
        optimizer (str, optional): Optimization method to use for the model. The default is "adam".
        metrics (List[str], optional): Metrics to evaluate during model training. Default is ["accuracy"].

    Returns:
        Model: A TensorFlow Keras Model representing the U-Net.
    """
    # Create input layer
    inputs = layers.Input(shape=input_shape)

    # Initialize components for subsequent encoder and decoder layers
    encoder_layers = []
    x = inputs  # updates at every layer
    filters = 64  # Initial number of filters, updates at every layer block
    
    # For each encoder block create 3 layers: 2 ReLU convolution layers and 1 MaxPooling layer. Also store the result from the 2 convolution layers for the skip connection used when decoding.
    for idx in range(num_blocks):
        x = layers.Conv2D(filters, 3, activation="relu", padding="same")(x)
        x = layers.Conv2D(filters, 3, activation="relu", padding="same")(x)
        encoder_layers.append(x)  # Store the result for skip connections
        x = layers.MaxPooling2D(2)(x)
        filters *= 2 # Double the number of filters for the next block

    # Bottleneck block: 2 ReLU convolution layers
    x = layers.Conv2D(filters, 3, activation="relu", padding="same")(x)
    x = layers.Conv2D(filters, 3, activation="relu", padding="same")(x)

    # For each decoder block create 4 layers: 1 UpSampling, 1 Concatenation with the skipped encoder layer from the same level, and 2 ReLU convolution layers.
    for idx in reversed(range(num_blocks)):
        filters //= 2 # Halve the number of filters for the next block
        x = layers.UpSampling2D(size=2)(x)
        x = layers.concatenate([x, encoder_layers[idx]], axis=-1) # Concatenate with the corresponding encoder layer
        x = layers.Conv2D(filters, 3, activation="relu", padding="same")(x)
        x = layers.Conv2D(filters, 3, activation="relu", padding="same")(x)

    # Select the appropriate activation and loss functions based on the number of classes in the input data.
    if num_classes == 1:
        activation_fn = "sigmoid"
        loss_fn = "binary_crossentropy"
    else:
        activation_fn = "softmax"
        loss_fn = "categorical_crossentropy"

    # Output layer
    outputs = layers.Conv2D(num_classes, 1, activation=activation_fn)(x)

    # Create the model as tensorflow NN Model.
    unet_model = models.Model(inputs, outputs)

    # Compile the model with the desired optimizer, loss function, and metrics
    unet_model.compile(optimizer=optimizer, loss=loss_fn, metrics=metrics)

    return unet_model



### Plot Model Losses
The `plot_losses` is a public function that plots the training and validation losses over the epochs from the training process.

#### plot_losses
```
def plot_losses(model_fit: Any) -> None:
```
#### Description

Plots the training and validation losses from the model fitting process.

#### Parameters

- `model_fit` (Any): 

    The result of the model fitting process (e.g., the history object returned by `model.fit`). 

#### Example
```
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

# Example model and training process
model = Sequential([Dense(10, activation='relu', input_shape=(100,)), Dense(1, activation='sigmoid')])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
history = model.fit(np.random.rand(1000, 100), np.random.randint(0, 2, 1000), epochs=10, validation_split=0.2)

# Plot the losses
plot_losses(model_fit=history)
```
#### Notes
- Plots the training loss using `model_fit.history["loss"]`.
- Plots the validation loss using `model_fit.history["val_loss"]`.
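
The `history` dictionary that `plot_losses` reads can also be inspected directly, for example to find the epoch with the lowest validation loss. The sketch below uses a `SimpleNamespace` as a stand-in for the object returned by `model.fit`, with made-up loss values, so it runs without training a model. Note that Keras only records `"val_loss"` when validation data (or a `validation_split`) is passed to `model.fit`.

```python
from types import SimpleNamespace

# Stand-in for the object returned by `model.fit` (loss values are made up)
fake_fit = SimpleNamespace(
    history={
        "loss": [0.9, 0.6, 0.4, 0.3],
        "val_loss": [1.0, 0.7, 0.65, 0.8],
    }
)

train_loss = fake_fit.history["loss"]
val_loss = fake_fit.history["val_loss"]

# Index (0-based epoch) of the lowest validation loss
best_epoch = min(range(len(val_loss)), key=val_loss.__getitem__)

# Gap between validation and training loss at the final epoch
final_gap = val_loss[-1] - train_loss[-1]

print(f"Best epoch: {best_epoch}, val_loss: {val_loss[best_epoch]}")
print(f"Final validation/training gap: {final_gap:.2f}")
```

A validation loss that rises while the training loss keeps falling, as in the last epoch of this made-up history, is the usual sign of overfitting to look for in the plot.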


In [None]:
# Function to plot the losses from the training process of the U-Net model
def plot_losses(model_fit: Any) -> None:
    """
    Plots the training and validation losses from the model fitting process.

    Args:
        model_fit (Any): The result of the model fitting process (e.g., the history object returned by `model.fit`).

    Returns:
        None: This function does not return anything. It plots the training and validation losses.
    """
    # Create a new figure with specified size
    plt.figure(figsize=(6, 4))

    # Plot the training loss
    plt.plot(model_fit.history["loss"], label="Train Loss")
    
    # Plot the validation loss
    plt.plot(model_fit.history["val_loss"], label="Validation Loss")

    # Set axes labels and limits
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.xlim(left=0)
    plt.ylim(bottom=0)

    # Set the title 
    plt.title("Training Loss vs. Validation Loss")

    # Add legend
    plt.legend()

    # Tighten the layout, and don't block execution while the plot is shown
    plt.tight_layout()
    plt.show(block=False)
    return


### Test Trained U-Net
The `test_unet` is a public function that predicts the masks for a given set of input images, using the trained U-Net model. It then prints the metrics from [`_compute_basic_metrics`](#compute-basic-metrics) to compare the model's predictions with the true masks. If `display_figures` is set to `True`, it also displays figures showing images, masks, and predicted masks together for a visual comparison. It returns the predicted masks so that they can be saved using [`save_predicted_masks`](#save-predicted-masks).

#### test_unet
```
def test_unet(
    model: Model,
    images: np.ndarray,
    masks: np.ndarray,
    num_classes: int,
    threshold: float = 0.5,
    image_names: Optional[List[str]] = None,
    class_labels: Optional[Dict[int, str]] = None,
    class_color_map: Optional[str] = None,
    display_figures: bool = True,
    display_count: int = 3,
    images_per_figure: int = 3,
) -> np.ndarray:
```
#### Description

Tests the U-Net model by using the model to predict a set of masks for a given set of images and then comparing these predictions with the true masks. 

#### Parameters

- `model` (Model): 

    The U-Net model to be tested.

- `images` (np.ndarray): 

    The input images.

- `masks` (np.ndarray): 

    The true masks.

- `num_classes` (int): 

    The number of classes.

- `threshold` (float, optional): 

    Threshold to use if the number of classes is 1 (binary). Default is `0.5`.

- `image_names` (Optional[List[str]], optional): 

    List of image names. Default is `None`.

- `class_labels` (Optional[Dict[int, str]], optional): 

    A dictionary where the key is the class index and the value is the class label. Default is `None`.

- `class_color_map` (Optional[str], optional): 

    Name of a Matplotlib colormap. Default is `None`, in which case `'viridis'` is expected to be used.

- `display_figures` (bool, optional): 

    Whether to display the figures. Default is `True`.

- `display_count` (int, optional): 

    The number of images to display. Default is `3`.

- `images_per_figure` (int, optional): 

    The number of images per figure. Default is `3`.

#### Returns

- np.ndarray:

    The predicted masks.

#### Example
```
import numpy as np
from tensorflow.keras.models import Model

# Example U-Net model, images, and masks
model = unet_model(input_shape=(256, 256, 3), num_classes=3)
images = np.random.rand(10, 256, 256, 3)
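# One-hot encoded true masks, as expected by the multi-class (softmax) branch
masks = np.eye(3, dtype=np.float32)[np.random.randint(0, 3, (10, 256, 256))]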

# Test the U-Net model
predicted_masks = test_unet(
    model=model,
    images=images,
    masks=masks,
    num_classes=3,
    threshold=0.5,
    display_figures=True,
    display_count=3,
    images_per_figure=3
)
```

#### Notes
- Predicts masks using the U-Net model.
- Processes predicted masks based on the number of classes (binary or multi-class).
- Gets class labels, ensuring they are defined for potential figure legends.
- Computes metrics for the predicted masks and prints them.
- Displays figures if the `display_figures` flag is set to `True`.
- Uses `_get_color_masks` to get the labels and colored masks.
- Uses `_view_predictions` to display the figures showing the image, mask, and prediction.
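
The post-processing step in the notes above can be illustrated on small arrays. The following NumPy sketch applies the same two conversions that `test_unet` uses, thresholding for the binary case and `argmax` for the multi-class case, to hand-made prediction scores (the values are illustrative and do not come from a trained model).

```python
import numpy as np

# Binary case: one sigmoid channel per pixel, shape (batch, height, width, 1)
binary_probs = np.array([[[[0.2], [0.7]],
                          [[0.5], [0.9]]]])
threshold = 0.5
# Drop the channel axis and keep pixels strictly above the threshold
binary_mask = (binary_probs[..., 0] > threshold).astype(np.uint8)

# Multi-class case: softmax scores, shape (batch, height, width, num_classes)
multi_probs = np.array([[[[0.1, 0.7, 0.2],
                          [0.6, 0.3, 0.1]]]])
# The predicted class is the index of the highest score
multi_mask = np.argmax(multi_probs, axis=-1).astype(np.uint8)

print(binary_mask.shape, binary_mask.tolist())  # (1, 2, 2) [[[0, 1], [0, 1]]]
print(multi_mask.shape, multi_mask.tolist())    # (1, 1, 2) [[[1, 0]]]
```

In both cases the channel axis disappears, so the returned masks have shape `(batch, height, width)`. A score exactly equal to the threshold is assigned to the background, because the comparison is strict.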


In [None]:
# Function to test the U-Net model after training
def test_unet(
    model: Model,
    images: np.ndarray,
    masks: np.ndarray,
    num_classes: int,
    threshold: float = 0.5,
    image_names: Optional[List[str]] = None,
    class_labels: Optional[Dict[int, str]] = None,
    class_color_map: Optional[str] = None,
    display_figures: bool = True,
    display_count: int = 3,
    images_per_figure: int = 3,
) -> np.ndarray:
    """
    Tests the U-Net model by plotting the predictions on a set of images and their corresponding masks.

    Args:
        model (Model): The U-Net model to be tested.
        images (np.ndarray): The input images.
        masks (np.ndarray): The true masks.
        num_classes (int): The number of classes.
        threshold (float, optional): Threshold to use if the number of classes is 1 (binary). Default is 0.5.
        image_names (Optional[List[str]], optional): List of image names. Default is None.
        class_labels (Optional[Dict[int, str]], optional): A dictionary where the key is the class index and the value is the class label. Default is None.
        display_figures (bool, optional): Whether to display the figures. Default is True.
        class_color_map (Optional[str], optional): Name of a Matplotlib colormap. Default is None, in which case viridis is expected to be used.
        display_count (int, optional): The number of images to display. Default is 3.
        images_per_figure (int, optional): The number of images per figure. Default is 3.

    Returns:
        np.ndarray: The predicted masks.
    """

    # Predict masks using the U-Net model
    predicted_masks = model.predict(images)

    # Process predicted masks based on the number of classes
    if num_classes == 1:
        # Converting the predicted masks and masks back to the correct format for plotting and for performance metrics
        predicted_masks = (predicted_masks[..., 0] > threshold).astype(np.uint8)
        masks = (masks > threshold).astype(np.uint8)

        # Setting the method to be used in the metrics based on the number of classes
        metric_ave_method = "binary" 
    else:
        # Converting the predicted masks and masks back to the correct format for plotting and for performance metrics
        predicted_masks = np.argmax(predicted_masks, axis=-1).astype(np.uint8)
        masks = np.argmax(masks, axis=-1).astype(np.uint8)
        
        # Setting the method to be used in the metrics based on the number of classes
        metric_ave_method = None 

    # Get class labels, ensuring they are defined for potential figure legends
    class_labels = _get_class_labels(class_labels, num_classes)

    # Compute metrics for the predicted masks
    metrics = _compute_basic_metrics(masks, predicted_masks, metric_ave_method)
    print("Metrics based on test data:\n")
    
    # If multiple classes then print metrics for each class in the predicted masks
    if metric_ave_method != "binary":
        for idx in range(num_classes):

            # Verify that there are metrics for the given class index
            if 0 <= idx < len(metrics["precision"]):
                print(f"Class {class_labels[idx]} Individual Metrics:")
                print(f"Precision: {metrics['precision'][idx]}")
                print(f"Recall: {metrics['recall'][idx]}")
                print(f"F1 Score: {metrics['f1 score'][idx]}\n")
            else:
                print(f"Class {class_labels[idx]} Individual Metrics:")
                print(
                    f"Class {class_labels[idx]} not found in predicted masks. No metrics available."
                )

    # Else if only a background and foreground class (binary), print the metrics for the foreground
    else:
        print(f"Precision: {metrics['precision']}\n")
        print(f"Recall: {metrics['recall']}\n")
        print(f"F1 Score: {metrics['f1 score']}\n")

    # Display figures if the flag is set to True
    if display_figures:
        # Get the labels and colored masks
        colored_masks, colored_preds, class_colors_and_labels = _get_color_masks(
            num_classes, masks, predicted_masks, class_labels, class_color_map
        )

        # Display the figures that show the image, mask, and prediction
        _view_predictions(
            images,
            colored_masks,
            colored_preds,
            image_names,
            class_colors_and_labels,
            display_count,
            images_per_figure,
        )

    return predicted_masks

### Save Predicted Masks
The `save_predicted_masks` is a public function that saves the predicted masks from [`test_unet`](#test-trained-u-net) in any format accepted by the [PIL](https://pillow.readthedocs.io/en/stable/index.html) library's `Image.save()`. Images are saved in the format defined by `file_ext`, which is `"tif"` by default. 

#### save_predicted_masks
```
def save_predicted_masks(
    predicted_masks: np.ndarray,
    output_dir: str,
    mask_names: Optional[List[str]] = None,
    file_ext: str = "tif",
) -> None:
```
#### Description

Saves the predicted masks to the specified output directory.

#### Parameters

- `predicted_masks` (np.ndarray): 

    Array containing the predicted masks, with one mask per entry along the first axis.

- `output_dir` (str): 

    Path to the directory where the masks will be saved.

- `mask_names` (Optional[List[str]], optional): 

    List of names to use for the saved mask files, one per mask. If `None`, default names are generated from the mask indices. Default is `None`.

- `file_ext` (str, optional): 

    File extension for the saved masks. Default is `"tif"`.

#### Example
```
import numpy as np
from pathlib import Path
from PIL import Image

# Example predicted masks
predicted_masks = np.random.randint(0, 256, (10, 256, 256), dtype=np.uint8)

# Save the predicted masks to the specified output directory
save_predicted_masks(
    predicted_masks=predicted_masks,
    output_dir="output_masks",
    mask_names=[f"mask_{i}" for i in range(10)],
    file_ext="tif"
)
```

#### Notes
- Creates the output directory if it doesn't exist.
- Normalizes the file extension to lowercase and handles `"tif"` as `"tiff"`, necessitated by the PIL library.
- If no mask names are provided, generates default names based on indices.
- Prints a message indicating the data has been saved.
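
The naming steps in the notes above can be sketched without writing any files. The helper below is a hypothetical variation, not the notebook's implementation: it keeps the extension the caller asked for and looks up the Pillow format name separately, using a small illustrative mapping (`PIL_FORMATS`, an assumption of this sketch that covers only a few extensions). This matters for extensions such as `jpg`, where Pillow's format name is `"JPEG"` rather than the upper-cased extension.

```python
from pathlib import Path

# Illustrative subset of extension -> Pillow format names (assumption of this sketch)
PIL_FORMATS = {
    "tif": "TIFF",
    "tiff": "TIFF",
    "png": "PNG",
    "jpg": "JPEG",
    "jpeg": "JPEG",
}


def build_mask_paths(output_dir, num_masks, mask_names=None, file_ext="tif"):
    """Return the output paths and the Pillow format name, without saving."""
    # Normalize the extension: trim whitespace, lowercase, drop a leading dot
    ext = file_ext.strip().lower().lstrip(".")
    pil_format = PIL_FORMATS[ext]

    # Fall back to index-based names when none are given
    names = mask_names or [str(idx) for idx in range(num_masks)]

    # Build paths with pathlib so a plain string `output_dir` also works
    out_dir = Path(output_dir)
    return [out_dir / f"{name}.{ext}" for name in names], pil_format


default_paths, default_format = build_mask_paths("output_masks", num_masks=2)
named_paths, named_format = build_mask_paths(
    "output_masks", num_masks=1, mask_names=["mask_a"], file_ext=" JPG "
)

print([p.name for p in default_paths], default_format)  # ['0.tif', '1.tif'] TIFF
print([p.name for p in named_paths], named_format)      # ['mask_a.jpg'] JPEG
```

Each returned path could then be passed to `Image.fromarray(mask).save(path, format=pil_format)`.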

In [None]:
def save_predicted_masks(
    predicted_masks: np.ndarray,
    output_dir: str,
    mask_names: Optional[List[str]] = None,
    file_ext: str = "tif",
) -> None:
    """
    Saves the predicted masks to the specified output directory.

    Args:
        predicted_masks (np.ndarray): Array containing the predicted masks.
        output_dir (str): Path to the directory where the masks will be saved.
        mask_names (Optional[List[str]], optional): List of names to use for the saved mask files, one per mask. If None, default names are generated from the mask indices. Default is None.
        file_ext (str, optional): File extension for the saved masks. Default is "tif".

    Returns:
        None: This function does not return anything.
    """
    # Create the output directory if it doesn't exist
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Normalize file extension to lowercase and handle "tif" as "tiff"
    file_ext = file_ext.strip().lower()
    if file_ext == "tif":
        file_ext = "tiff"

    # If no mask names are provided, generate default names based on indices
    if not mask_names:
        mask_names = [str(idx) for idx in range(len(predicted_masks))]

    # Save each predicted mask with the corresponding name and file extension
    for predicted_mask, mask_name in zip(predicted_masks, mask_names):
        file_name = f"{mask_name}.{file_ext}"
        # Build the path with pathlib, since `output_dir` is passed in as a string
        Image.fromarray(predicted_mask).save(
            Path(output_dir) / file_name, format=file_ext.upper()
        )
    print(f"Data saved in {output_dir}.")
    return