In [None]:
import os
import random
from PIL import Image, ImageEnhance
from collections import defaultdict

# ==============================
# Locate dataset automatically
# ==============================
def _first_subdir(parent):
    """Return the path of the alphabetically first sub-directory of *parent*.

    ``os.listdir`` returns entries in arbitrary order and may include plain
    files, so the entries are filtered to directories and sorted to make the
    choice deterministic.

    Raises:
        FileNotFoundError: if *parent* contains no sub-directory.
    """
    subdirs = sorted(
        entry
        for entry in os.listdir(parent)
        if os.path.isdir(os.path.join(parent, entry))
    )
    if not subdirs:
        raise FileNotFoundError(f"No sub-directory found in {parent!r}")
    return os.path.join(parent, subdirs[0])


BASE_PATH = "/kaggle/input"
# The first dataset folder under BASE_PATH, then the first folder inside it;
# DATA_ROOT is the directory whose sub-directories are the class folders.
DATASET_PATH = _first_subdir(BASE_PATH)
DATA_ROOT = _first_subdir(DATASET_PATH)

# ==============================
# Load dataset structure
# ==============================
# Maps each class name (a sub-directory of DATA_ROOT) to the list of image
# file paths found directly inside that sub-directory.
class_images = {}

for entry in os.listdir(DATA_ROOT):
    folder = os.path.join(DATA_ROOT, entry)
    if not os.path.isdir(folder):
        # Plain files sitting next to the class folders are ignored.
        continue

    image_paths = []
    for filename in os.listdir(folder):
        # Case-insensitive match on the accepted image extensions.
        if filename.lower().endswith((".jpg", ".jpeg", ".png")):
            image_paths.append(os.path.join(folder, filename))
    class_images[entry] = image_paths

# ==============================
# Show class distribution
# ==============================
print("Original class distribution:")
for name in class_images:
    print(f"{name}: {len(class_images[name])} images")

# ==============================
# Identify minority classes
# ==============================
# Size of the largest class; every class with fewer images than this is
# treated as a minority class.
max_count = max(map(len, class_images.values()))

# ==============================
# Augmentation function
# ==============================
def augment_image(img):
    """Return a randomly augmented copy of a PIL image.

    The pipeline is: a horizontal flip with probability 0.5, a rotation by
    90, 180 or 270 degrees chosen at random, then brightness and contrast
    scaling by factors drawn independently and uniformly from [0.7, 1.3].

    Args:
        img: The source ``PIL.Image.Image``.

    Returns:
        The augmented ``PIL.Image.Image``. After a 90 or 270 degree rotation
        of a non-square input, the width and height are swapped.
    """
    if random.random() > 0.5:
        img = img.transpose(Image.FLIP_LEFT_RIGHT)

    # expand=True grows the canvas to hold the whole rotated image. Without
    # it, a 90/270 degree rotation of a non-square image keeps the original
    # canvas size, cropping the content and padding the rest with black.
    angle = random.choice([90, 180, 270])
    img = img.rotate(angle, expand=True)

    brightness = ImageEnhance.Brightness(img)
    img = brightness.enhance(random.uniform(0.7, 1.3))

    contrast = ImageEnhance.Contrast(img)
    img = contrast.enhance(random.uniform(0.7, 1.3))

    return img

# ==============================
# Perform augmentation (in-memory)
# ==============================
# Maps class name -> list of augmented PIL images. Only classes smaller than
# the largest class receive entries.
# NOTE(review): every augmented image is kept in RAM at full resolution, so
# memory use grows with the total deficit across all classes - confirm this
# fits the runtime before running on the full dataset.
augmented = defaultdict(list)

for cls, images in class_images.items():
    deficit = max_count - len(images)
    if deficit <= 0:
        continue

    if not images:
        # random.choice() raises IndexError on an empty list; a class folder
        # without any image has nothing to augment from, so skip it.
        print(f"\nSkipping class with no images: {cls}")
        continue

    print(f"\nAugmenting class: {cls}")
    for _ in range(deficit):
        img_path = random.choice(images)
        # convert() returns an independent in-memory image, so the source
        # file can be closed as soon as the conversion is done.
        with Image.open(img_path) as src:
            img = src.convert("RGB")
        augmented[cls].append(augment_image(img))

# ==============================
# Summary
# ==============================
# Report how many augmented samples were generated for each class.
print("\nAugmentation summary:")
for cls, imgs in augmented.items():
    print(f"{cls}: {len(imgs)} augmented samples")

# The check mark was mis-encoded (UTF-8 bytes decoded as cp1252); restored.
print("\n✅ Dataset balanced virtually (no files written).")

Original class distribution:
planet: 1472 images
galaxy: 3984 images
black hole: 656 images
asteroid: 283 images
comet: 416 images
star: 3269 images
constellation: 1552 images
nebula: 1192 images

Augmenting class: planet


### Augmentation Strategy

Since the dataset does not contain predefined train/test splits, 
augmentation was performed per class, sampling source images directly from each class folder.

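To avoid storage issues, augmented images were generated in memory only; nothing is written to disk.
Each minority class is topped up with augmented samples until it matches the size of the largest class.
This ensures balanced class representation during training while 
keeping the dataset unchanged on disk.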