# Augment the dataset with Albumentations

# In [21]:
import os
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from glob import glob

def create_dirs(output_image_dir, output_label_dir):
    """Ensure both output directories exist, creating parent folders as needed.

    Directories that already exist are left untouched (``exist_ok=True``).
    """
    for directory in (output_image_dir, output_label_dir):
        os.makedirs(directory, exist_ok=True)

def load_image_label(image_path, label_path, size=96):
    """Load one image/label pair from ``.npy`` files and crop both to ``size``.

    Args:
        image_path: Path to the image ``.npy`` file. The array is indexed as
            ``(bands, H, W)`` (48 bands in this dataset).
        label_path: Path to the label ``.npy`` file, stored either as
            ``(H, W)`` or as ``(1, H, W)``.
        size: Side length of the top-left square crop applied to both
            arrays. Defaults to 96, the size used throughout this script.

    Returns:
        Tuple ``(image, label)`` with shapes ``(bands, size, size)`` and
        ``(1, size, size)``.

    Raises:
        ValueError: If either array is spatially smaller than ``size``, or if
            the cropped label does not hold exactly ``size * size`` values
            (i.e. it has more than one channel).
    """
    image = np.load(image_path)
    label = np.load(label_path)

    # Crop only the trailing (spatial) axes. Slicing ``label[:size, :size]``
    # would hit the channel and height axes of a (1, H, W) label instead.
    image = image[:, :size, :size]
    label = label[..., :size, :size]

    # get_augmentations() disables Albumentations' shape check, so reject
    # undersized inputs here instead of silently writing mismatched pairs.
    if image.shape[-2:] != (size, size) or label.shape[-2:] != (size, size):
        raise ValueError(
            f"Expected spatial size ({size}, {size}) after cropping, got image "
            f"{image.shape} from {image_path!r} and label {label.shape} from "
            f"{label_path!r}."
        )
    if label.size != size * size:
        raise ValueError(
            f"Label {label_path!r} must have a single channel, got shape {label.shape}."
        )

    return image, label.reshape(1, size, size)

# Albumentations pipeline used in the study.
def get_augmentations():
    """Build the augmentation pipeline applied to every image/mask pair.

    ``A.Compose`` runs the transforms in order, and each one fires
    independently with its own probability ``p``. Several rotation
    transforms can therefore stack on the same sample.

    Returns:
        An ``A.Compose`` callable accepting ``image=`` and ``mask=`` keyword
        arguments.
    """
    flips = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
    ]
    rotations = [
        A.Rotate(limit=15, p=0.5),
        A.Rotate(limit=30, p=0.5),
        A.Rotate(limit=60, p=0.5),
        A.SafeRotate(limit=45, p=0.5, rotate_method="ellipse"),
        # Also tried: SafeRotate(limit=60, p=1.0) and SafeRotate(limit=135, p=1.0).
    ]
    deformations = [
        A.ElasticTransform(alpha=50, sigma=50, p=0.2),
    ]
    # Shape checking is disabled for the multi-band (48-channel) images.
    # NOTE(review): the check compares image/mask spatial sizes, so disabling
    # it also hides genuinely mismatched pairs — confirm this is intended.
    return A.Compose(flips + rotations + deformations, is_check_shapes=False)

def apply_augmentations(image, label, transform):
    """Run ``transform`` on a channels-first image/label pair.

    Albumentations expects channels-last ``(H, W, C)`` arrays. The inputs are
    therefore transposed before the transform runs and transposed back
    afterwards.

    Args:
        image: Array of shape ``(C, H, W)``, e.g. ``(48, 96, 96)``.
        label: Array of shape ``(1, H, W)``.
        transform: Callable taking ``image=`` and ``mask=`` keyword arguments
            and returning a dict with ``"image"`` and ``"mask"`` entries, such
            as the result of ``get_augmentations()``.

    Returns:
        Tuple ``(aug_image, aug_label)`` in channels-first layout:
        ``(C, H, W)`` and ``(1, H, W)``.
    """
    # np.transpose only returns a strided view; hand the OpenCV-backed
    # transforms a contiguous buffer instead.
    image_hwc = np.ascontiguousarray(np.transpose(image, (1, 2, 0)))
    label_hwc = np.ascontiguousarray(np.transpose(label, (1, 2, 0)))

    augmented = transform(image=image_hwc, mask=label_hwc)
    aug_image = augmented["image"]
    aug_mask = augmented["mask"]

    # OpenCV routines can drop a trailing singleton channel, turning
    # (H, W, 1) into (H, W). Restore it so the transpose below cannot fail.
    if aug_image.ndim == 2:
        aug_image = aug_image[..., np.newaxis]
    if aug_mask.ndim == 2:
        aug_mask = aug_mask[..., np.newaxis]

    return np.transpose(aug_image, (2, 0, 1)), np.transpose(aug_mask, (2, 0, 1))

def save_augmented(image, label, image_path, label_path, output_image_dir, output_label_dir, idx):
    """Save one augmented pair as ``<source stem>_aug<idx>.npy``.

    Args:
        image: Augmented image array to save.
        label: Augmented label array to save.
        image_path: Path of the source image; its base name (minus extension)
            becomes the output stem.
        label_path: Path of the source label, used the same way.
        output_image_dir: Directory that receives the image file.
        output_label_dir: Directory that receives the label file.
        idx: Augmentation index, appended as ``_aug<idx>``.
    """
    def _aug_name(src_path):
        # splitext strips only the final extension, whatever its case. A plain
        # str.replace(".npy", ...) leaves "X.NPY" unchanged, so every idx
        # would overwrite the same file.
        stem, _ = os.path.splitext(os.path.basename(src_path))
        return f"{stem}_aug{idx}.npy"

    np.save(os.path.join(output_image_dir, _aug_name(image_path)), image)
    np.save(os.path.join(output_label_dir, _aug_name(label_path)), label)

def process_images(image_folder, label_folder, output_image_dir, output_label_dir, num_augments=3):
    """Write ``num_augments`` augmented copies of every image/label pair.

    Image and label files are the ``*.npy`` files in ``image_folder`` and
    ``label_folder``. They are paired by their position in sorted filename
    order, so the i-th image is matched with the i-th label.

    Args:
        image_folder: Directory containing the source image ``.npy`` files.
        label_folder: Directory containing the source label ``.npy`` files.
        output_image_dir: Destination for augmented images (created if missing).
        output_label_dir: Destination for augmented labels (created if missing).
        num_augments: Number of augmented copies written per pair.

    Raises:
        ValueError: If the two folders hold different numbers of ``.npy``
            files, because positional pairing would then be wrong.
    """
    image_paths = sorted(glob(os.path.join(image_folder, "*.npy")))
    label_paths = sorted(glob(os.path.join(label_folder, "*.npy")))

    # zip() would silently truncate to the shorter list and mis-pair every
    # file after the first gap, producing corrupted training data.
    if len(image_paths) != len(label_paths):
        raise ValueError(
            f"Found {len(image_paths)} image files in {image_folder!r} but "
            f"{len(label_paths)} label files in {label_folder!r}; cannot pair "
            "them by sorted order."
        )

    create_dirs(output_image_dir, output_label_dir)

    # Build the pipeline once; each call still samples fresh random parameters.
    transform = get_augmentations()

    for img_path, lbl_path in zip(image_paths, label_paths):
        image, label = load_image_label(img_path, lbl_path)

        for i in range(num_augments):
            aug_image, aug_label = apply_augmentations(image, label, transform)
            save_augmented(aug_image, aug_label, img_path, lbl_path, output_image_dir, output_label_dir, i)
    
if __name__ == "__main__":
    # NOTE: these locations are machine-specific; edit them for your own setup.
    image_folder = "E:/ML/Levees/Datasets/N48/dataset_bin_48/npy_images"  # source images
    label_folder = "E:/ML/Levees/Datasets/N48/dataset_bin_48/npy_masks"  # source masks
    output_image_dir = "E:/ML/Levees/Datasets/N48/dataset_bin_48/npy_images_"  # augmented images
    output_label_dir = "E:/ML/Levees/Datasets/N48/dataset_bin_48/npy_masks_"  # augmented masks

    process_images(
        image_folder=image_folder,
        label_folder=label_folder,
        output_image_dir=output_image_dir,
        output_label_dir=output_label_dir,
    )
