# Utility Functions

## Overview
This notebook contains the utility functions used before the U-Net is implemented. Every function has a docstring that defines its inputs and outputs, along with their data types. Inline comments are included throughout each function for additional clarity.

## Structure
- **[Import Necessary Packages](#import-necessary-packages)**: Imports the packages used by the functions in this notebook.
- **[Internal Functions](#internal-functions)**: Internal or private functions that are only called by other functions and should never be called directly by a script. 
    - `_display_samples`: Function to plot the given data, whether images or masks.
- **[Public Functions](#public-functions)**: Public functions to be imported and called by other scripts, outside of this notebook. 
    - `custom_warnings`: Function to format warnings in the desired manner.
    - `load_images_and_masks`: Function to load images and masks from their given directories.
    - `remap_mask_classes`: Function to remap mask classes to a subset of the classes. 
    - `view_data`: Function to view images and/or masks. Calls `_display_samples`.
    - `preprocess_images_and_masks`: Function to preprocess and format the data for the U-Net.


## Usage
This notebook is not intended to be run individually. The public functions in this notebook should be imported into `main.ipynb` using the `import-ipynb` package. The purpose of this notebook is to clearly separate the utility functions that are called internally (in this notebook, by one of the other functions) from the utility functions that are called in `main.ipynb`.

## Import Necessary Packages

We first import the necessary packages used by the functions in this notebook. 

In [None]:
# Necessary imports
import os
import sys
import warnings
from pathlib import Path
import numpy as np
import re
from PIL import Image
import random
import matplotlib.pyplot as plt
from typing import Optional, Dict, List, Tuple, Any, TextIO

# Optional line to suppress unnecessary tensorflow warnings
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf

## Internal Functions

This section contains the internal (private) functions. They are called only by other functions in this notebook and are not intended to be used in any other way.

### Display Samples

The function `_display_samples` is called in [View Data](#view-data). It takes a given set of samples, either images or masks, and plots them in the desired format. It is separated out as an internal function because it is called multiple times in [View Data](#view-data), and repeated code is better placed in a function of its own.

In [None]:
# Internal function used by view_data to actually display the figures for a given set of image or mask data. 
def _display_samples(
    samples: Dict[str, Any],
    sample_names: List[str],
    sample_type: int,
    max_cols: int,
    max_plots: int,
    colors: bool,
    scale: float = 1,
) -> None:
    """
    Displays the samples by plotting them.

    Args:
        samples (Dict[str, Any]): Dictionary containing the samples. Keys are sample names.
        sample_names (List[str]): List of sample names.
        sample_type (int): Type of sample, either 1 or 0.
        max_cols (int): Maximum number of columns in the plot.
        max_plots (int): Maximum number of plots to display.
        colors (bool): Whether to use colors in the plots.
        scale (float, optional): Scale factor for the plot size. Default is 1.

    Returns:
        None: This function does not return anything. It displays the plots.
    """

    # Determine the number of columns and rows needed for the plots
    num_cols = min(max_cols, max_plots)
    num_rows = (max_plots + max_cols - 1) // max_cols

    # Create a figure and a grid of subplots
    fig, axes = plt.subplots(
        num_rows, num_cols, figsize=(num_cols * scale, num_rows * scale)
    )
    axes = np.array(axes).reshape(-1) # Flatten the axes array for easier iteration

    # Choose the color map based on the 'colors' flag
    if colors:
        cmap = "viridis"
        if sample_type == 1:
            # Since the sample_type is 1 our samples are masks 
            fig.suptitle("Masks")

            # For each mask, scale the values based on the value range so that it may be plotted in color. Then plot the mask in the correct subplot
            for idx, sample in enumerate(sample_names):
                max_value = samples[sample].max() if samples[sample].max() else 1 # Set the max_value to 1 if it is 0, to prevent division by 0
                axes[idx].imshow(samples[sample] / max_value, cmap=cmap)
                axes[idx].set_title(f"{sample}")
                axes[idx].axis("off")
        else:
            # Since the sample_type is not 1 our samples are images 
            fig.suptitle("Images")

            # For each image, plot the image in color, in the correct subplot
            for idx, sample in enumerate(sample_names):
                axes[idx].imshow(samples[sample], cmap=cmap)
                axes[idx].set_title(f"{sample}")
                axes[idx].axis("off")

    else:
        cmap = "gray" # "gray" is the colormap spelling recognized by all matplotlib versions
        # If the sample_type is 1 we have masks, otherwise we have images as our samples
        fig.suptitle("Masks" if sample_type == 1 else "Images")

        # For each sample (image or mask), plot the sample in greyscale in the correct subplot.
        for idx, sample in enumerate(sample_names):
            axes[idx].imshow(samples[sample], cmap=cmap)
            axes[idx].set_title(f"{sample}")
            axes[idx].axis("off")

    # Remove any unused axes (every axis after the last plotted sample)
    for empty_idx in range(len(sample_names), len(axes)):
        fig.delaxes(axes[empty_idx])

    # Adjust layout to prevent overlap and display the plots
    plt.tight_layout()
    plt.show(block=False)

    return
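
The grid sizing at the top of `_display_samples` is a ceiling division. A minimal standalone sketch of that arithmetic follows. The helper name `grid_shape` is hypothetical and is not part of this notebook.

```python
from typing import Tuple


def grid_shape(max_plots: int, max_cols: int) -> Tuple[int, int]:
    """Return (num_rows, num_cols) for a grid holding `max_plots` plots."""
    # Never use more columns than there are plots
    num_cols = min(max_cols, max_plots)
    # Ceiling division without floats: round up to the next full row
    num_rows = (max_plots + max_cols - 1) // max_cols
    return num_rows, num_cols


print(grid_shape(10, 5))  # (2, 5)
print(grid_shape(7, 5))   # (2, 5): the last row has 3 unused axes
print(grid_shape(3, 5))   # (1, 3)
```

The unused axes in the second case are the ones removed by the `fig.delaxes` loop at the end of the function.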


## Public Functions

### Custom Warnings
The `custom_warnings` function is technically a public function. In `main.ipynb` it is passed to Python's built-in `warnings` library to tell it how we would like our warnings displayed. **Its use can be deceptive** for two reasons. First, we never actually **call** the function ourselves: the function itself is **passed** to `warnings`. Second, only the `category` and `message` inputs are explicitly used in the body of `custom_warnings`. The remaining required inputs (`filename`, `lineno`) are there because the `warnings` library requires them when formatting warning displays. You can find more details [here](https://docs.python.org/3/library/warnings.html#warnings.formatwarning).


In [None]:
# Defining the custom warning display
def custom_warnings(
    message: str,
    category: type,
    filename: str,
    lineno: int,
    file: Optional[TextIO] = None,
    line: Optional[str] = None,
) -> str:
    """
    Custom warning formatter.

    Args:
        message (str): The warning message.
        category (type): The category of the warning.
        filename (str): The name of the file where the warning was raised.
        lineno (int): The line number where the warning was raised.
        file (Optional[TextIO], optional): The file object to write the warning to. Default is None.
        line (Optional[str], optional): The line of code where the warning was raised. Default is None.

    Returns:
        str: The formatted warning message.
    """
    return f"{category.__name__}: {message}\n" # Return the warning name, and the message of the warning.
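
The difference between passing and calling can be illustrated with a short sketch. It assumes the formatter is registered by assigning it to `warnings.formatwarning`, the hook described in the linked documentation. `short_format` is a hypothetical stand-in for `custom_warnings`, and the file name and line number are made-up values.

```python
import warnings


# Hypothetical stand-in with the same body as custom_warnings
def short_format(message, category, filename, lineno, line=None):
    return f"{category.__name__}: {message}\n"


original_formatter = warnings.formatwarning
try:
    # Hand the function object to the warnings module: note there are no parentheses
    warnings.formatwarning = short_format

    # The warnings module supplies all of these arguments when it formats a warning.
    # The hook is called directly here only to inspect the text it would produce.
    formatted = warnings.formatwarning(
        UserWarning("Mask not found"), UserWarning, "main.ipynb", 12
    )
finally:
    # Restore the default formatter so the sketch has no lasting side effects
    warnings.formatwarning = original_formatter

print(formatted, end="")  # UserWarning: Mask not found
```

The file name and line number are accepted but ignored, which is why the output contains only the category and the message.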

### Load Images and Masks

In [None]:
# Function to load in the images and associated masks
def load_images_and_masks(
    images_dir: str,
    masks_dir: str,
    file_ext: str = "tif",
    max_count: int = 100,
    trim_names: bool = True,
) -> Tuple[Dict[str, Any], Dict[str, Any], List[str]]:
    """
    Loads images and their corresponding masks from their respective directories.

    Args:
        images_dir (str): Path to the directory containing the images.
        masks_dir (str): Path to the directory containing the masks.
        file_ext (str, optional): The file extension of the images and masks. Default is "tif".
        max_count (int, optional): The maximum number of image-mask pairs to load. Default is 100 images.
        trim_names (bool, optional): Whether to trim the image names. Default is True.

    Returns:
        Tuple[Dict[str, Any], Dict[str, Any], List[str]]: A tuple containing:
            - images (Dict[str, Any]): Dictionary containing the images. Keys are image names.
            - masks (Dict[str, Any]): Dictionary containing the masks. Keys are image names.
            - missing_masks (List[str]): List containing the names of images with no associated masks.
    """
    # Normalize file extension to lowercase and strip any surrounding whitespace
    file_ext = file_ext.strip().lower()

    # Get the list of image files in the directory with the specified file extension
    images_dir_list = list(Path(images_dir).glob("*." + file_ext))
    if not images_dir_list:
        raise ValueError(
            f"No images of file type '{file_ext}' found in directory '{images_dir}'."
        )

    # Limit the number of images to max_count if specified. If max_count is greater than the number of available images, it will just get all images. 
    if max_count and max_count >= 1:
        images_dir_list = images_dir_list[:max_count]

    # Convert masks directory to Path object
    masks_dir = Path(masks_dir)

    # Initialize dictionaries to store images and masks, and a list for missing masks
    images = {}
    masks = {}
    missing_masks = []

    # Iterate over the list of image files
    for image_path in images_dir_list:
        file_name = image_path.stem # Get the file name without extension and without preceding directories.
        mask_path = masks_dir / f"{file_name}.{file_ext}" # Construct the mask file path

        # If the mask exists, load the image and mask, and store them in the dictionaries
        if mask_path.exists():
            image = Image.open(str(image_path))
            mask = Image.open(str(mask_path))
            images[file_name] = np.array(image)
            masks[file_name] = np.array(mask)
        else:
            # If the mask does not exist, add the file name to the missing masks list
            missing_masks.append(str(file_name))

    # Print the missing masks if any
    if missing_masks:
        print(f"The following masks are missing: {', '.join(missing_masks)}\n")
        print("The images associated with the missing masks will not be included.\n")

    # Optionally trim the names of the images and masks
    if trim_names:
        updated_names = []
        buffer_digits = 1 # Extra digits to account for in the trimmed names
        image_names, images = zip(*images.items())
        _, masks = zip(*masks.items())

        # Extract numeric part from the image names
        for image_name in image_names:
            str_match = re.search(r"\d+$", image_name)
            (
                updated_names.append(str_match.group())
                if str_match
                else updated_names.append(image_name)
            )
        
        # Determine the maximum number in the numeric part for padding purposes
        numeric_file_names = [
            int(file_name) for file_name in updated_names if file_name.isdigit()
        ]
        max_num = max(numeric_file_names)
        char_count = len(str(max_num)) + buffer_digits

        # Trim the image names to the appropriate length
        image_names = [file_name[-char_count:] for file_name in updated_names]
        
        # Update the dictionaries with the trimmed names
        images = dict(zip(image_names, images))
        masks = dict(zip(image_names, masks))
    return (images, masks, missing_masks)
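
The name trimming applied when `trim_names` is `True` can be sketched on its own, without loading any files. The helper name `trim_to_numeric_suffix` and the sample file names are hypothetical. The sketch returns the names unchanged when none of them end in digits.

```python
import re
from typing import List


def trim_to_numeric_suffix(names: List[str], buffer_digits: int = 1) -> List[str]:
    """Shorten names to their trailing digits, keeping spare leading digits."""
    # Keep only the trailing run of digits; names without one are left unchanged
    suffixes = []
    for name in names:
        match = re.search(r"\d+$", name)
        suffixes.append(match.group() if match else name)

    # The largest number decides how many characters are kept
    numbers = [int(suffix) for suffix in suffixes if suffix.isdigit()]
    if not numbers:
        return suffixes
    char_count = len(str(max(numbers))) + buffer_digits
    return [suffix[-char_count:] for suffix in suffixes]


print(trim_to_numeric_suffix(["scan_00001", "scan_00042", "scan_00107"]))
# ['0001', '0042', '0107']
```

The largest number here is 107 (three digits), so each name keeps its last four characters: three digits plus the one buffer digit.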


### Remap Mask Classes

`remap_mask_classes` is a public function that remaps the mask classes to a subset of the original classes. It returns the remapped masks together with the number of unique new classes.

In [None]:
# Function to remap the masks to a subset of the original classes
def remap_mask_classes(
    masks: Dict[str, Any], class_mapping: Dict[int, int]
) -> Tuple[Dict[str, Any], int]:
    """
    Remaps the masks to the new classes based on the provided class mapping.

    Args:
        masks (Dict[str, Any]): Dictionary containing the masks. Keys are mask names.
        class_mapping (Dict[int, int]): Dictionary mapping old class indices to new class indices. The keys are the old indices, and the values are the new indices.

    Returns:
        Tuple[Dict[str, Any], int]: A tuple containing:
            - masks (Dict[str, Any]): The remapped masks. Keys are mask names.
            - num_classes (int): The number of unique new classes.
    """
    
    # Get the masks as a list so that we can check how many classes there are
    mask_names, masks_list = zip(*masks.items())
    masks_list = np.stack(masks_list, axis=0)

    # Get the unique class indices present in the masks
    unique_classes = np.unique(masks_list)

    # Get the maximum class index for the old classes in the mapping
    max_class = max(class_mapping.keys())

    # Identify any classes that are out of range of the provided class mapping
    classes_out_of_range = unique_classes[unique_classes > max_class]

    # Handle classes that are out of range of the mapping
    if len(classes_out_of_range) > 0:
        choose_class = (
            input(
                f"There are classes found in the masks that are not accounted for in the mapping: [{', '.join(map(str, classes_out_of_range))}].\n"
                "Would you like to map all out-of-range classes to one of the new classes?\n"
                "(Note: If you select yes you will then be prompted for a value. If you select no the program will terminate.):\n"
                "(Y/N): "
            )
            .strip()
            .lower()
        )
        if choose_class == "yes" or choose_class == "y":
            # Prompt the user to enter a new class index to map the out-of-range classes to
            default_class = int(
                input(
                    f"Please enter the new class index to map old indices [{', '.join(map(str, classes_out_of_range))}] to.\n"
                    f"Your options are [{', '.join(map(str, np.unique(list(class_mapping.values()))))}].\n"
                    f"New index: "
                )
            )
            # Update the class mapping dictionary to include the out-of-range classes
            for class_idx in classes_out_of_range:
                class_mapping[class_idx] = default_class
        else:
            # Exit the program if the user chooses not to map the out-of-range classes
            sys.exit()

    # Create an array version of the class mapping to apply to the masks
    max_idx = max(class_mapping.keys())
    mapping_array = np.zeros(max_idx + 1, dtype=np.uint16)
    for old_idx, new_idx in class_mapping.items():
        mapping_array[old_idx] = new_idx

    # Apply the mapping to each mask
    for mask_name, mask in masks.items():
        masks[mask_name] = mapping_array[mask]

    # Calculate the number of unique new classes
    num_classes = len(set(class_mapping.values()))
    return masks, num_classes
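
The lookup-array step at the end of `remap_mask_classes` can be illustrated with a minimal sketch. The mapping and the 2x2 mask below are made-up values chosen only for illustration.

```python
import numpy as np

# Hypothetical mapping: merge old classes 1 and 2 into new class 1
class_mapping = {0: 0, 1: 1, 2: 1, 3: 2}
mask = np.array([[0, 1], [2, 3]], dtype=np.uint8)

# Build a lookup table whose position is the old index and whose value is the new index
lookup = np.zeros(max(class_mapping) + 1, dtype=np.uint16)
for old_idx, new_idx in class_mapping.items():
    lookup[old_idx] = new_idx

# Indexing the table with the mask remaps every pixel in one vectorized step
remapped = lookup[mask]

print(remapped.tolist())                 # [[0, 1], [1, 2]]
print(len(set(class_mapping.values())))  # 3
```

Because the table starts out as zeros, any old index below the maximum key that is missing from the mapping is remapped to class 0.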

### View Data

The `view_data` function is a public function that is called in `main.ipynb` to view images and/or masks. It calls the internal function [`_display_samples`](#display-samples) to actually generate the necessary plots. 

In [None]:
# Function to view the images, masks, or both 
def view_data(
    images: Optional[Dict[str, Any]] = None,
    masks: Optional[Dict[str, Any]] = None,
    max_plots: int = 10,
    max_cols: int = 5,
    randomize: bool = False,
    colors: bool = True,
) -> None:
    """
    View images or masks. Random subset of images can be viewed in color or grey.

    Args:
        images (Optional[Dict[str, Any]]): Dictionary of images. Keys are image names.
        masks (Optional[Dict[str, Any]]): Dictionary of masks. Keys are mask names.
        max_plots (int, optional): Maximum number of images or masks to show. Default is 10.
        max_cols (int, optional): Maximum number of columns in the display grid. Default is 5.
        randomize (bool, optional): Option to select a random subset from the images/masks. Default is False.
        colors (bool, optional): Option to display the images or masks in color. Default is True.

    Returns:
        None: This function does not return anything.
    """

    # Verify that at least images or masks are provided.
    if images is None and masks is None:
        raise ValueError("You must provide 'images' and/or 'masks' to view.")

    # Ensure that the maximum number of columns is a reasonable value
    if max_cols < 1 or max_cols > 10:
        print("max_cols should be between 1 and 10. Defaulting to 5.")
        max_cols = 5

    # Ensure that the maximum number of plots is a reasonable value
    if max_plots < 1:
        print("max_plots should be 1 or greater. Defaulting to 1.")
        max_plots = 1

    # If images are provided, prepare to display them
    if images is not None:
        sample_type = 0 # Indicate that we are displaying images
        sample_names = list(images.keys())
        sample_count = len(sample_names)

        # Adjust max_plots if it exceeds the number of available samples
        if max_plots > sample_count:
            print(
                f"max_plots ({max_plots}) exceeds the number of samples ({sample_count}). All samples will be displayed."
            )
            max_plots = sample_count

        # Select a subset of sample names to display, either an ordered subset or a random subset
        if not randomize:
            sample_names = sample_names[:max_plots]
        else:
            sample_names = random.sample(sample_names, max_plots)

        # Display the selected images
        _display_samples(
            images, sample_names, sample_type, max_cols, max_plots, colors, scale=1
        )

    # If masks are provided, prepare to display them
    if masks is not None:
        sample_type = 1 # Indicate that we are displaying masks

        # If no images are provided, determine the sample names, max_plots, and randomization from the function inputs. If images are provided, this information has already been determined above.
        if images is None:
            sample_names = list(masks.keys())
            sample_count = len(sample_names)

            # Adjust max_plots if it exceeds the number of available samples
            if max_plots > sample_count:
                print(
                    f"max_plots ({max_plots}) exceeds the number of samples ({sample_count}). Will display all samples."
                )
                max_plots = sample_count

            # Select a subset of sample names to display, either an ordered subset or a random subset
            if not randomize:
                sample_names = sample_names[:max_plots]
            else:
                sample_names = random.sample(sample_names, max_plots)

        # Display the selected masks
        _display_samples(
            masks, sample_names, sample_type, max_cols, max_plots, colors, scale=1
        )
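
The sample selection used by `view_data` (clamp `max_plots`, then take either the first names or a random subset) can be sketched as a standalone helper. The name `select_sample_names` and its `seed` parameter are hypothetical additions for reproducibility and are not part of this notebook.

```python
import random
from typing import List, Optional


def select_sample_names(
    names: List[str],
    max_plots: int,
    randomize: bool = False,
    seed: Optional[int] = None,
) -> List[str]:
    """Pick the names to plot: the first `max_plots`, or a random subset."""
    # Clamp the request to what is available, as view_data does
    max_plots = min(max_plots, len(names))
    if not randomize:
        return names[:max_plots]
    # random.sample draws without replacement, so no name is repeated
    return random.Random(seed).sample(names, max_plots)


names = ["001", "002", "003", "004"]
print(select_sample_names(names, 2))   # ['001', '002']
print(select_sample_names(names, 10))  # ['001', '002', '003', '004']
print(len(select_sample_names(names, 3, randomize=True, seed=0)))  # 3
```

When both images and masks are passed to `view_data`, the names selected for the images are reused for the masks, so each mask is shown next to its matching image.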


### Preprocess Images and Masks
The `preprocess_images_and_masks` function is a public function that preprocesses the images and masks to prepare them for the U-Net. It then displays some details about the preprocessed images and masks for verification. It returns the preprocessed images, the preprocessed masks, the threshold value used for binary masks, the image names, and the number of classes in the preprocessed masks.

In [None]:
# Function to preprocess the images and masks into the right structures for the TensorFlow U-Net.
def preprocess_images_and_masks(
    images: Dict[str, Any],
    masks: Dict[str, Any],
    num_classes: int = 3,
    target_size: Tuple[int, int] = (256, 256),
    threshold: float = 0.5,
) -> Tuple[np.ndarray, np.ndarray, float, List[str], int]:
    """
    Preprocesses the images and masks into the correct format for use in the U-Net.

    Args:
        images (Dict[str, Any]): Dictionary of images. Keys are image names.
        masks (Dict[str, Any]): Dictionary of masks. Keys are mask names.
        num_classes (int, optional): Number of classes in the masks. Default is assumed to be 3, but will be checked against the masks input.
        target_size (Tuple[int, int], optional): Target size of the images. Default is assumed to be (256, 256).
        threshold (float, optional): Threshold to use if the number of classes is 1 (binary). Default is 0.5.

    Returns:
        Tuple: A tuple containing:
            - images (np.ndarray): processed images as numpy arrays.
            - masks (np.ndarray): processed masks as numpy arrays.
            - threshold (float): Threshold to use if the number of classes is 1 (binary).
            - image_names (List[str]): Image names to be used in plotting labels
            - num_classes (int): Number of classes in the masks.
    """
    # Get the images and their names
    image_names, images = zip(*images.items())

    # Convert the images to the appropriate format
    images = tf.stack(images)
    images = tf.cast(images, dtype=tf.float32) / 255
    images = tf.image.resize(images, target_size, method=tf.image.ResizeMethod.BILINEAR)
    images = images.numpy()

    # Get the masks and their names
    mask_names, masks = zip(*masks.items())
    # Convert the masks to a stack for easier manipulation
    masks = tf.stack(masks)
    # Check the number of classes.
    unique_classes = np.unique(masks)
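    # Get the full list of the possible classes. Note this is necessary for one-hot encoding.
    # Cast the maximum to a Python int first so that adding 1 cannot overflow a small integer dtype such as uint8.
    unique_class_range = list(range(int(max(unique_classes)) + 1))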
    num_unique_classes = len(unique_class_range)
    if num_unique_classes < 20:
        # If the number of classes is greater than the number of unique classes, give the user the option to just use the number of unique classes. If the number of classes is less than the number of unique classes, inform the user that they must have at least the number of unique classes. The user can either use this new number as the number of classes or terminate.
        if num_unique_classes < num_classes:
            choose_classes = (
                input(
                    f"There have been {num_unique_classes} mask classes detected, which differs from the input of {num_classes} mask classes.\n"
                    f"Would you like to use the detected number of classes {num_unique_classes}?\n"
                    "(Note: To err on the side of caution it is recommended to choose the larger number.):\n"
                    "(Y/N): "
                )
                .strip()
                .lower()
            )
            if choose_classes == "yes" or choose_classes == "y":
                num_classes = num_unique_classes

        elif num_unique_classes > num_classes:
            if not (num_unique_classes == 2 and num_classes == 1):
                warnings.warn(
                    f"Detected {num_unique_classes} classes, which is larger than the input {num_classes} classes.\n"
                    f"You must have at least the number of possible classes, in this case {num_unique_classes}, with the following classes: "
                    f"[{', '.join(map(str, unique_class_range))}]",
                    UserWarning,
                )

                choose_classes = (
                    input(
                        f"Would you like to use the detected number of classes {num_unique_classes}?\n"
                        "(Note: If you select no, the program will terminate.):\n"
                        "(Y/N): "
                    )
                    .strip()
                    .lower()
                )
                if choose_classes == "yes" or choose_classes == "y":
                    num_classes = num_unique_classes
                else:
                    sys.exit()
    elif num_classes != 1:
        print(
            f"The masks appear to be grey-scale images because they appear to have 20 or more 'classes': {num_unique_classes}.\n"
            f"We will treat them as grey-scale images, to be binary sorted into foreground and background using the threshold: {threshold}."
        )
        num_classes = 1

    # Initialize one hot encoding as not having already occurred.
    already_one_hot = False
    if len(masks.shape) == 4:
        if masks.shape[-1] == num_classes:
            verify = (
                input(
                    f"Your masks already have {num_classes} channels.\n"
                    "Do these channels represent the number of classes?\n"
                    "(Note: selecting YES assumes the data is already one-hot encoded)\n"
                    "(Y/N):"
                )
                .strip()
                .lower()
            )
            if verify == "yes" or verify == "y":
                print(
                    f"Will assume that the number of channels == the number of classes == {num_classes}.\n"
                )
                already_one_hot = True
            else:
                first_channel = (
                    input(
                        "Do you want to only use the first channel?\n"
                        "(Note: If you select NO, the program will terminate.)\n"
                        "(Y/N):"
                    )
                    .strip()
                    .lower()
                )
                if first_channel == "yes" or first_channel == "y":
                    warnings.warn(
                        f"Expected {num_classes} classes, but found {masks.shape[-1]}.\n"
                        "Only using the first channel.\n",
                        UserWarning,
                    )
                    masks = masks[:, :, :, 0]
                    masks = tf.expand_dims(masks, axis=-1)
                else:
                    sys.exit()

        elif masks.shape[-1] != 1:
            raise ValueError(
                f"Unexpected number of channels ({masks.shape[-1]}).\n"
                f"Expected either 1 (where the single channel contains all classes) or {num_classes} (where the number of channels == the number of classes).\n"
            )

    elif len(masks.shape) == 3:
        masks = tf.expand_dims(masks, axis=-1)
    else:
        raise ValueError(
            f"Unexpected number of dimensions for masks: {len(masks.shape)}.\n"
        )

    masks = tf.cast(masks, dtype=tf.uint8)
    masks = tf.image.resize(
        masks, target_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
    )
    masks = tf.cast(masks, dtype=tf.uint8)
    if not already_one_hot:
        masks = masks[:, :, :, 0]
        if num_classes == 1:
            threshold_255 = int(threshold * 255)
            masks = tf.cast(masks > threshold_255, tf.int32)
        elif num_classes == 2:
            print(
                "Data has only two classes.\n"
                "Masks will be converted to binary masks for faster evaluation."
            )
            foreground = max(unique_class_range)
            masks = tf.where(masks == foreground, 1, 0)
            num_classes = 1
        else:
            masks = tf.cast(masks, tf.int32)
            masks = tf.one_hot(masks, depth=num_classes)

    masks = tf.cast(masks, tf.float32)
    masks = masks.numpy()

    print("Preprocessed data information:")
    print(f"Number of images: {len(images)}.")
    print(f"Shape of images dataset: {images.shape}.")
    print(f"Images type: {images.dtype}.")
    print(f"Number of masks: {len(masks)}.")
    print(f"Shape of masks dataset: {masks.shape}.")
    print(f"Masks type: {masks.dtype}.")
    print(f"Number of classes: {num_classes}.")

    return (images, masks, threshold, list(image_names), num_classes)
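
The two mask encodings produced at the end of `preprocess_images_and_masks` can be illustrated with small arrays. The sketch below swaps in NumPy for the TensorFlow calls used in the function (`tf.cast` and `tf.one_hot`) so that it runs without TensorFlow. The 2x2 masks are made-up values.

```python
import numpy as np

# Binary case (num_classes == 1): compare 8-bit grey levels with the scaled threshold
threshold = 0.5
threshold_255 = int(threshold * 255)  # 127
grey_mask = np.array([[0, 127], [128, 255]], dtype=np.uint8)
binary_mask = (grey_mask > threshold_255).astype(np.float32)

# Multi-class case: one channel per class, taken from rows of the identity matrix
num_classes = 3
label_mask = np.array([[0, 1], [2, 1]], dtype=np.int32)
one_hot_mask = np.eye(num_classes, dtype=np.float32)[label_mask]

print(binary_mask.tolist())  # [[0.0, 0.0], [1.0, 1.0]]
print(one_hot_mask.shape)    # (2, 2, 3)
```

The comparison is strict, so a pixel exactly at the scaled threshold (127 here) is treated as background.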
