## Resize

In [None]:
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage

# Shared geometry for the resize pipeline: first pad every input with zeros to a
# 1:1 (square) aspect ratio, keeping the content centred, then resize to 256 px.
# Images use cubic interpolation; masks use nearest-neighbour so that label
# values are never blended. (Construction order is kept as-is: imgaug draws
# each augmenter's seed from its global RNG when it is created.)
pad_aug = iaa.PadToAspectRatio(
    1.0,                 # target width/height ratio -> square output
    pad_cval=0,
    pad_mode="constant",
    position="center",
)
resize_img = iaa.Resize(256, interpolation="cubic")    # for colour images
resize_mask = iaa.Resize(256, interpolation="nearest")  # for label masks

# Lambda callback that resizes images with the module-level `resize_img`.
def image_resizer(images, random_state, parents, hooks):
    """
    Resize each image in a batch with the ``resize_img`` augmenter.

    The signature matches imgaug's ``Lambda(func_images=...)`` callback; only
    ``images`` is used, the remaining arguments are accepted and ignored.

    Args:
        images (list): Images to resize.
        random_state: Unused (imgaug callback argument).
        parents: Unused (imgaug callback argument).
        hooks: Unused (imgaug callback argument).
    Returns:
        list: The resized images, in the same order as the input.
    """
    resized = []
    for image in images:
        resized.append(resize_img.augment_image(image))
    return resized

# Lambda callback that resizes segmentation maps with the module-level `resize_mask`.
def label_resizer(segmaps, random_state, parents, hooks):
    """
    Resize each segmentation map in a batch with the ``resize_mask`` augmenter.

    The map's underlying array is resized directly and wrapped in a fresh
    ``SegmentationMapsOnImage`` whose ``shape`` is the resized array's shape.
    The signature matches imgaug's ``Lambda(func_segmentation_maps=...)``
    callback; only ``segmaps`` is used.

    Args:
        segmaps (list): ``SegmentationMapsOnImage`` objects to resize.
        random_state: Unused (imgaug callback argument).
        parents: Unused (imgaug callback argument).
        hooks: Unused (imgaug callback argument).
    Returns:
        list: New ``SegmentationMapsOnImage`` objects, in input order.
    """
    def _resize_one(segmap):
        resized_arr = resize_mask.augment_image(segmap.arr)
        return SegmentationMapsOnImage(resized_arr, shape=resized_arr.shape)

    return [_resize_one(segmap) for segmap in segmaps]

# Common tail of every augmentation pipeline: pad to a square, then resize
# images and masks separately (cubic vs. nearest) through a Lambda step.
resize_aug = iaa.Sequential(
    [
        pad_aug,
        iaa.Lambda(
            func_segmentation_maps=label_resizer,
            func_images=image_resizer,
        ),
    ]
)

## Rotation

In [None]:
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage

# Rotation pipeline for images and segmentation maps: rotate by an angle drawn
# from [45, 315] degrees, enlarge the canvas so the rotated content is kept in
# full (fit_output=True), fill uncovered pixels with a constant 0, then bring
# the result back to 256x256 with the shared resize pipeline.
rotation_aug = iaa.Sequential(
    [
        iaa.Affine(
            rotate=(45, 315),
            fit_output=True,
            mode="constant",
            cval=0,
            backend="cv2",
        ),
        resize_aug,
    ]
)

## Center and Random Square Cropping

In [None]:
import imgaug.augmenters as iaa
from imgaug.imgaug import SegmentationMapsOnImage

class CenterSquareCropAugmenter(iaa.Augmenter):
    """
    Crop images and segmentation maps to a square around their centre.

    The work is delegated to ``iaa.CropToAspectRatio(1.0, position="center")``.
    Because the crop position is fixed (centre), an image and a segmentation map
    of the same height and width receive the same crop.

    Args:
        name (str, optional): Name of the augmenter. Defaults to None.
        deterministic (bool, optional): Passed to ``iaa.Augmenter``. Defaults to False.
        random_state (optional): Passed to ``iaa.Augmenter``. Defaults to None.
    """

    def __init__(self, name=None, deterministic=False, random_state=None):
        super().__init__(name=name, deterministic=deterministic, random_state=random_state)
        # Crops the longer side symmetrically until width == height.
        self.cropper = iaa.CropToAspectRatio(1.0, position="center")

    def _augment_images(self, images, random_state, parents, hooks):
        """
        Centre-crop every image in ``images`` to a square.

        Args:
            images (list of numpy.ndarray): Images to crop.
            random_state: Unused; the crop is deterministic.
            parents: Unused (imgaug argument).
            hooks: Unused (imgaug argument).
        Returns:
            list of numpy.ndarray: The cropped images, in input order.
        """
        cropped_images = []
        for image in images:
            cropped_images.append(self.cropper.augment_image(image))
        return cropped_images

    def _augment_segmentation_maps(self, segmaps, random_state, parents, hooks):
        """
        Centre-crop every segmentation map's array and re-wrap it.

        Args:
            segmaps (list of SegmentationMapsOnImage): Maps to crop.
            random_state: Unused; the crop is deterministic.
            parents: Unused (imgaug argument).
            hooks: Unused (imgaug argument).
        Returns:
            list of SegmentationMapsOnImage: New maps whose ``shape`` equals
            the cropped array's shape.
        """
        def _crop_one(segmap):
            arr = self.cropper.augment_image(segmap.arr)
            return SegmentationMapsOnImage(arr, shape=arr.shape)

        return [_crop_one(segmap) for segmap in segmaps]

    def get_parameters(self):
        """This augmenter exposes no stochastic parameters."""
        return []


class RandomSquareCropAugmenter(iaa.Augmenter):
    """
    Crop a randomly positioned square out of each image and segmentation map.

    The square's side is ``int(min(H, W) * crop_factor)`` (at least 1 pixel),
    and its top-left corner is drawn uniformly from all positions where the
    square fits inside the input.

    Images and segmentation maps are cropped in separate methods. Each method
    draws its own offsets from the ``random_state`` it receives, making the same
    calls in the same order (x first, then y, per item). The crops therefore
    only line up if imgaug gives both methods identically seeded random states
    and each map has the same height and width as its image.
    NOTE(review): confirm this holds for the imgaug version in use.

    Args:
        crop_factor (float, optional): Side of the square as a fraction of the
            shorter edge; must be in (0, 1]. Defaults to 2/3.
        name (str, optional): Name of the augmenter. Defaults to None.
        deterministic (bool, optional): Passed to ``iaa.Augmenter``. Defaults to False.
        random_state (optional): Passed to ``iaa.Augmenter``. Defaults to None.

    Raises:
        ValueError: If ``crop_factor`` is not in (0, 1]. Such values would
            otherwise give empty crops (<= 0) or make ``randint`` fail at
            augmentation time (> 1).
    """

    def __init__(self, crop_factor=2/3, name=None, deterministic=False, random_state=None):
        if not 0 < crop_factor <= 1:
            raise ValueError(f"crop_factor must be in (0, 1], got {crop_factor!r}")
        super(RandomSquareCropAugmenter, self).__init__(
            name=name, deterministic=deterministic, random_state=random_state)
        self.crop_factor = crop_factor

    def _random_square(self, arr, random_state):
        """Return a randomly positioned square crop of ``arr`` (axes 0/1 are H/W)."""
        height, width = arr.shape[:2]
        # Clamp to >= 1 so tiny inputs never produce an empty crop.
        crop_size = max(1, int(min(height, width) * self.crop_factor))
        # x is drawn before y, matching the draw order used for every item.
        x1 = random_state.randint(0, width - crop_size + 1)
        y1 = random_state.randint(0, height - crop_size + 1)
        return arr[y1:y1 + crop_size, x1:x1 + crop_size]

    def _augment_images(self, images, random_state, parents, hooks):
        """
        Apply the random square crop to a list of images.

        Args:
            images (list of numpy.ndarray): Images to crop.
            random_state: imgaug RNG used to draw the crop offsets.
            parents: Unused (imgaug argument).
            hooks: Unused (imgaug argument).
        Returns:
            list of numpy.ndarray: The cropped images (views into the inputs).
        """
        return [self._random_square(img, random_state) for img in images]

    def _augment_segmentation_maps(self, segmaps, random_state, parents, hooks):
        """
        Apply the random square crop to segmentation maps.

        Args:
            segmaps (list of SegmentationMapsOnImage): Maps to crop.
            random_state: imgaug RNG used to draw the crop offsets.
            parents: Unused (imgaug argument).
            hooks: Unused (imgaug argument).
        Returns:
            list of SegmentationMapsOnImage: New maps whose ``shape`` equals
            the cropped array's shape.
        """
        out_segmaps = []
        for segmap in segmaps:
            cropped_arr = self._random_square(segmap.arr, random_state)
            out_segmaps.append(SegmentationMapsOnImage(cropped_arr, shape=cropped_arr.shape))
        return out_segmaps

    def get_parameters(self):
        """Return the augmenter's configuration (the crop factor)."""
        return [self.crop_factor]
    
# Square-crop pipelines. Both end with the shared pad + resize step, so every
# output is 256x256. (Built in this order because imgaug seeds each new
# augmenter from its global RNG.)
center_crop_aug = iaa.Sequential(
    [CenterSquareCropAugmenter(), resize_aug]   # centred square
)
random_crop_aug = iaa.Sequential(
    [RandomSquareCropAugmenter(), resize_aug]   # random square, 2/3 of the short side
)

## Random Masking

In [None]:
import numpy as np
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage

# Coarse dropout for the random-masking augmentation (p=0.15, size_percent=1/50).
# Images and labels use two separate augmenters with identical settings and the
# same fixed seed (2).
# NOTE(review): the image and label dropout patterns only coincide if both
# augmenters are called once per sample, in lockstep, on inputs of the same
# height/width — confirm.
mask_im_aug = iaa.Sequential([iaa.CoarseDropout(p=0.15, size_percent=1 / 50, random_state=2)])
mask_label_aug = iaa.Sequential([iaa.CoarseDropout(p=0.15, size_percent=1 / 50, random_state=2)])

def random_masking_labels(segmaps, random_state, parents, hooks):
    """
    Apply the label dropout augmenter (``mask_label_aug``) to segmentation maps.

    Each map's array is cast to uint8, masked, and wrapped in a new
    ``SegmentationMapsOnImage`` whose ``shape`` is the masked array's shape.
    The signature matches imgaug's ``Lambda(func_segmentation_maps=...)``
    callback; only ``segmaps`` is used.

    Args:
        segmaps (list): Input ``SegmentationMapsOnImage`` objects.
        random_state: Unused (imgaug callback argument).
        parents: Unused (imgaug callback argument).
        hooks: Unused (imgaug callback argument).
    Returns:
        list: The masked segmentation maps, in input order.
    """
    def _mask_one(segmap):
        masked = mask_label_aug.augment_image(segmap.arr.astype(np.uint8))
        return SegmentationMapsOnImage(masked, shape=masked.shape)

    return [_mask_one(segmap) for segmap in segmaps]

# Random masking pipeline: dropout on images (mask_im_aug) and on label maps
# (random_masking_labels), followed by the shared pad + resize step.
masking_aug = iaa.Sequential(
    [
        iaa.Lambda(
            func_images=lambda images, rs, parents, hooks: [
                mask_im_aug.augment_image(img) for img in images
            ],
            func_segmentation_maps=random_masking_labels,
        ),
        resize_aug,  # resize only after masking
    ]
)

## Grayscale

In [None]:
# Grayscale pipeline: fully replace colour with its grayscale version
# (alpha=1.0, inputs interpreted as RGB), then pad + resize to 256x256.
grayscale = iaa.Grayscale(from_colorspace="RGB", alpha=1.0)
grayscale_aug = iaa.Sequential(
    [
        grayscale,
        resize_aug,
    ]
)

## Laplace Noise

In [None]:
# Laplace-noise pipeline: additive Laplace noise whose scale is drawn from
# [0.1*255, 0.3*255] (10-30 % of the uint8 range), sampled per channel,
# followed by pad + resize.
laplace = iaa.AdditiveLaplaceNoise(per_channel=True, scale=(0.1 * 255, 0.3 * 255))
laplace_aug = iaa.Sequential(
    [
        laplace,
        resize_aug,
    ]
)

## Blur

In [None]:
# Blur pipeline: average (box) blur with a fixed kernel size of 12 (a scalar,
# not a range), followed by pad + resize.
blur = iaa.AverageBlur(k=12)
blur_aug = iaa.Sequential(
    [
        blur,
        resize_aug,
    ]
)

## Contrast

In [None]:
# Contrast pipeline: linear contrast with a factor drawn from [0.2, 0.6]
# (i.e. contrast is always reduced), followed by pad + resize.
contrast = iaa.LinearContrast((0.2, 0.6))
contrast_aug = iaa.Sequential(
    [
        contrast,
        resize_aug,
    ]
)

## Merge

In [None]:
from PIL import Image
import math
import os 
import numpy as np 
from utils import convert_rgb_label_to_classes


def combine_images_preserve_aspect_ratio(image1_path, image2_path, output_path=None, is_label=False):
    """
    Combine two images of the same orientation into one 256x256 image.

    Both images are scaled by one common factor, so each keeps its aspect ratio
    and their relative size is preserved.
    - Portrait pairs (h > w) are placed side by side, left to right.
    - Landscape pairs (w >= h, so square images count as landscape) are stacked
      top to bottom.

    The common factor is the largest one at which the combined strip fits
    entirely inside the 256x256 canvas, so no content is cropped. The strip is
    then centred on a black canvas.

    Colour images are resized with bicubic interpolation, consistent with the
    cubic image resize used in the imgaug pipeline. Labels (``is_label=True``)
    are resized with nearest-neighbour so that no new label colours appear, and
    are then converted to a 1-channel class map with
    ``convert_rgb_label_to_classes``.

    Args:
        image1_path (str): Path to the first image (placed left / top).
        image2_path (str): Path to the second image (placed right / bottom).
        output_path (str, optional): If given, the final image is saved here.
            Defaults to None.
        is_label (bool, optional): Treat the inputs as RGB label images:
            nearest-neighbour resizing plus class-map conversion. Defaults to False.

    Returns:
        PIL.Image.Image | None: The 256x256 result ('RGB', or 'L' when
        ``is_label`` is True). Returns None when the two images have different
        orientations (a message is printed) or a degenerate size.

    Raises:
        FileNotFoundError: If an input path does not exist (raised by ``Image.open``).
        PIL.UnidentifiedImageError: If an input cannot be decoded as an image.
    """
    TARGET_DIMENSION = 256
    # Nearest keeps label colours exact; photos get a smoother filter.
    resample_method = Image.Resampling.NEAREST if is_label else Image.Resampling.BICUBIC

    def load_image(path):
        """Load ``path`` as an RGB image, compositing any transparency onto black."""
        # The context manager closes the file; convert()/paste() create new
        # in-memory images, so nothing refers to the file afterwards.
        with Image.open(path) as img:
            if img.mode in ('RGBA', 'LA', 'P'):
                rgba_img = img.convert('RGBA')
                background = Image.new('RGB', rgba_img.size, (0, 0, 0))
                background.paste(rgba_img, mask=rgba_img.getchannel('A'))
                return background
            return img.convert('RGB')

    def get_orientation(w, h):
        """Return 'portrait' if taller than wide, otherwise 'landscape'."""
        return 'portrait' if h > w else 'landscape'

    img1 = load_image(image1_path)
    img2 = load_image(image2_path)
    w1, h1 = img1.size
    w2, h2 = img2.size

    orientation1 = get_orientation(w1, h1)
    orientation2 = get_orientation(w2, h2)
    if orientation1 != orientation2:
        print(f" Mismatched orientations ({orientation1} vs {orientation2}) for {os.path.basename(image1_path)}, {os.path.basename(image2_path)}")
        return None
    portrait = orientation1 == 'portrait'

    # "Major" is the axis the two images are laid out along; "minor" is the other.
    if portrait:
        major1, major2, minor1, minor2 = w1, w2, h1, h2
    else:
        major1, major2, minor1, minor2 = h1, h2, w1, w2
    total_major = major1 + major2
    max_minor = max(minor1, minor2)
    if total_major == 0 or max_minor == 0:
        return None

    # Fill the canvas along the major axis, unless that would make the larger
    # image overflow the minor axis (which used to be silently cropped).
    scale = min(TARGET_DIMENSION / total_major, TARGET_DIMENSION / max_minor)

    def scaled(length):
        """Scale a side length, rounding up, and keep it within [1, TARGET_DIMENSION]."""
        return min(TARGET_DIMENSION, max(1, math.ceil(length * scale)))

    major1_s, major2_s = scaled(major1), scaled(major2)
    minor1_s, minor2_s = scaled(minor1), scaled(minor2)

    # Rounding up can overshoot the major axis by a pixel or two; trim the
    # excess from the larger image.
    overshoot = major1_s + major2_s - TARGET_DIMENSION
    if overshoot > 0:
        if major1_s >= major2_s:
            major1_s = max(1, major1_s - overshoot)
        else:
            major2_s = max(1, major2_s - overshoot)

    if portrait:
        size1, size2 = (major1_s, minor1_s), (major2_s, minor2_s)
    else:
        size1, size2 = (minor1_s, major1_s), (minor2_s, major2_s)

    img1_resized = img1.resize(size1, resample_method)
    img2_resized = img2.resize(size2, resample_method)

    # Lay both images out along the major axis on a black strip.
    if portrait:
        combined = Image.new('RGB', (size1[0] + size2[0], max(size1[1], size2[1])), (0, 0, 0))
        offset2 = (size1[0], 0)
    else:
        combined = Image.new('RGB', (max(size1[0], size2[0]), size1[1] + size2[1]), (0, 0, 0))
        offset2 = (0, size1[1])
    combined.paste(img1_resized, (0, 0))
    combined.paste(img2_resized, offset2)

    # Centre the strip on the final square canvas (offsets are >= 0 by construction).
    final_img = Image.new('RGB', (TARGET_DIMENSION, TARGET_DIMENSION), (0, 0, 0))
    paste_x = (TARGET_DIMENSION - combined.width) // 2
    paste_y = (TARGET_DIMENSION - combined.height) // 2
    final_img.paste(combined, (paste_x, paste_y))

    if is_label:
        class_map = convert_rgb_label_to_classes(np.array(final_img))
        # Cast explicitly: fromarray infers mode 'L' from a 2-D uint8 array.
        # Assumes class ids fit in uint8 — TODO confirm against utils.
        final_img = Image.fromarray(np.asarray(class_map, dtype=np.uint8))

    if output_path:
        final_img.save(output_path)

    return final_img

## Augmentation without merge

In [None]:
import os
import random
import imageio
import numpy as np
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
import math

# Augmenter name -> imgaug pipeline. Iteration order is the order in which the
# augmenters run, and each name becomes part of the augmented output filenames.
augmenter_dict = dict(
    rotation=rotation_aug,
    center_crop=center_crop_aug,
    random_crop=random_crop_aug,
    masking=masking_aug,
    grayscale=grayscale_aug,
    laplace=laplace_aug,
    blur=blur_aug,
    contrast=contrast_aug,
)

# The augmentation budget is split evenly across this many augmenters.
num_augmenters = len(augmenter_dict)

# Configuration
# Inputs: original colour images and their label PNGs (matched by base name).
folder_path = "Train/color"
label_folder_path = "Train/label"
# Outputs: resized originals plus all augmented copies.
save_color_dir = "astrain/color"
save_label_dir = "astrain/label"
# Both classes are topped up to this factor times the size of the larger class.
majority_aug_factor = 1.5

# Make sure both output folders exist (no error if they already do).
os.makedirs(save_color_dir, exist_ok=True)
os.makedirs(save_label_dir, exist_ok=True)

# File Discovery and Classification
# Every .jpg/.png (case-insensitive) in the colour folder, in listdir order.
print("Scanning for image files...")
image_extensions = ('.jpg', '.png')
filenames = [name for name in os.listdir(folder_path) if name.lower().endswith(image_extensions)]
# Species (breed) is encoded in the filename as everything before the last "_",
# e.g. "Russian_Blue_12.jpg" -> "Russian_Blue".

def get_species(filename):
    """
    Return the species prefix of ``filename``.

    The extension is dropped, and the part before the last underscore is
    returned. A name without any underscore is returned unchanged (minus its
    extension).

    Args:
        filename (str): The name of the file.
    Returns:
        str: The species name.
    """
    stem = os.path.splitext(filename)[0]
    head, separator, _ = stem.rpartition('_')
    return head if separator else stem

# Cat breed names as they appear in filename prefixes; any other prefix is
# treated as a dog breed.
cat_species = {
    "Russian_Blue", "Siamese", "Sphynx", "Maine_Coon", "Abyssinian",
    "Bombay", "British_Shorthair", "Bengal", "Egyptian_Mau", "Persian",
    "Ragdoll", "Birman"
}

# Split the discovered images into cats and dogs (stored as base names without
# extension). Images without a matching label PNG are skipped, because every
# later step reads "<label_folder_path>/<base>.png" for each selected image.
# NOTE(review): later steps also always read "<base>.jpg" for the colour image,
# so .png colour files are assumed not to occur — confirm.
cat_files = []
dog_files = []
for fname in filenames:
    species = get_species(fname)
    name_no_ext = os.path.splitext(fname)[0]
    label_path_check = os.path.join(label_folder_path, name_no_ext + ".png")
    if not os.path.isfile(label_path_check):
        print(f"Skipping {fname}: no label found at {label_path_check}")
        continue
    if species in cat_species:
        cat_files.append(name_no_ext)
    else:
        dog_files.append(name_no_ext)

# Get the number of cat and dog files
N_cat = len(cat_files)
N_dog = len(dog_files)

print(f"Initial counts: Cats = {N_cat}, Dogs = {N_dog}")

# Copy Original Files: resize every original image/label pair to 256x256
# (pad to square + resize) and save it under its original base name, so the
# output folders hold the originals alongside the augmented copies.
# The output folders were already created in the configuration step above.
print("Processing originals with resize augmentation...")
all_original_files = cat_files + dog_files
processed_count = 0

for fname in all_original_files:
    orig_color_path = os.path.join(folder_path, fname + ".jpg")
    orig_label_path = os.path.join(label_folder_path, fname + ".png")

    # Destination paths keep the original base name.
    dest_color_path = os.path.join(save_color_dir, fname + ".jpg")
    dest_label_path = os.path.join(save_label_dir, fname + ".png")

    img = imageio.v2.imread(orig_color_path)
    label = imageio.v2.imread(orig_label_path)

    # Drop an alpha channel before augmenting, so the pipeline only sees RGB
    # (and does not pad/resize a channel that would be thrown away anyway).
    if img.ndim == 3 and img.shape[2] == 4:
        img = img[..., :3]

    segmap = SegmentationMapsOnImage(label, shape=img.shape)

    # Apply the same resize pipeline to the image and its label map.
    resized_img, resized_segmap = resize_aug(image=img, segmentation_maps=segmap)
    resized_label = convert_rgb_label_to_classes(resized_segmap.get_arr())

    # Cast to uint8 before writing.
    imageio.imwrite(dest_color_path, resized_img.astype(np.uint8))
    imageio.imwrite(dest_label_path, resized_label.astype(np.uint8))

    processed_count += 1

print(f"Processed and saved {processed_count} original image/label pairs using resize_aug.")


# --- Calculate Augmentation Needs ---
# Both classes are topped up to one common target: the larger class size times
# majority_aug_factor (ties use the dog count, which is the same number).
# Each augmenter then handles an equal, rounded-up share of each class's
# shortfall, so slightly more than the exact shortfall may be produced.
target_final_count = round(max(N_cat, N_dog) * majority_aug_factor)

total_aug_cat_needed = max(0, target_final_count - N_cat)
total_aug_dog_needed = max(0, target_final_count - N_dog)

print(f"Target final count per class: {target_final_count}")
print(f"Total augmentations needed: Cats = {total_aug_cat_needed}, Dogs = {total_aug_dog_needed}")

num_cats_per_aug = math.ceil(total_aug_cat_needed / num_augmenters)
num_dogs_per_aug = math.ceil(total_aug_dog_needed / num_augmenters)

print(f"Will select approximately {num_cats_per_aug} cats and {num_dogs_per_aug} dogs per augmenter.")

# Running totals of how many files were drawn across all augmenters.
num_selected_cats = num_selected_dogs = 0

# Augmentation Loop: for each augmenter, draw its share of cats and dogs (with
# replacement), augment each image/label pair, and save it as
# "<base>_<augmenter>_<index>" in the output folders.
generated_aug_count = 0
for i, (aug_name, aug_object) in enumerate(augmenter_dict.items()):
    # Randomly select files for augmentation (with replacement).
    selected_cats = random.choices(cat_files, k=num_cats_per_aug) if N_cat > 0 and num_cats_per_aug > 0 else []
    selected_dogs = random.choices(dog_files, k=num_dogs_per_aug) if N_dog > 0 and num_dogs_per_aug > 0 else []
    selected_files = selected_cats + selected_dogs
    num_selected_cats += len(selected_cats)
    num_selected_dogs += len(selected_dogs)

    print(f"\nUsing augmenter '{aug_name}' ({i+1}/{num_augmenters}): processing {len(selected_files)} images ({len(selected_cats)} cats, {len(selected_dogs)} dogs)")

    processed_in_batch = 0
    for fname in selected_files:
        color_path = os.path.join(folder_path, fname + ".jpg")
        label_path = os.path.join(label_folder_path, fname + ".png")

        img = imageio.v2.imread(color_path)
        label = imageio.v2.imread(label_path)

        # Drop an alpha channel BEFORE augmenting: e.g. the grayscale augmenter
        # is configured with from_colorspace="RGB" and expects 3-channel input.
        if img.ndim == 3 and img.shape[2] == 4:
            img = img[..., :3]

        segmap = SegmentationMapsOnImage(label, shape=img.shape)

        # Apply the augmentation to the image and its label map together.
        augmented_img, augmented_segmap = aug_object(image=img, segmentation_maps=segmap)
        augmented_label = convert_rgb_label_to_classes(augmented_segmap.get_arr())

        augmented_label = augmented_label.astype(np.uint8)
        augmented_img = augmented_img.astype(np.uint8)

        # processed_in_batch keeps names unique when the same file is drawn twice.
        out_color_path = os.path.join(save_color_dir, f"{fname}_{aug_name}_{processed_in_batch}.jpg")
        out_label_path = os.path.join(save_label_dir, f"{fname}_{aug_name}_{processed_in_batch}.png")

        imageio.imwrite(out_color_path, augmented_img)
        imageio.imwrite(out_label_path, augmented_label)

        generated_aug_count += 1
        processed_in_batch += 1

    print(f"Augmenter '{aug_name}' finished. Processed {processed_in_batch} images.")

print(f"\nGenerated {generated_aug_count} augmented pairs in total ({num_selected_cats} cats, {num_selected_dogs} dogs).")

## Augmentation with merge

In [None]:
import os
import random

# Inputs: original colour images and their label PNGs.
SOURCE_COLOR_DIR = "Train/color"
SOURCE_LABEL_DIR = "Train/label"
# Outputs: merged pairs are written next to the augmented dataset.
DEST_COLOR_DIR = "astrain/color"
DEST_LABEL_DIR = "astrain/label"
# Target number of merged images for each pairing type (cat+dog, cat+cat, dog+dog).
NUM_COMBINATIONS_PER_TYPE = 126

# Cat breed prefixes; every other prefix is treated as a dog breed.
cat_species = {
    "Abyssinian", "Bengal", "Birman", "Bombay",
    "British_Shorthair", "Egyptian_Mau", "Maine_Coon", "Persian",
    "Ragdoll", "Russian_Blue", "Siamese", "Sphynx",
}

def get_species(filename):
    """
    Return the species prefix encoded in ``filename``.

    The extension is removed, and everything before the last underscore is
    returned (e.g. "Maine_Coon_7.jpg" -> "Maine_Coon"). A name without an
    underscore is returned as-is, minus its extension.

    Args:
        filename (str): The name of the file.
    Returns:
        str: The species name, or the base filename if there is no underscore.
    """
    stem = os.path.splitext(filename)[0]
    head, separator, _ = stem.rpartition('_')
    return head if separator else stem


# 1. Create destination directories
os.makedirs(DEST_COLOR_DIR, exist_ok=True)
os.makedirs(DEST_LABEL_DIR, exist_ok=True)

# 2. Scan the source directory and classify files into cats and dogs (base
#    names without extension). Images without a matching label PNG are
#    skipped, because every combination reads the label of both images.
all_files_in_color = [
    f for f in os.listdir(SOURCE_COLOR_DIR)
    if f.lower().endswith(('.jpg', '.png')) # Assuming color can be jpg or png
]

cat_files = []
dog_files = []

for fname_ext in all_files_in_color:
    fname_no_ext = os.path.splitext(fname_ext)[0]
    label_path_check = os.path.join(SOURCE_LABEL_DIR, fname_no_ext + ".png")
    if not os.path.isfile(label_path_check):
        print(f"Skipping {fname_ext}: no label found at {label_path_check}")
        continue

    species = get_species(fname_ext)
    if species in cat_species:
        cat_files.append(fname_no_ext)
    else:
        dog_files.append(fname_no_ext)

N_cat = len(cat_files)
N_dog = len(dog_files)

print(f"Found {N_cat} cat images with labels.")
print(f"Found {N_dog} dog images with labels.")

# Function to generate combinations for a specific type
def generate_combinations(combo_type, files1_list, files2_list, num_required, output_prefix):
    """
    Generate up to ``num_required`` merged image/label pairs from random file pairs.

    Each attempt picks one file from each list (or two distinct files when both
    arguments are the same list object) and merges the colour images and the
    label images with ``combine_images_preserve_aspect_ratio``. Both results
    are built in memory first and are written only if BOTH merges succeed. A
    pair that cannot be merged (e.g. mismatched orientations) is neither saved
    nor counted. Every unordered pair is tried at most once. Generation stops
    after ``num_required`` successes or ``10 * num_required`` attempts.

    Outputs are named "<output_prefix>_<k>.jpg" / ".png" in DEST_COLOR_DIR /
    DEST_LABEL_DIR, where k counts successful combinations from 0.

    Args:
        combo_type (str): Description used in the summary line (e.g. "1 Cat + 1 Dog").
        files1_list (list): Base filenames for the first image.
        files2_list (list): Base filenames for the second image.
        num_required (int): Target number of successful combinations.
        output_prefix (str): Prefix for output filenames (e.g. "cat_dog").

    Returns:
        int: The number of combinations actually written.
    """
    combinations_done = 0
    attempts = 0
    max_attempts = num_required * 10
    generated_pairs = set()  # unordered pairs already tried, successful or not
    same_pool = files1_list is files2_list

    while combinations_done < num_required and attempts < max_attempts:
        attempts += 1

        if same_pool:
            # Combining within one list: sample 2 distinct files.
            if len(files1_list) < 2:
                break
            file1_base, file2_base = random.sample(files1_list, 2)
        else:
            # Combining across lists: sample one from each.
            if not files1_list or not files2_list:
                break
            file1_base = random.choice(files1_list)
            file2_base = random.choice(files2_list)

        pair_key = tuple(sorted((file1_base, file2_base)))
        if pair_key in generated_pairs:
            continue
        # Mark the pair as tried now, so a pair that cannot be merged is not retried.
        generated_pairs.add(pair_key)

        img1_color_ext = ".jpg"
        img2_color_ext = ".jpg"
        img1_color_path = os.path.join(SOURCE_COLOR_DIR, file1_base + img1_color_ext)
        img1_label_path = os.path.join(SOURCE_LABEL_DIR, file1_base + ".png")
        img2_color_path = os.path.join(SOURCE_COLOR_DIR, file2_base + img2_color_ext)
        img2_label_path = os.path.join(SOURCE_LABEL_DIR, file2_base + ".png")

        # Merge without saving first, so a colour image is never written without its label.
        combined_color = combine_images_preserve_aspect_ratio(img1_color_path, img2_color_path)
        if combined_color is None:
            continue
        combined_label = combine_images_preserve_aspect_ratio(img1_label_path, img2_label_path, is_label=True)
        if combined_label is None:
            continue

        output_base_name = f"{output_prefix}_{combinations_done}"
        combined_color.save(os.path.join(DEST_COLOR_DIR, output_base_name + ".jpg"))
        combined_label.save(os.path.join(DEST_LABEL_DIR, output_base_name + ".png"))

        print(f"\n  Generated: {output_base_name}.jpg/.png using [{file1_base}{img1_color_ext}, {file2_base}{img2_color_ext}]")

        combinations_done += 1

    print(f"{combo_type}: generated {combinations_done}/{num_required} combinations in {attempts} attempts.")
    return combinations_done


# Generate the three pairing types in turn: cat+dog, cat+cat, dog+dog.
# The same-species cases pass the SAME list object twice, which tells
# generate_combinations to sample two distinct files from that one list.
for _combo_type, _files1, _files2, _prefix in (
    ("1 Cat + 1 Dog", cat_files, dog_files, "cat_dog"),
    ("2 Cats", cat_files, cat_files, "cat_cat"),
    ("2 Dogs", dog_files, dog_files, "dog_dog"),
):
    generate_combinations(
        combo_type=_combo_type,
        files1_list=_files1,
        files2_list=_files2,
        num_required=NUM_COMBINATIONS_PER_TYPE,
        output_prefix=_prefix,
    )

print("\n--- Combination process finished. ---")

## Prompt Augmentation

In [None]:
import numpy as np
import random
import torch
import numpy as np
import random
import os
import time
import shutil 
from torchvision.io import read_image
from PIL import Image 
from utils.dataset import target_remap

def create_gaussian_heatmap(size=(256, 256), sigma=3.0, center=None):
    """
    Create a 2-D heatmap with a Gaussian spot around one pixel.

    Every pixel holds ``exp(-d**2 / (2 * sigma**2))``, where ``d`` is its
    Euclidean distance (in pixels) from the centre pixel. The value is therefore
    1.0 at the centre and decays towards 0.

    Args:
        size (tuple): (height, width) of the heatmap.
        sigma (float): Standard deviation of the Gaussian in pixels; a larger
            sigma gives a wider, smoother spot.
        center (tuple, optional): (y, x) centre pixel. If None (default), a
            pixel is drawn uniformly at random with the ``random`` module.

    Returns:
        numpy.ndarray: float64 array of shape (height, width) with values in [0, 1].
        tuple: The (y, x) coordinates of the centre pixel.

    Raises:
        ValueError: If a size dimension or sigma is not positive, or if
            ``center`` lies outside the heatmap.
    """
    height, width = size
    if height <= 0 or width <= 0:
        raise ValueError("Size dimensions must be positive integers.")
    if sigma <= 0:
        raise ValueError("Sigma must be positive.")

    if center is None:
        # y is drawn before x.
        center_y = random.randint(0, height - 1)
        center_x = random.randint(0, width - 1)
    else:
        center_y, center_x = center
        if not (0 <= center_y < height and 0 <= center_x < width):
            raise ValueError(f"Center {center} lies outside a heatmap of size {size}.")
    print(f"Selected center pixel (y, x): ({center_y}, {center_x})")

    # Broadcast a column of row indices against a row of column indices,
    # instead of materialising two full (height, width) index grids.
    y_coords = np.arange(height)[:, np.newaxis]
    x_coords = np.arange(width)[np.newaxis, :]
    dist_sq = (x_coords - center_x) ** 2 + (y_coords - center_y) ** 2

    heatmap = np.exp(-dist_sq / (2 * sigma ** 2))

    return heatmap, (center_y, center_x)


# --- Helper function for the selection process ---
def select_dominant_class(heatmap, remapped_mask):
    """
    Pick the positive class whose pixels collect the most heatmap mass.

    Every class value > 0 present in ``remapped_mask`` is scored by summing the
    ``heatmap`` values over that class's pixels. The highest score wins; ties go
    to the smallest class value, because ``np.unique`` yields sorted values and
    ``max`` keeps the first maximum it sees.

    Args:
        heatmap (numpy.ndarray): Per-pixel weights, indexable by a boolean mask
            of ``remapped_mask``'s shape.
        remapped_mask (numpy.ndarray): Class mask; values <= 0 are ignored.

    Returns:
        int: The winning class value, or 0 if no class > 0 is present or every
            score is below 1e-9.
        dict: Class value -> summed heatmap score (empty if no class > 0).
    """
    candidates = [cls for cls in np.unique(remapped_mask) if cls > 0]
    if not candidates:
        return 0, {}

    # Each candidate came from np.unique, so its pixel set is never empty.
    scores = {cls: np.sum(heatmap[remapped_mask == cls]) for cls in candidates}

    if all(score < 1e-9 for score in scores.values()):
        return 0, scores
    return max(scores, key=scores.get), scores


# --- Input locations ---
TRAIN_IMG_DIR = "astrain/color"   # original colour images
TRAIN_LBL_DIR = "astrain/label"   # label masks (only *.png files are used below)

# --- Prompt-sampling parameters ---
HEATMAP_SIGMA = 3.0   # Gaussian spread passed to create_gaussian_heatmap
MAX_ATTEMPTS = 1000   # upper bound on random center draws per label file

# --- Output locations (copied images, heatmap images, final label masks) ---
PSTRAIN_BASE_DIR = "pstrain"
PSTRAIN_IMG_DIR, PSTRAIN_HEATMAP_DIR, PSTRAIN_LABEL_DIR = (
    os.path.join(PSTRAIN_BASE_DIR, sub_dir) for sub_dir in ("color", "point_prompt", "label")
)


start_time = time.time()

os.makedirs(PSTRAIN_IMG_DIR, exist_ok=True)
os.makedirs(PSTRAIN_HEATMAP_DIR, exist_ok=True)
os.makedirs(PSTRAIN_LABEL_DIR, exist_ok=True)

# One print call; "\n".join plus print's trailing newline matches one-print-per-line output.
print("\n".join([
    f"Reading original images from:   {os.path.abspath(TRAIN_IMG_DIR)}",
    f"Reading labels from:            {os.path.abspath(TRAIN_LBL_DIR)}",
    "-" * 30,
    f"Saving copied images to:        {os.path.abspath(PSTRAIN_IMG_DIR)}",
    f"Saving heatmap images to:       {os.path.abspath(PSTRAIN_HEATMAP_DIR)}",
    f"Saving final label masks to:    {os.path.abspath(PSTRAIN_LABEL_DIR)}",
]))


# Visible (non-dot) PNG label files, in sorted order
all_label_files = os.listdir(TRAIN_LBL_DIR)
label_files = sorted(
    name for name in all_label_files
    if not name.startswith('.') and name.lower().endswith('.png')
)


# --- Loop through all found label files ---
# For each label mask: find its image, sample random Gaussian point prompts until
# two distinct dominant classes are hit, then write two (image, heatmap, label) triplets.
import traceback  # stack dumps for unexpected per-file errors

processed_count = 0
skipped_count = 0
error_count = 0
img_not_found_count = 0

total_files = len(label_files)
print(f"\nStarting processing for {total_files} label files...")

# Index TRAIN_IMG_DIR once (base name -> file name) instead of re-listing the whole
# directory for every label, which made the lookup O(labels * images).
# setdefault keeps the FIRST file seen per base name, matching the previous
# "first match in os.listdir order" behavior when e.g. both a.jpg and a.png exist.
image_index = {}
image_index_error = None
try:
    for img_file in os.listdir(TRAIN_IMG_DIR):
        image_index.setdefault(os.path.splitext(img_file)[0], img_file)
except OSError as e:
    # As before, report the failure for each label and count it as an error
    # instead of aborting the whole run.
    image_index_error = e

for i, label_filename in enumerate(label_files):
    img_name_base = os.path.splitext(label_filename)[0]  # Base name without extension
    label_filepath = os.path.join(TRAIN_LBL_DIR, label_filename)

    print(f"\nProcessing label {i+1}/{total_files}: {label_filename} (Base: {img_name_base})")

    # Find the corresponding original image file (same base name, any extension)
    if image_index_error is not None:
        print(f"!!! ERROR searching for image file for {label_filename}: {image_index_error}")
        error_count += 1
        continue
    img_file = image_index.get(img_name_base)
    if img_file is None:
        print(f"  Skipping: Could not find corresponding image file for base name '{img_name_base}' in {TRAIN_IMG_DIR}")
        img_not_found_count += 1
        skipped_count += 1
        continue
    original_img_path = os.path.join(TRAIN_IMG_DIR, img_file)
    original_img_ext = os.path.splitext(img_file)[1]  # Extension incl. dot, e.g. '.jpg'
    print(f"  Found corresponding image: {img_file}")

    try:
        # Load the label mask file
        label_tensor_loaded = read_image(label_filepath)

        # Ensure a single channel: keep channel 0 of a 3-channel label; skip any
        # other channel count that is not already (1, H, W).
        if label_tensor_loaded.shape[0] != 1:
            if label_tensor_loaded.shape[0] == 3:
                label_tensor_loaded = label_tensor_loaded[0:1, :, :]
                print("  Info: Label had 3 channels, took the first.")
            else:
                print(f"  Skipping: Label has unexpected shape {label_tensor_loaded.shape}, expected (1, H, W).")
                skipped_count += 1
                continue

        # First remap via the project helper (per the original note: 255 -> 3)
        label_tensor_original = target_remap(label_tensor_loaded)

        # Drop the channel axis and convert to NumPy once per sample
        label_squeezed = label_tensor_original.squeeze(0)
        mask_post_remap1 = label_squeezed.numpy().astype(np.uint8)
        mask_size = mask_post_remap1.shape

        # Second remap: fold class 3 into class 0, then shift every value by +1.
        # NOTE: this MERGES the original 0 and 3 into final class 1 (it is not a
        # two-way swap); original 1 -> 2 and 2 -> 3. The array is uint8, so any
        # 255 that survived target_remap would wrap around to 0.
        mask_swapped = mask_post_remap1.copy()
        mask_swapped[mask_post_remap1 == 3] = 0
        remapped_mask_final = mask_swapped + 1
        final_present_classes = np.unique(remapped_mask_final)
        final_target_classes = final_present_classes[final_present_classes > 0]

        # Two distinct prompts require at least two distinct classes
        if len(final_target_classes) < 2:
            print(f"  Skipping: Mask only contains {len(final_target_classes)} target class(es) {final_target_classes.tolist()} after remapping. Cannot select 2 distinct.")
            skipped_count += 1
            continue

        # Draw random point prompts until two DIFFERENT dominant classes are hit
        # or MAX_ATTEMPTS draws have been used.
        selected_results = []  # (selected_class, final_mask_array, heatmap_array)
        attempts = 0
        found_classes = set()

        while len(selected_results) < 2 and attempts < MAX_ATTEMPTS:
            attempts += 1
            heatmap, center_coords = create_gaussian_heatmap(size=mask_size, sigma=HEATMAP_SIGMA)
            current_selected_class, _ = select_dominant_class(heatmap, remapped_mask_final)

            if current_selected_class > 0 and current_selected_class not in found_classes:
                # Target mask: pixels of the selected class keep its value, all others are 0
                final_mask = np.zeros_like(remapped_mask_final, dtype=np.uint8)
                final_mask[remapped_mask_final == current_selected_class] = current_selected_class
                selected_results.append((current_selected_class, final_mask, heatmap))
                found_classes.add(current_selected_class)
                print(f"    Attempt {attempts}: Found distinct class {current_selected_class} at {center_coords}")


        if len(selected_results) == 2:
            print("  Successfully found two distinct classes.")
            sel_cls_1, fin_msk_1, heatmap_1 = selected_results[0]
            sel_cls_2, fin_msk_2, heatmap_2 = selected_results[1]

            # --- Define final output filenames (consistent naming) ---
            output_base_name_1 = f"{img_name_base}_1"
            output_base_name_2 = f"{img_name_base}_2"

            # Paths for triplet 1
            output_img1_path = os.path.join(PSTRAIN_IMG_DIR, f"{output_base_name_1}{original_img_ext}")
            output_heatmap1_path = os.path.join(PSTRAIN_HEATMAP_DIR, f"{output_base_name_1}.png")
            output_label1_path = os.path.join(PSTRAIN_LABEL_DIR, f"{output_base_name_1}.png")

            # Paths for triplet 2
            output_img2_path = os.path.join(PSTRAIN_IMG_DIR, f"{output_base_name_2}{original_img_ext}")
            output_heatmap2_path = os.path.join(PSTRAIN_HEATMAP_DIR, f"{output_base_name_2}.png")
            output_label2_path = os.path.join(PSTRAIN_LABEL_DIR, f"{output_base_name_2}.png")


            # Copy the original image twice (one per triplet)
            shutil.copy2(original_img_path, output_img1_path)
            shutil.copy2(original_img_path, output_img2_path)
            print(f"    Copied original image to: {os.path.basename(output_img1_path)}")
            print(f"    Copied original image to: {os.path.basename(output_img2_path)}")

            # Save heatmaps as 8-bit grayscale PNGs (scaled 0-255). A 2-D uint8
            # array is stored as mode 'L' automatically; the explicit `mode=`
            # argument of Image.fromarray is deprecated since Pillow 11.3.
            heatmap1_scaled = (heatmap_1 * 255).astype(np.uint8)
            heatmap2_scaled = (heatmap_2 * 255).astype(np.uint8)
            Image.fromarray(heatmap1_scaled).save(output_heatmap1_path)
            Image.fromarray(heatmap2_scaled).save(output_heatmap2_path)
            print(f"    Saved Heatmap 1: {os.path.basename(output_heatmap1_path)}")
            print(f"    Saved Heatmap 2: {os.path.basename(output_heatmap2_path)}")


            # Save final masks as PNG images (already uint8)
            Image.fromarray(fin_msk_1).save(output_label1_path)
            Image.fromarray(fin_msk_2).save(output_label2_path)
            print(f"    Saved Label 1:   {os.path.basename(output_label1_path)}")
            print(f"    Saved Label 2:   {os.path.basename(output_label2_path)}")


            processed_count += 1  # Counts label files that yielded 2 triplets
        else:
            print(f"  Skipping: Failed to find two distinct classes within {MAX_ATTEMPTS} attempts.")
            skipped_count += 1

    except FileNotFoundError:
        print(f"!!! ERROR processing label file {label_filename}: File not found (unexpected).")
        error_count += 1
    except Exception as e:
        # Per-file boundary: log with a traceback, then continue with the next file
        print(f"!!! ERROR processing label file {label_filename}: {e}")
        traceback.print_exc()
        error_count += 1

# --- Final Summary ---
end_time = time.time()
total_time = end_time - start_time

# Emit the whole report in one call; joining with "\n" (plus print's own trailing
# newline) reproduces the previous line-by-line output exactly.
print("\n".join([
    "\n" + "=" * 40,
    "Processing Complete.",
    f"Output base directory:                          {os.path.abspath(PSTRAIN_BASE_DIR)}",
    "-" * 40,
    f"Total label files found:                        {total_files}",
    f"Successfully processed (2 triplets generated):  {processed_count}",
    f"Skipped (due to various reasons):               {skipped_count}",
    f"  - Skipped because original image not found:   {img_not_found_count}",
    f"  - Skipped (other reasons, e.g., too few classes): {skipped_count - img_not_found_count}",
    f"Errors during processing:                       {error_count}",
    "-" * 40,
    f"Total files created in '{os.path.basename(PSTRAIN_IMG_DIR)}':      {processed_count * 2}",
    f"Total files created in '{os.path.basename(PSTRAIN_HEATMAP_DIR)}':  {processed_count * 2}",
    f"Total files created in '{os.path.basename(PSTRAIN_LABEL_DIR)}':   {processed_count * 2}",
    "-" * 40,
    f"Total time: {total_time:.2f} seconds",
    "=" * 40,
]))