## Resize

In [None]:
import imgaug.augmenters as iaa
import imgaug as ia
import imageio
import numpy as np
import matplotlib.pyplot as plt
from imgaug.augmentables.segmaps import SegmentationMapsOnImage

# Sample image/label pair used to preview the resize pipelines below.
img_path = "Abyssinian_1"
img_dir = "Train/color/"
label_dir = "Train/label/"
image = imageio.imread(f"{img_dir}{img_path}.jpg")
# Keep only the first channel of the label PNG so the mask is (H, W).
mask = imageio.imread(f"{label_dir}{img_path}.png")[:, :, 0]


def _pad_to_square_then_resize(interpolation):
    """Return a pipeline that center-pads to a 1:1 aspect ratio, then resizes to 512x512.

    Padding first means the final resize scales both axes by the same factor,
    so content is not stretched. ``interpolation`` is forwarded to
    ``iaa.Resize``: use "cubic" for images and "nearest" for label masks so no
    interpolated (non-existent) class values are created.
    """
    return iaa.Sequential([
        iaa.PadToAspectRatio(
            1.0,                 # target aspect ratio (width / height)
            position="center",   # center the content within the padding
            pad_mode="constant",
            pad_cval=0,          # pad with 0 (black / background)
        ),
        iaa.Resize(512, interpolation=interpolation),
    ])


# Image pipeline ("cubic" is also iaa.Resize's default interpolation).
im_resize_padding = _pad_to_square_then_resize("cubic")
# Plain resize without padding (distorts non-square inputs); kept for comparison.
im_resize = iaa.Resize(512)

# Mask pipeline: nearest-neighbour keeps mask values restricted to existing ids.
label_resize_padding = _pad_to_square_then_resize("nearest")
label_resize = iaa.Resize(512, interpolation="nearest")

# Apply to the image and to the mask (already single-channel, see above).
resized_image = im_resize_padding.augment_image(image)
resized_mask = label_resize_padding.augment_image(mask)

plt.imshow(resized_image)
plt.show()



# Building blocks for `resize_aug`.
# 1. pad_aug: center-pads images/masks with 0s to a 1:1 (square) aspect ratio,
#    so the following resize scales both axes equally.
pad_aug = iaa.PadToAspectRatio(1.0, position="center", pad_mode="constant", pad_cval=0)

# 2. Resizers to 512x512: cubic for images, nearest-neighbour for masks so no
#    interpolated label values appear.
resize_img, resize_mask = (
    iaa.Resize(512, interpolation="cubic"),
    iaa.Resize(512, interpolation="nearest"),
)

# Hooks for `iaa.Lambda`: imgaug calls them with four positional arguments
# (data, random_state, parents, hooks); only the data argument is used.
def tmp1(images, random_state, parents, hooks):
    """Resize every image in the batch with ``resize_img`` (cubic, 512x512)."""
    resized = []
    for single_image in images:
        resized.append(resize_img.augment_image(single_image))
    return resized

def tmp2(segmaps, random_state, parents, hooks):
    """Lambda segmentation-map hook: resize each map with ``resize_mask`` (nearest).

    The resize runs on the map's internal ``arr``. The result is then cast back
    to the dtype/ndim reported by ``segmap.get_arr()`` before the map is
    rebuilt. Without that step the rebuilt map would remember the internal int32
    (H, W, C) layout as its "input", and ``get_arr()`` on the output would no
    longer return the kind of array the caller passed in (e.g. an (H, W) uint8
    mask).

    Args:
        segmaps: iterable of ``SegmentationMapsOnImage``.
        random_state, parents, hooks: unused; required by ``iaa.Lambda``.

    Returns:
        list of new ``SegmentationMapsOnImage``, one per input map.
    """
    new_segmaps = []
    for segmap in segmaps:
        # Only dtype and ndim of the caller-facing array are needed.
        template = segmap.get_arr()
        # Resize the internal array with nearest-neighbour interpolation.
        new_arr = resize_mask.augment_image(segmap.arr)
        new_arr = new_arr.astype(template.dtype)
        if template.ndim == 2:
            new_arr = new_arr[:, :, 0]
        # `shape` describes the image the map belongs to: new height/width,
        # same channel count as before.
        new_shape = tuple(new_arr.shape[:2]) + tuple(segmap.shape[2:])
        new_segmaps.append(SegmentationMapsOnImage(new_arr, shape=new_shape))
    return new_segmaps

# One reusable pipeline: pad to a square, then resize images (cubic, tmp1) and
# segmentation maps (nearest, tmp2) to 512x512.
resize_aug = iaa.Sequential([
    pad_aug,
    iaa.Lambda(func_images=tmp1, func_segmentation_maps=tmp2),
])

# Wrap the mask so imgaug applies the same geometry to it as to the image.
segmap_obj = SegmentationMapsOnImage(mask, shape=image.shape)
aug_image, aug_segmap = resize_aug(image=image, segmentation_maps=segmap_obj)

# Back to a plain numpy array for plotting.
processed_mask = aug_segmap.get_arr()

# Side-by-side preview of the augmented image and mask.
plt.figure(figsize=(10, 5))
for _panel, (_data, _title) in enumerate(
    ((aug_image, "Augmented Image"), (processed_mask, "Augmented Mask")),
    start=1,
):
    plt.subplot(1, 2, _panel)
    plt.imshow(_data)
    plt.title(_title)
plt.show()

In [None]:
# import os
# import glob
# import imageio
# import imgaug.augmenters as iaa
# from tqdm import tqdm
# from PIL import Image
# import numpy as np


# # Define augmentation pipelines
# im_resize_padding = iaa.Sequential([
#     iaa.PadToAspectRatio(1.0, position="center", pad_mode="constant", pad_cval=0),
#     iaa.Resize(512)
# ])

# label_resize_padding = iaa.Sequential([
#     iaa.PadToAspectRatio(1.0, position="center", pad_mode="constant", pad_cval=0),
#     iaa.Resize(512, interpolation="nearest")
# ])

# # Create output directories
# os.makedirs("atrain/color", exist_ok=True)
# os.makedirs("atrain/label", exist_ok=True)

# # Get list of files and process with progress bar
# color_files = glob.glob("Train/color/*.jpg")
# for color_path in tqdm(color_files, desc="Processing images", unit="img"):
#     base_name = os.path.splitext(os.path.basename(color_path))[0]
    
#     # Read and process image
#     image = Image.open(color_path).convert('RGB')  # Force RGB mode
#     image = np.array(image)
    
#     # Apply augmentation
#     resized_image = im_resize_padding.augment_image(image)
    
#     # Convert to PIL Image for safe saving
#     pil_image = Image.fromarray(resized_image.astype(np.uint8))
    
#     # Ensure RGB mode (in case augmentation changed it)
#     if pil_image.mode == 'RGBA':
#         pil_image = pil_image.convert('RGB')
    
#     # Save with JPEG format
#     pil_image.save(f"atrain/color/{base_name}.jpg", "JPEG", quality=95)
    
#     # Process corresponding label
#     label_path = f"Train/label/{base_name}.png"
#     if os.path.exists(label_path):
#         mask = imageio.imread(label_path)
#         mask = mask[:, :, 0]  # Remove alpha channel if present
        
#         # Apply label augmentation
#         resized_mask = label_resize_padding.augment_image(mask)
        
#         # Save mask as PNG
#         Image.fromarray(resized_mask.astype(np.uint8)).save(
#             f"atrain/label/{base_name}.png"
#         )
#     else:
#         print(f"\nMissing label for: {base_name}")

## Rotation

In [None]:
# import imgaug.augmenters as iaa
# from imgaug.augmentables.segmaps import SegmentationMapsOnImage

# # Convert mask to SegmentationMapsOnImage object
# segmap = SegmentationMapsOnImage(mask, shape=image.shape)

# # Define augmentation with black padding
# aug = iaa.Affine(
#     rotate=(45, 315),
#     order=3,          # Image interpolation (cubic)
#     mode="constant",  # Image padding mode
#     cval=0,           # Image padding value (black)
#     backend="cv2"
# )

# # Apply augmentation
# augmented_image, augmented_segmap = aug(
#     image=image,
#     segmentation_maps=segmap
# )

# # Extract augmented mask
# augmented_mask = augmented_segmap.get_arr()

# plt.imshow(augmented_image)
# plt.show()
# plt.imshow(augmented_mask)
# plt.show()

In [None]:
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
import matplotlib.pyplot as plt
import imageio

# Random rotation followed by the shared pad-and-resize pipeline.
# fit_output=True enlarges the canvas so rotated content is not cut off; the
# uncovered corners are filled with 0 (black).
rotation_aug = iaa.Sequential([
    iaa.Affine(
        rotate=(45, 315),   # angle sampled from the range [45, 315] degrees
        fit_output=True,
        mode="constant",
        cval=0,
        backend="cv2",
    ),
    # Pads the rotated result to a square and resizes it to 512x512
    # (used instead of a Resize({"longer-side": 512}) + PadToSquare pair).
    resize_aug,
])

# Sample pair; keep only the first label channel so the mask is (H, W).
image = imageio.imread("Train/color/Abyssinian_1.jpg")
mask = imageio.imread("Train/label/Abyssinian_1.png")[:, :, 0]

segmap = SegmentationMapsOnImage(mask, shape=image.shape)
aug_image, aug_segmap = rotation_aug(image=image, segmentation_maps=segmap)
aug_mask = aug_segmap.get_arr()

# Rotated image next to its rotated mask.
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
for _axis, _data, _title, _cmap in (
    (ax[0], aug_image, 'Rotated Image with Padding', None),
    (ax[1], aug_mask, 'Rotated Mask', 'jet'),
):
    _axis.imshow(_data, cmap=_cmap)
    _axis.set_title(_title)
plt.show()

## Random Cropping

In [None]:
# import imgaug as ia
# import imgaug.augmenters as iaa
# import numpy as np

# def random_square_crop_imgaug(image, mask, crop_biggest_square=False, random_state=None):
#     """
#     Crops a square from the image and a corresponding label mask
#     using imgaug's built-in functions.

#     If crop_biggest_square is True, it will crop the biggest possible square
#     from the center of the image. Otherwise, it will crop a random square
#     with a side length of 2/3 of the smallest edge of the image.

#     Args:
#         image (np.ndarray): The input image (H, W, C).
#         mask (np.ndarray): The label mask (H, W).
#         crop_biggest_square (bool): Whether to crop the biggest square
#             from the center of the image. Defaults to False.
#         random_state (None or int or imgaug.random.RNG or numpy.random.RandomState, optional):
#             Random state to use for random operations. Defaults to None.

#     Returns:
#         tuple: A tuple containing the cropped image and the cropped mask.
#     """

#     if random_state is None:
#         random_state = np.random.RandomState()  # Use numpy's RandomState by default
#     elif isinstance(random_state, int):
#         random_state = np.random.RandomState(random_state)
#     elif isinstance(random_state, ia.random.RNG):
#         pass  # Use the provided imgaug RNG directly
#     elif isinstance(random_state, np.random.RandomState):
#         pass  # Use the provided numpy RandomState directly
#     else:
#         raise ValueError("Invalid random_state.  Must be None, int, imgaug.random.RNG, or numpy.random.RandomState.")


#     # Convert numpy RandomState to imgaug.random.RNG if needed
#     if isinstance(random_state, np.random.RandomState):
#         random_state = ia.random.RNG(random_state.randint(0, 10**6))

#     height, width = image.shape[:2]

#     if crop_biggest_square:
#         cropper = iaa.CropToAspectRatio(1, "center")
#         cropped_image = cropper.augment_image(image)
#         cropped_mask = cropper.augment_image(mask)
#     else:
#         # Crop a random square with a side length of 2/3 of the smallest edge
#         min_side = min(height, width)
#         crop_size = int(min_side * (2/3))

#         # Determine maximum possible top-left corner coordinates for the crop
#         max_x = width - crop_size
#         max_y = height - crop_size

#         # Generate random top-left corner coordinates
#         x1 = random_state.randint(0, max_x + 1)
#         y1 = random_state.randint(0, max_y + 1)

#         # Create a bounding box object representing the crop
#         bbox = ia.BoundingBox(x1=x1, y1=y1, x2=x1 + crop_size, y2=y1 + crop_size)
#         bbs = ia.BoundingBoxesOnImage([bbox], shape=image.shape)

#         # Crop the image and mask based on the bounding box
#         cropper = iaa.CropToFixedSize(width=crop_size, height=crop_size) # Crop and resize to the crop size
#         cropped_image = cropper.augment_image(image[bbox.y1:bbox.y2, bbox.x1:bbox.x2])
#         cropped_mask = cropper.augment_image(mask[bbox.y1:bbox.y2, bbox.x1:bbox.x2])


#     return cropped_image, cropped_mask


# # Crop the image and mask using the imgaug function
# cropped_image, cropped_mask = random_square_crop_imgaug(image, mask)

# plt.imshow(cropped_image)
# plt.show()
# plt.imshow(cropped_mask)
# plt.show()

In [None]:
class CenterSquareCropAugmenter(iaa.Augmenter):
    """Crop the largest centered square from each image and segmentation map.

    The crop itself is done by ``iaa.CropToAspectRatio(1.0, position="center")``.
    Segmentation maps are cropped on their internal array and then cast back to
    the dtype/ndim reported by ``get_arr()``, so the output map's ``get_arr()``
    returns the same kind of array the caller passed in.
    """

    def __init__(self, name=None, deterministic=False, random_state=None):
        super(CenterSquareCropAugmenter, self).__init__(
            name=name, deterministic=deterministic, random_state=random_state)
        # Center-positioned crop: no randomness to share between the image and
        # segmentation-map paths.
        self.cropper = iaa.CropToAspectRatio(1.0, position="center")

    def _augment_images(self, images, random_state, parents, hooks):
        # Apply the center crop to each image independently.
        return [self.cropper.augment_image(img) for img in images]

    def _augment_segmentation_maps(self, segmaps, random_state, parents, hooks):
        out = []
        for segmap in segmaps:
            # Only dtype and ndim of the caller-facing array are needed.
            template = segmap.get_arr()
            cropped_arr = self.cropper.augment_image(segmap.arr)
            cropped_arr = cropped_arr.astype(template.dtype)
            if template.ndim == 2:
                cropped_arr = cropped_arr[:, :, 0]
            # New image height/width, same image channel count as before.
            new_shape = tuple(cropped_arr.shape[:2]) + tuple(segmap.shape[2:])
            out.append(SegmentationMapsOnImage(cropped_arr, shape=new_shape))
        return out

    def get_parameters(self):
        # No stochastic parameters.
        return []


# Augmenter that crops a random square with side length = crop_factor (default
# 2/3) times the smallest image edge.
class RandomSquareCropAugmenter(iaa.Augmenter):
    """Randomly position a square crop of ``crop_factor`` x the shorter image edge.

    The image path and the segmentation-map path both call ``_sample_square``
    with the image's height/width. That method draws ``randint`` for x, then
    for y. Image and mask therefore get the same crop whenever imgaug gives both
    paths the same RNG state. The notebook relies on this through
    ``to_deterministic()`` and joint ``image=``/``segmentation_maps=`` calls.
    NOTE(review): confirm the joint-call behaviour for the installed imgaug
    version.
    """

    def __init__(self, crop_factor=2/3, name=None, deterministic=False, random_state=None):
        """
        crop_factor: Ratio of the smallest edge to use for the square.
        """
        super(RandomSquareCropAugmenter, self).__init__(
            name=name, deterministic=deterministic, random_state=random_state)
        self.crop_factor = crop_factor

    def _sample_square(self, height, width, random_state):
        """Return ``(y1, x1, size)`` of a random square inside a height x width area."""
        shorter = min(height, width)
        # Clamp to [1, shorter edge]. A tiny input then never yields an empty
        # crop, and crop_factor > 1 cannot request a square larger than the image.
        size = min(shorter, max(1, int(shorter * self.crop_factor)))
        # x is drawn before y; both augment paths must keep this order.
        x1 = random_state.randint(0, width - size + 1)
        y1 = random_state.randint(0, height - size + 1)
        return y1, x1, size

    def _augment_images(self, images, random_state, parents, hooks):
        out_images = []
        for img in images:
            y1, x1, size = self._sample_square(img.shape[0], img.shape[1], random_state)
            out_images.append(img[y1:y1 + size, x1:x1 + size])
        return out_images

    def _augment_segmentation_maps(self, segmaps, random_state, parents, hooks):
        out_segmaps = []
        for segmap in segmaps:
            # Sample in image coordinates (`segmap.shape` is the image's shape),
            # so the draws match the image path ...
            img_h, img_w = segmap.shape[0], segmap.shape[1]
            y1, x1, size = self._sample_square(img_h, img_w, random_state)
            # ... then map the square onto the array, which may be stored at a
            # different resolution than the image (identity when they match).
            arr = segmap.arr
            arr_h, arr_w = arr.shape[0], arr.shape[1]
            top = int(round(y1 * arr_h / img_h))
            bottom = int(round((y1 + size) * arr_h / img_h))
            left = int(round(x1 * arr_w / img_w))
            right = int(round((x1 + size) * arr_w / img_w))
            cropped = arr[top:bottom, left:right]
            # Restore the dtype/ndim reported by get_arr(), so the rebuilt map
            # hands back the same kind of array the caller supplied.
            template = segmap.get_arr()
            cropped = cropped.astype(template.dtype)
            if template.ndim == 2:
                cropped = cropped[:, :, 0]
            new_shape = (size, size) + tuple(segmap.shape[2:])
            out_segmaps.append(SegmentationMapsOnImage(cropped, shape=new_shape))
        return out_segmaps

    def get_parameters(self):
        return [self.crop_factor]
    
# Sample pair for previewing the two crop augmenters.
image = imageio.imread("Train/color/Abyssinian_10.jpg")  # (H, W, 3)
mask = imageio.imread("Train/label/Abyssinian_10.png")
# Keep only the first label channel, as the other preview cells do, so the
# mask is (H, W) whether the PNG is single- or multi-channel.
if mask.ndim == 3:
    mask = mask[:, :, 0]
# imgaug's `shape` argument is the shape of the image the map belongs to.
maskaug = SegmentationMapsOnImage(mask, shape=image.shape)

# Option 1: center crop, then the shared pad/resize pipeline.
center_crop_aug = iaa.Sequential([CenterSquareCropAugmenter(), resize_aug])
aug_img, aug_segmap_obj = center_crop_aug(image=image, segmentation_maps=maskaug)
aug_segmap = aug_segmap_obj.get_arr()

# Option 2: random square crop, then the shared pad/resize pipeline.
random_crop_aug = iaa.Sequential([RandomSquareCropAugmenter(), resize_aug])
# A deterministic copy repeats the same random draws on every call, so image
# and mask can be augmented in separate calls and still get the same crop.
random_crop_aug_det = random_crop_aug.to_deterministic()
aug_img2 = random_crop_aug_det(image=image)
aug_segmap2 = random_crop_aug_det(segmentation_maps=maskaug).get_arr()


def _show_crop_comparison(orig_image, orig_mask, crop_image, crop_mask, label):
    """Plot the original image/mask (top row) above a cropped image/mask (bottom row)."""
    _, axes = plt.subplots(2, 2, figsize=(10, 10))
    axes[0, 0].imshow(orig_image)
    axes[0, 0].set_title("Original Image")
    axes[0, 1].imshow(orig_mask, cmap="gray")
    axes[0, 1].set_title("Original Mask")
    axes[1, 0].imshow(crop_image)
    axes[1, 0].set_title(f"{label} - Image")
    axes[1, 1].imshow(crop_mask, cmap="gray")
    axes[1, 1].set_title(f"{label} - Mask")
    plt.tight_layout()
    plt.show()


_show_crop_comparison(image, mask, aug_img, aug_segmap, "Center Crop")
_show_crop_comparison(image, mask, aug_img2, aug_segmap2, "Random Crop")



## Random masking

In [None]:
import numpy as np
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
import imageio
import matplotlib.pyplot as plt

# Twin CoarseDropout pipelines: one for images, one for labels. Each drops
# cells (probability 0.15) of a coarse grid at 1/50 of the input size.
# Both use random_state=2, presumably so that, when called in lock-step on
# inputs of equal height/width, they drop the same regions.
# NOTE(review): that alignment breaks if one is called more often than the
# other; confirm this is the intended pairing mechanism.
_coarse_dropout_kwargs = dict(p=0.15, size_percent=(1/50), random_state=2)
mask_im_aug = iaa.Sequential([iaa.CoarseDropout(**_coarse_dropout_kwargs)])
mask_label_aug = iaa.Sequential([iaa.CoarseDropout(**_coarse_dropout_kwargs)])

def tmp3(segmaps, random_state, parents, hooks):
    """Lambda segmentation-map hook: apply ``mask_label_aug`` (coarse dropout).

    Each map's internal array is cast to uint8 before augmentation, so the
    dropout augmenter accepts it. Label values above 255 would be truncated.
    Dropped regions become 0. The result is cast back to the dtype/ndim
    reported by ``segmap.get_arr()``, so ``get_arr()`` on the returned map
    matches what the caller passed in.

    Args:
        segmaps: iterable of ``SegmentationMapsOnImage``.
        random_state, parents, hooks: unused; required by ``iaa.Lambda``.

    Returns:
        list of new ``SegmentationMapsOnImage``, one per input map.
    """
    new_segmaps = []
    for segmap in segmaps:
        # Only dtype and ndim of the caller-facing array are needed.
        template = segmap.get_arr()
        dropped = mask_label_aug.augment_image(segmap.arr.astype(np.uint8))
        dropped = dropped.astype(template.dtype)
        if template.ndim == 2:
            dropped = dropped[:, :, 0]
        # Dropout does not change geometry: keep the original image shape.
        new_shape = tuple(dropped.shape[:2]) + tuple(segmap.shape[2:])
        new_segmaps.append(SegmentationMapsOnImage(dropped, shape=new_shape))
    return new_segmaps

def _coarse_dropout_images(images, random_state, parents, hooks):
    """Lambda image hook: run ``mask_im_aug`` (coarse dropout) on each image."""
    return [mask_im_aug.augment_image(img) for img in images]


# Coarse dropout on images (and, via tmp3, on their masks), then pad/resize.
masking_aug = iaa.Sequential([
    iaa.Lambda(
        func_images=_coarse_dropout_images,
        func_segmentation_maps=tmp3,
    ),
    resize_aug,
])

# Sample base names: Abyssinian_10 ... Abyssinian_15 (no extension).
files = [f"Abyssinian_{number}" for number in range(10, 16)]

images = []
masks = []
for file in files:
    image = imageio.imread(f"Train/color/{file}.jpg")
    # uint8 cast avoids int32 label arrays (assumes label values fit in 0-255).
    mask = imageio.imread(f"Train/label/{file}.png").astype(np.uint8)
    segmap = SegmentationMapsOnImage(mask, shape=mask.shape)
    aug_image, aug_segmap = masking_aug(image=image, segmentation_maps=segmap)
    images += [aug_image]
    masks += [aug_segmap.get_arr()]

# One row per sample: augmented image on the left, augmented mask on the right.
fig, axes = plt.subplots(len(images), 2, figsize=(10, 5 * len(images)))
for i, (img, mask) in enumerate(zip(images, masks)):
    image_ax, mask_ax = axes[i, 0], axes[i, 1]

    image_ax.imshow(img)
    image_ax.axis('off')
    image_ax.set_title(f'Image: {files[i]}')

    # Make the mask displayable: drop an alpha channel or a singleton channel.
    if mask.ndim == 3 and mask.shape[2] == 4:
        mask = mask[..., :3]
    elif mask.ndim == 3 and mask.shape[2] == 1:
        mask = mask.squeeze()

    mask_ax.imshow(mask, cmap='jet' if mask.ndim == 2 else None)
    mask_ax.axis('off')
    mask_ax.set_title(f'Mask: {files[i]}')

plt.tight_layout()
plt.show()

## Split

In [None]:
# import os
# import shutil
# import random

# # Set seed for reproducibility
# random.seed(42)

# # Create train and validation directories
# os.makedirs('Train/color', exist_ok=True)
# os.makedirs('Train/label', exist_ok=True)
# os.makedirs('Val/color', exist_ok=True)
# os.makedirs('Val/label', exist_ok=True)

# image_dir = 'TrainVal/color'
# label_dir = 'TrainVal/label'

# species_files = {}

# # Collect image files with corresponding labels, grouped by species
# for filename in os.listdir(image_dir):
#     if filename.endswith('.jpg'):
#         label_filename = filename.replace('.jpg', '.png')
#         label_path = os.path.join(label_dir, label_filename)
#         if not os.path.exists(label_path):
#             print(f"Skipping {filename} (label not found)")
#             continue
#         species = filename.split('_')[0]
#         if species not in species_files:
#             species_files[species] = []
#         species_files[species].append(filename)

# # Process each species to split and move files
# for species, files in species_files.items():
#     random.shuffle(files)
#     split_idx = int(0.8 * len(files))
#     train_files = files[:split_idx]
#     val_files = files[split_idx:]
    
#     # Move training data
#     for file in train_files:
#         # Move image
#         src_img = os.path.join(image_dir, file)
#         dst_img = os.path.join('Train/color', file)
#         shutil.move(src_img, dst_img)
#         # Move label
#         label_file = file.replace('.jpg', '.png')
#         src_label = os.path.join(label_dir, label_file)
#         dst_label = os.path.join('Train/label', label_file)
#         shutil.move(src_label, dst_label)
    
#     # Move validation data
#     for file in val_files:
#         src_img = os.path.join(image_dir, file)
#         dst_img = os.path.join('Val/color', file)
#         shutil.move(src_img, dst_img)
#         # Move label
#         label_file = file.replace('.jpg', '.png')
#         src_label = os.path.join(label_dir, label_file)
#         dst_label = os.path.join('Val/label', label_file)
#         shutil.move(src_label, dst_label)

# print("Dataset split into Train and Val folders successfully.")

## Grayscale

In [None]:
# Full desaturation (alpha=1.0) of an RGB image, then the shared pad/resize.
grayscale_aug_2 = iaa.Grayscale(alpha=1.0, from_colorspace="RGB")
grayscale_aug = iaa.Sequential([grayscale_aug_2, resize_aug])

# Preview on `image`, i.e. whichever sample was loaded most recently.
image_gray = grayscale_aug(image=image)
plt.imshow(image_gray)
plt.show()

## Laplace Noise

In [None]:
# Per-channel additive Laplace noise, scale sampled from [0.1*255, 0.3*255],
# followed by the shared pad/resize.
laplace = iaa.AdditiveLaplaceNoise(
    scale=(0.1 * 255, 0.3 * 255),
    per_channel=True,
)
laplace_aug = iaa.Sequential([laplace, resize_aug])

image_laplace = laplace_aug(image=image)
plt.imshow(image_laplace)
plt.show()

## Blur

In [None]:
# Mean filter with a fixed kernel size of 12 (a plain int, not a range; pass a
# tuple such as (3, 12) to sample it), then the shared pad/resize.
blur = iaa.AverageBlur(k=12)
blur_aug = iaa.Sequential([blur, resize_aug])

image_blur = blur_aug(image=image)
plt.imshow(image_blur)
plt.show()


## Contrast

In [None]:
# Linear contrast with a factor sampled from [0.2, 0.6] (factors below 1
# reduce contrast), followed by the shared pad/resize.
contrast = iaa.LinearContrast((0.2, 0.6))
contrast_aug = iaa.Sequential([contrast, resize_aug])

# Works on a single image (or a batch via images=...).
image_contrast = contrast_aug(image=image)
plt.imshow(image_contrast)
plt.show()

In [None]:
from PIL import Image
import math

def combine_images(image1_path, image2_path, output_path=None, size=512,
                   resample=None):
    """Combine two same-orientation images into one ``size`` x ``size`` RGB image.

    Portrait pairs (height > width) are placed side by side. Landscape or square
    pairs (height <= width) are stacked vertically. Both images are scaled by one
    common factor, chosen so that their widths (portrait) or heights (landscape)
    add up to ``size``. The other dimension is set to the smaller of the two
    scaled values, so one image can be slightly distorted. The pair is centered
    on a black canvas. If it is larger than the canvas along the shared
    dimension, the centered paste crops it.

    Args:
        image1_path: Path to the first image (placed left / top).
        image2_path: Path to the second image (placed right / bottom).
        output_path: Optional path; when given, the result is also saved there.
        size: Side length of the square output canvas. Defaults to 512.
        resample: PIL resampling filter. ``None`` (default) means
            nearest-neighbour. That keeps label masks free of blended values but
            looks blocky on photos; pass e.g. ``Image.Resampling.LANCZOS`` for
            smoother color images.

    Returns:
        The combined ``PIL.Image.Image`` in RGB mode, ``size`` x ``size``.

    Raises:
        ValueError: If the orientations differ, or if the second image would be
            scaled to zero pixels along the stacking direction.
    """
    if resample is None:
        resample = Image.Resampling.NEAREST

    def load_image(path):
        # Composite images that carry an alpha band onto black, then force RGB.
        img = Image.open(path)
        if img.mode in ('RGBA', 'LA'):
            background = Image.new('RGB', img.size, (0, 0, 0))
            background.paste(img, mask=img.split()[-1])
            img = background
        return img.convert('RGB')

    def get_orientation(w, h):
        return 'portrait' if h > w else 'landscape'

    img1 = load_image(image1_path)
    img2 = load_image(image2_path)
    w1, h1 = img1.size
    w2, h2 = img2.size

    orientation = get_orientation(w1, h1)
    if orientation != get_orientation(w2, h2):
        raise ValueError("Mismatched orientations")
    portrait = orientation == 'portrait'

    # One common factor so the stacked dimension of both images totals `size`.
    scale = size / ((w1 + w2) if portrait else (h1 + h2))

    def scaled(dim):
        # Round up; the second image absorbs the remainder so the total is exact.
        return math.ceil(dim * scale)

    if portrait:
        first = scaled(w1)
        second = size - first
        shared = min(scaled(h1), scaled(h2))
        size1, size2 = (first, shared), (second, shared)
        offset2 = (first, 0)
        combined_size = (size, shared)
    else:
        first = scaled(h1)
        second = size - first
        shared = min(scaled(w1), scaled(w2))
        size1, size2 = (shared, first), (shared, second)
        offset2 = (0, first)
        combined_size = (shared, size)

    if second <= 0:
        raise ValueError(
            "Second image is too small relative to the first to fit in the canvas")

    combined = Image.new('RGB', combined_size)
    combined.paste(img1.resize(size1, resample), (0, 0))
    combined.paste(img2.resize(size2, resample), offset2)

    # Center the pair on a black square canvas.
    final_img = Image.new('RGB', (size, size), (0, 0, 0))
    final_img.paste(combined, (
        (size - combined.width) // 2,
        (size - combined.height) // 2,
    ))

    if output_path:
        final_img.save(output_path)
    return final_img

# Usage: merge two beagle photos into one 512x512 preview.
first_path, second_path = "Train/color/beagle_154.jpg", "Train/color/beagle_145.jpg"
combined = combine_images(first_path, second_path)
plt.imshow(combined)
plt.show()

In [None]:
# Every augmentation pipeline defined above, in a fixed order.
augmenters = [
    rotation_aug,
    center_crop_aug,
    random_crop_aug,
    masking_aug,
    grayscale_aug,
    laplace_aug,
    blur_aug,
    contrast_aug,
]

import os

folder_path = "Train/color"  # Update this with your actual folder path

# Image files only: names whose lower-cased extension is .jpg or .png.
filenames = list(
    filter(lambda name: name.lower().endswith(('.jpg', '.png')), os.listdir(folder_path))
)

def get_species(filename):
    """Return the species (breed) part of an image filename.

    The extension is dropped and everything before the last underscore is
    kept, e.g. "Russian_Blue_12.jpg" -> "Russian_Blue". A name without any
    underscore is returned as-is (minus its extension).
    """
    stem, _ext = os.path.splitext(filename)
    head, sep, _tail = stem.rpartition('_')
    return head if sep else stem

# Known cat breeds; every other species is treated as a dog.
cat_species = {
    "Russian_Blue", "Siamese", "Sphynx", "Maine_Coon", "Abyssinian",
    "Bombay", "British_Shorthair", "Bengal", "Egyptian_Mau", "Persian",
    "Ragdoll", "Birman",
}

# Split the discovered files into cat and dog base names (extension stripped).
cat_files, dog_files = [], []
for fname in filenames:
    species = get_species(fname)
    name_no_ext = os.path.splitext(fname)[0]
    (cat_files if species in cat_species else dog_files).append(name_no_ext)

import random

# Build a small preview set in tmp/: for every augmenter, augment a random 20%
# sample of the (capped) cat and dog lists and save each image/label pair.

# Directories for source and destination images.
color_dir = "Train/color"
label_dir = "Train/label"
save_color_dir = "tmp/color"
save_label_dir = "tmp/label"

# Create destination directories if they do not exist.
os.makedirs(save_color_dir, exist_ok=True)
os.makedirs(save_label_dir, exist_ok=True)

# Report class sizes before capping.
print(len(cat_files))
print(len(dog_files))
# Cap each class at its first 20 files (looks like a quick-preview limit;
# NOTE(review): remove for a full run).
cat_files = cat_files[:20]
dog_files = dog_files[:20]

for i, aug in enumerate(augmenters):
    # Try to use aug.name if it exists, otherwise fall back to an index-based name.
    # NOTE(review): imgaug augmenters appear to always carry a `name`
    # (auto-generated for the unnamed Sequential pipelines used here), so the
    # fallback is likely never used. The `_{i}` suffix in the output names is
    # what keeps files from different augmenters apart.
    aug_name = getattr(aug, "name", f"aug_{i}")

    # Randomly select 20% (at least one) of the cat files and of the dog files.
    # `random` is not seeded here, so the selection differs between runs.
    num_cats = max(1, int(len(cat_files) * 0.2))
    num_dogs = max(1, int(len(dog_files) * 0.2))
    selected_cats = random.sample(cat_files, num_cats) if cat_files else []
    selected_dogs = random.sample(dog_files, num_dogs) if dog_files else []

    # Combine file lists.
    selected_files = selected_cats + selected_dogs

    print(f"Using augmenter '{aug_name}': processing {len(selected_files)} images")

    for fname in selected_files:
        # Build full file paths (color images are .jpg and labels are .png).
        color_path = os.path.join(color_dir, fname + ".jpg")
        label_path = os.path.join(label_dir, fname + ".png")

        # Skip pairs where either file is missing.
        if not os.path.exists(color_path) or not os.path.exists(label_path):
            print(f"Skipping {fname}: missing color or label file.")
            continue

        # Read the input image and its label. Unlike the preview cells, no
        # channel is selected from the label, so a multi-channel PNG stays
        # multi-channel.
        img = imageio.imread(color_path)
        label = imageio.imread(label_path)

        # Create a segmentation map from the label.
        # NOTE(review): imgaug documents `shape` as the image's shape; label.shape
        # matches it on height/width only if the label has the image's resolution.
        segmap = SegmentationMapsOnImage(label, shape=label.shape)

        # Apply the augmentation to image and label together.
        augmented_img, augmented_segmap = aug(image=img, segmentation_maps=segmap)
        augmented_label = augmented_segmap.get_arr()

        # Labels are written as uint8 PNGs (assumes class ids fit in 0-255).
        augmented_label = augmented_label.astype(np.uint8)

        # Construct output file names:
        # original base name + augmenter name + augmenter index.
        out_color_path = os.path.join(save_color_dir, f"{fname}_{aug_name}_{i}.jpg")
        out_label_path = os.path.join(save_label_dir, f"{fname}_{aug_name}_{i}.png")

        # Save the augmented image (JPEG) and label (PNG).
        imageio.imwrite(out_color_path, augmented_img)
        imageio.imwrite(out_label_path, augmented_label)

        print(f"Saved augmented version of {fname} using {aug_name}")


In [None]:
import os
import random
import imageio
import numpy as np
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
import math
import shutil # Added for copying original files

# --- Augmenter registry: short name -> pipeline defined in earlier cells ---
augmenter_dict = dict(
    rotation=rotation_aug,
    center_crop=center_crop_aug,
    random_crop=random_crop_aug,
    masking=masking_aug,
    grayscale=grayscale_aug,
    laplace=laplace_aug,
    blur=blur_aug,
    contrast=contrast_aug,
)
num_augmenters = len(augmenter_dict)

# --- Configuration ---
folder_path = "Train/color"
label_folder_path = "Train/label"
save_color_dir = "atrain/color"
save_label_dir = "atrain/label"
majority_aug_factor = 1.5  # presumably consumed further down in this cell

for _out_dir in (save_color_dir, save_label_dir):
    os.makedirs(_out_dir, exist_ok=True)

# --- File discovery and classification ---
print("Scanning for image files...")
try:
    filenames = [
        f for f in os.listdir(folder_path)
        if f.lower().endswith(('.jpg', '.png'))
    ]
except FileNotFoundError:
    print(f"Error: Source directory not found at {folder_path}")
    # `exit()` is a `site`-module helper meant for the interactive prompt.
    # SystemExit works in scripts too, and status 1 marks the run as failed.
    raise SystemExit(1) from None
print(f"Found {len(filenames)} potential image files.")

def get_species(filename):
    """Return the breed/species prefix of an image filename.

    The extension is dropped, then everything before the last underscore is
    returned (``"Abyssinian_1.jpg"`` -> ``"Abyssinian"``). A name without any
    underscore is returned unchanged, minus its extension.
    """
    stem, _ = os.path.splitext(filename)
    prefix, sep, _ = stem.rpartition('_')
    return prefix if sep else stem

# Breed prefixes (as returned by get_species) that are cats; any other
# prefix is treated as a dog.
cat_species = {
    "Russian_Blue", "Siamese", "Sphynx", "Maine_Coon", "Abyssinian",
    "Bombay", "British_Shorthair", "Bengal", "Egyptian_Mau", "Persian",
    "Ragdoll", "Birman"
}
cat_files = []   # base names (no extension) of usable cat images
dog_files = []   # base names (no extension) of usable dog images
seen_names = set()  # base names already classified, to avoid duplicates
for fname in filenames:
    species = get_species(fname)
    name_no_ext = os.path.splitext(fname)[0]
    # Every later stage reads f"{name}.jpg" from folder_path, even though
    # discovery also accepts .png (and upper-case extensions). Keep only names
    # whose .jpg file actually exists, so the class counts and the pool that
    # augmentations are sampled from match what can really be loaded.
    color_path_check = os.path.join(folder_path, name_no_ext + ".jpg")
    if not os.path.exists(color_path_check):
        print(f"Warning: Color file {color_path_check} not found (only .jpg is supported), skipping {fname}.")
        continue
    # "X.jpg" and "X.png" map to the same base name; count it only once.
    if name_no_ext in seen_names:
        continue
    label_path_check = os.path.join(label_folder_path, name_no_ext + ".png")
    if not os.path.exists(label_path_check):
        print(f"Warning: Label file missing for {fname}, skipping this image.")
        continue
    seen_names.add(name_no_ext)
    if species in cat_species:
        cat_files.append(name_no_ext)
    else:
        dog_files.append(name_no_ext)

N_cat = len(cat_files)
N_dog = len(dog_files)

if N_cat == 0 and N_dog == 0:
    print("Error: No valid cat or dog images found.")
    # SystemExit instead of the interactive-shell-only `exit()` helper, which
    # would shut down a Jupyter kernel and exits with status 0.
    raise SystemExit(1)
print(f"Initial counts: Cats = {N_cat}, Dogs = {N_dog}")

# --- Resize original files into the output directory ---
# Each valid original pair is passed through `resize_aug` (not defined in this
# cell) and written under its original base name. Pairs whose outputs already
# exist are not recomputed, but they are still counted in processed_count.
print("Processing originals with resize augmentation...")
all_original_files = cat_files + dog_files
processed_count = 0

for out_dir in (save_color_dir, save_label_dir):
    os.makedirs(out_dir, exist_ok=True)

for fname in all_original_files:
    src_color = os.path.join(folder_path, f"{fname}.jpg")
    src_label = os.path.join(label_folder_path, f"{fname}.png")
    dst_color = os.path.join(save_color_dir, f"{fname}.jpg")
    dst_label = os.path.join(save_label_dir, f"{fname}.png")

    already_done = os.path.exists(dst_color) and os.path.exists(dst_label)
    if already_done:
        processed_count += 1
        continue

    if not os.path.exists(src_color):
        print(f"Warning: Skipping {fname}, missing original color file: {src_color}")
        continue
    if not os.path.exists(src_label):
        print(f"Warning: Skipping {fname}, missing original label file: {src_label}")
        continue

    try:
        img = imageio.v2.imread(src_color)
        label = imageio.v2.imread(src_label)

        # The segmentation map is tied to the image shape so that resize_aug
        # transforms image and mask consistently.
        resized_img, resized_segmap = resize_aug(
            image=img,
            segmentation_maps=SegmentationMapsOnImage(label, shape=img.shape),
        )
        resized_label = resized_segmap.get_arr()

        # JPEG cannot store an alpha channel, so drop it from 4-channel images.
        has_alpha = resized_img.ndim == 3 and resized_img.shape[2] == 4
        if has_alpha:
            print(f"    Converting RGBA image to RGB for {fname}.jpg")
            resized_img = resized_img[..., :3]

        rgb_out = resized_img.astype(np.uint8)
        label_out = resized_label.astype(np.uint8)
        imageio.imwrite(dst_color, rgb_out)
        imageio.imwrite(dst_label, label_out)

        processed_count += 1
    except Exception as e:
        print(f"Error processing {fname} with resize_aug: {e}")

print(f"Processed and saved {processed_count} original image/label pairs using resize_aug.")


# --- Calculate Augmentation Needs ---
# Report which class (if any) dominates. In every case the per-class target
# is the larger class count scaled by majority_aug_factor.
if N_cat == N_dog:
    print("Dataset is already balanced.")
elif N_dog > N_cat:
    print("Dogs are the majority class.")
else:
    print("Cats are the majority class.")
target_final_count = round(max(N_cat, N_dog) * majority_aug_factor)

# Only the shortfall has to be generated; never a negative amount.
total_aug_cat_needed = max(0, target_final_count - N_cat)
total_aug_dog_needed = max(0, target_final_count - N_dog)

print(f"Target final count per class: {target_final_count}")
print(f"Total augmentations needed: Cats = {total_aug_cat_needed}, Dogs = {total_aug_dog_needed}")

if num_augmenters <= 0:
    print("Warning: No augmenters provided. Only copying original files.")
    num_cats_per_aug = num_dogs_per_aug = 0
else:
    # Equal share per augmenter, rounded up. The total is therefore never
    # below the requirement, but it may overshoot by up to
    # num_augmenters - 1 per class.
    num_cats_per_aug = math.ceil(total_aug_cat_needed / num_augmenters)
    num_dogs_per_aug = math.ceil(total_aug_dog_needed / num_augmenters)

print(f"Will select approximately {num_cats_per_aug} cats and {num_dogs_per_aug} dogs per augmenter.")

# Running totals of cat/dog samples drawn across all augmenters (reported
# again in the final verification next to the on-disk counts).
num_selected_cats = 0
num_selected_dogs = 0
# --- Augmentation Loop ---
# Number of augmented pairs actually written: incremented only after both
# imwrite calls below succeed.
generated_aug_count = 0
if num_augmenters > 0 and (num_cats_per_aug > 0 or num_dogs_per_aug > 0):
    # 'i' is used only for the "(i+1/N)" progress message; output filenames
    # are made unique by aug_name plus the per-augmenter processed_in_batch.
    for i, (aug_name, aug_object) in enumerate(augmenter_dict.items()):

        # Draw a fresh sample for each augmenter. random.choices samples WITH
        # replacement, so one source image may be picked several times for the
        # same augmenter (processed_in_batch keeps those outputs distinct).
        selected_cats = random.choices(cat_files, k=num_cats_per_aug) if N_cat > 0 and num_cats_per_aug > 0 else []
        selected_dogs = random.choices(dog_files, k=num_dogs_per_aug) if N_dog > 0 and num_dogs_per_aug > 0 else []
        selected_files = selected_cats + selected_dogs
        num_selected_cats += len(selected_cats)
        num_selected_dogs += len(selected_dogs)

        if not selected_files:
            print(f"No files selected for augmenter '{aug_name}', skipping.")
            continue

        # Progress message: augmenter name, its position, and batch composition.
        print(f"\nUsing augmenter '{aug_name}' ({i+1}/{num_augmenters}): processing {len(selected_files)} images ({len(selected_cats)} cats, {len(selected_dogs)} dogs)")

        # Counts successful saves for this augmenter; also the numeric suffix
        # of each output filename.
        processed_in_batch = 0
        for fname in selected_files:
            # Source paths: colour image is always .jpg, label is always .png.
            color_path = os.path.join(folder_path, fname + ".jpg")
            label_path = os.path.join(label_folder_path, fname + ".png")

            if not os.path.exists(color_path):
                print(f"Skipping {fname}: missing color file ({color_path}).")
                continue
            if not os.path.exists(label_path):
                 print(f"Skipping {fname}: missing label file ({label_path}).")
                 continue

            try:
                # Read images. NOTE(review): the source is the raw image in
                # folder_path, not the resized copy written earlier; unless
                # each augmenter includes its own resize/pad step (defined
                # outside this cell, please confirm), augmented outputs may not
                # match the size of the resized originals.
                img = imageio.v2.imread(color_path)
                label = imageio.v2.imread(label_path)
                segmap = SegmentationMapsOnImage(label, shape=img.shape)

                # Apply the augmenter to image and segmentation map together,
                # so the mask receives the same geometric transform.
                augmented_img, augmented_segmap = aug_object(image=img, segmentation_maps=segmap)
                augmented_label = augmented_segmap.get_arr()

                # JPEG has no alpha channel: drop it from 4-channel images.
                if augmented_img.ndim == 3 and augmented_img.shape[2] == 4:
                     print(f"    Converting augmented RGBA image to RGB for {fname}_{aug_name}_{processed_in_batch}.jpg")
                     augmented_img = augmented_img[..., :3] # Slice to keep only R, G, B

                # Cast both arrays to uint8 before writing.
                augmented_label = augmented_label.astype(np.uint8)
                augmented_img = augmented_img.astype(np.uint8)

                # Output names follow the pattern f"{fname}_{aug_name}_{k}",
                # where k = processed_in_batch. The final verification regex
                # relies on this pattern.
                out_color_path = os.path.join(save_color_dir, f"{fname}_{aug_name}_{processed_in_batch}.jpg")
                out_label_path = os.path.join(save_label_dir, f"{fname}_{aug_name}_{processed_in_batch}.png")

                # Save the augmented images
                imageio.imwrite(out_color_path, augmented_img)
                imageio.imwrite(out_label_path, augmented_label)

                generated_aug_count += 1
                processed_in_batch += 1

            except Exception as e:
                # Best-effort: report the failure and continue with the next file.
                print(f"Error processing or saving {fname} with augmenter {aug_name}: {e}")

        # Per-augmenter summary (successful saves only).
        print(f"Augmenter '{aug_name}' finished. Processed {processed_in_batch} images.")

else:
    print("Skipping augmentation loop as no augmenters or augmentation needed.")

# --- Final Count Verification (naming: fname_augname_batchcount) ---
# Count cat and dog images in the output colour directory. Augmented files are
# named f"{base}_{aug_name}_{k}"; any other name is treated as a resized
# original named after its source base name.
import re

print("\nVerifying final counts in output directory...")
try:
    final_color_files = os.listdir(save_color_dir)
except FileNotFoundError:
    print(f"Error: Output directory not found at {save_color_dir}")
    final_color_files = []

final_cat_count = 0
final_dog_count = 0
unclassified_count = 0

# Check the names this section depends on once, up front, rather than once
# per file. At module level locals() is globals(), so one lookup suffices.
# SystemExit is used instead of `exit()`, which is meant only for the
# interactive shell (under Jupyter it shuts the kernel down).
if 'augmenter_dict' not in globals():
    print("Error: augmenter_dict not found. Please define it before verification.")
    raise SystemExit(1)
if 'get_species' not in globals():
    print("Error: get_species function not found.")
    raise SystemExit(1)
if 'cat_species' not in globals():
    print("Error: cat_species set not found.")
    raise SystemExit(1)

# Known augmentation suffixes are the augmenter dictionary keys.
known_aug_suffixes = set(augmenter_dict.keys())
aug_name_pattern = "|".join(re.escape(name) for name in known_aug_suffixes)
# ^(.*?)                   group 1: the original base name (non-greedy)
# _(?:{aug_name_pattern})  "_" followed by one of the known augmenter names
# _\d+$                    "_" followed by the batch counter, at the very end
suffix_pattern = re.compile(rf"^(.*?)_(?:{aug_name_pattern})_\d+$")

for fname_with_ext in final_color_files:
    if not fname_with_ext.lower().endswith(('.jpg', '.png')):
        continue

    base_name = os.path.splitext(fname_with_ext)[0]
    match = suffix_pattern.match(base_name)
    # Augmented file: strip the suffix. Otherwise the name is already the
    # original base.
    original_base = match.group(1) if match else base_name

    if not original_base:
        # e.g. an empty captured prefix; cannot map this file to a species.
        print(f"    Warning: Could not determine original base for file: {fname_with_ext}")
        unclassified_count += 1
        continue

    try:
        # get_species expects a filename; add a dummy extension.
        species = get_species(original_base + ".jpg")
    except Exception as e:
        print(f"    Warning: Error processing species for base '{original_base}' from file '{fname_with_ext}': {e}")
        unclassified_count += 1
        continue

    if species in cat_species:
        final_cat_count += 1
    else:
        # Anything that is not a known cat breed is counted as a dog.
        final_dog_count += 1

print(f"\nFinal counts in '{save_color_dir}':")
print(f" - Cats: {final_cat_count}")
print(f"selected_cats: {num_selected_cats}")
print(f" - Dogs: {final_dog_count}")
print(f"selected_dogs: {num_selected_dogs}")
if unclassified_count > 0:
    print(f" - Unclassified: {unclassified_count}")

# Show the per-class target for comparison, if it was computed.
if 'target_final_count' in globals():
    print(f"Target was approximately {target_final_count} per class.")
else:
    print("Target count variable (target_final_count) not found for comparison.")

print("Final count verification complete.")


# Optional: compare the number of image files on disk with what this run
# accounted for. The on-disk total is a plain directory listing, so files
# left over from earlier runs are included and the numbers can differ.
total_files_in_dir = sum(1 for f in final_color_files if f.lower().endswith(('.jpg', '.png')))
print(f"Total .jpg/.png files found in output dir: {total_files_in_dir}")
# processed_count: originals resized this run or already present in the output.
# generated_aug_count: augmented pairs actually saved. It is incremented only
# after both imwrite calls succeed, so failed attempts are not included.
if 'processed_count' in globals() and 'generated_aug_count' in globals():
    expected_total = processed_count + generated_aug_count
    print(f"Theoretical file count (originals + augmentations saved): {expected_total}")
else:
    print("Counters (processed_count, generated_aug_count) not found for comparison.")
