In [46]:
import os
import shutil
import ntpath
import pickle
import numpy as np
import matplotlib.pyplot as plt

In [3]:
paths = {
    'FD-027': '/home/miguel/GI/1.5 - Synthetic Data Generation/diffuse-gen/diffuse-gen/image_samples_FD-027/diffgen-2025-07-13-00-54/styled_samples_400x256x256x4.pkl',
    'FD-029': '/home/miguel/GI/1.5 - Synthetic Data Generation/diffuse-gen/diffuse-gen/image_samples_FD-029/diffgen-2025-07-13-07-47/styled_samples_560x256x256x4.pkl',
    'FD-030': '',
    'FD-031': '/home/miguel/GI/1.5 - Synthetic Data Generation/diffuse-gen/diffuse-gen/image_samples_FD-031/diffgen-2025-07-13-17-25/styled_samples_860x256x256x4.pkl',
    'FD-032': ''
}

In [30]:
import cv2
import numpy as np

def is_cohesive_mask(mask_gray: np.ndarray,
                     thresh_method: str = 'otsu',
                     min_area: int = 50,
                     main_frac_thresh: float = 0.6,
                     max_extra_components: int = 1) -> bool:
    """
    Decide whether a grayscale mask is coherent (one big region) or incohesive
    (many small regions).

    Parameters
    ----------
    mask_gray : np.ndarray
        2D array, single-channel mask (uint8 or float in [0,1]).
    thresh_method : str
        'otsu' | 'adaptive' | 'fixed' — how to binarize.
    min_area : int
        Ignore components smaller than this (noise filter).
    main_frac_thresh : float
        Fraction of total mask area the largest component must exceed to be
        considered dominant.
    max_extra_components : int
        Allow at most this many additional (filtered) components.

    Returns
    -------
    bool
        True if “coherent” (one dominant region), False otherwise.
    """
    # ——— ensure uint8 [0,255] ———
    if mask_gray.dtype == np.float32 or mask_gray.dtype == np.float64:
        img = (mask_gray * 255).astype(np.uint8)
    else:
        img = mask_gray.astype(np.uint8)

    # ——— binarize ———
    if thresh_method == 'otsu':
        _, bw = cv2.threshold(img, 0, 255,
                              cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    elif thresh_method == 'adaptive':
        bw = cv2.adaptiveThreshold(img, 255,
                                   cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                   cv2.THRESH_BINARY,
                                   blockSize=11, C=2)
    else:  # 'fixed'
        _, bw = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)

    # ——— connected components ———
    n_labels, labels, stats, _ = cv2.connectedComponentsWithStats(bw,
                                                                  connectivity=8)
    # stats[:, cv2.CC_STAT_AREA] is pixel-area for each label;
    # skip label 0 (background)
    areas = stats[1:, cv2.CC_STAT_AREA]

    # ——— filter out tiny noise blobs ———
    areas = areas[areas >= min_area]
    if areas.size == 0:
        return False

    total = areas.sum()
    largest = areas.max()
    extras = areas.size - 1

    # ——— decision ———
    # coherent if one blob covers ≥ main_frac_thresh of all mask pixels,
    # and there are at most max_extra_components others
    if (largest / total) >= main_frac_thresh and extras <= max_extra_components:
        return True
    else:
        return False

In [52]:
original_dataset_folder_path = '/home/miguel/GI/1.5 - Synthetic Data Generation/Singan-Seg/unet_singan_augmented_datasets_2/augmented_dataset_expansion_factor_0'

In [56]:
dataset_folder = '../augmented_datasets_varied_expansion_factor'
shutil.rmtree(dataset_folder, ignore_errors=True)

In [57]:
os.makedirs(dataset_folder, exist_ok=True)

In [58]:
for expansion_factor in [
    0.5, 0.75, 1.0, 1.4
]:
    expansion_factor_folder = os.path.join(dataset_folder, f'expansion_factor_{expansion_factor}')
    augmentation_amount = {}
    for subject in list(paths.keys()):
        path = paths[subject]
        if path:
            cv_subject_folder = os.path.join(expansion_factor_folder, subject)
            train_folder = os.path.join(cv_subject_folder, 'train')
            os.makedirs(train_folder, exist_ok=True)
            val_folder = os.path.join(cv_subject_folder, 'val')
            os.makedirs(val_folder, exist_ok=True)
            
            # Copy over all the images for the CV-subject, originally. 
            original_cv_subject_folder = os.path.join(original_dataset_folder_path, subject)
            original_train_folder = os.path.join(original_cv_subject_folder, 'train')
            original_val_folder = os.path.join(original_cv_subject_folder, 'val')

            # Copy all the original images to the new folder
            for original_img_path in os.listdir(original_train_folder):
                original_img_full_path = os.path.join(original_train_folder, original_img_path)
                if original_img_path.endswith('.png'):
                    shutil.copy(original_img_full_path, train_folder)
            
            for original_img_path in os.listdir(original_val_folder):
                original_img_full_path = os.path.join(original_val_folder, original_img_path)
                if original_img_path.endswith('.png'):
                    shutil.copy(original_img_full_path, val_folder)

            # Load the augmented data
            num_augmented_images = 0
            subject_image_data = pickle.load(open(path, 'rb'))
            for original_img_path in subject_image_data.keys():
                subject_img_data_list = subject_image_data[original_img_path]
                for img in subject_img_data_list:
                    original_img_name = os.path.basename(original_img_path)
                    
                    synthetic_image_path = os.path.join(train_folder, f"{original_img_name.replace('-image.png', '-synthetic-image.png')}")
                    synthetic_mask_path = os.path.join(train_folder, f"{original_img_name.replace('-image.png', '-synthetic-mask.png')}")
                    # Ensure the augmented data is of a high quality. 
                    synthetic_mask = img[:, :, 3]
                    synthetic_image = img[:, :, :3]
                    if (is_cohesive_mask(synthetic_mask)):
                        # Export the mask to the location
                        cv2.imwrite(synthetic_mask_path, synthetic_mask)

                        # Export the image to the location
                        cv2.imwrite(synthetic_image_path, synthetic_image)

                        num_augmented_images += 1
            augmentation_amount[subject] = num_augmented_images
    print('augmentation_amount', augmentation_amount)


augmentation_amount {'FD-027': 50, 'FD-029': 465, 'FD-031': 303}
augmentation_amount {'FD-027': 50, 'FD-029': 465, 'FD-031': 303}
augmentation_amount {'FD-027': 50, 'FD-029': 465, 'FD-031': 303}
augmentation_amount {'FD-027': 50, 'FD-029': 465, 'FD-031': 303}


In [31]:
synthetic_mask = img[:, :, 3]
synthetic_image = img[:, :, :3]

In [32]:
is_cohesive_mask(synthetic_mask)

False

In [45]:
os.path.basename(original_img_path)

'FD-032-slice-53-image.png'

In [None]:
# 1. Recreate the original training dataset

# 2. Load the augmented data

In [14]:
len([(arr for arr in subject_image_data[key]) for key in subject_image_data.keys()])

NameError: name 'key' is not defined

In [15]:
[(arr for arr in subject_image_data[key]) for key in subject_image_data.keys()]

[<generator object <genexpr> at 0x78f6c0b96b00>,
 <generator object <genexpr> at 0x78f6c0b96800>,
 <generator object <genexpr> at 0x78f6c0b96980>,
 <generator object <genexpr> at 0x78f6c0b96d40>,
 <generator object <genexpr> at 0x78f6c0b96ec0>,
 <generator object <genexpr> at 0x78f6c0b96e00>,
 <generator object <genexpr> at 0x78f6c0b97100>,
 <generator object <genexpr> at 0x78f6c0b96f80>,
 <generator object <genexpr> at 0x78f6c0b971c0>,
 <generator object <genexpr> at 0x78f6c0b97040>,
 <generator object <genexpr> at 0x78f6c0b97280>,
 <generator object <genexpr> at 0x78f6c0b97340>,
 <generator object <genexpr> at 0x78f6c0b97400>,
 <generator object <genexpr> at 0x78f6c0b974c0>,
 <generator object <genexpr> at 0x78f6c0b97700>,
 <generator object <genexpr> at 0x78f6c0b97580>,
 <generator object <genexpr> at 0x78f6c0b97640>,
 <generator object <genexpr> at 0x78f6c0b97940>,
 <generator object <genexpr> at 0x78f6c0b97a00>,
 <generator object <genexpr> at 0x78f6c0b97ac0>,
 <generator object <