In [None]:
# Necessary imports
import os
import sys
import warnings
from pathlib import Path
import numpy as np
import re
from PIL import Image
import random
import matplotlib.pyplot as plt
from typing import Optional, Dict, List, Tuple, Any, TextIO

# Optional line to suppress unnecessary tensorflow warnings
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf

In [None]:
# Internal function used by view_data to actually display the figures for a given set of image or mask data. 
def _display_samples(
    samples: Dict[str, Any],
    sample_names: List[str],
    sample_type: int,
    max_cols: int,
    max_plots: int,
    colors: bool,
    scale: float = 1,
) -> None:
    """
    Displays the samples by plotting them.

    Args:
        samples (Dict[str, Any]): Dictionary containing the samples. Keys are sample names.
        sample_names (List[str]): List of sample names.
        sample_type (int): Type of sample, either 1 or 0.
        max_cols (int): Maximum number of columns in the plot.
        max_plots (int): Maximum number of plots to display.
        colors (bool): Whether to use colors in the plots.
        scale (float, optional): Scale factor for the plot size. Default is 1.

    Returns:
        None: This function does not return anything. It displays the plots.
    """
    num_cols = min(max_cols, max_plots)
    num_rows = (max_plots + max_cols - 1) // max_cols
    fig, axes = plt.subplots(
        num_rows, num_cols, figsize=(num_cols * scale, num_rows * scale)
    )
    axes = np.array(axes).reshape(-1)
    if colors:
        cmap = "viridis"
        if sample_type == 1:
            fig.suptitle("Masks")
            for idx, sample in enumerate(sample_names):
                max_value = samples[sample].max() if samples[sample].max() else 1
                axes[idx].imshow(samples[sample] / max_value, cmap=cmap)
                axes[idx].set_title(f"{sample}")
                axes[idx].axis("off")
        else:
            fig.suptitle("Images")
            for idx, sample in enumerate(sample_names):
                axes[idx].imshow(samples[sample], cmap=cmap)
                axes[idx].set_title(f"{sample}")
                axes[idx].axis("off")

    else:
        cmap = "grey"
        fig.suptitle("Masks") if sample_type == 1 else fig.suptitle("Images")
        for idx, sample in enumerate(sample_names):
            axes[idx].imshow(samples[sample], cmap=cmap)
            axes[idx].set_title(f"{sample}")
            axes[idx].axis("off")

    for empty_idx in range(idx + 1, len(axes)):
        fig.delaxes(axes[empty_idx])

    plt.tight_layout()
    plt.show(block=False)

    return


In [None]:
# Defining the custom warning display
def custom_warnings(
    message: str,
    category: type,
    filename: str,
    lineno: int,
    file: Optional[TextIO] = None,
    line: Optional[str] = None,
) -> str:
    """
    Render a warning as a single "<CategoryName>: <message>" line.

    Only ``message`` and ``category`` contribute to the output; the remaining
    parameters are accepted but not used.

    Args:
        message (str): The warning message.
        category (type): The warning class; its ``__name__`` becomes the prefix.
        filename (str): The name of the file where the warning was raised (unused).
        lineno (int): The line number where the warning was raised (unused).
        file (Optional[TextIO], optional): The file object to write the warning to (unused). Default is None.
        line (Optional[str], optional): The line of code where the warning was raised (unused). Default is None.

    Returns:
        str: The formatted warning, terminated by a newline.
    """
    category_name = category.__name__
    formatted = f"{category_name}: {message}\n"
    return formatted

In [None]:
# Function to load in the images and associated masks
def load_images_and_masks(
    images_dir: str,
    masks_dir: str,
    file_ext: str = "tif",
    max_count: int = 100,
    trim_names: bool = True,
) -> Tuple[Dict[str, Any], Dict[str, Any], List[str]]:
    """
    Loads images and their corresponding masks from their respective directories.

    An image is only loaded when a mask with the same file name exists in
    ``masks_dir``; otherwise its name is reported in ``missing_masks``.

    Args:
        images_dir (str): Path to the directory containing the images.
        masks_dir (str): Path to the directory containing the masks.
        file_ext (str, optional): The file extension of the images and masks, with or
            without a leading dot. Default is "tif".
        max_count (int, optional): The maximum number of images to consider (taken in
            sorted file-name order). A falsy value or a value below 1 disables the
            limit. Default is 100 images.
        trim_names (bool, optional): Whether to shorten the keys to the trailing digits
            of each file name. Names are left untouched when no file name ends in
            digits or when shortening would make two names identical. Default is True.

    Returns:
        Tuple[Dict[str, Any], Dict[str, Any], List[str]]: A tuple containing:
            - images (Dict[str, Any]): Dictionary containing the images. Keys are image names.
            - masks (Dict[str, Any]): Dictionary containing the masks. Keys are image names.
            - missing_masks (List[str]): List containing the names of images with no associated masks.

    Raises:
        ValueError: If no file with the requested extension exists in ``images_dir``.
    """
    # Normalise the extension so ".TIF", "tif" and " tif " are all accepted.
    file_ext = file_ext.strip().lower().lstrip(".")
    # Sort the matches: glob order is arbitrary, and the max_count cut below should
    # select the same files on every run.
    image_paths = sorted(Path(images_dir).glob("*." + file_ext))
    if not image_paths:
        raise ValueError(
            f"No images of file type '{file_ext}' found in directory '{images_dir}'."
        )

    # Apply the optional cap. Slicing past the end simply keeps every image.
    if max_count and max_count >= 1:
        image_paths = image_paths[:max_count]

    # Only the directory is resolved here; whether each mask exists is checked per image.
    masks_root = Path(masks_dir)
    images: Dict[str, Any] = {}
    masks: Dict[str, Any] = {}
    missing_masks: List[str] = []

    for image_path in image_paths:
        file_name = image_path.stem
        mask_path = masks_root / f"{file_name}.{file_ext}"
        if not mask_path.exists():
            missing_masks.append(str(file_name))
            continue
        # Context managers release the file handles as soon as the pixel data has
        # been copied into NumPy arrays.
        with Image.open(str(image_path)) as image:
            images[file_name] = np.array(image)
        with Image.open(str(mask_path)) as mask:
            masks[file_name] = np.array(mask)

    if missing_masks:
        print(f"The following masks are missing: {', '.join(missing_masks)}\n")
        print("The images associated with the missing masks will not be included.\n")

    # Nothing to rename when no image/mask pair was loaded.
    if trim_names and images:
        buffer_digits = 1
        original_names = list(images.keys())
        suffixes = []
        numeric_values = []
        for image_name in original_names:
            digits = re.search(r"\d+$", image_name)
            if digits:
                suffixes.append(digits.group())
                numeric_values.append(int(digits.group()))
            else:
                # No trailing number: fall back to the whole name.
                suffixes.append(image_name)

        # Without any numeric suffix there is no width to trim to; keep names as-is.
        if numeric_values:
            # Keep as many characters as the largest number has digits, plus a buffer.
            char_count = len(str(max(numeric_values))) + buffer_digits
            trimmed_names = [suffix[-char_count:] for suffix in suffixes]
            if len(set(trimmed_names)) == len(trimmed_names):
                images = {
                    new: images[old] for old, new in zip(original_names, trimmed_names)
                }
                masks = {
                    new: masks[old] for old, new in zip(original_names, trimmed_names)
                }
            else:
                # Colliding keys would silently drop samples when building the dicts.
                print(
                    "Trimmed names are not unique; keeping the original file names.\n"
                )
    return (images, masks, missing_masks)


In [None]:
# Function to remap the masks to a subset of the original classes
def remap_mask_classes(
    masks: Dict[str, Any], class_mapping: Dict[int, int]
) -> Tuple[Dict[str, Any], int]:
    """
    Remaps the masks to the new classes based on the provided class mapping.

    Every class value present in the masks must have an entry in ``class_mapping``.
    When some do not, the user is asked interactively whether they should all be
    mapped to a single new class; declining terminates the program via ``sys.exit``.

    Args:
        masks (Dict[str, Any]): Dictionary containing the masks. Keys are mask names.
            Masks may differ in shape. The dictionary is updated in place once all
            masks have been remapped successfully.
        class_mapping (Dict[int, int]): Dictionary mapping old class indices to new class indices.
            The keys are the old indices, and the values are the new indices.
            This dictionary is not modified.

    Returns:
        Tuple[Dict[str, Any], int]: A tuple containing:
            - masks (Dict[str, Any]): The remapped masks (the same dictionary object that was passed in).
            - num_classes (int): Number of distinct new class indices in the mapping actually used.

    Raises:
        ValueError: If ``class_mapping`` is empty, or if the class index typed at the
            prompt is not an integer.
    """
    if not class_mapping:
        raise ValueError("'class_mapping' must contain at least one entry.")

    # Work on a copy so entries added for unmapped classes do not leak into the
    # caller's dictionary.
    mapping = dict(class_mapping)

    # Collect the class values per mask instead of stacking, so masks of different
    # shapes are supported.
    if masks:
        unique_classes = np.unique(
            np.concatenate([np.unique(mask) for mask in masks.values()])
        ).tolist()
    else:
        unique_classes = []

    # Any class without an explicit mapping entry is unaccounted for. This includes
    # values that lie between mapped keys, not only those above the largest key;
    # otherwise they would silently fall through to class 0 in the lookup table.
    unmapped_classes = [
        class_idx for class_idx in unique_classes if class_idx not in mapping
    ]

    if len(unmapped_classes) > 0:
        choose_class = (
            input(
                f"There are classes found in the masks, that are not accounted for in the mapping: [{', '.join(map(str, unmapped_classes))}].\n"
                "Would you like to map all out of range classes to a one of the new classes?\n"
                "(Note: If you select yes you will then be prompted for a value. If you select no the program will terminate.):\n"
                "(Y/N): "
            )
            .strip()
            .lower()
        )
        if choose_class == "yes" or choose_class == "y":
            default_class = int(
                input(
                    f"Please enter the new class index to map old indices [{', '.join(map(str, unmapped_classes))}] to.\n"
                    f"You're options are [{', '.join(map(str, np.unique(list(class_mapping.values()))))}].\n"
                    f"New index: "
                )
            )
            for class_idx in unmapped_classes:
                mapping[class_idx] = default_class
        else:
            sys.exit()

    # Lookup table: position = old class index, value = new class index.
    # int() keeps the size arithmetic in Python integers even for NumPy-typed keys.
    max_idx = int(max(mapping.keys()))
    mapping_array = np.zeros(max_idx + 1, dtype=np.uint16)
    for old_idx, new_idx in mapping.items():
        mapping_array[old_idx] = new_idx

    # Remap everything first and only then update the caller's dictionary, so an
    # error part-way through cannot leave it half remapped.
    remapped = {
        mask_name: mapping_array[np.asarray(mask)] for mask_name, mask in masks.items()
    }
    masks.update(remapped)

    num_classes = len(set(mapping.values()))
    return masks, num_classes

In [None]:
# Function to view the images, masks, or both 
def view_data(
    images: Optional[Dict[str, Any]] = None,
    masks: Optional[Dict[str, Any]] = None,
    max_plots: int = 10,
    max_cols: int = 5,
    randomize: bool = False,
    colors: bool = True,
) -> None:
    """
    View images, masks, or both. A random subset can be viewed in color or grey.

    The samples to show are chosen once, from ``images`` when it is given and from
    ``masks`` otherwise, so that when both are supplied the two figures show the
    same samples.

    Args:
        images (Optional[Dict[str, Any]]): Dictionary of images. Keys are image names.
        masks (Optional[Dict[str, Any]]): Dictionary of masks. Keys are mask names.
        max_plots (int, optional): Maximum number of images or masks to show. Default is 10.
        max_cols (int, optional): Maximum number of columns in the display grid. Default is 5.
        randomize (bool, optional): Option to select a random subset from the images/masks. Default is False.
        colors (bool, optional): Option to display the images or masks in color. Default is True.

    Returns:
        None: This function does not return anything.

    Raises:
        ValueError: If neither ``images`` nor ``masks`` is provided.
    """
    # Verify that at least images or masks are provided.
    if images is None and masks is None:
        raise ValueError("You must provide either 'images' and/or 'masks' to view.")

    # Ensure that the maximum number of columns is a reasonable value
    if max_cols < 1 or max_cols > 10:
        print("max_cols should be between 1 and 10. Defaulting to 5.")
        max_cols = 5
    # Ensure that the maximum number of plots is a reasonable value
    if max_plots < 1:
        print("max_plots should be 1 or greater. Defaulting to 1.")
        max_plots = 1

    # Pick the sample names once, from whichever collection drives the selection.
    name_source = images if images is not None else masks
    sample_names = list(name_source.keys())
    sample_count = len(sample_names)
    if sample_count == 0:
        # An empty grid cannot be plotted; stop before reaching the plotting helper.
        print("No samples to display.")
        return

    if max_plots > sample_count:
        print(
            f"max_plots ({max_plots}) exceeds the number of samples ({sample_count}). Will display all samples."
        )
        max_plots = sample_count

    if randomize:
        sample_names = random.sample(sample_names, max_plots)
    else:
        sample_names = sample_names[:max_plots]

    if images is not None:
        _display_samples(images, sample_names, 0, max_cols, max_plots, colors, scale=1)

    if masks is not None:
        # When the names came from `images`, some may have no mask; skip those
        # rather than failing with a KeyError inside the plotting helper.
        mask_names = [name for name in sample_names if name in masks]
        skipped = [name for name in sample_names if name not in masks]
        if skipped:
            print(
                f"No mask found for: {', '.join(map(str, skipped))}. These will be skipped."
            )
        if mask_names:
            _display_samples(
                masks, mask_names, 1, max_cols, len(mask_names), colors, scale=1
            )


In [None]:
# Function to preprocess the images and masks into the array structures expected by the TensorFlow U-Net.
def preprocess_images_and_masks(
    images: Dict[str, Any],
    masks: Dict[str, Any],
    num_classes: int = 3,
    target_size: Tuple[int, int] = (256, 256),
    threshold: float = 0.5,
) -> Tuple[np.ndarray, np.ndarray, float, List[str], int]:
    """
    Preprocesses the images and masks into the correct format for use in the U-Net.

    Images are stacked, cast to float32, divided by 255 and resized bilinearly.
    Masks are stacked, resized with nearest-neighbour interpolation and then either
    thresholded (one class), reduced to foreground/background (two classes) or
    one-hot encoded (more classes). When the requested and detected class counts
    disagree, the user is asked interactively how to proceed.

    Args:
        images (Dict[str, Any]): Dictionary of images. Keys are image names.
        masks (Dict[str, Any]): Dictionary of masks. Keys are mask names.
        num_classes (int, optional): Number of classes in the masks. Default is assumed to be 3, but will be checked against masks input.
        target_size (Tuple[int, int], optional): Target size of the images. Default is assumed to be (256, 256).
        threshold (float, optional): Threshold to use if the number of classes is 1 (binary),
            expressed as a fraction of 255. Default is 0.5.

    Returns:
        Tuple: A tuple containing:
            - images (np.ndarray): processed images as numpy arrays.
            - masks (np.ndarray): processed masks as numpy arrays.
            - threshold (float): Threshold to use if the number of classes is 1 (binary).
            - image_names: Image names (the keys of ``images``, in order, as a tuple) to be used in plotting labels.
            - num_classes (int): Number of classes in the masks. Two-class data is reported as 1.

    Raises:
        ValueError: If ``images`` or ``masks`` is empty, or if the stacked masks have
            an unexpected number of dimensions or channels.
        SystemExit: If the user declines a prompt whose alternative is termination.
    """

    def _confirmed(prompt: str) -> bool:
        """Ask a yes/no question; True when the answer is 'y' or 'yes' (any case)."""
        answer = input(prompt).strip().lower()
        return answer == "yes" or answer == "y"

    # Fail with a clear message instead of an opaque tuple-unpacking error.
    if not images or not masks:
        raise ValueError("Both 'images' and 'masks' must contain at least one entry.")

    # Get the images and their names
    image_names, image_arrays = zip(*images.items())

    # Convert the images to the appropriate format
    images = tf.stack(image_arrays)
    images = tf.cast(images, dtype=tf.float32) / 255
    if len(images.shape) == 3:
        # A stack of single-channel images has no channel axis. Without one,
        # tf.image.resize would interpret the rank-3 stack as a single image.
        images = tf.expand_dims(images, axis=-1)
    images = tf.image.resize(images, target_size, method=tf.image.ResizeMethod.BILINEAR)
    images = images.numpy()

    # Convert the masks to a stack for easier manipulation
    masks = tf.stack(list(masks.values()))
    # Check the number of classes.
    unique_classes = np.unique(masks)
    # Get the full list of the possible classes. Note this is necessary for one-hot.
    # int() moves the arithmetic out of the mask dtype: for uint8 masks containing
    # 255, adding 1 in that dtype would wrap around to 0.
    unique_class_range = list(range(int(unique_classes.max()) + 1))
    num_unique_classes = len(unique_class_range)
    if num_unique_classes < 20:
        if num_unique_classes < num_classes:
            # Fewer classes detected than requested: offer to use the detected count.
            if _confirmed(
                f"There have been {num_unique_classes} mask classes detected, which differs from the input of {num_classes} mask classes.\n"
                f"Would you like to use the detected number of classes {num_unique_classes}?\n"
                "(Note: To err on the side of caution it is recommended to choose the larger number.):\n"
                "(Y/N): "
            ):
                num_classes = num_unique_classes

        elif num_unique_classes > num_classes:
            # Two detected values with num_classes == 1 is the ordinary binary case.
            if not (num_unique_classes == 2 and num_classes == 1):
                warnings.warn(
                    f"Detected {num_unique_classes} classes, which is larger than the input {num_classes} classes."
                    f"You must have at least the the number of possible classes, in this case {num_unique_classes}, with the following classes: "
                    f"[{', '.join(map(str, unique_class_range))}]",
                    UserWarning,
                )

                if _confirmed(
                    f"Would you like to use the detected number of classes {num_unique_classes}?\n"
                    "(Note: If you select no, the program will terminate.):\n"
                    "(Y/N): "
                ):
                    num_classes = num_unique_classes
                else:
                    sys.exit()
    elif num_classes != 1:
        # Twenty or more distinct levels: treat the masks as grey-scale intensities.
        print(
            f"The image appears to be a grey-scale image because it appears to have more than 20 'classes': {num_unique_classes}.\n"
            f"We will treat this as grey-scale image, to be binary sorted into foreground and background using the threshold: {threshold}."
        )
        num_classes = 1

    # Initialize one hot encoding as not having already occurred.
    already_one_hot = False
    if len(masks.shape) == 4:
        if masks.shape[-1] == num_classes:
            if _confirmed(
                f"You're masks already have {num_classes} channels.\n"
                "Do these channels represent the number of classes?\n"
                "(Note: selecting YES assumes the data is already one-hot encoded)\n"
                "(Y/N):"
            ):
                print(
                    f"Will assume that the number of channels == the number of classes == {num_classes}.\n"
                )
                already_one_hot = True
            elif _confirmed(
                "Do you want to only use the first channel?\n"
                "(Note: If you select NO, the program will terminate.)\n"
                "(Y/N):"
            ):
                warnings.warn(
                    f"Expected {num_classes} classes, but found {masks.shape[-1]}.\n"
                    "Only using the first channel.\n",
                    UserWarning,
                )
                masks = masks[:, :, :, 0]
                masks = tf.expand_dims(masks, axis=-1)
            else:
                # Terminate as announced in the prompt.
                sys.exit()

        elif masks.shape[-1] != 1:
            raise ValueError(
                f"Unexpected number of channels ({masks.shape[-1]}).\n"
                f"Expected either 1 (where the single channel contains all classes) or {num_classes} (where the number of channels == the number of classes).\n"
            )

    elif len(masks.shape) == 3:
        # Add the channel axis that tf.image.resize expects for a batch of images.
        masks = tf.expand_dims(masks, axis=-1)
    else:
        raise ValueError(
            f"Unexpected number of dimensions for masks: {len(masks.shape)}.\n"
        )

    masks = tf.cast(masks, dtype=tf.uint8)
    # Nearest-neighbour keeps class labels intact (no interpolated in-between values).
    masks = tf.image.resize(
        masks, target_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
    )
    masks = tf.cast(masks, dtype=tf.uint8)
    if not already_one_hot:
        masks = masks[:, :, :, 0]
        if num_classes == 1:
            # NOTE(review): the threshold is applied on the 0-255 scale, so masks
            # stored as 0/1 labels would all fall below the default cut-off — confirm
            # the expected mask encoding with the data owner.
            threshold_255 = int(threshold * 255)
            masks = tf.cast(masks > threshold_255, tf.int32)
        elif num_classes == 2:
            print(
                "Data has only two classes.\n"
                "Masks will be converted to binary masks for faster evaluation."
            )
            # The highest class index is taken as foreground.
            foreground = max(unique_class_range)
            masks = tf.where(masks == foreground, 1, 0)
            num_classes = 1
        else:
            masks = tf.cast(masks, tf.int32)
            masks = tf.one_hot(masks, depth=num_classes)

    masks = tf.cast(masks, tf.float32)
    masks = masks.numpy()

    print("Preprocessed data information:")
    print(f"Number of images: {len(images)}.")
    print(f"Shape of images dataset: {images.shape}.")
    print(f"Images type: {images.dtype}.")
    print(f"Number of masks: {len(masks)}.")
    print(f"Shape of masks dataset: {masks.shape}.")
    print(f"Masks type: {masks.dtype}.")
    print(f"Number of classes: {num_classes}.")

    return (images, masks, threshold, image_names, num_classes)
