# Functions

In [None]:
import numpy as np
import tifffile as tiff
import random
import scipy.ndimage
import os
from tqdm import tqdm

# Load a stack of TIFF images from a directory
def load_tiff_stack_from_dir(directory):
    '''Read every .tif/.tiff file in ``directory`` and stack them into one array.

    Files are matched case-insensitively by extension and read in sorted
    (lexicographic) file-name order; each image becomes one entry along a new
    leading axis of the returned array.

    Raises:
    - FileNotFoundError: if ``directory`` contains no TIFF files (or, via
      ``os.listdir``, if ``directory`` does not exist).
    '''
    names = sorted(
        name for name in os.listdir(directory)
        if name.lower().endswith(('.tiff', '.tif'))
    )
    if len(names) == 0:
        raise FileNotFoundError('No TIFF files found in the directory.')

    slices = []
    for name in names:
        slices.append(tiff.imread(os.path.join(directory, name)))
    return np.stack(slices, axis=0)

# Save dataset to a directory
def _slice_filename(index, num_slices):
    '''Return the zero-padded file name for slice ``index`` of a ``num_slices`` stack.'''
    # At least 3 digits (the historical format), wider when the stack needs it,
    # so every name in one stack has the same width and sorts in numeric order.
    width = max(3, len(str(max(num_slices - 1, 0))))
    return f'slice_{index:0{width}d}.tiff'


def save_tiff_stack_to_dir(volume, output_dir):
    '''Saves a 3D NumPy array as individual TIFF images in a directory.

    Slice ``i`` (along axis 0) is written to ``slice_<i>.tiff``. ``i`` is
    zero-padded to at least 3 digits, and padded wider when the stack has more
    than 1000 slices. With a fixed 3-digit pad, ``slice_1000.tiff`` would sort
    before ``slice_101.tiff``, and ``load_tiff_stack_from_dir`` (which sorts
    file names lexicographically) would reassemble the stack out of order.

    Parameters:
    - volume (numpy array): stack to write; one file per entry along axis 0.
    - output_dir (str): destination directory, created if it does not exist.
    '''
    os.makedirs(output_dir, exist_ok=True)

    num_slices = volume.shape[0]
    for i in range(num_slices):
        tiff.imwrite(os.path.join(output_dir, _slice_filename(i, num_slices)), volume[i])

# Custom Mirror Transformations
def mirror(volume, mode):
    '''Flip ``volume`` along one axis chosen by name.

    - "horizontal": reverse axis 2
    - "vertical":   reverse axis 1
    - "depth":      reverse axis 0

    Returns the result of ``np.flip`` (a reversed view of the input).
    Raises ValueError for any other ``mode``.
    '''
    for name, flip_axis in (("horizontal", 2), ("vertical", 1), ("depth", 0)):
        if mode == name:
            return np.flip(volume, axis=flip_axis)
    raise ValueError("Mode must be 'horizontal', 'vertical', or 'depth'.")

def rotate_90_3d(volume, num_rot, rot_axes):
    '''Rotate ``volume`` by ``num_rot`` quarter turns in the plane ``rot_axes``.

    Thin wrapper around ``np.rot90``: rotation goes from the first axis in
    ``rot_axes`` towards the second, and a negative ``num_rot`` rotates the
    opposite way. The two axes in ``rot_axes`` swap sizes for odd ``num_rot``.
    '''
    quarter_turns = num_rot
    plane = rot_axes
    return np.rot90(volume, quarter_turns, plane)

# Shearing Crop
def shear_crop(volume, crop_factor, crop_dim='height'):
    '''Symmetrically crop the in-plane borders of a 3D volume.

    Despite the name, no shear is applied: this is a centre crop. Along each
    selected axis, ``int(size * crop_factor / 2)`` entries are removed from
    both ends. The depth axis (axis 0) is never cropped.

    Parameters:
    - volume (numpy array): 3D array indexed (depth, height, width).
    - crop_factor (float): total fraction of the axis to remove, in [0, 1).
    - crop_dim (str): 'height' (axis 1), 'width' (axis 2) or 'both'.

    Returns:
    - A cropped view of ``volume``.

    Raises:
    - ValueError: for an unknown ``crop_dim`` or a ``crop_factor`` outside
      [0, 1). A negative factor would otherwise wrap the slice bounds and
      return the wrong rows, and a factor >= 1 would return an empty volume.
    '''
    _, h, w = volume.shape
    if crop_dim not in ('height', 'width', 'both'):
        raise ValueError("Crop dimension must be 'height', 'width', or 'both'")
    if not 0 <= crop_factor < 1:
        raise ValueError("crop_factor must be in the range [0, 1)")

    crop_h = int(h * crop_factor / 2) if crop_dim in ('height', 'both') else 0
    crop_w = int(w * crop_factor / 2) if crop_dim in ('width', 'both') else 0
    return volume[:, crop_h:h - crop_h, crop_w:w - crop_w]

def random_erasing(volume, seed, num_cubes=100, min_size=50, max_size=75):
    """
    Zero out randomly placed cubes inside a 3D volume.

    The volume is modified in place and also returned. Random draws come from
    a private ``np.random.RandomState(seed)``. It produces the same sequence
    as seeding the global generator with ``np.random.seed(seed)``, so a given
    seed erases the same cubes without resetting NumPy's global RNG.

    Parameters:
    - volume (numpy array): 3D volume to augment (modified in place).
    - seed (int): Random seed for reproducibility.
    - num_cubes (int): Number of erased regions.
    - min_size (int): Minimum edge length of an erased cube.
    - max_size (int): Maximum edge length of an erased cube (inclusive).

    Returns:
    - The same array, with the selected cubes set to 0.

    If an axis is shorter than the drawn edge length, the cube is clipped to
    that axis' full extent instead of raising from ``randint``.
    """
    rng = np.random.RandomState(seed)
    d, h, w = volume.shape

    for _ in range(num_cubes):
        size = rng.randint(min_size, max_size + 1)
        # Clip the edge per axis so a cube larger than a thin axis still fits.
        size_d, size_h, size_w = min(size, d), min(size, h), min(size, w)
        x = rng.randint(0, d - size_d + 1)
        y = rng.randint(0, h - size_h + 1)
        z = rng.randint(0, w - size_w + 1)

        # Occluded region is filled with zeros (not noise).
        volume[x:x + size_d, y:y + size_h, z:z + size_w] = 0

    return volume

def scale(volume, scale_factor=1.3, axis='height', order=3):
    '''Resample ``volume`` along a single named axis with ``scipy.ndimage.zoom``.

    ``axis`` picks the dimension multiplied by ``scale_factor``:
    'height' -> axis 1, 'width' -> axis 2, 'depth' -> axis 0. The other two
    dimensions keep their size, so the output shape changes only along the
    chosen axis. ``order`` is the spline interpolation order given to zoom.

    Raises ValueError for an unknown ``axis``.
    '''
    _depth, _height, _width = volume.shape  # unpacked only to require a 3D input

    target = None
    for name, idx in (('height', 1), ('width', 2), ('depth', 0)):
        if axis == name:
            target = idx
            break
    if target is None:
        raise ValueError("Axis must be 'height', 'width', or 'depth'")

    factors = [1, 1, 1]
    factors[target] = scale_factor
    return scipy.ndimage.zoom(volume, zoom=tuple(factors), order=order)


def zoom_fixed(volume, zoom_factor=1.3, order=3):
    '''Zoom the height/width plane of a 3D volume while keeping its shape.

    Axes 1 and 2 are resampled by ``zoom_factor`` (axis 0 is untouched). Each
    in-plane axis is then brought back to its original length: a longer axis
    is centre-cropped, and a shorter one is zero-padded, with any odd extra
    padding going to the end.
    '''
    _, target_h, target_w = volume.shape
    result = scipy.ndimage.zoom(volume, (1, zoom_factor, zoom_factor), order=order)

    # Height first, then width, each fitted back to the input length.
    for ax, target in ((1, target_h), (2, target_w)):
        current = result.shape[ax]
        if current > target:
            start = (current - target) // 2
            window = [slice(None)] * 3
            window[ax] = slice(start, start + target)
            result = result[tuple(window)]
        elif current < target:
            before = (target - current) // 2
            widths = [(0, 0)] * 3
            widths[ax] = (before, target - current - before)
            result = np.pad(result, widths, mode='constant')

    return result


def augment_volume(volume, config, dtype, order, seed):
    '''Apply the augmentations enabled in ``config`` to ``volume`` in a fixed order.

    Order: shear_crop -> scale -> zoom_fixed -> mirror -> rotate -> random_erasing.
    Each step runs only when its boolean flag in ``config`` is truthy. Its
    parameters come from the other config keys, with the defaults shown in
    the ``config.get`` calls below.

    Parameters:
    - volume (numpy array): 3D volume to augment.
    - config (dict): augmentation flags and parameters. Random erasing also
      reads the optional keys 'erase_num_cubes', 'erase_min_size' and
      'erase_max_size' (defaults 100, 50, 75).
    - dtype: dtype of the returned array.
    - order (int): spline interpolation order for scale / zoom_fixed (the
      main script passes 3 for grayscale and 0 for segmented volumes).
    - seed (int): seed passed to random_erasing. NumPy's global RNG is not
      reseeded here, because random_erasing is the only step that draws
      random numbers and it receives the seed directly.

    Returns:
    - A new NumPy array of ``dtype`` holding the augmented volume.
    '''
    if config.get('shear_crop', False):
        volume = shear_crop(volume, crop_factor=config.get('crop_factor', 0.3), crop_dim=config.get('shear_crop_dim', 'height'))
    if config.get('scale', False):
        volume = scale(volume, scale_factor=config.get('scale_factor', 1.3), axis=config.get('scale_axis', 'height'), order=order)
    if config.get('zoom_fixed', False):
        volume = zoom_fixed(volume, zoom_factor=config.get('zoom_factor', 1.2), order=order)
    if config.get('mirror', False):
        volume = mirror(volume, mode=config.get('mirror_mode', 'horizontal'))
    if config.get('rotate', False):
        volume = rotate_90_3d(volume, num_rot=config.get('num_rot', 1), rot_axes=config.get('rot_axes', (1, 2)))
    if config.get('random_erasing', False):
        volume = random_erasing(
            volume,
            seed,
            num_cubes=config.get('erase_num_cubes', 100),
            min_size=config.get('erase_min_size', 50),
            max_size=config.get('erase_max_size', 75),
        )

    # np.array copies, so the result never aliases the caller's input
    # (mirror/rotate return views).
    return np.array(volume, dtype=dtype)

def process_and_save_tiff(input_dir, output_dir, config, dtype, order, seed=42):
    '''Load the TIFF stack in ``input_dir``, augment it, and write it to ``output_dir``.

    ``config``, ``dtype``, ``order`` and ``seed`` are forwarded unchanged to
    ``augment_volume``. The result is written one slice per file by
    ``save_tiff_stack_to_dir``.
    '''
    save_tiff_stack_to_dir(
        augment_volume(load_tiff_stack_from_dir(input_dir), config, dtype, order, seed),
        output_dir,
    )

# Main

In [6]:
from tqdm import tqdm
import os

# Define augmentation configurations
def _mirror_rotate_config(mirror_mode, rot_axes):
    '''Build a config that mirrors along one axis, then rotates once by 90 degrees.

    ``mirror_mode`` is passed to mirror(): 'horizontal' flips axis 2 and
    'vertical' flips axis 1. ``rot_axes`` is the plane handed to np.rot90
    for a single quarter turn.
    '''
    return {
        'mirror': True,              # enable the mirror step (one axis only)
        'mirror_mode': mirror_mode,  # which axis mirror() flips
        'rotate': True,              # enable the 90-degree rotation step
        'num_rot': 1,                # number of quarter turns
        'rot_axes': rot_axes,        # plane in which the rotation happens
    }


config_Aug1 = _mirror_rotate_config('horizontal', (0, 1))
config_Aug2 = _mirror_rotate_config('horizontal', (0, 2))
config_Aug3 = _mirror_rotate_config('horizontal', (1, 2))
config_Aug4 = _mirror_rotate_config('vertical', (0, 1))
config_Aug5 = _mirror_rotate_config('vertical', (0, 2))

if __name__ == '__main__':
    img_names = ['2_Tablet', '4_GenericD12', '5_ClaritinD12']
    base_path = r'd:\Darren\Files\database\tablet_dataset'
    augmentations = ['Aug1', 'Aug2', 'Aug3', 'Aug4', 'Aug5']
    configs = [config_Aug1, config_Aug2, config_Aug3, config_Aug4, config_Aug5]

    # zip() silently drops trailing entries if these parallel lists ever drift
    # apart, so check the lengths explicitly. The pairs do not depend on the
    # image, so build them once, outside the loop.
    if len(augmentations) != len(configs):
        raise ValueError('augmentations and configs must have the same length')
    aug_config_pairs = list(zip(augmentations, configs))

    for img_name in tqdm(img_names, desc="Processing images", position=0, leave=True):
        grayscale_path = os.path.join(base_path, 'grayscale', 'tiff', img_name)
        segmented_path = os.path.join(base_path, 'segmented', 'tiff', img_name)

        for aug, config in aug_config_pairs:
            # tqdm.write prints without breaking the active progress bar.
            tqdm.write(f"Applying {aug} to {img_name}...")

            output_grayscale = f"{grayscale_path}_{aug}"
            output_segmented = f"{segmented_path}_{aug}"

            # Grayscale: cubic interpolation (order=3), saved as uint16.
            process_and_save_tiff(grayscale_path, output_grayscale, config, dtype=np.uint16, order=3)
            # Segmentation: order=0 (nearest neighbour), so label values are not
            # interpolated if a scale/zoom step is enabled. Saved as uint8.
            process_and_save_tiff(segmented_path, output_segmented, config, dtype=np.uint8, order=0)


Processing images:   0%|          | 0/3 [00:00<?, ?it/s]

Applying Aug1 to 2_Tablet...
Applying Aug2 to 2_Tablet...
Applying Aug3 to 2_Tablet...
Applying Aug4 to 2_Tablet...
Applying Aug5 to 2_Tablet...


Processing images:  33%|███▎      | 1/3 [01:10<02:21, 70.64s/it]

Applying Aug1 to 4_GenericD12...
Applying Aug2 to 4_GenericD12...
Applying Aug3 to 4_GenericD12...
Applying Aug4 to 4_GenericD12...
Applying Aug5 to 4_GenericD12...


Processing images:  67%|██████▋   | 2/3 [02:33<01:18, 78.05s/it]

Applying Aug1 to 5_ClaritinD12...
Applying Aug2 to 5_ClaritinD12...
Applying Aug3 to 5_ClaritinD12...
Applying Aug4 to 5_ClaritinD12...
Applying Aug5 to 5_ClaritinD12...


Processing images: 100%|██████████| 3/3 [04:05<00:00, 81.87s/it]
