In [13]:
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img

In [8]:
import os
import random
from PIL import Image, ImageEnhance
from collections import defaultdict


# Kaggle mounts each attached dataset under /kaggle/input. os.listdir order is
# filesystem-dependent, so the listings are sorted to make the "first entry"
# choice deterministic across sessions.
# NOTE(review): assumes a single attached dataset whose first (alphabetical)
# entry is the directory holding one sub-folder per class — confirm.
BASE_PATH = "/kaggle/input"
DATASET_PATH = os.path.join(BASE_PATH, sorted(os.listdir(BASE_PATH))[0])
DATA_ROOT = os.path.join(DATASET_PATH, sorted(os.listdir(DATASET_PATH))[0])

RANDOM_SEED = 42

# Seeds Python's `random`, which drives image selection and every augmentation.
random.seed(RANDOM_SEED)

# Only files with these extensions are scanned; corrupt/unreadable ones among
# them are filtered out separately using PIL's verify().
VALID_EXTENSIONS = (".jpg", ".jpeg", ".png")

In [9]:
# Map class name -> list of image paths that PIL can verify. Directory listings
# are sorted because os.listdir order is arbitrary: a stable order keeps the
# per-class lists, and hence the seeded random.choice calls made when
# augmenting, reproducible from run to run.
class_images = defaultdict(list)
skipped_files = []  # (path, exception) for files PIL could not verify

for cls in sorted(os.listdir(DATA_ROOT)):
    class_dir = os.path.join(DATA_ROOT, cls)
    if not os.path.isdir(class_dir):
        continue

    for file in sorted(os.listdir(class_dir)):
        if not file.lower().endswith(VALID_EXTENSIONS):
            continue
        file_path = os.path.join(class_dir, file)
        try:
            # verify() checks integrity without decoding pixels; the image
            # object is unusable afterwards, so it is opened just for the check.
            with Image.open(file_path) as img:
                img.verify()
        except Exception as exc:  # PIL raises several error types for bad files
            skipped_files.append((file_path, exc))
            continue
        class_images[cls].append(file_path)

print(f"Skipped {len(skipped_files)} unreadable image file(s)")

In [17]:
# Size of the largest class; every smaller class is topped up to this count.
# Fail loudly (rather than with max()'s "empty sequence" error) if the scan
# found nothing, which usually means DATA_ROOT points at the wrong folder.
assert class_images, f"No readable images found under {DATA_ROOT!r}"
maximum = max(len(paths) for paths in class_images.values())
maximum

3269

In [18]:
def augment_image(img, max_rotation=20, brightness_range=(0.8, 1.2),
                  contrast_range=(0.8, 1.2), flip_prob=0.5, rng=None):
    """Return a randomly augmented copy of a PIL image.

    Applies, in order: an optional left-right flip, a random rotation, a
    brightness jitter and a contrast jitter. The defaults reproduce the
    notebook's original settings and draw random numbers in the same order.

    Args:
        img: PIL image to augment. Each PIL operation used here returns a new
            image, so the input object is not modified.
        max_rotation: rotation angle is drawn uniformly from
            [-max_rotation, max_rotation] degrees.
        brightness_range: (low, high) factor range passed to
            ImageEnhance.Brightness; 1.0 leaves brightness unchanged.
        contrast_range: (low, high) factor range passed to
            ImageEnhance.Contrast; 1.0 leaves contrast unchanged.
        flip_prob: probability of applying a left-right flip.
        rng: object providing ``random()`` and ``uniform()`` (e.g. a
            ``random.Random`` instance). Defaults to the global ``random``
            module, so the notebook-level ``random.seed`` still applies.

    Returns:
        The augmented PIL image. Because rotation uses ``expand=True``, the
        canvas grows to fit the rotated image, so the output size can differ
        from the input size.
    """
    rng = random if rng is None else rng

    if rng.random() < flip_prob:
        img = img.transpose(Image.FLIP_LEFT_RIGHT)

    # expand=True keeps the rotated corners instead of cropping them; the
    # uncovered area is filled with PIL's default fill colour.
    img = img.rotate(rng.uniform(-max_rotation, max_rotation), expand=True)

    img = ImageEnhance.Brightness(img).enhance(rng.uniform(*brightness_range))
    img = ImageEnhance.Contrast(img).enhance(rng.uniform(*contrast_range))

    return img



In [20]:
# Destination for the synthetic images; created up front and left as-is if it
# already exists from a previous run.
SAVE_ROOT = "/kaggle/working/augmented_train"
os.makedirs(name=SAVE_ROOT, exist_ok=True)

In [21]:
# Top up every class to `maximum` images by writing augmented copies of randomly
# chosen originals to SAVE_ROOT/<class>/aug_<i>.jpg. Re-seeding here makes this
# cell produce the same files on every execution, not only on a fresh Run All
# (nothing between the config cell's seed and this cell draws random numbers).
# NOTE(review): augmentation is applied to the whole dataset; if a train/val
# split is made later, augmented copies of validation images will end up in
# training data — consider splitting first and augmenting only the train part.
random.seed(RANDOM_SEED)

for cls, images in class_images.items():
    deficit = maximum - len(images)
    if deficit <= 0:
        continue

    print(f"Augmenting class: {cls}")

    save_cls_dir = os.path.join(SAVE_ROOT, cls)
    os.makedirs(save_cls_dir, exist_ok=True)

    failures = 0
    for i in range(deficit):
        img_path = random.choice(images)
        try:
            # Manage the opened file itself: convert() returns a new image, so
            # a `with` around the converted copy would not close this handle.
            with Image.open(img_path) as src:
                img = src.convert("RGB")
            aug_img = augment_image(img)
            save_path = os.path.join(save_cls_dir, f"aug_{i}.jpg")
            aug_img.save(save_path)
        except Exception:
            # Best-effort: keep going, but count misses so a short class is visible.
            failures += 1

    if failures:
        print(f"  {failures}/{deficit} augmentations failed; "
              f"{cls!r} is {failures} image(s) short of {maximum}")

Augmenting class: planet
Augmenting class: galaxy
Augmenting class: black hole
Augmenting class: asteroid
Augmenting class: comet
Augmenting class: constellation
Augmenting class: nebula
