In [None]:
# Geometric augmentations applied jointly to each image and its class-index mask
# (augment_and_save calls this as augmentation_pipeline(image=..., mask=...)).
# In albumentations the `interpolation` argument controls how the *image* is
# resampled; masks are resampled with nearest-neighbour by default, so class
# indices are never blended. Bilinear resampling keeps image edges smooth
# instead of blocky.
augmentation_pipeline = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.ShiftScaleRotate(
        shift_limit=0.1,
        scale_limit=0.1,
        rotate_limit=15,
        p=0.5,
        interpolation=cv2.INTER_LINEAR,  # image only; masks stay nearest-neighbour
    ),
])

def class_to_color(mask, label_color_map):
    """Paint a 2-D class-index mask with the colour assigned to each class.

    Args:
        mask: 2-D array of class indices. Its shape is unpacked into
            (height, width), so an array of any other rank raises ValueError.
        label_color_map: dict mapping a 3-component colour to a class index.
            Pixels whose index appears in no entry stay (0, 0, 0); if several
            colours share one index, the one iterated last wins.

    Returns:
        A (height, width, 3) uint8 array of colours.
    """
    height, width = mask.shape
    painted = np.zeros((height, width, 3), dtype=np.uint8)
    for colour, class_index in label_color_map.items():
        selected = mask == class_index
        painted[selected] = colour
    return painted

def _to_class_indices(one_hot_mask):
    """Collapse a one-hot (or per-class score) mask to a 2-D class-index map.

    The class axis is the last axis. The result uses the smallest unsigned
    dtype that can hold every class index (uint8 for up to 256 classes)
    instead of argmax's int64: OpenCV 4.x has no 64-bit integer image depth,
    so cv2-backed warps may reject int64 masks, and the smaller dtype also
    saves memory.
    """
    one_hot_mask = np.asarray(one_hot_mask)
    num_classes = one_hot_mask.shape[-1]
    index_dtype = np.min_scalar_type(max(num_classes - 1, 0))
    return np.argmax(one_hot_mask, axis=-1).astype(index_dtype)


def _to_uint8_image(image):
    """Convert an augmented image to uint8 for saving with cv2.imwrite."""
    image = np.asarray(image)
    if np.issubdtype(image.dtype, np.floating):
        # Float images are assumed to be scaled to [0, 1]. Round rather than
        # truncate (0.999 -> 255, not 254) and clip any stray overshoot.
        return np.clip(np.rint(image * 255.0), 0, 255).astype(np.uint8)
    # Integer images are taken to already be in [0, 255]; multiplying them by
    # 255 would overflow and wrap around in uint8.
    return np.clip(image, 0, 255).astype(np.uint8)


def _write_or_raise(path, array):
    """Write *array* with cv2.imwrite, raising instead of silently returning False."""
    if not cv2.imwrite(path, array):
        raise OSError(f"cv2.imwrite failed to write {path!r}")


def augment_and_save(images, masks, output_images_dir, output_masks_dir, augmentation_pipeline, num_augments=3, start_index=1, color_map=None):
    """Augment paired images/masks and write the results as numbered BMP files.

    Every (image, mask) pair is augmented ``num_augments`` times. Each result
    is saved as ``<n>.bmp`` in both output folders, with ``n`` counting up
    from ``start_index``. If these folders will be merged with an existing
    set numbered ``1.bmp`` ... ``N.bmp``, pass ``start_index=N + 1``.

    Args:
        images: Sequence of images. Float images are assumed to be scaled to
            [0, 1]; integer images are assumed to already be in [0, 255].
        masks: Sequence of one-hot masks (class axis last), aligned with
            ``images``.
        output_images_dir: Folder for augmented images (created if missing).
        output_masks_dir: Folder for colour-coded masks (created if missing).
        augmentation_pipeline: Called as ``augmentation_pipeline(image=...,
            mask=...)``; must return a mapping with ``'image'`` and ``'mask'``
            keys (the albumentations ``Compose`` interface).
        num_augments: Number of augmentation passes over the whole dataset.
        start_index: Number used in the first output file name.
        color_map: Mapping of colour -> class index passed to
            ``class_to_color``. Defaults to the module-level
            ``label_color_map``.

    Returns:
        The next unused file number (one past the last file written).

    Raises:
        ValueError: If ``images`` and ``masks`` differ in length.
        OSError: If OpenCV fails to write an output file.
    """
    # Materialise once: the dataset is traversed num_augments times, which a
    # one-shot iterator (e.g. a generator) could not support.
    images = list(images)
    masks = list(masks)
    if len(images) != len(masks):
        raise ValueError(
            f"images and masks must have the same length, got {len(images)} and {len(masks)}"
        )
    if color_map is None:
        # Backward-compatible default: the module-level mapping used before.
        color_map = label_color_map

    os.makedirs(output_images_dir, exist_ok=True)
    os.makedirs(output_masks_dir, exist_ok=True)

    # The one-hot -> index conversion does not depend on the augmentation
    # round, so do it once per mask instead of once per round.
    index_masks = [_to_class_indices(mask) for mask in masks]

    current_index = start_index
    for _ in range(num_augments):
        for image, index_mask in zip(images, index_masks):
            augmented = augmentation_pipeline(image=image, mask=index_mask)
            rgb_mask = class_to_color(augmented['mask'], color_map)

            file_name = f"{current_index}.bmp"
            image_path = os.path.join(output_images_dir, file_name)
            mask_path = os.path.join(output_masks_dir, file_name)
            # NOTE(review): cv2.imwrite treats 3-channel arrays as BGR. If the
            # images and label_color_map colours are RGB-ordered, red and blue
            # end up swapped in the saved files — confirm against the loader.
            _write_or_raise(image_path, _to_uint8_image(augmented['image']))
            _write_or_raise(mask_path, rgb_mask)

            current_index += 1
    return current_index

# Destination folders for the augmented segmentation data; augment_and_save
# creates them if they do not exist yet.
output_segmentation_images_dir = '/kaggle/working/augmented_segmentation_images'
output_segmentation_masks_dir = '/kaggle/working/augmented_segmentation_masks'

# Write three augmented copies of every (image, mask) pair, numbered 1.bmp,
# 2.bmp, ... in both folders.
# NOTE(review): numbering starts at 1, so these names collide with any
# un-augmented set numbered the same way. If the folders will be merged, pass
# the first free number as start_index instead.
augment_and_save(
    images,
    segmentation_masks,
    output_segmentation_images_dir,
    output_segmentation_masks_dir,
    augmentation_pipeline,
    num_augments=3,
    start_index=1,
)