# In [None]:
import numpy as np
import imgaug.augmenters as iaa
import matplotlib.pyplot as plt
from glob import glob
import os
from imgaug.augmentables.segmaps import SegmentationMapsOnImage

data_path = 'C:/Users/teele_k/Downloads/2d_dataset/data_cleaned_normalized_augmented'

# Sorted so the processing order (and therefore which files get augmented, and
# in which order) is reproducible; glob() returns paths in arbitrary OS order.
npy_files = sorted(glob(os.path.join(data_path, '*.npy')))

# Target TOTAL number of samples in the dataset (existing files + new
# augmented ones). The number of new samples to create is the shortfall,
# which is <= 0 when the directory already holds enough files.
desired_number_of_images = 2000
additional_images_needed = desired_number_of_images - len(npy_files)

# Augmentation pipeline; iaa.Sequential runs the listed steps in order.
seq = iaa.Sequential(
    [
        # Mild geometric jitter: up to +/-10 degrees of rotation, up to 10%
        # translation along each axis, and a uniform 0.9-1.1 zoom.
        iaa.Affine(
            scale=(0.9, 1.1),
            translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)},
            rotate=(-10, 10),
        ),
        # Elastic warp with strength alpha drawn from [0, 5] and a fixed
        # sigma of 0.25. NOTE(review): such a small sigma smooths the
        # displacement field very little; confirm the resulting look is the
        # intended one for this data.
        iaa.ElasticTransformation(alpha=(0, 5.0), sigma=0.25),
        # Light motion blur, kernel size k drawn from [3, 5].
        iaa.MotionBlur(k=(3, 5)),
    ]
)

def augment_image_and_annotation(image, annotation, augmenter=None):
    """Apply one random draw of an augmentation pipeline to an image and its annotation.

    The image and the annotation are passed to the augmenter in a single call,
    so both receive the same sampled transformation and stay aligned.

    Args:
        image: Image array. Its shape is also used as the shape of the
            segmentation map.
        annotation: Label array for ``image``. It may be float; it is rounded
            to the nearest integer and cast to ``np.int32`` before use.
        augmenter: imgaug augmenter to apply. Defaults to the module-level
            ``seq`` pipeline, looked up at call time.

    Returns:
        Tuple ``(image_aug, annotation_aug)``, where ``annotation_aug`` is the
        augmented label array returned by ``SegmentationMapsOnImage.get_arr()``.
    """
    if augmenter is None:
        augmenter = seq

    # The stored annotations are floats; the segmentation map is built from
    # integer labels.
    annotation_int = np.round(annotation).astype(np.int32)
    segmap = SegmentationMapsOnImage(annotation_int, shape=image.shape)

    # A single call with both inputs applies the same augmentation to image
    # and annotation.
    image_aug, segmap_aug = augmenter(image=image, segmentation_maps=segmap)
    return image_aug, segmap_aug.get_arr()

# Augment the dataset.
# Cycle through the source files in order until `additional_images_needed` new
# samples have been written. A single pass over npy_files could create at most
# len(npy_files) samples, leaving the dataset short of the target whenever
# more new samples than source files are needed.
augmented_files = 0   # number of new samples written in this run
output_index = 0      # candidate suffix for the next aug_<k>.npy file name
num_source_files = len(npy_files)
# The `num_source_files` guard makes this a no-op (instead of a
# ZeroDivisionError in the modulo below) when no .npy files were found.
while num_source_files and augmented_files < additional_images_needed:
    source_file = npy_files[augmented_files % num_source_files]
    data = np.load(source_file)
    # Each file holds the image at index 0 and the annotation at index 1,
    # the same layout that is saved below.
    image, annotation = data[0], data[1]
    image_aug, annotation_aug = augment_image_and_annotation(image, annotation)

    # Skip names that already exist, so re-running the script adds files
    # instead of overwriting earlier output. Existing files count toward the
    # target, so overwriting would leave the total below it.
    new_file_name = os.path.join(data_path, f'aug_{output_index}.npy')
    while os.path.exists(new_file_name):
        output_index += 1
        new_file_name = os.path.join(data_path, f'aug_{output_index}.npy')
    output_index += 1

    # Save the augmented pair in the same stacked layout as the inputs.
    np.save(new_file_name, np.array([image_aug, annotation_aug]))

    augmented_files += 1

print(f"Augmentation completed. {augmented_files} new images created.")




# Visualize samples from the dataset directory. The pattern matches every
# .npy file, so both original and augmented (aug_*.npy) files are included.

# Sorted so the same files are shown on every run; glob() order is arbitrary.
npy_files = sorted(glob(os.path.join(data_path, '*.npy')))

# Number of files you want to visualize
number_to_visualize = 50

# Function to load and display images and annotations
def visualize_images(files, num_images):
    """Show the image and annotation of the first ``num_images`` files side by side.

    Each file is loaded with ``np.load`` and is expected to hold the image at
    index 0 and the annotation at index 1. One figure is shown per file, with
    the image on the left and the annotation on the right, both in grayscale.

    Args:
        files: Iterable of paths to ``.npy`` files.
        num_images: Maximum number of files to display.
    """
    for i, path in enumerate(files):
        if i >= num_images:
            break
        data = np.load(path)
        image, annotation = data[0], data[1]

        fig, (image_ax, annotation_ax) = plt.subplots(1, 2, figsize=(10, 5))

        image_ax.imshow(image, cmap='gray')
        image_ax.set_title('Image')
        image_ax.axis('off')

        annotation_ax.imshow(annotation, cmap='gray')
        annotation_ax.set_title('Annotation')
        annotation_ax.axis('off')

        plt.show()
        # pyplot keeps every figure alive until it is closed. Closing each one
        # bounds memory and avoids the "too many open figures" warning when
        # dozens of files are shown (e.g. under non-interactive backends).
        plt.close(fig)

# Render the selected files (image | annotation, one figure per file).
visualize_images(files=npy_files, num_images=number_to_visualize)