# Utility Functions

## Overview
This notebook contains utility functions used prior to the implementation of the U-Net. Every function has a docstring that defines its inputs, its outputs, and their data types. Inline comments are also included throughout each function for additional clarity.

## Structure
- **[Import Necessary Packages](#import-necessary-packages)**: Importing the necessary packages used by functions in this notebook.
- **[Internal Functions](#internal-functions)**: Internal or private functions that are only called by other functions and should never be called directly by a script. 
    - [`_display_samples`](#display-samples): Function to plot the given data, whether images or masks.
- **[Public Functions](#public-functions)**: Public functions to be imported and called by other scripts, outside of this notebook. 
    - [`custom_warnings`](#custom-warnings): Function to format warnings in the desired manner.
    - [`load_images_and_masks`](#load-images-and-masks): Function to load images and masks from their given directories.
    - [`remap_mask_classes`](#remap-mask-classes): Function to remap mask classes to a subset of the classes. 
    - [`view_data`](#view-data): Function to view images and/or masks. Calls [`_display_samples`](#display-samples).
    - [`preprocess_images_and_masks`](#preprocess-images-and-masks): Function to preprocess and format the data for the U-Net.
    - [`save_datasets`](#save-datasets): Function to save pre-split datasets into individual training, test, and validation `.npy` files.
    - [`load_datasets`](#load-datasets): Function to load in the pre-split datasets, saved by [`save_datasets`](#save-datasets).
    - [`save_associated_files`](#save-associated-files): Function to copy existing files and save them to another folder. This function is generally intended for use with predicted masks and the associated files that contain geospatial information for the original masks.


## Usage
This notebook is not intended to be run on its own. The public functions in this notebook should be imported into `main.ipynb` using the `import-ipynb` package. The purpose of this notebook is to clearly separate and define the utility functions that are called internally (by other functions in this notebook) from the utility functions that are called in `main.ipynb`.

## Import Necessary Packages

We first import the necessary packages used by the functions in this notebook. These libraries should already be installed using the [`requirements.txt`](requirements.txt). 

In [None]:
# Necessary imports
import os
import sys
import warnings
import shutil
from pathlib import Path
import numpy as np
import re
from PIL import Image
import random
import matplotlib.pyplot as plt
from typing import Optional, Dict, List, Tuple, Any, TextIO

# Optional line to suppress unnecessary tensorflow warnings
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf

## Internal Functions

This section contains the internal (private) functions. These functions are called only by other functions in this notebook and are not intended to be used in any other way.

### Display Samples

The function `_display_samples` is called in [View Data](#view-data). This function takes in a given set of samples, either images or masks, and plots them in the desired format. It is separated out as an internal function because it is called multiple times in [View Data](#view-data), and it is better practice to put repeated code in its own function.


#### _display_samples
```
def _display_samples(
    samples: Dict[str, Any],
    sample_names: List[str],
    sample_type: int,
    max_cols: int,
    max_plots: int,
    colors: bool,
    scale: float = 1,
) -> None:
```
#### Description

Displays the samples by plotting them.

#### Parameters
- `samples` (Dict[str, Any]): 

    Dictionary containing the samples. Keys are sample names.

- `sample_names` (List[str]): 

    List of sample names.

- `sample_type` (int): 

    Type of sample, either 1 (masks) or 0 (images).

- `max_cols` (int): 

    Maximum number of columns in the figure.

- `max_plots` (int): 

    Maximum number of plots to display.

- `colors` (bool): 

    Whether to use colors in the plots.

- `scale` (float, optional): 

    Scale factor for the plot size. Default is 1.


#### Example
```
import matplotlib.pyplot as plt
import numpy as np

# Example data
samples = {
    "sample1": np.random.rand(10, 10),
    "sample2": np.random.rand(10, 10),
    "sample3": np.random.rand(10, 10),
    "sample4": np.random.rand(10, 10)
}
sample_names = ["sample1", "sample2", "sample3"]
sample_type = 1  # Mask type
max_cols = 2
max_plots = 3
colors = True
scale = 1.5

# Display the samples
_display_samples(samples, sample_names, sample_type, max_cols, max_plots, colors, scale)
```
#### Notes
- The function determines the number of columns and rows needed for the plots based on `max_cols` and `max_plots`.
- It creates a figure and a grid of subplots using matplotlib.
- The function chooses the color map based on the `colors` flag. If `colors` is `True`, it uses the "viridis" color map; otherwise, it uses a greyscale color map.
- The function sets the title of the figure to "Masks" if `sample_type` is 1; otherwise, it sets the title to "Images".
- For each sample, the function plots the sample in the correct subplot, using color or greyscale based on the colors flag.
- The function removes any unused axes and adjusts the layout to prevent overlap.
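The grid arithmetic described in these notes can be sketched on its own, without any plotting. This is a minimal illustration that mirrors the row and column calculation in `_display_samples`; the helper name `grid_shape` is hypothetical and does not exist in the notebook.

```python
from typing import Tuple


def grid_shape(max_cols: int, max_plots: int) -> Tuple[int, int]:
    """Return (num_rows, num_cols) for a grid that holds max_plots subplots."""
    # Never use more columns than there are plots
    num_cols = min(max_cols, max_plots)
    # Ceiling division: enough rows to hold every plot
    num_rows = (max_plots + max_cols - 1) // max_cols
    return num_rows, num_cols


rows, cols = grid_shape(max_cols=2, max_plots=3)
unused_axes = rows * cols - 3  # axes left over after the three plots are placed
print(rows, cols, unused_axes)  # 2 2 1
```

With three plots in two columns, the grid has four axes, so one axis is left over. That leftover axis is what the final clean-up step removes.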

In [None]:
# Internal function used by view_data to actually display the figures for a given set of image or mask data. 
def _display_samples(
    samples: Dict[str, Any],
    sample_names: List[str],
    sample_type: int,
    max_cols: int,
    max_plots: int,
    colors: bool,
    scale: float = 1,
) -> None:
    """
    Displays the samples by plotting them.

    Args:
        samples (Dict[str, Any]): Dictionary containing the samples. Keys are sample names.
        sample_names (List[str]): List of sample names.
        sample_type (int): Type of sample, either 1 (masks) or 0 (images).
        max_cols (int): Maximum number of columns in the plot.
        max_plots (int): Maximum number of plots to display.
        colors (bool): Whether to use colors in the plots.
        scale (float, optional): Scale factor for the plot size. Default is 1.

    Returns:
        None: This function does not return anything. It displays the plots.
    """

    # Determine the number of columns and rows needed for the plots
    num_cols = min(max_cols, max_plots)
    num_rows = (max_plots + max_cols - 1) // max_cols

    # Create a figure and a grid of subplots
    fig, axes = plt.subplots(
        num_rows, num_cols, figsize=(num_cols * scale, num_rows * scale)
    )
    axes = np.array(axes).reshape(-1) # Flatten the axes array for easier iteration

    # Choose the color map based on the 'colors' flag
    if colors:
        cmap = "viridis"
        if sample_type == 1:
            # Since the sample_type is 1 our samples are masks 
            fig.suptitle("Masks")

            # For each mask, scale the values based on the value range so that it may be plotted in color. Then plot the mask in the correct subplot
            for idx, sample in enumerate(sample_names):
                max_value = samples[sample].max() if samples[sample].max() else 1 # Set the max_value to 1 if it is 0, to prevent division by 0
                axes[idx].imshow(samples[sample] / max_value, cmap=cmap)
                axes[idx].set_title(f"{sample}")
                axes[idx].axis("off")
        else:
            # Since the sample_type is not 1 our samples are images 
            fig.suptitle("Images")

            # For each image, plot the image in color, in the correct subplot
            for idx, sample in enumerate(sample_names):
                axes[idx].imshow(samples[sample], cmap=cmap)
                axes[idx].set_title(f"{sample}")
                axes[idx].axis("off")

    else:
        cmap = "gray"  # "gray" is the canonical matplotlib name; the "grey" alias may not be recognized by older releases
        # If the sample_type is 1 we have masks, otherwise we have images as our samples
        fig.suptitle("Masks" if sample_type == 1 else "Images")

        # For each sample (image or mask), plot the sample in greyscale in the correct subplot.
        for idx, sample in enumerate(sample_names):
            axes[idx].imshow(samples[sample], cmap=cmap)
            axes[idx].set_title(f"{sample}")
            axes[idx].axis("off")

    # Remove any unused axes (every axis after the last plotted sample)
    for empty_idx in range(len(sample_names), len(axes)):
        fig.delaxes(axes[empty_idx])

    # Adjust layout to prevent overlap and display the plots
    plt.tight_layout()
    plt.show(block=False)

    return


## Public Functions

The functions in this section are public functions. They can be called or passed outside this notebook and are intended for use in `main.ipynb`.

### Custom Warnings
The `custom_warnings` function is technically a public function: in `main.ipynb` it is passed to Python's built-in `warnings` library to tell it how we would like our warnings displayed. However, we never actually **call** the function ourselves. **This function is deceptive in its use.** In `main.ipynb` the function itself is **passed** to `warnings` instead of **called**. It is also deceptive in that only the `category` and `message` inputs are explicitly used in `custom_warnings`. The remaining required inputs (`filename`, `lineno`) are present because the `warnings` library supplies them when it formats a warning. You can find more details [here](https://docs.python.org/3/library/warnings.html#warnings.formatwarning).


#### custom_warnings
```
def custom_warnings(
    message: str,
    category: type,
    filename: str,
    lineno: int,
    file: Optional[TextIO] = None,
    line: Optional[str] = None,
) -> str:
```
#### Description

Custom warning formatter.

#### Parameters
- `message` (str): 

    The warning message.

- `category` (type): 

    The category of the warning.

- `filename` (str): 

    The name of the file where the warning was raised.

- `lineno` (int): 

    The line number where the warning was raised.

- `file` (Optional[TextIO], optional): 

    The file object to write the warning to. Default is None.

- `line` (Optional[str], optional): 

    The line of code where the warning was raised. Default is None.

#### Returns

- `str`: 
    The formatted warning message.

#### Example
```
import warnings
from typing import Optional, TextIO

# Define the custom warning display function
def custom_warnings(
    message: str,
    category: type,
    filename: str,
    lineno: int,
    file: Optional[TextIO] = None,
    line: Optional[str] = None,
) -> str:
    return f"{category.__name__}: {message}\n"

# Set the custom warning display function
warnings.formatwarning = custom_warnings

# Trigger a warning to see the custom format
warnings.warn("This is a custom warning message", UserWarning)
```
#### Notes
- The function formats the warning message by including the category name and the message.
- The formatted message is returned as a string.
- This function can be set as the warning formatter using `warnings.formatwarning = custom_warnings`.
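Because assigning to `warnings.formatwarning` changes the formatter for the whole session, it can be useful to install it temporarily and restore the previous one afterwards. The sketch below is one way to do that. The context manager `use_formatter` is a hypothetical helper, not part of the notebook, and the formatter is called directly at the end only to show the string that the `warnings` library would receive back.

```python
import warnings
from contextlib import contextmanager
from typing import Optional, TextIO


def custom_warnings(
    message: str,
    category: type,
    filename: str,
    lineno: int,
    file: Optional[TextIO] = None,
    line: Optional[str] = None,
) -> str:
    # Only the category name and the message are used in the output
    return f"{category.__name__}: {message}\n"


@contextmanager
def use_formatter(formatter):
    """Temporarily install a warning formatter (hypothetical helper)."""
    previous = warnings.formatwarning
    warnings.formatwarning = formatter  # the function is passed, not called
    try:
        yield
    finally:
        warnings.formatwarning = previous  # always restore the old formatter


with use_formatter(custom_warnings):
    # Call the installed formatter the way the library would, to inspect its output
    formatted = warnings.formatwarning(
        "Mask directory is empty", UserWarning, "main.ipynb", 1
    )

print(repr(formatted))  # 'UserWarning: Mask directory is empty\n'
```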


In [None]:
# Defining the custom warning display
def custom_warnings(
    message: str,
    category: type,
    filename: str,
    lineno: int,
    file: Optional[TextIO] = None,
    line: Optional[str] = None,
) -> str:
    """
    Custom warning formatter.

    Args:
        message (str): The warning message.
        category (type): The category of the warning.
        filename (str): The name of the file where the warning was raised.
        lineno (int): The line number where the warning was raised.
        file (Optional[TextIO], optional): The file object to write the warning to. Default is None.
        line (Optional[str], optional): The line of code where the warning was raised. Default is None.

    Returns:
        str: The formatted warning message.
    """
    return f"{category.__name__}: {message}\n" # Return the warning name, and the message of the warning.

### Load Images and Masks

The `load_images_and_masks` function is a public function that is called in `main.ipynb` to load the images and masks from their respective directories.


#### load_images_and_masks
```
def load_images_and_masks(
    images_dir: str,
    masks_dir: str,
    file_ext: str = "tif",
    max_count: int = 100,
    trim_names: bool = True,
) -> Tuple[Dict[str, Any], Dict[str, Any], List[str], Dict[str, str]]:
```
#### Description

Loads images and their corresponding masks from their respective directories.

#### Parameters
- `images_dir` (str): 

    Path to the directory containing the images.

- `masks_dir` (str): 

    Path to the directory containing the masks.

- `file_ext` (str, optional): 

    The file extension of the images and masks. Default is "tif".

- `max_count` (int, optional): 

    The maximum number of image-mask pairs to load. Default is 100.

- `trim_names` (bool, optional): 

    Whether to trim the image names. Default is True.


#### Returns

- Tuple[Dict[str, Any], Dict[str, Any], List[str], Dict[str, str]]: 

    A tuple containing:

    - `images` (Dict[str, Any]): 
    
        Dictionary containing the images. Keys are image names.

    - `masks` (Dict[str, Any]): 
    
        Dictionary containing the masks. Keys are image names.

    - `missing_masks` (List[str]): 
    
        List containing the names of images with no associated masks.
    
    - `names_map` (Dict[str, str]): 
    
        If `trim_names` is `True`, a dictionary with the new (trimmed) names as keys and the old names as values; otherwise, an empty dictionary. This can be used in the `save_predicted_masks` function from the [`unet.ipynb`](unet.ipynb) to get other associated files with the original name, such as `.twf` files which contain geospatial information.

#### Raises
- `ValueError`: If no images of the specified file type are found in the directory.

#### Example
```
from pathlib import Path
from PIL import Image
import numpy as np

# Example usage
images_dir = "path/to/images"
masks_dir = "path/to/masks"
file_ext = "tif"
max_count = 50
trim_names = True

# Load images and masks
images, masks, missing_masks, names_map = load_images_and_masks(
    images_dir, masks_dir, file_ext, max_count, trim_names
)

# Display the loaded images and masks
print(f"Loaded {len(images)} images and {len(masks)} masks.")
if missing_masks:
    print(f"Missing masks for images: {', '.join(missing_masks)}")
```
#### Notes
- The function normalizes the file extension to lowercase and strips any surrounding whitespace.
- It retrieves the list of image files in the directory with the specified file extension.
- The function limits the number of images to max_count if specified.
- It initializes dictionaries to store images and masks, and a list for missing masks.
- For each image, the function checks if the corresponding mask exists and loads both if available.
- If a mask is missing, the image name is added to the missing_masks list, and the image is not loaded.
- The function optionally trims the image names based on the trim_names flag.
- The trimmed names are updated in the dictionaries for images and masks.
- A message is printed if any masks are missing.
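The name-trimming step is the least obvious part of these notes, so here is a minimal, standalone sketch of it. It follows the same idea as the function (keep the trailing digits, then trim to the width of the largest number plus one buffer digit), but the helper name `trim_image_names` and the sample file names are hypothetical. Unlike the notebook's function, this sketch simply returns the names unchanged when none of them end in digits.

```python
import re
from typing import Dict, List


def trim_image_names(old_names: List[str], buffer_digits: int = 1) -> Dict[str, str]:
    """Map trimmed names (keys) to the original names (values)."""
    # Keep the trailing run of digits; fall back to the full name if there is none
    numeric_parts = []
    for name in old_names:
        match = re.search(r"\d+$", name)
        numeric_parts.append(match.group() if match else name)

    # Assumption for this sketch: with no numeric names, leave everything as-is
    numbers = [int(part) for part in numeric_parts if part.isdigit()]
    if not numbers:
        return dict(zip(old_names, old_names))

    # Width of the largest number, plus extra digits as a buffer
    char_count = len(str(max(numbers))) + buffer_digits
    trimmed = [part[-char_count:] for part in numeric_parts]
    return dict(zip(trimmed, old_names))


names_map = trim_image_names(
    ["scene_A_tile_00007", "scene_A_tile_00042", "scene_A_tile_00123"]
)
print(names_map)
# {'0007': 'scene_A_tile_00007', '0042': 'scene_A_tile_00042', '0123': 'scene_A_tile_00123'}
```

The largest number here is 123 (three digits), so every name is trimmed to four characters. The resulting dictionary has the same shape as `names_map`: new names as keys, old names as values.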



In [None]:
# Function to load in the images and associated masks
def load_images_and_masks(
    images_dir: str,
    masks_dir: str,
    file_ext: str = "tif",
    max_count: int = 100,
    trim_names: bool = True,
) -> Tuple[Dict[str, Any], Dict[str, Any], List[str], Dict[str, str]]:
    """
    Loads images and their corresponding masks from their respective directories.

    Args:
        images_dir (str): Path to the directory containing the images.
        masks_dir (str): Path to the directory containing the masks.
        file_ext (str, optional): The file extension of the images and masks. Default is "tif".
        max_count (int, optional): The maximum number of image-mask pairs to load. Default is 100 images.
        trim_names (bool, optional): Whether to trim the image names. Default is True.

    Returns:
        Tuple[Dict[str, Any], Dict[str, Any], List[str], Dict[str, str]]: A tuple containing:
            - images (Dict[str, Any]): Dictionary containing the images. Keys are image names.
            - masks (Dict[str, Any]): Dictionary containing the masks. Keys are image names.
            - missing_masks (List[str]): List containing the names of images with no associated masks.
            - names_map (Dict[str, str]): If `trim_names` is `True`, a dictionary with the new names as keys and the old names as values. Otherwise, an empty dictionary.
    """
    # Normalize file extension to lowercase and strip any surrounding whitespace
    file_ext = file_ext.strip().lower()

    # Get the list of image files in the directory with the specified file extension
    images_dir_list = list(Path(images_dir).glob("*." + file_ext))
    if not images_dir_list:
        raise ValueError(
            f"No images of file type '{file_ext}' found in directory '{images_dir}'."
        )

    # Limit the number of images to max_count if specified. If max_count is greater than the number of available images, it will just get all images. 
    if max_count and max_count >= 1:
        images_dir_list = images_dir_list[:max_count]

    # Convert masks directory to Path object
    masks_dir = Path(masks_dir)

    # Initialize dictionaries to store images and masks, a list for missing masks, and a dictionary mapping new names to old names
    images = {}
    masks = {}
    missing_masks = []
    names_map = {}

    # Iterate over the list of image files
    for image_path in images_dir_list:
        file_name = image_path.stem # Get the file name without extension and without preceding directories.
        mask_path = masks_dir / f"{file_name}.{file_ext}" # Construct the mask file path

        # If the mask exists, load the image and mask, and store them in the dictionaries
        if mask_path.exists():
            image = Image.open(str(image_path))
            mask = Image.open(str(mask_path))
            images[file_name] = np.array(image)
            masks[file_name] = np.array(mask)
        else:
            # If the mask does not exist, add the file name to the missing masks list
            missing_masks.append(str(file_name))

    # Print the missing masks if any
    if missing_masks:
        print(f"The following masks are missing: {', '.join(missing_masks)}\n")
        print("The images associated with the missing masks will not be included.\n")

    # Optionally trim the names of the images and masks (skipped when no image-mask pairs were loaded, since there is nothing to unpack)
    if trim_names and images:
        updated_names = []
        buffer_digits = 1 # Extra digits to account for in the trimmed names
        old_image_names, images = zip(*images.items())
        _, masks = zip(*masks.items())

        # Extract numeric part from the image names
        for image_name in old_image_names:
            str_match = re.search(r"\d+$", image_name)
            # Fall back to the full name when it does not end in digits
            updated_names.append(str_match.group() if str_match else image_name)
        
        # Determine the maximum number in the numeric part for padding purposes
        numeric_file_names = [
            int(file_name) for file_name in updated_names if file_name.isdigit()
        ]
        max_num = max(numeric_file_names)
        char_count = len(str(max_num)) + buffer_digits

        # Trim the image names to the appropriate length
        image_names = [file_name[-char_count:] for file_name in updated_names]
        
        # Update the dictionaries with the trimmed names
        images = dict(zip(image_names, images))
        masks = dict(zip(image_names, masks))
        names_map = dict(zip(image_names, old_image_names))
    return (images, masks, missing_masks, names_map)


### Remap Mask Classes
The `remap_mask_classes` function is a public function that remaps the mask classes to a subset of the original classes. 

#### remap_mask_classes
```
def remap_mask_classes(
    masks: Dict[str, Any], class_mapping: Dict[int, int]
) -> Tuple[Dict[str, Any], int]:
```
#### Description

Remaps the masks to the new classes based on the provided class mapping.

#### Parameters
- `masks` (Dict[str, Any]): 

    Dictionary containing the masks. Keys are mask names.

- `class_mapping` (Dict[int, int]): 

    Dictionary mapping old class indices to new class indices. The keys are the old indices, and the values are the new indices.


#### Returns

- Tuple[Dict[str, Any], int]: The remapped masks and the number of unique new classes.

#### Example
```
import numpy as np

# Example masks and class mapping
masks = {
    "mask1": np.array([[0, 1, 2], [2, 1, 0]]),
    "mask2": np.array([[1, 2, 3], [3, 2, 1]])
}
class_mapping = {0: 0, 1: 1, 2: 2, 3: 1}

# Remap the mask classes
remapped_masks, num_classes = remap_mask_classes(masks, class_mapping)

# Display the remapped masks and number of new classes
print(f"Remapped Masks: {remapped_masks}")
print(f"Number of Unique New Classes: {num_classes}")
```
#### Notes
- The function retrieves the masks as a list to check the number of classes present.
- It identifies any classes that are out of range of the provided class mapping.
- If there are classes out of range, the user is prompted to map these classes to a new class index or terminate the program.
- The function creates an array version of the class mapping to apply to the masks.
- It applies the mapping to each mask and calculates the new number of classes.
- The function returns the remapped masks and the new number of classes.
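The "array version of the class mapping" mentioned above is a lookup table: indexing it with a mask replaces every old class index with its new one in a single step. The following is a minimal sketch of that idea using NumPy integer array indexing. The helper name `remap_with_lookup` is hypothetical, and the sketch leaves out the interactive handling of out-of-range classes.

```python
from typing import Dict

import numpy as np


def remap_with_lookup(mask: np.ndarray, class_mapping: Dict[int, int]) -> np.ndarray:
    """Remap a mask of old class indices to new class indices."""
    # Position i of the lookup table holds the new index for old class i
    lookup = np.zeros(max(class_mapping) + 1, dtype=np.uint16)
    for old_idx, new_idx in class_mapping.items():
        lookup[old_idx] = new_idx
    # Integer array indexing applies the table to every pixel at once
    return lookup[mask]


class_mapping = {0: 0, 1: 1, 2: 2, 3: 1}
mask = np.array([[1, 2, 3], [3, 2, 1]])

remapped = remap_with_lookup(mask, class_mapping)
num_classes = len(set(class_mapping.values()))

print(remapped.tolist())  # [[1, 2, 1], [1, 2, 1]]
print(num_classes)        # 3
```

Note that any old class index missing from the mapping, but smaller than the largest key, is silently sent to class 0 by this approach, because the table is initialized with zeros.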


In [None]:
# Function to remap the masks to a subset of the original classes
def remap_mask_classes(
    masks: Dict[str, Any], class_mapping: Dict[int, int]
) -> Tuple[Dict[str, Any], int]:
    """
    Remaps the masks to the new classes based on the provided class mapping.

    Args:
        masks (Dict[str, Any]): Dictionary containing the masks. Keys are mask names.
        class_mapping (Dict[int, int]): Dictionary mapping old class indices to new class indices. The keys are the old indices, and the values are the new indices.

    Returns:
        Tuple[Dict[str, Any], int]: The remapped masks and the number of unique new classes.
    """
    
    # Get the masks as a list so that we can check how many classes there are
    mask_names, masks_list = zip(*masks.items())
    masks_list = np.stack(masks_list, axis=0)

    # Get the unique class indices present in the masks
    unique_classes = np.unique(masks_list)

    # Get the maximum class index for the old classes in the mapping
    max_class = max(class_mapping.keys())

    # Identify any classes that are out of range of the provided class mapping
    classes_out_of_range = unique_classes[unique_classes > max_class]

    # Handle classes that are out of range of the mapping
    if len(classes_out_of_range) > 0:
        choose_class = (
            input(
                f"Classes were found in the masks that are not accounted for in the mapping: [{', '.join(map(str, classes_out_of_range))}].\n"
                "Would you like to map all out-of-range classes to one of the new classes?\n"
                "(Note: If you select yes you will then be prompted for a value. If you select no the program will terminate.):\n"
                "(Y/N): "
            )
            .strip()
            .lower()
        )
        if choose_class == "yes" or choose_class == "y":
            # Prompt the user to enter a new class index to map the out-of-range classes to
            default_class = int(
                input(
                    f"Please enter the new class index to map old indices [{', '.join(map(str, classes_out_of_range))}] to.\n"
                    f"Your options are [{', '.join(map(str, np.unique(list(class_mapping.values()))))}].\n"
                    f"New index: "
                )
            )
            # Update the class mapping dictionary to include the out-of-range classes
            for class_idx in classes_out_of_range:
                class_mapping[class_idx] = default_class
        else:
            # Exit the program if the user chooses not to map the out-of-range classes
            sys.exit()

    # Create an array version of the class mapping to apply to the masks
    max_idx = max(class_mapping.keys())
    mapping_array = np.zeros(max_idx + 1, dtype=np.uint16)
    for old_idx, new_idx in class_mapping.items():
        mapping_array[old_idx] = new_idx

    # Apply the mapping to each mask
    for mask_name, mask in masks.items():
        masks[mask_name] = mapping_array[mask]

    # Calculate the number of unique new classes
    num_classes = len(set(class_mapping.values()))
    return masks, num_classes

### View Data

The `view_data` function is a public function that is called in `main.ipynb` to view images and/or masks. It calls the internal function [`_display_samples`](#display-samples) to actually generate the necessary plots. 

#### view_data
```
def view_data(
    images: Optional[Dict[str, Any]] = None,
    masks: Optional[Dict[str, Any]] = None,
    max_plots: int = 10,
    max_cols: int = 5,
    randomize: bool = False,
    colors: bool = True,
) -> None:
```
#### Description

View images or masks. A random subset of images can be viewed in color or grey.

#### Parameters
- `images` (Optional[Dict[str, Any]]): Dictionary of images. Keys are image names.
- `masks` (Optional[Dict[str, Any]]): Dictionary of masks. Keys are mask names.
- `max_plots` (int, optional): Maximum number of images or masks to show. Default is 10.
- `max_cols` (int, optional): Maximum number of columns in the display grid. Default is 5.
- `randomize` (bool, optional): Option to select a random subset from the images/masks. Default is False.
- `colors` (bool, optional): Option to display the images or masks in color. Default is True.

#### Raises
- `ValueError`: If neither images nor masks are provided.

#### Example
```
import numpy as np

# Example images and masks. Both dictionaries share the same keys because,
# when images and masks are passed together, the masks are looked up by image name.
images = {
    "sample1": np.random.rand(10, 10, 3),
    "sample2": np.random.rand(10, 10, 3),
    "sample3": np.random.rand(10, 10, 3)
}
masks = {
    "sample1": np.random.randint(0, 2, (10, 10)),
    "sample2": np.random.randint(0, 2, (10, 10)),
    "sample3": np.random.randint(0, 2, (10, 10))
}

# View the images and masks
view_data(images=images, masks=masks, max_plots=3, max_cols=2, randomize=True, colors=True)
```
#### Notes
- The function verifies that at least one of images or masks is provided.
- It ensures that the maximum number of columns is between 1 and 10.
- It ensures that the maximum number of plots is at least 1.
- If images are provided, the function prepares to display them, adjusting max_plots if it exceeds the number of available samples.
- The function selects a subset of sample names to display, either in order or randomly.
- It calls the `_display_samples` function to display the selected images or masks.
- If masks are provided, the function prepares to display them similarly to images.
- The function handles cases where images and masks are both provided or only one of them is provided.
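The subset selection described in these notes can be sketched separately from the plotting. This is a minimal illustration under stated assumptions: the helper name `select_sample_names` and its `seed` parameter are hypothetical additions for reproducibility and are not part of `view_data`.

```python
import random
from typing import List, Optional


def select_sample_names(
    names: List[str],
    max_plots: int,
    randomize: bool = False,
    seed: Optional[int] = None,
) -> List[str]:
    """Pick at most max_plots names, either in order or at random."""
    if not names:
        return []
    # Clamp max_plots to the range [1, number of available samples]
    max_plots = max(1, min(max_plots, len(names)))
    if randomize:
        # A dedicated Random instance keeps the global random state untouched
        return random.Random(seed).sample(names, max_plots)
    return names[:max_plots]


names = ["sample1", "sample2", "sample3", "sample4"]
print(select_sample_names(names, max_plots=2))       # ['sample1', 'sample2']
print(len(select_sample_names(names, max_plots=10)))  # 4

# Random subsets contain no duplicates and only names from the original list
subset = select_sample_names(names, max_plots=3, randomize=True, seed=0)
print(sorted(subset))
```

When both images and masks are passed to `view_data`, the names selected for the images are reused for the masks, so each mask shown corresponds to an image shown. This only works if both dictionaries use the same keys.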



In [None]:
# Function to view the images, masks, or both 
def view_data(
    images: Optional[Dict[str, Any]] = None,
    masks: Optional[Dict[str, Any]] = None,
    max_plots: int = 10,
    max_cols: int = 5,
    randomize: bool = False,
    colors: bool = True,
) -> None:
    """
    View images or masks. Random subset of images can be viewed in color or grey.

    Args:
        images (Optional[Dict[str, Any]]): Dictionary of images. Keys are image names.
        masks (Optional[Dict[str, Any]]): Dictionary of masks. Keys are mask names.
        max_plots (int, optional): Maximum number of images or masks to show. Default is 10.
        max_cols (int, optional): Maximum number of columns in the display grid. Default is 5.
        randomize (bool, optional): Option to select a random subset from the images/masks. Default is False.
        colors (bool, optional): Option to display the images or masks in color. Default is True.

    Returns:
        None: This function does not return anything.
    """

    # Verify that at least one of images or masks is provided.
    if images is None and masks is None:
        raise ValueError("You must provide 'images' and/or 'masks' to view.")

    # Ensure that the maximum number of columns is a reasonable value
    if max_cols < 1 or max_cols > 10:
        print("max_cols should be between 1 and 10. Defaulting to 5.")
        max_cols = 5

    # Ensure that the maximum number of plots is a reasonable value
    if max_plots < 1:
        print("max_plots should be 1 or greater. Defaulting to 1.")
        max_plots = 1

    # If images are provided, prepare to display them
    if images is not None:
        sample_type = 0 # Indicate that we are displaying images
        sample_names = list(images.keys())
        sample_count = len(sample_names)

        # Adjust max_plots if it exceeds the number of available samples
        if max_plots > sample_count:
            print(
                f"max_plots ({max_plots}) exceeds the number of samples ({sample_count}). All samples will be displayed."
            )
            max_plots = sample_count

        # Select a subset of sample names to display, either an ordered subset or a random subset
        if not randomize:
            sample_names = sample_names[:max_plots]
        else:
            sample_names = random.sample(sample_names, max_plots)

        # Display the selected images
        _display_samples(
            images, sample_names, sample_type, max_cols, max_plots, colors, scale=1
        )

    # If masks are provided, prepare to display them
    if masks is not None:
        sample_type = 1 # Indicate that we are displaying masks

        # If no images are provided, determine the sample names, max_plots, and randomization from the function inputs. If images are provided then the information in this `if` statement, is already captured. 
        if images is None:
            sample_names = list(masks.keys())
            sample_count = len(sample_names)

            # Adjust max_plots if it exceeds the number of available samples
            if max_plots > sample_count:
                print(
                    f"max_plots ({max_plots}) exceeds the number of samples ({sample_count}). Will display all samples."
                )
                max_plots = sample_count

            # Select a subset of sample names to display, either an ordered subset or a random subset
            if not randomize:
                sample_names = sample_names[:max_plots]
            else:
                sample_names = random.sample(sample_names, max_plots)

        # Display the selected masks
        _display_samples(
            masks, sample_names, sample_type, max_cols, max_plots, colors, scale=1
        )


### Preprocess Images and Masks
The `preprocess_images_and_masks` function is a public function that preprocesses the images and masks to prepare them for the U-Net. This includes verifying the number of classes and the format of the image and mask data. It can handle cases where the number of classes in the masks does not match the input number of classes, and it requests information from the user in those cases. It then displays some details about the preprocessed images and masks for verification. It returns the preprocessed images, the preprocessed masks, the threshold value used for binary masks, the image names, and the number of classes in the preprocessed masks. 

#### preprocess_images_and_masks
```
def preprocess_images_and_masks(
    images: Dict[str, Any],
    masks: Dict[str, Any],
    num_classes: int = 3,
    target_size: Tuple[int, int] = (256, 256),
    threshold: float = 0.5,
) -> Tuple[np.ndarray, np.ndarray, float, List[str], int]:
```
#### Description

Preprocesses the images and masks into the correct format for use in the U-Net.

#### Parameters
- `images` (Dict[str, Any]): Dictionary of images. Keys are image names.
- `masks` (Dict[str, Any]): Dictionary of masks. Keys are mask names.
- `num_classes` (int, optional): Number of classes in the masks. Default is 3.
- `target_size` (Tuple[int, int], optional): Target size of the images. Default is (256, 256).
- `threshold` (float, optional): Threshold to use if the number of classes is 1 (binary). Default is 0.5.

#### Returns
- Tuple[np.ndarray, np.ndarray, float, List[str], int]: 

    A tuple containing:

    - `images` (np.ndarray): 
    
        Processed images as numpy arrays.

    - `masks` (np.ndarray): 
    
        Processed masks as numpy arrays.

    - `threshold` (float): 
    
        Threshold to use if the number of classes is 1 (binary).

    - `image_names` (List[str]): 
    
        Image names to be used in plotting labels.

    - `num_classes` (int): 
    
        Number of classes in the masks.

#### Raises
- `ValueError`: If the number of channels in the masks does not match the expected number of classes.
- `ValueError`: If the masks have an unexpected number of dimensions.

#### Example
```
import numpy as np

# Example images (8-bit RGB values, 0 to 255, since the function divides by 255) and masks
images = {
    "image1": np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8),
    "image2": np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8),
    "image3": np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
}
masks = {
    "mask1": np.random.randint(0, 3, (256, 256)),
    "mask2": np.random.randint(0, 3, (256, 256)),
    "mask3": np.random.randint(0, 3, (256, 256))
}

# Preprocess the images and masks
processed_images, processed_masks, threshold, image_names, num_classes = preprocess_images_and_masks(
    images, masks, num_classes=3, target_size=(256, 256), threshold=0.5
)
```
#### Notes
- The function converts the images to float32 numpy arrays, scales them from the 0 to 255 range down to the range [0, 1] by dividing by 255, and resizes them to the target size.
- It converts the masks to a stack for easier manipulation and checks the number of unique classes.
- The function handles cases where the number of unique classes differs from the input `num_classes`.
- It checks if the masks are already one-hot encoded and handles cases where they are not.
- The function resizes the masks to the target size and converts them to float32 numpy arrays.
- It prints information about the preprocessed data, including the number of images, shape of the dataset, and number of classes.
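
The class handling described in these notes can be sketched with NumPy alone. This is a minimal illustration under stated assumptions, not the function's TensorFlow implementation: the helper names `full_class_range`, `one_hot_encode`, and `binarize` are hypothetical, and `np.eye` indexing stands in for `tf.one_hot`.

```python
import numpy as np


def full_class_range(masks: np.ndarray) -> list:
    """Return every class index from 0 up to the largest value found in the masks."""
    return list(range(int(masks.max()) + 1))


def one_hot_encode(masks: np.ndarray, num_classes: int) -> np.ndarray:
    """One-hot encode integer masks of shape (N, H, W) into (N, H, W, num_classes)."""
    return np.eye(num_classes, dtype=np.float32)[masks]


def binarize(masks: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Threshold 0-255 grey-scale masks into foreground (1) and background (0)."""
    threshold_255 = int(threshold * 255)  # Scale the 0-1 threshold to the 0-255 range
    return (masks > threshold_255).astype(np.float32)


# Two 2x2 masks that use classes 0 and 2 only; class 1 never appears
label_masks = np.array([[[0, 2], [2, 0]], [[0, 0], [2, 2]]])
classes = full_class_range(label_masks)  # Class 1 is still counted
encoded = one_hot_encode(label_masks, num_classes=len(classes))

# A single grey-scale mask with values on both sides of the scaled threshold (127)
grey_masks = np.array([[[0, 200], [127, 128]]], dtype=np.uint8)
binary = binarize(grey_masks, threshold=0.5)
```

Counting the full range of class indices, rather than only the values that happen to appear, matters because the one-hot depth must cover the highest class index even when an intermediate class is absent from the data.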





In [None]:
# Function to preprocess the images and masks into the right structures for the Tensorflow U-Net.
def preprocess_images_and_masks(
    images: Dict[str, Any],
    masks: Dict[str, Any],
    num_classes: int = 3,
    target_size: Tuple[int, int] = (256, 256),
    threshold: float = 0.5,
) -> Tuple[np.ndarray, np.ndarray, float, List[str], int]:
    """
    Preprocesses the images and masks into the correct format for use in the U-Net.

    Args:
        images (Dict[str, Any]): Dictionary of images. Keys are image names.
        masks (Dict[str, Any]): Dictionary of masks. Keys are mask names.
        num_classes (int, optional): Number of classes in the masks. Default is assumed to be 3, but will be checked against masks input.
        target_size (Tuple[int, int], optional): Target size of the images. Default is assumed to be (256, 256).
        threshold (float, optional): Threshold to use if the number of classes is 1 (binary). Default is 0.5.

    Returns:
        Tuple: A tuple containing:
            - images (np.ndarray): processed images as numpy arrays.
            - masks (np.ndarray): processed masks as numpy arrays.
            - threshold (float): Threshold to use if the number of classes is 1 (binary).
            - image_names (List[str]): Image names to be used in plotting labels
            - num_classes (int): Number of classes in the masks.
    """
    # Get the images and their names
    image_names, images = zip(*images.items())

    # Convert the images to the appropriate format, in this case a float32 numpy array
    images = tf.stack(images)
    images = tf.cast(images, dtype=tf.float32) / 255 # Scale the images down from RGB 255 range to 0 to 1 range
    images = tf.image.resize(images, target_size, method=tf.image.ResizeMethod.BILINEAR) # Resize the images to the desired shape
    images = images.numpy()

    # Get the masks and their names
    mask_names, masks = zip(*masks.items())

    # Convert the masks to a stack for easier manipulation
    masks = tf.stack(masks)

    # Check the number of mask classes
    unique_classes = np.unique(masks)

    # Get the full list of the possible classes. Note this is necessary for proper one-hot encoding
    unique_class_range = list(range(max(unique_classes) + 1))
    num_unique_classes = len(unique_class_range)

    # Handle the case where the number of unique classes differs from the input num_classes
    if num_unique_classes < 20:
        # If the number of detected classes is less than the number of input classes, the user can optionally choose to use the detected number of classes or continue with the input number of classes.
        if num_unique_classes < num_classes:
            choose_classes = (
                input(
                    f"There have been {num_unique_classes} mask classes detected, which differs from the input of {num_classes} mask classes.\n"
                    f"Would you like to use the detected number of classes {num_unique_classes}?\n"
                    "(Note: To err on the side of caution it is recommended to choose the larger number.):\n"
                    "(Y/N): "
                )
                .strip()
                .lower()
            )
            if choose_classes == "yes" or choose_classes == "y":
                num_classes = num_unique_classes

        # If the number of detected classes is larger than the number of input classes, the user must either opt to use the detected number of classes or terminate the program.
        elif num_unique_classes > num_classes:
            if not (num_unique_classes == 2 and num_classes == 1):
                warnings.warn(
                    f"Detected {num_unique_classes} classes, which is larger than the input {num_classes} classes.\n"
                    f"You must have at least the number of possible classes, in this case {num_unique_classes}, with the following classes: "
                    f"[{', '.join(map(str, unique_class_range))}]",
                    UserWarning,
                )

                choose_classes = (
                    input(
                        f"Would you like to use the detected number of classes {num_unique_classes}?\n"
                        "(Note: If you select no, the program will terminate.):\n"
                        "(Y/N): "
                    )
                    .strip()
                    .lower()
                )
                if choose_classes == "yes" or choose_classes == "y":
                    num_classes = num_unique_classes
                else:
                    sys.exit()
    # If the detected number of classes is 20 or more, and the input number of classes was not already set to 1, it indicates the masks are likely grey-scale (many distinct intensity values rather than discrete class labels). In this case we will automatically treat them as such, with a single class.
    elif num_classes != 1:
        print(
            f"The image appears to be a grey-scale image because it appears to have 20 or more 'classes': {num_unique_classes}.\n"
            f"We will treat this as a grey-scale image, to be binary sorted into foreground and background using the threshold: {threshold}."
        )
        num_classes = 1

    
    already_one_hot = False
    # Handle cases where the masks may or may not already be one hot encoded (4th dimension of a value other than 1)
    if len(masks.shape) == 4:

        # If the mask shape looks like it is already one hot encoded (the 4th dimension matches the number of classes), ask the user to verify this as correct or to let us know that this is wrong. If it is not one hot encoded but has multiple channels, then it will select only the first channel, based on the assumption that all the channels contain the same information. 
        if masks.shape[-1] == num_classes:
            verify = (
                input(
                    f"Your masks already have {num_classes} channels.\n"
                    "Do these channels represent the number of classes?\n"
                    "(Note: selecting YES assumes the data is already one-hot encoded)\n"
                    "(Y/N):"
                )
                .strip()
                .lower()
            )

            # If the user says yes, use the channels and assume already one-hot encoded
            if verify == "yes" or verify == "y":
                print(
                    f"Will assume that the number of channels == the number of classes == {num_classes}.\n"
                )
                already_one_hot = True

            # If the user doesn't say yes, verify the user wants to use the first channel.
            else:
                first_channel = (
                    input(
                        "Do you want to only use the first channel?\n"
                        "(Note: If you select NO, the program will terminate.)\n"
                        "(Y/N):"
                    )
                    .strip()
                    .lower()
                )

                # If the user says yes, use the first channel, assuming they are all the same.
                if first_channel == "yes" or first_channel == "y":
                    warnings.warn(
                        f"Expected {num_classes} classes, but found {masks.shape[-1]}.\n"
                        "Only using the first channel.\n",
                        UserWarning,
                    )
                    masks = masks[:, :, :, 0]
                    masks = tf.expand_dims(masks, axis=-1)

                # If the user doesn't say yes, terminate the program.
                else:
                    sys.exit()

        # If the number of channels (the 4th dimension of the mask) does not match the number of classes and it is not a single channel (which indicates either greyscale or that the classes are all captured in the single channel by having discrete values in the channel such as 0,1,2,3 etc.), raise an error.
        elif masks.shape[-1] != 1:
            raise ValueError(
                f"Unexpected number of channels ({masks.shape[-1]}).\n"
                f"Expected either 1 (where the single channel contains all classes) or {num_classes} (where the number of channels == the number of classes).\n"
            )
    # If the masks does not have a 4th dimension, then just add the 4th dimension to achieve the necessary shape for one-hot encoding.
    elif len(masks.shape) == 3:
        masks = tf.expand_dims(masks, axis=-1)
    # If the masks is not shape 3 or 4, then raise an error.
    else:
        raise ValueError(
            f"Unexpected number of dimensions for masks: {len(masks.shape)}.\n"
        )

    # Convert the masks to an int type for resizing, and then resize to desired shape
    masks = tf.cast(masks, dtype=tf.uint8)
    masks = tf.image.resize(
        masks, target_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
    )
    masks = tf.cast(masks, dtype=tf.uint8)

    # Check if the masks are already one-hot encoded
    if not already_one_hot:
        masks = masks[:, :, :, 0]
        # If there is only one class, then treat it as binary case (foreground and background)
        if num_classes == 1:
            threshold_255 = int(threshold * 255)
            masks = tf.cast(masks > threshold_255, tf.int32)

        # If the number of classes is 2, then still treat it as binary case, but inform the user of the change (foreground and background)
        elif num_classes == 2:
            print(
                "Data has only two classes.\n"
                "Masks will be converted to binary masks for faster evaluation."
            )
            foreground = max(unique_class_range) # Use the larger value as the foreground (1) class.
            masks = tf.where(masks == foreground, 1, 0)
            num_classes = 1
        else:
            masks = tf.cast(masks, tf.int32)
            masks = tf.one_hot(masks, depth=num_classes)

    # Convert the masks to the appropriate format, in this case a float32 numpy array
    masks = tf.cast(masks, tf.float32)
    masks = masks.numpy()

    # Print information about the preprocessed data
    print("Preprocessed data information:")
    print(f"Number of images: {len(images)}.")
    print(f"Shape of images dataset: {images.shape}.")
    print(f"Images type: {images.dtype}.")
    print(f"Number of masks: {len(masks)}.")
    print(f"Shape of masks dataset: {masks.shape}.")
    print(f"Masks type: {masks.dtype}.")
    print(f"Number of classes: {num_classes}.")

    return (images, masks, threshold, image_names, num_classes)


### Save Datasets

The `save_datasets` function is a public function called in `main.ipynb` to save split datasets into individual `.npy` files for later use. 

#### save_datasets
```
def save_datasets(
    training: Optional[Tuple[np.ndarray, np.ndarray]] = None,
    test: Optional[Tuple[np.ndarray, np.ndarray]] = None,
    validation: Optional[Tuple[np.ndarray, np.ndarray]] = None,
    save_dir: str = "Datasets",
) -> None:
```
#### Description

Saves the training, test, and validation datasets to the specified directory.

#### Parameters
- `save_dir` (str, optional): 

    Path to the directory where the datasets will be saved. Default is "Datasets".

- `training` (Optional[Tuple[np.ndarray, np.ndarray]], optional): 

    Training dataset to be saved, as a tuple of images and then masks. Default is None.

- `test` (Optional[Tuple[np.ndarray, np.ndarray]], optional): 

    Test dataset to be saved, as a tuple of images and then masks. Default is None.

- `validation` (Optional[Tuple[np.ndarray, np.ndarray]], optional): 

    Validation dataset to be saved, as a tuple of images and then masks. Default is None.


#### Raises
- `ValueError`: If none of the datasets (training, test, validation) are provided to save.

#### Example
```
import numpy as np

# Example data
training_data = (np.random.rand(10, 10), np.random.rand(10, 10))
test_data = (np.random.rand(10, 10), np.random.rand(10, 10))
validation_data = (np.random.rand(10, 10), np.random.rand(10, 10))

# Save the datasets to the specified directory
save_datasets(
    training=training_data,
    test=test_data,
    validation=validation_data,
    save_dir="my_datasets"
)
```
#### Notes
- Ensure that at least one of the datasets (training, test, validation) is provided to save.
- The function converts the `save_dir` to a `Path` object for easier path manipulations.
- A message is printed to indicate the save directory once the data is saved.
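
The file layout produced for one split can be sketched as a round trip with `np.save` and `np.load`. This is an illustrative sketch rather than a call to `save_datasets` itself: the array shapes are hypothetical stand-ins, and a temporary directory is used so that nothing is left on disk.

```python
import tempfile
from pathlib import Path

import numpy as np

# Hypothetical stand-in data: 4 RGB images of 8x8 pixels with single-channel masks
images = np.random.rand(4, 8, 8, 3).astype(np.float32)
masks = np.random.randint(0, 2, (4, 8, 8, 1)).astype(np.float32)

with tempfile.TemporaryDirectory() as tmp:
    save_dir = Path(tmp) / "Datasets"
    save_dir.mkdir(parents=True, exist_ok=True)

    # Same naming scheme as save_datasets: images_<split>.npy and masks_<split>.npy
    np.save(save_dir / "images_training.npy", images)
    np.save(save_dir / "masks_training.npy", masks)

    # List what was written and read it back to confirm the data is unchanged
    saved_files = sorted(path.name for path in save_dir.glob("*.npy"))
    reloaded_images = np.load(save_dir / "images_training.npy")
    reloaded_masks = np.load(save_dir / "masks_training.npy")
```

Because the file names are fixed, saving a second dataset to the same `save_dir` replaces the earlier files, so a separate directory per dataset is the safer choice.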


In [None]:
def save_datasets(
    training: Optional[Tuple[np.ndarray, np.ndarray]] = None,
    test: Optional[Tuple[np.ndarray, np.ndarray]] = None,
    validation: Optional[Tuple[np.ndarray, np.ndarray]] = None,
    save_dir: str = "Datasets",
) -> None:
    """
    Saves the training, test, and validation datasets to the specified directory.

    Args:
        training (Optional[Tuple[np.ndarray, np.ndarray]], optional): Tuple containing training images and masks. Default is None.
        test (Optional[Tuple[np.ndarray, np.ndarray]], optional): Tuple containing test images and masks. Default is None.
        validation (Optional[Tuple[np.ndarray, np.ndarray]], optional): Tuple containing validation images and masks. Default is None.
        save_dir (str, optional): Path to the directory where the datasets will be saved. Default is "Datasets".

    Returns:
        None: This function does not return anything.
    """

    # Ensure that at least one of the datasets (training, test, validation) is provided
    if training is None and test is None and validation is None:
        raise ValueError(
            "You must provide at least one of the three data types to save: training, test, or validation."
        )

    # Create the save directory if it doesn't exist
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Save the training dataset if provided
    if training:
        np.save(save_dir / "images_training.npy", training[0])
        np.save(save_dir / "masks_training.npy", training[1])

    # Save the test dataset if provided
    if test:
        np.save(save_dir / "images_test.npy", test[0])
        np.save(save_dir / "masks_test.npy", test[1])

    # Save the validation dataset if provided
    if validation:
        np.save(save_dir / "images_validation.npy", validation[0])
        np.save(save_dir / "masks_validation.npy", validation[1])

    print(f"Data saved in {save_dir}.")

    return

### Load Datasets

The `load_datasets` function is a public function called in `main.ipynb` to load an existing pre-split dataset that was saved in the format used by [`save_datasets`](#save-datasets). 

#### load_datasets
```
def load_datasets(
    load_dir: str,
    training: bool = True,
    test: bool = True,
    validation: bool = True,
) -> Tuple[
    Optional[np.ndarray],
    Optional[np.ndarray],
    Optional[np.ndarray],
    Optional[np.ndarray],
    Optional[np.ndarray],
    Optional[np.ndarray],
]:
```
#### Description

Loads the training, test, and validation datasets from the specified directory.

#### Parameters
- `load_dir` (str): 

    Path to the directory where the datasets are stored.

- `training` (bool, optional): 

    Whether to load the training dataset. Default is True.

- `test` (bool, optional): 

    Whether to load the test dataset. Default is True.

- `validation` (bool, optional): 

    Whether to load the validation dataset. Default is True.

#### Returns
- Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]

    A tuple containing:

    - `images_training` (Optional[np.ndarray]):

        The training images.

    - `masks_training` (Optional[np.ndarray]): 
    
        The training masks.

    - `images_test` (Optional[np.ndarray]): 

        The test images.

    - `masks_test` (Optional[np.ndarray]): 

        The test masks.

    - `images_validation` (Optional[np.ndarray]):

        The validation images.

    - `masks_validation` (Optional[np.ndarray]): 

        The validation masks.

#### Raises
- `ValueError`: If none of the datasets (training, test, validation) is selected to load.

#### Example
```
from pathlib import Path
import numpy as np

# Load datasets from the specified directory
images_training, masks_training, images_test, masks_test, images_validation, masks_validation = load_datasets(
    load_dir="path/to/datasets",
    training=True,
    test=True,
    validation=True
)
```
#### Notes
- Ensure that at least one of the datasets (training, test, validation) is selected to load.
- The function converts the `load_dir` to a `Path` object for easier path manipulations.
- If a dataset type is not selected, its corresponding return value will be `None`.
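
`np.load` raises a `FileNotFoundError` when a requested file is absent, so it can be useful to check the directory before loading. The sketch below is one possible pre-check, assuming the `images_<split>.npy` and `masks_<split>.npy` naming used by `save_datasets`; the helper `find_missing_dataset_files` is hypothetical and is not part of this notebook.

```python
import tempfile
from pathlib import Path
from typing import List


def find_missing_dataset_files(load_dir: str, splits: List[str]) -> List[str]:
    """Return the expected .npy file names for the given splits that are absent from load_dir."""
    load_dir = Path(load_dir)
    expected = [f"{kind}_{split}.npy" for split in splits for kind in ("images", "masks")]
    return [name for name in expected if not (load_dir / name).exists()]


with tempfile.TemporaryDirectory() as tmp:
    # Create empty placeholder files for the training split only
    for name in ("images_training.npy", "masks_training.npy"):
        (Path(tmp) / name).touch()

    # The test split was never saved, so both of its files are reported
    missing = find_missing_dataset_files(tmp, ["training", "test"])
```

An empty result means every selected split can be loaded; otherwise the corresponding flag (`training`, `test`, or `validation`) can be set to `False` for the missing split.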


In [None]:
def load_datasets(
    load_dir: str,
    training: bool = True,
    test: bool = True,
    validation: bool = True,
) -> Tuple[
    Optional[np.ndarray],
    Optional[np.ndarray],
    Optional[np.ndarray],
    Optional[np.ndarray],
    Optional[np.ndarray],
    Optional[np.ndarray],
]:
    """
    Loads the training, test, and validation datasets from the specified directory.

    Args:
        load_dir (str): Path to the directory where the datasets are stored.
        training (bool, optional): Whether to load the training dataset. Default is True.
        test (bool, optional): Whether to load the test dataset. Default is True.
        validation (bool, optional): Whether to load the validation dataset. Default is True.


    Returns:
        Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
        A tuple containing:
            - images_training (Optional[np.ndarray]): The training images.
            - masks_training (Optional[np.ndarray]): The training masks.
            - images_test (Optional[np.ndarray]): The test images.
            - masks_test (Optional[np.ndarray]): The test masks.
            - images_validation (Optional[np.ndarray]): The validation images.
            - masks_validation (Optional[np.ndarray]): The validation masks.
    """
    # Ensure that at least one of the datasets (training, test, validation) is selected to load
    if not (training or test or validation):
        raise ValueError(
            "You must select at least one of the three data types to load: training, test, or validation."
        )
    load_dir = Path(load_dir) # Convert the load directory to a Path object

    # Initialize variables for datasets
    images_training, masks_training = None, None
    images_test, masks_test = None, None
    images_validation, masks_validation = None, None

    # Load the training dataset if selected
    if training:
        images_training = np.load(load_dir / "images_training.npy")
        masks_training = np.load(load_dir / "masks_training.npy")

    # Load the test dataset if selected
    if test:
        images_test = np.load(load_dir / "images_test.npy")
        masks_test = np.load(load_dir / "masks_test.npy")

    # Load the validation dataset if selected
    if validation:
        images_validation = np.load(load_dir / "images_validation.npy")
        masks_validation = np.load(load_dir / "masks_validation.npy")

    return (
        images_training,
        masks_training,
        images_test,
        masks_test,
        images_validation,
        masks_validation,
    )


### Save Associated Files

The `save_associated_files` function is a public function called in `main.ipynb` to copy files from one directory and save them to another. This is relevant in the case of `.tif` files with associated `.tfw` files that contain necessary geospatial information. When the predicted masks are created, their geospatial data does not change, so those files should be included with the predicted masks. 

Note that **if a `names_map` is provided, each file found using a value from the dictionary will be renamed to the associated key from the dictionary**. For example, given a key-value pair of `{new_name: original_name}`, the file `original_name` found in the `original_dir` will be copied to `output_dir` and saved under the name `new_name`. **This means that any existing files in the output folder with the same name and file type will be overwritten.** 

#### save_associated_files

```
def save_associated_files(
    original_dir: str,
    output_dir: str,
    file_ext: str = "tfw",
    file_names: Optional[List[str]] = None,
    names_map: Optional[Dict[str, str]] = None,
) -> None:
```
#### Description

Copies associated files from the original directory to the output directory.

#### Parameters
- `original_dir` (str): 

    Path to the directory containing the original files.

- `output_dir` (str): 

    Path to the directory where the files will be saved.

- `file_ext` (str, optional): 

    File extension of the associated files. Default is `tfw`.

- `file_names` (Optional[List[str]], optional): 

    List of file names to copy. If `None`, all files with the specified extension will be copied. Default is `None`.

- `names_map` (Optional[Dict[str, str]], optional): 

    Dictionary mapping the output file names (keys) to their corresponding file names in the original directory (values). Default is `None`.

#### Raises
- `ValueError`: If no files of the specified file type are found in the directory.

#### Example
```
import shutil
from pathlib import Path

# Example usage
original_dir = "path/to/original"
output_dir = "path/to/output"
file_ext = "tfw"
file_names = ["file1", "file2", "file3"]
names_map = {"file1": "original_file1", "file2": "original_file2", "file3": "original_file3"}

# Save the associated files
save_associated_files(original_dir, output_dir, file_ext, file_names, names_map)
```
#### Notes
- The function creates the output_dir if it does not exist.
- It normalizes the file extension to lowercase and strips any surrounding whitespace.
- The function retrieves the list of files in the original directory with the specified file extension.
- If `file_names` is provided, it copies the specified files. If `names_map` is also provided, it uses the mapping to find the corresponding file names in the original directory, and saves each copy under the file name given by its `names_map` key.
- If `file_names` is not provided, it copies all files with the specified extension.
- The function prints a message if any files are missing from the original directory or the names_map.
- A message is printed to indicate the output directory once the files are saved.
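
The renaming behaviour of `names_map` can be sketched with `shutil.copy2` in a temporary directory. This is a minimal illustration of the copy-and-rename step only, not a call to `save_associated_files`; the file name `tile_001` and its contents are hypothetical.

```python
import shutil
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    original_dir = Path(tmp) / "original"
    output_dir = Path(tmp) / "output"
    original_dir.mkdir()
    output_dir.mkdir()

    # Hypothetical world file that accompanies an original mask
    (original_dir / "tile_001.tfw").write_text("placeholder geospatial data")

    # Keys are the new names, values are the names in original_dir
    names_map = {"tile_001_predicted": "tile_001"}

    for new_name, original_name in names_map.items():
        src_file = original_dir / f"{original_name}.tfw"
        if src_file.exists():
            # Copy the file and save it under the key from names_map
            shutil.copy2(src_file, output_dir / f"{new_name}.tfw")

    copied = sorted(path.name for path in output_dir.iterdir())
    copied_text = (output_dir / "tile_001_predicted.tfw").read_text()
```

The copy keeps the original contents and only changes the file name, which is what allows a predicted mask to reuse the geospatial information of the mask it was derived from.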

In [None]:
def save_associated_files(
    original_dir: str,
    output_dir: str,
    file_ext: str = "tfw",
    file_names: Optional[List[str]] = None,
    names_map: Optional[Dict[str, str]] = None,
) -> None:
    """
    Copies associated files from the original directory to the output directory.

    Args:
        original_dir (str): Path to the directory containing the original files.
        output_dir (str): Path to the directory where the files will be saved.
        file_ext (str, optional): File extension of the associated files. Default is "tfw".
        file_names (Optional[List[str]], optional): List of file names to copy. If None, all files with the specified extension will be copied. Default is None.
        names_map (Optional[Dict[str, str]], optional): Dictionary mapping file names to their corresponding names in the original directory. Default is None.

    Returns:
        None: This function does not return anything.
    """
    # Convert the output directory to a Path object and create it if it doesn't exist
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Convert original directory to Path object
    original_dir = Path(original_dir)

    # Normalize file extension to lowercase and strip any surrounding whitespace
    file_ext = file_ext.strip().lower()

    # Initialize empty missing files and keys
    missing_files = []
    missing_keys = []

    # Get the list of files in the original directory with the specified file extension
    original_files_list = list(Path(original_dir).glob("*." + file_ext))

    # If no files with the given extension are found in the original directory, raise an error.
    if not original_files_list:
        raise ValueError(
            f"No files of file type '{file_ext}' found in directory '{original_dir}'."
        )

    # If specific file names are provided, then either match them to the names_map if provided, or directly search for them in the original directory.
    if file_names:
        if names_map:
            # If a names_map is provided, map the file names accordingly
            keys = [key for key in file_names if key in names_map]
            missing_keys = list(set(file_names) - set(keys))  # Identify missing keys
            for key in keys:
                file_name = names_map[key]
                src_file = original_dir / f"{file_name}.{file_ext}"
                output_file = output_dir / f"{key}.{file_ext}"
                if src_file.exists():
                    shutil.copy2(
                        src_file, output_file
                    )  # Copy the file to the output directory, with the same name as the key 
                else:
                    missing_files.append(str(file_name))  # Track missing files
        else:
            # If no names_map is provided, use the file names directly
            for file_name in file_names:
                src_file = original_dir / f"{file_name}.{file_ext}"
                if src_file.exists():
                    shutil.copy2(
                        src_file, output_dir
                    )  # Copy the file to the output directory
                else:
                    missing_files.append(str(file_name))  # Track missing files
    else:
        # If no specific file names are provided, copy all files with the specified extension
        for src_file in original_files_list:
            shutil.copy2(src_file, output_dir)

    # Print missing files if any
    if missing_files:
        print(
            f"The following files are missing from {original_dir}: \n{', '.join(missing_files)}\n"
        )
    # Print missing keys if any
    if missing_keys:
        print(
            f"The following files are missing from `names_map`: \n{', '.join(missing_keys)}\n"
        )

    print(f"Files saved to {output_dir}.")
    return
