In [1]:
from pathlib import Path
import numpy as np
from tqdm import tqdm
import utils
import json
from os.path import join
from instance2border_core import instance2border_core
import zarr
from acvl_utils.miscellaneous.ptqdm import ptqdm
import pickle
from skimage.measure import regionprops
from typing import List, Tuple, Dict, Any
import numpy as np
import cc3d
import copy
from scipy.ndimage import label as nd_label
from skimage.morphology import footprint_rectangle, dilation
from scipy.ndimage import distance_transform_edt
from tqdm import tqdm
import numpy_indexed as npi
from typing import Tuple, Type
from skimage.morphology import binary_erosion, dilation, ball, footprint_rectangle


def preprocess_all(load_dir: str, names: List[str], save_dir: str, target_spacing: float,
                   target_particle_size_in_pixel: int, dataset_name: str, processes: int,
                   border_thickness_in_pixel: int, gpu: bool, binary: bool, device: int) -> None:
    """
    Preprocess every requested sample and write the dataset-level files.

    Each sample is handled by ``preprocess_single``; afterwards ``dataset.json`` and
    ``regionprops.pkl`` are generated inside ``<save_dir>/<dataset_name>``.

    :param load_dir: Base directory holding 'metadata.json', 'images' and either 'binary_mask' (binary=True) or 'instance_seg' (binary=False).
    :param names: The name(s) without extension of the image(s) that should be used for training.
    :param save_dir: Path to the preprocessed dataset directory.
    :param target_spacing: The target spacing in millimeters (replicated for all three axes).
    :param target_particle_size_in_pixel: The target particle size in pixels (replicated for all three axes).
    :param dataset_name: The name of the preprocessed dataset.
    :param processes: Number of processes to use for parallel processing. None to disable multiprocessing.
    :param border_thickness_in_pixel: Border thickness in pixel.
    :param gpu: Flag indicating whether to use the GPU for preprocessing.
    :param binary: Flag indicating whether a binary mask is used instead of an instance segmentation.
    :param device: Value indicating which GPU to use.
    :raises RuntimeError: If any entry of ``names`` has no record in metadata.json.
    """
    metadata_load_filepath = join(load_dir, "metadata.json")
    with open(metadata_load_filepath) as f:
        metadata = json.load(f)

    # All outputs live below one dataset directory.
    dataset_dir = join(save_dir, dataset_name)
    image_save_dir = join(dataset_dir, "imagesTr")
    semantic_seg_save_dir = join(dataset_dir, "labelsTr")
    instance_seg_save_dir = join(dataset_dir, "labelsTr_instance")
    for nifti_dir in (instance_seg_save_dir, image_save_dir, semantic_seg_save_dir):
        Path(nifti_dir).mkdir(parents=True, exist_ok=True)

    # Fail early if a requested sample has no metadata record.
    for name in names:
        if name not in metadata:
            raise RuntimeError("{} is missing in metadata!".format(name))

    seg_subdir = "binary_mask" if binary else "instance_seg"

    # Keyword arguments shared by every preprocess_single call, sequential or parallel.
    shared_kwargs = dict(
        names=names,
        image_load_filepaths=[join(load_dir, "images", name + ".nii.gz") for name in names],
        seg_load_filepaths=[join(load_dir, seg_subdir, name + ".nii.gz") for name in names],
        metadata_load_filepath=metadata_load_filepath,
        image_save_dir=image_save_dir,
        semantic_seg_save_dir=semantic_seg_save_dir,
        instance_seg_save_dir=instance_seg_save_dir,
        semantic_seg_zarr_save_dir=join(dataset_dir, "labelsTr_zarr"),
        instance_seg_zarr_save_dir=join(dataset_dir, "labelsTr_instance_zarr"),
        target_spacing=[target_spacing] * 3,
        target_particle_size_in_pixel=[target_particle_size_in_pixel] * 3,
        border_thickness_in_pixel=border_thickness_in_pixel,
        gpu=gpu,
        binary=binary,
        device=device,
    )

    sample_indices = range(len(names))
    if processes is None:
        for i in tqdm(sample_indices):
            preprocess_single(i, **shared_kwargs)
    else:
        ptqdm(preprocess_single, sample_indices, processes, **shared_kwargs)

    utils.generate_dataset_json(join(dataset_dir, 'dataset.json'), image_save_dir, None, ("noNorm",), {0: 'bg', 1: 'core', 2: 'border'}, dataset_name)

    gen_regionprops(instance_seg_save_dir, join(dataset_dir, "regionprops.pkl"))


def preprocess_single(i: int,
                      names: List[str],
                      image_load_filepaths: List[str],
                      seg_load_filepaths: List[str],
                      metadata_load_filepath: str,
                      image_save_dir: str,
                      semantic_seg_save_dir: str,
                      instance_seg_save_dir: str,
                      semantic_seg_zarr_save_dir: str,
                      instance_seg_zarr_save_dir: str,
                      target_spacing: List[float],
                      target_particle_size_in_pixel: List[int],
                      border_thickness_in_pixel: int,
                      gpu: bool,
                      binary: bool,
                      device: int) -> None:
    """
    Normalize, resample and save one image together with its segmentations.

    The image is z-scored with its own mean and standard deviation, resampled so that
    particles reach the target size, and written as NIfTI. The segmentation is resampled
    the same way and stored as an instance map and a border-core map, both as NIfTI and
    as zarr.

    Args:
        i (int): Index into ``names`` / ``image_load_filepaths`` / ``seg_load_filepaths``.
        names (List[str]): Names of the images to preprocess.
        image_load_filepaths (List[str]): Paths to the input image files.
        seg_load_filepaths (List[str]): Paths to the input segmentation files (binary mask if ``binary`` else instance segmentation).
        metadata_load_filepath (str): Path to the metadata file; must provide "particle_size" and "spacing" for the sample.
        image_save_dir (str): Directory for the preprocessed image.
        semantic_seg_save_dir (str): Directory for the border-core segmentation (NIfTI).
        instance_seg_save_dir (str): Directory for the instance segmentation (NIfTI).
        semantic_seg_zarr_save_dir (str): Directory for the border-core segmentation (zarr).
        instance_seg_zarr_save_dir (str): Directory for the instance segmentation (zarr).
        target_spacing (List[float]): Target spacing in millimeters.
        target_particle_size_in_pixel (List[int]): Target particle size in pixels.
        border_thickness_in_pixel (int): Border thickness in pixels; only used when ``binary`` is False.
        gpu (bool): If True, use GPU for resampling.
        binary (bool): If True, a binary mask is used instead of instance segmentation.
        device (int): Value indicating which GPU to use.

    Returns:
        None
    """
    name = names[i]
    image_path = image_load_filepaths[i]
    seg_path = seg_load_filepaths[i]

    with open(metadata_load_filepath) as f:
        metadata = json.load(f)

    # Per-image z-scoring: statistics come from the image itself.
    image = utils.load_nifti(image_path)
    image = utils.standardize(image, {"mean": np.mean(image), "std": np.std(image)})

    sample_meta = metadata[name]
    source_particle_size_in_mm = [sample_meta["particle_size"]] * 3
    source_spacing = [sample_meta["spacing"]] * 3
    target_particle_size_in_mm = tuple(utils.pixel2mm(target_particle_size_in_pixel, target_spacing))

    # Work out the output shape that rescales particles to the target size.
    size_conversion_factor = utils.compute_size_conversion_factor(
        source_particle_size_in_mm, source_spacing, target_particle_size_in_mm, target_spacing)
    target_patch_size_in_pixel = np.rint(np.asarray(image.shape) / size_conversion_factor).astype(int).tolist()
    resample_scale = 1 / (size_conversion_factor[0])

    def _resample(volume, **extra):
        # Same target shape / scale / device for image and segmentation.
        return utils.resample(volume, target_patch_size_in_pixel, resample_scale,
                              gpu=gpu, disable=True, device=device, **extra)

    patch_name = f"{name}"

    # The image is written before the segmentation is touched.
    image = _resample(image)
    utils.save_nifti(join(image_save_dir, patch_name + "_0000.nii.gz"), image, spacing=target_spacing)

    resampled_seg = _resample(utils.load_nifti(seg_path), seg=True)
    if binary:
        # Binary mask -> border-core map -> instances.
        semantic_seg = generate_bordercore(resampled_seg)
        instance_seg, _ = border_core2instance(semantic_seg)
    else:
        # Instance map -> border-core map.
        instance_seg = resampled_seg
        semantic_seg = instance2border_core(instance_seg, border_thickness_in_pixel)

    utils.save_nifti(join(instance_seg_save_dir, patch_name + ".nii.gz"), instance_seg,
                     spacing=target_spacing, is_seg=True, dtype=np.uint16)
    utils.save_nifti(join(semantic_seg_save_dir, patch_name + ".nii.gz"), semantic_seg,
                     spacing=target_spacing, is_seg=True, dtype=np.uint8)

    # NOTE(review): zarr.save may treat extra keyword arguments as additional named arrays
    # rather than array options, so `chunks=` might not set the chunk shape - confirm
    # against the installed zarr version and the reader of these stores before changing it.
    zarr.save(join(semantic_seg_zarr_save_dir, patch_name + ".zarr"), zarr.array(semantic_seg), chunks=(64, 64, 64))
    zarr.save(join(instance_seg_zarr_save_dir, patch_name + ".zarr"), zarr.array(instance_seg), chunks=(64, 64, 64))


def gen_regionprops(load_dir: str, metadata_filepath: str) -> None:
    """Collect per-instance bounding boxes for every volume in a directory and pickle them.

    Args:
        load_dir (str): Directory containing the instance segmentations as '<name>.nii.gz'.
        metadata_filepath (str): Output file; receives a pickled dict mapping each sample
            name to the filtered bounding boxes returned by ``gen_regionprops_single``.
    """
    sample_names = utils.load_filepaths(load_dir, return_path=False, return_extension=False)

    bboxes_by_sample = {}
    for sample_name in tqdm(sample_names):
        volume = utils.load_nifti(join(load_dir, sample_name + ".nii.gz"))
        # Only the filtered boxes are kept; the two counts are not needed here.
        bboxes_by_sample[sample_name], _, _ = gen_regionprops_single(volume)

    with open(metadata_filepath, 'wb') as handle:
        pickle.dump(bboxes_by_sample, handle, protocol=pickle.HIGHEST_PROTOCOL)


def gen_regionprops_single(instance_seg: np.ndarray) -> Tuple[Dict[int, Tuple[int, int, int, int, int, int]], int, int]:
    """Return the bounding boxes of all instances that do not touch the volume boundary.

    Args:
        instance_seg (np.ndarray): Instance label volume passed to ``regionprops``.

    Returns:
        A tuple ``(kept, n_total, n_kept)``: ``kept`` maps instance label to its bbox
        (all start indices followed by all stop indices) for instances whose box touches
        neither index 0 nor the far edge on any axis; ``n_total`` is the number of
        instances found and ``n_kept`` is ``len(kept)``.
    """
    all_bboxes = {region.label: region.bbox for region in regionprops(instance_seg)}
    volume_shape = instance_seg.shape

    interior_bboxes = {}
    for label, bbox in all_bboxes.items():
        half = len(bbox) // 2
        starts, stops = bbox[:half], bbox[half:]
        # A box starting at 0 or ending at the shape limit reaches the volume boundary.
        touches_boundary = any(
            starts[axis] == 0 or stops[axis] == volume_shape[axis]
            for axis in range(len(volume_shape))
        )
        if not touches_boundary:
            interior_bboxes[label] = bbox

    return interior_bboxes, len(all_bboxes), len(interior_bboxes)


def generate_bordercore(image, border_thickness=1, n_erosions=1):
    """
    Turn a 3D binary mask into a border-core label map.

    Parameters:
        image (np.ndarray): 3D mask; every value greater than 0 counts as foreground.
        border_thickness (int): Half-width of the cubic dilation footprint
            (side length ``2 * border_thickness + 1``).
        n_erosions (int): Radius passed to ``ball`` for the single erosion step.

    Returns:
        np.ndarray: uint8 array shaped like ``image`` with 0 = background, 1 = core
        (eroded foreground) and 2 = border (dilated core minus core).
    """
    foreground = image > 0

    # Core: foreground eroded once with a spherical footprint.
    core_mask = binary_erosion(foreground, ball(n_erosions))

    # Border: whatever a cubic dilation of the core adds around it.
    side = 2 * border_thickness + 1
    grown_mask = dilation(core_mask, footprint_rectangle([side, side, side]))

    labels = np.zeros_like(image, dtype=np.uint8)
    labels[grown_mask & ~core_mask] = 2  # Border pixels
    labels[core_mask] = 1  # Core pixels
    return labels


def border_core2instance(border_core: np.ndarray, dtype: Type = np.uint16,
                         core_label: int = 1, border_label: int = 2) -> Tuple[np.ndarray, int]:
    """Convert a border-core label map into an instance segmentation.

    The foreground is split into connected components; inside each component the cores
    are labelled individually and grown into the surrounding border voxels by
    ``border_core_component2instance_dilation``.

    Args:
        border_core (np.ndarray): Border-core label map (0 = background).
        dtype (Type): Integer dtype of the returned instance map.
        core_label (int): Value marking core voxels in ``border_core``. The default (1)
            matches the output of ``generate_bordercore``.
        border_label (int): Value marking border voxels in ``border_core``. The default (2)
            matches the output of ``generate_bordercore``.

    Returns:
        Tuple[np.ndarray, int]: The instance map (0 = background) and the highest
        instance id that was assigned.
    """
    border_core_array = np.asarray(border_core)
    component_seg = cc3d.connected_components(border_core_array > 0).astype(dtype)
    instances = np.zeros_like(border_core_array, dtype=dtype)
    num_instances = 0

    # Entry 0 of the statistics is the background component and is skipped.
    bounding_boxes = cc3d.statistics(component_seg)["bounding_boxes"]
    props = {label: bbox for label, bbox in enumerate(bounding_boxes) if label != 0}

    for label, bbox in tqdm(props.items(), desc="Border-Core2Instance"):
        # Isolate this component: other components sharing the bounding box are blanked.
        component_mask = component_seg[bbox] == label
        border_core_patch = border_core_array[bbox].copy()
        border_core_patch[~component_mask] = 0

        # The labels are passed explicitly: the helper's own defaults (core=2, border=1)
        # are the reverse of what generate_bordercore writes, which would seed the
        # region growing from the border shells and merge touching particles.
        instances_patch = border_core_component2instance_dilation(
            border_core_patch, core_label=core_label, border_label=border_label).astype(dtype)

        # Shift the patch-local ids behind the ids already handed out.
        foreground = instances_patch > 0
        instances_patch[foreground] += num_instances
        num_instances = max(num_instances, int(instances_patch.max()))

        # instances[bbox] is a view, so the masked assignment writes into `instances`.
        instances[bbox][foreground] = instances_patch[foreground]

    return instances, num_instances

def border_core_component2instance_dilation(patch: np.ndarray, core_label: int = 2, border_label: int = 1) -> np.ndarray:
    '''Label the cores of a border-core patch and grow them into the border voxels.

    ``core_label`` and ``border_label`` must match the encoding actually used in ``patch``:
    voxels equal to ``core_label`` become the seeds (one id per connected core, via
    ``nd_label``), and voxels equal to ``border_label`` are handed to whichever grown seed
    reaches them first. Swapping the two values seeds the growth from the border instead.

    Args:
        patch (np.ndarray): Border-core label patch (0 = background).
        core_label (int): Value marking core voxels in ``patch``.
        border_label (int): Value marking border voxels in ``patch``.

    Returns:
        np.ndarray: uint16 instance ids (0 = background). If ``patch`` contains no core
        voxel, ``patch`` itself is returned unchanged (raw label values, not instance ids).
    '''
    core_instances = np.zeros_like(patch, dtype=np.uint16)
    num_instances = nd_label(patch == core_label, output=core_instances)
    if num_instances == 0:
        return patch
    # remove_small_cores() was tried here and disabled: "Weird effect on segmentation".
    instances = core_instances.copy()
    border = patch == border_label
    # Loop invariants: growth never leaves the patch foreground, and the footprint is fixed.
    outside = patch == 0
    footprint = footprint_rectangle([3, 3, 3])
    while border.any():
        dilated = dilation(core_instances, footprint)  # Bottleneck when many instances
        dilated[outside] = 0
        # Voxels that were unlabelled before this step and are labelled now.
        diff = (core_instances == 0) & (dilated != core_instances)
        if not diff.any():
            # Nothing grew: the remaining border voxels cannot be reached from any core,
            # so repeating the dilation would loop forever.
            break
        newly_claimed = diff & border
        instances[newly_claimed] = dilated[newly_claimed]
        border[diff] = 0
        core_instances = dilated
    return instances

def remove_small_cores(
    patch: np.ndarray, 
    core_instances: np.ndarray, 
    core_label: int, 
    border_label: int, 
    min_distance: float = 1, 
    min_ratio_threshold: float = 0.95, 
    max_distance: float = 3, 
    max_ratio_threshold: float = 0.0) -> Tuple[np.ndarray, np.ndarray, int]:
    """Demote cores that are thin everywhere to border.

    For every id in ``core_instances`` the Euclidean distance transform of the core mask
    is inspected. An id is removed when the share of its voxels with distance
    <= ``min_distance`` is at least ``min_ratio_threshold`` and the share with distance
    >= ``max_distance`` is at most ``max_ratio_threshold`` (a threshold of None disables
    that check). Removed ids are set to 0 in ``core_instances`` and their core voxels in
    ``patch`` are relabelled to ``border_label``; ``patch`` is modified in place.

    Returns:
        Tuple[np.ndarray, np.ndarray, int]: The patch, the updated core instance map and
        the number of ids in ``core_instances`` that were not removed.
    """
    edt = distance_transform_edt(patch == core_label)
    core_ids = np.unique(core_instances)

    core_ids_to_remove = []
    for core_id in core_ids:
        core_distances = edt[core_instances == core_id]
        voxel_count = core_distances.size
        shallow_ratio = np.count_nonzero(core_distances <= min_distance) / voxel_count
        deep_ratio = np.count_nonzero(core_distances >= max_distance) / voxel_count
        mostly_shallow = min_ratio_threshold is None or shallow_ratio >= min_ratio_threshold
        rarely_deep = max_ratio_threshold is None or deep_ratio <= max_ratio_threshold
        if mostly_shallow and rarely_deep:
            core_ids_to_remove.append(core_id)

    num_cores = len(core_ids) - len(core_ids_to_remove)

    if core_ids_to_remove:
        # Map every removed id to 0, then turn the orphaned core voxels into border.
        zeros = np.zeros(len(core_ids_to_remove), dtype=int)
        remapped = npi.remap(core_instances.flatten(), core_ids_to_remove, zeros)
        core_instances = remapped.reshape(patch.shape)
        patch[(patch == core_label) & (core_instances == 0)] = border_label

    return patch, core_instances, num_cores


# def main():
#     parser = argparse.ArgumentParser()
#     parser.add_argument('-i', "--input", required=True,
#                         help="Absolute input path to the base folder that contains the dataset structured in the form of the directories 'images' and 'instance_seg' and the file metadata.json.")
#     parser.add_argument('-o', "--output", required=True, help="Absolute output path to the preprocessed dataset directory.")
#     parser.add_argument('-n', "--name", required=False, type=str, default=None, nargs="+", help="(Optional) The name(s) without extension of the image(s) that should be used for training. Multiple names must be separated by spaces.")
#     parser.add_argument('-t', '--task', required=False, default=500, type=int, help="(Optional) The task id that should be assigned to this dataset.")
#     parser.add_argument('-z', '--zscore', default=(5850.29762143569, 7078.294543817302), required=False, type=float, nargs=2,
#                         help="(Optional) The target spacing in millimeters given as three numbers separate by spaces.")
#     parser.add_argument('-target_particle_size', default=60, required=False, type=int,
#                         help="(Optional) The target particle size in pixels given as three numbers separate by spaces.")
#     parser.add_argument('-target_spacing', default=0.1, required=False, type=float,
#                         help="(Optional) The target spacing in millimeters given as three numbers separate by spaces.")
#     parser.add_argument('-p', '--processes', required=False, default=None, type=int, help="(Optional) Number of processes to use for parallel processing. None to disable multiprocessing.")
#     parser.add_argument('-thickness', required=False, default=2, type=int, help="(Optional) Border thickness in pixel.")
#     parser.add_argument('--disable_gpu', required=False, default=False, action="store_true", help="(Optional) Disables use of the GPU for preprocessing.")
#     args = parser.parse_args()

#     parser = argparse.ArgumentParser(description="Preprocess a dataset for training a particle segmentation model.")

#     names = args.name

#     if names is None:
#         names = utils.load_filepaths(join(args.input, "images"), extension=".nii.gz", return_path=False, return_extension=False)

#     print("Samples: ", names)
#     print("Num samples: ", len(names))

#     dataset_name = "Task{}_ParticleSeg3D".format(str(args.task).zfill(3))

#     preprocess_all(args.input, names, args.output, args.target_spacing, args.target_particle_size, dataset_name, args.processes, args.thickness, not args.disable_gpu, args.zscore)


# if __name__ == '__main__':
#     main()


In [2]:
import os

def setup_paths(dir_location, is_original_data, output_to_cloud, task):
    """Resolve the input and output directories for one preprocessing task.

    Args:
        dir_location: One of 'internal', 'external', 'cloud' or 'refine' (case-insensitive);
            selects the machine-specific base directory.
        is_original_data: If truthy, read from the 'orignal_dataset' folder (spelling as on
            disk), otherwise from 'tablet_dataset'.
        output_to_cloud: If truthy, write below the OneDrive base directory instead of the
            base directory selected by ``dir_location``.
        task: Task identifier; appended to 'Task' to form the input folder name.

    Returns:
        Tuple ``(input_path, output_path)``. The output directory is created if missing.

    Raises:
        ValueError: If ``dir_location`` is unknown or the input directory does not exist.
    """
    base_paths = {
        'internal': r'C:\Senior_Design',
        'external': r'D:\Senior_Design',
        'cloud': r'C:\Users\dchen\OneDrive - University of Connecticut\Courses\Year 4\Fall 2024\BME 4900 and 4910W (Kumavor)\Python\Files',
        'refine': r'D:\Darren\Files',
    }
    base_path = base_paths.get(dir_location.lower())
    if base_path is None:
        raise ValueError('Invalid directory location type')

    dataset_folder = 'orignal_dataset' if is_original_data else 'tablet_dataset'
    input_path = os.path.join(base_path, 'database', dataset_folder, 'training', 'Task' + str(task))
    if not os.path.isdir(input_path):
        raise ValueError("This input path is not valid:\n" + input_path)

    # Cloud output replaces the base directory for the output side only.
    output_base = base_path
    if output_to_cloud:
        output_base = r'D:\Darren\OneDrive - University of Connecticut\Courses\Year 4\Fall 2024\BME 4900 and 4910W (Kumavor)\Python\Files'
    output_path = os.path.join(output_base, 'training', 'nnUNet_raw_data_base', 'nnUNet_raw_data')
    os.makedirs(output_path, exist_ok=True)

    print('Paths set')
    return input_path, output_path

In [None]:
import utils
from os.path import join

# Run configuration for this notebook; edit the values below, then run the cell.
input_path = "/absolute/path/to/dataset"  # Placeholder only: overwritten by setup_paths() below
output_path = "/absolute/path/to/output"  # Placeholder only: overwritten by setup_paths() below
names = None  # List of image names (without extensions) to use for training; None to use all available images
task = '504'  # Task ID, given as a string; used for the 'Task504' input folder and the dataset name
target_spacing = 0.1  # Target spacing in millimeters
target_particle_size = 60  # Target particle size in pixels
processes = None  # Number of processes for parallel processing; None disables multiprocessing
thickness = 4  # Border thickness in pixels; only passed to instance2border_core, i.e. used when binary is False
disable_gpu = False  # If True, disables GPU usage for preprocessing
binary = True # If True, segmentations are read from 'binary_mask' and converted to border-core; otherwise from 'instance_seg'
device = 1 # Value indicating which GPU to use.

# Resolve the real input/output directories (raises ValueError if the input directory is missing)
input_path, output_path = setup_paths('refine', False, False, task)

# Default to every '.nii.gz' image found in the input 'images' folder
if names is None:
    names = utils.load_filepaths(join(input_path, "images"), extension=".nii.gz", return_path=False, return_extension=False)

print("Samples: ", names)
print("Num samples: ", len(names))

# Dataset folder name, e.g. 'Task504_ParticleSeg3D' (task id left-padded with zeros to three characters)
dataset_name = "Task{}_ParticleSeg3D".format(str(task).zfill(3))

# The gpu argument is the negation of disable_gpu
preprocess_all(input_path, names, output_path, target_spacing, target_particle_size, dataset_name, processes, thickness, not disable_gpu, binary, device)


Paths set
Samples:  ['2_Tablet_Aug1', '2_Tablet_Aug2', '2_Tablet_Aug3', '2_Tablet_Aug4', '2_Tablet_Aug5', '4_GenericD12_Aug1', '4_GenericD12_Aug2', '4_GenericD12_Aug3', '4_GenericD12_Aug4', '4_GenericD12_Aug5', '5_ClaritinD12_Aug1', '5_ClaritinD12_Aug2', '5_ClaritinD12_Aug3', '5_ClaritinD12_Aug4', '5_ClaritinD12_Aug5']
Num samples:  15


  0%|          | 0/15 [00:00<?, ?it/s]

