In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras.preprocessing.image import img_to_array, load_img, ImageDataGenerator
from google.colab import drive

# Mount Google Drive (Colab only) so that the dataset stored under
# /content/drive/MyDrive is reachable through the local filesystem.
drive.mount('/content/drive')

# Define function to calculate maximum number of images among all classes
def calculate_max_images(data_dir):
    """Return the size of the largest class folder inside ``data_dir``.

    Every sub-directory of ``data_dir`` is treated as one class, and its size
    is the number of entries it contains. Entries of ``data_dir`` that are not
    directories (stray files such as ``.DS_Store``) are ignored instead of
    raising ``NotADirectoryError``.

    Args:
        data_dir: Path to the dataset root holding one folder per class.

    Returns:
        The largest entry count found among the class folders, or 0 when
        ``data_dir`` contains no class folders.
    """
    max_images = 0
    for image_class in os.listdir(data_dir):
        class_dir = os.path.join(data_dir, image_class)
        # Only directories are classes; listing a regular file would raise.
        if not os.path.isdir(class_dir):
            continue
        max_images = max(max_images, len(os.listdir(class_dir)))
    return max_images


# Define function to balance dataset with augmentation
def balance_dataset_with_augmentation(data_dir, target_size=(256, 256)):
    """Top up every smaller class in ``data_dir`` with augmented images.

    Each sub-directory of ``data_dir`` is treated as one class. For every
    class holding fewer entries than the largest class, randomly transformed
    copies of its images are generated and written into the same class folder
    as PNG files with the ``augmented_`` prefix. Exactly the missing number of
    images is requested for each class.

    Args:
        data_dir: Path to the dataset root holding one folder per class.
        target_size: ``(height, width)`` every source image is resized to
            when loaded. Defaults to ``(256, 256)``.

    Returns:
        None. The dataset is modified in place on disk.
    """
    max_images = calculate_max_images(data_dir)
    datagen = ImageDataGenerator(
        rotation_range=20,
        width_shift_range=0.1,
        height_shift_range=0.1,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest'
    )
    for image_class in os.listdir(data_dir):
        class_dir = os.path.join(data_dir, image_class)
        # Stray files next to the class folders are not classes.
        if not os.path.isdir(class_dir):
            continue

        entries = os.listdir(class_dir)
        images_needed = max_images - len(entries)
        if images_needed <= 0:
            continue

        # Only regular files can be loaded as images; skip nested folders.
        candidate_paths = (os.path.join(class_dir, entry) for entry in entries)
        image_files = [path for path in candidate_paths if os.path.isfile(path)]
        if not image_files:
            # Nothing to augment from; an empty array would make flow() fail.
            continue

        # Load the whole class into one (N, height, width, channels) array.
        source_images = np.stack([
            img_to_array(load_img(path, target_size=target_size))
            for path in image_files
        ])

        # batch_size=1 makes every draw produce and save exactly one image,
        # so the number of draws equals the number of images still missing.
        augmented_images = datagen.flow(
            source_images,
            batch_size=1,
            save_to_dir=class_dir,
            save_prefix='augmented_',
            save_format='png'
        )
        for _ in range(images_needed):
            # Builtin next() relies only on the iterator protocol rather than
            # on a version-specific ``.next()`` method.
            next(augmented_images)


# Specify the path to your dataset folder. It is expected to contain one
# sub-directory per class, since every entry is treated as a class folder.
data_dir = '/content/drive/MyDrive/EM401/Augmentation/train'

# Balance the classes in place: augmented PNG files are written directly into
# the existing class folders under data_dir.
balance_dataset_with_augmentation(data_dir)

Mounted at /content/drive
