In [5]:
import os
import random
import shutil
import tensorflow as tf
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator

def remove_extra_images(class_dir, target_count, remove_augmented_only=False):
    """Randomly delete files in ``class_dir`` until at most ``target_count`` remain.

    Only regular files are counted and considered for deletion; subdirectories
    are ignored.

    Args:
        class_dir: Directory holding the images of one class.
        target_count: Desired total number of files left in ``class_dir``.
        remove_augmented_only: If True, only files whose names start with
            ``'aug'`` may be deleted, and original images are kept. The target
            still refers to the *total* number of files in the directory. If
            there are too few augmented files to reach it, all of them are
            deleted and the directory stays above the target.
    """
    all_images = [
        name for name in os.listdir(class_dir)
        if os.path.isfile(os.path.join(class_dir, name))
    ]

    # Excess is measured against every file in the directory, even when only
    # augmented files are eligible for deletion.
    excess = len(all_images) - target_count
    if excess <= 0:
        return

    candidates = all_images
    if remove_augmented_only:
        candidates = [img for img in all_images if img.startswith('aug')]

    for image in random.sample(candidates, min(excess, len(candidates))):
        image_path = os.path.join(class_dir, image)
        try:
            os.remove(image_path)
        except OSError as err:
            # Best effort: report and keep going rather than aborting the run.
            print(f"Could not remove {image_path}: {err}")

def _unique_aug_path(output_dir, stem):
    """Return ``<output_dir>/aug_<stem>_<k>.png`` for the smallest k not already on disk."""
    k = 0
    while True:
        candidate = os.path.join(output_dir, f"aug_{stem}_{k}.png")
        if not os.path.exists(candidate):
            return candidate
        k += 1


def augment_class(class_dir, target_count, output_dir, batch_size=500, target_size=(32, 32)):
    """Write augmented images until the class reaches ``target_count`` images.

    The number of images generated is ``target_count`` minus the number of
    files currently in ``class_dir``. Nothing is written when the class is
    already at or above the target, or when ``class_dir`` has no files to
    augment from.

    Source images are visited in random order. Each contributes at most
    ``ceil(num_needed / num_existing)`` samples, so the synthetic images are
    spread over the whole class rather than derived from a single file.

    Each sample is saved as a PNG named ``aug_<source-stem>_<k>.png`` in
    ``output_dir``; the ``aug`` prefix is what ``remove_extra_images`` filters
    on. Names are chosen so an existing file is never overwritten, which keeps
    the number of files written equal to the number of samples counted.

    Args:
        class_dir: Directory containing the source images of one class.
        target_count: Desired total number of images for the class.
        output_dir: Directory the augmented images are written to.
        batch_size: Currently unused; kept so existing call sites still work.
        target_size: ``(height, width)`` each source image is resized to on
            load, and therefore the size of every augmented output.
    """
    datagen = ImageDataGenerator(
        rotation_range=10,
        width_shift_range=0.05,
        height_shift_range=0.05,
        shear_range=0.1,
        zoom_range=0.1,
        horizontal_flip=True,
        vertical_flip=True,
        fill_mode='nearest',
        brightness_range=[0.9, 1.1]
    )

    existing_images = [
        name for name in os.listdir(class_dir)
        if os.path.isfile(os.path.join(class_dir, name))
    ]
    num_existing = len(existing_images)
    num_needed = target_count - num_existing
    if num_needed <= 0 or num_existing == 0:
        # Already at/above target, or nothing to augment from.
        return

    # Ceiling division: the most samples any single source image contributes.
    per_image = -(-num_needed // num_existing)

    augmented = 0
    # Random visiting order: when fewer samples than sources are needed, the
    # augmented subset is not biased toward os.listdir order.
    for image_file in random.sample(existing_images, num_existing):
        image_path = os.path.join(class_dir, image_file)
        img = tf.keras.preprocessing.image.load_img(image_path, target_size=target_size)
        x = np.expand_dims(tf.keras.preprocessing.image.img_to_array(img), axis=0)
        stem = os.path.splitext(image_file)[0]

        # flow() over one image is an endless stream of random transforms of
        # it, so draw only this image's share instead of iterating it freely.
        flow = datagen.flow(x, batch_size=1)
        for _ in range(min(per_image, num_needed - augmented)):
            batch = next(flow)
            tf.keras.preprocessing.image.save_img(_unique_aug_path(output_dir, stem), batch[0])
            augmented += 1

        if augmented >= num_needed:
            return

def balance_test_set(test_class_dir, target_count):
    """Trim a test-class directory down to ``target_count`` entries, without augmentation.

    Directories already at or below the target are left untouched. Otherwise
    surplus images are removed at random by ``remove_extra_images``.
    """
    if len(os.listdir(test_class_dir)) <= target_count:
        return
    remove_extra_images(test_class_dir, target_count)

def process_set_with_balancing_and_augmentation(
    set_number,
    root_dir="D:/project_geo/code_test_combined/5folds_with_test_updated_combined_ver3",
    target_train_count=10000,
    target_val_count=2500,
    target_test_count=2000,
):
    """Balance the per-class image counts of one cross-validation fold.

    The train and val splits under ``<root_dir>/set_<set_number>`` are handled
    per class:

    * Road, Vegetation and Water are first topped up with augmented images,
      written next to the originals. They are then trimmed by deleting only
      augmented images, so original images are never removed.
    * Building and Land_(unpaved_area) are only trimmed, by deleting random
      images; no augmentation is applied.

    The shared test split ``<root_dir>/test`` is trimmed, never augmented, to
    ``target_test_count`` images per class. Every fold points at the same test
    directory, so after the first fold has trimmed it, later calls normally
    find nothing left to delete.

    Args:
        set_number: Fold index; selects the ``set_<set_number>`` directory.
        root_dir: Directory containing the ``set_*`` folds and ``test``.
        target_train_count: Target images per class in the train split.
        target_val_count: Target images per class in the val split.
        target_test_count: Maximum images per class in the test split.
    """
    base_path = f"{root_dir}/set_{set_number}"
    train_path = os.path.join(base_path, 'train')
    val_path = os.path.join(base_path, 'val')
    test_path = f"{root_dir}/test"

    classes = ['Building', 'Land_(unpaved_area)', 'Road', 'Vegetation', 'Water']
    augmented_classes = ('Road', 'Vegetation', 'Water')

    for class_name in classes:
        train_class_dir = os.path.join(train_path, class_name)
        val_class_dir = os.path.join(val_path, class_name)
        test_class_dir = os.path.join(test_path, class_name)

        if class_name in augmented_classes:
            # Top up the minority classes in place (output dir == class dir).
            # NOTE(review): augmenting the val split evaluates on synthetic
            # images; confirm this is intended.
            augment_class(train_class_dir, target_train_count, train_class_dir, batch_size=500)
            augment_class(val_class_dir, target_val_count, val_class_dir, batch_size=500)
            # Trim any overshoot, deleting augmented images only.
            remove_extra_images(train_class_dir, target_train_count, remove_augmented_only=True)
            remove_extra_images(val_class_dir, target_val_count, remove_augmented_only=True)
        else:
            # Remaining classes are only downsampled.
            remove_extra_images(train_class_dir, target_train_count)
            remove_extra_images(val_class_dir, target_val_count)

        # The test split is balanced by deletion only, never augmented.
        balance_test_set(test_class_dir, target_test_count)

# Run the balancing/augmentation pipeline on each fold, set_1 through set_5.
for set_number in (1, 2, 3, 4, 5):
    process_set_with_balancing_and_augmentation(set_number)


No images removed from D:/project_geo/code_test_combined/5folds_with_test_updated_combined_ver3/set_1\train\Building. Already under the target of 10000.
No images removed from D:/project_geo/code_test_combined/5folds_with_test_updated_combined_ver3/set_1\val\Building. Already under the target of 2500.
Test set for Building already has 1590 images.
No images removed from D:/project_geo/code_test_combined/5folds_with_test_updated_combined_ver3/set_1\train\Land_(unpaved_area). Already under the target of 10000.
No images removed from D:/project_geo/code_test_combined/5folds_with_test_updated_combined_ver3/set_1\val\Land_(unpaved_area). Already under the target of 2500.
Test set for Land_(unpaved_area) already has 2000 images.
No augmentation needed for D:/project_geo/code_test_combined/5folds_with_test_updated_combined_ver3/set_1\train\Road. Already at or above target of 10000.
No augmentation needed for D:/project_geo/code_test_combined/5folds_with_test_updated_combined_ver3/set_1\val\Ro