In [None]:
from keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img
import os

# Class labels. Each one is the name of a sub-folder under both the source
# directory and the augmented-output directory.
classes = ['green_bus', 'red_bus', 'blue_bus', 'yellow_bus']

# Where the original images live, and where augmented copies are written.
base_dir, augmented_dir = 'bus', 'bus_data'

# Number of augmented images the script aims to produce for every class.
target_images_per_class = 1000

# Size every source image is resized to when loaded (Keras `load_img`
# interprets target_size as (height, width); both are 150 here).
img_size = (150,) * 2

# Random augmentation pipeline: small rotations, shifts, shears and zooms,
# plus horizontal flips. Pixels exposed by a transform are filled from the
# nearest edge pixel.
datagen = ImageDataGenerator(
    horizontal_flip=True,
    fill_mode='nearest',
    rotation_range=40,
    zoom_range=0.2,
    shear_range=0.2,
    width_shift_range=0.2,
    height_shift_range=0.2,
)

def _load_as_batch(img_path):
    """Load one image, resized to ``img_size``, as a single-sample batch.

    Returns the ``img_to_array`` result with a leading batch axis of length 1
    prepended, which is what ``datagen.flow`` receives for one sample.
    """
    x = img_to_array(load_img(img_path, target_size=img_size))
    return x.reshape((1,) + x.shape)


def _save_augmentations(x, output_dir, count):
    """Draw ``count`` random augmentations of batch ``x`` and save each as PNG.

    ``datagen.flow(..., save_to_dir=...)`` writes a file for every batch it
    yields and keeps yielding indefinitely. Exactly ``count`` batches are
    pulled with ``islice``, so ``count`` files are written.
    """
    import itertools

    flow = datagen.flow(x, batch_size=1, save_to_dir=output_dir,
                        save_prefix='aug', save_format='png')
    for _ in itertools.islice(flow, count):
        pass


def _count_pngs(directory):
    """Return how many ``.png`` files (case-insensitive) are in ``directory``."""
    return sum(1 for name in os.listdir(directory) if name.lower().endswith('.png'))


for cls in classes:
    input_dir = os.path.join(base_dir, cls)
    output_dir = os.path.join(augmented_dir, cls)

    # Create the folder that will hold this class's augmented images.
    os.makedirs(output_dir, exist_ok=True)

    # Source images for this class. Sorted so runs are reproducible in which
    # sources receive the top-up images below.
    image_files = sorted(f for f in os.listdir(input_dir) if f.lower().endswith('.png'))
    total_images = len(image_files)
    if total_images == 0:
        # Without sources we cannot augment. This also avoids dividing by zero.
        print(f"{cls} : no .png images found in {input_dir}, skipping.")
        continue

    # Originals are not copied to output_dir, so each source needs roughly
    # target // total augmentations. At least one, so every source appears.
    augment_per_image = max(1, target_images_per_class // total_images)

    print(f"{cls} : Each image will be augmented {augment_per_image} times.")

    # Phase 1: an equal number of augmentations for every source image.
    for image_file in image_files:
        x = _load_as_batch(os.path.join(input_dir, image_file))
        _save_augmentations(x, output_dir, augment_per_image)

    # Phase 2: top up the remainder left by integer division. Count what is
    # actually on disk; PNGs from earlier runs into the same folder also count.
    # (Assumption to confirm: Keras' random filename suffix can occasionally
    # collide and overwrite a file, which this recount would reflect.)
    current_augmented_images = _count_pngs(output_dir)
    while current_augmented_images < target_images_per_class:
        # Round-robin: one extra image per source. This spreads the remainder
        # instead of drawing it all from the first image.
        for image_file in image_files:
            if current_augmented_images >= target_images_per_class:
                break
            x = _load_as_batch(os.path.join(input_dir, image_file))
            _save_augmentations(x, output_dir, 1)
            current_augmented_images += 1

print(f"Total augmented images saved to {augmented_dir}")