In [5]:
import os
import shutil
from sklearn.model_selection import train_test_split
import torchvision.transforms as T
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms.functional as F
import random
import math
import sys

### Augmentation with balancing of classes


In [8]:
# Paths
source_dir = '/home/luka/Pictures/Kaninchen'
output_dir = '/home/luka/Pictures/Kaninchen_aug'

# Constants
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
valid_exts = ('.jpg', '.jpeg', '.png')
SPLITS = ['train', 'val', 'test']
RATIOS = {'train': 0.7, 'val': 0.15, 'test': 0.15}
random.seed(42)

# Augmentations
flip = T.RandomHorizontalFlip(p=1.0)

def rotate_crop_borders(img, angle):
    rotate_plus = T.RandomRotation(degrees=(angle, angle))
    img = rotate_plus(img)

    angle_deg = angle % 180
    if angle_deg > 90:
        angle_deg = 180 - angle_deg
    angle_rad = math.radians(angle_deg)

    sin_a = math.sin(angle_rad)
    cos_a = math.cos(angle_rad)

    if img.height * sin_a <= img.width * cos_a:
        new_w = (img.width * cos_a) - (img.height * sin_a)
        new_h = (img.height * cos_a) - (img.width * sin_a)
    else:
        new_w = (img.height * cos_a) - (img.width * sin_a)
        new_h = (img.width * cos_a) - (img.height * sin_a)

    if new_w <= 0 or new_h <= 0:
        return None

    left = (img.width - new_w) // 2
    top = (img.height - new_h) // 2
    right = left + new_w
    bottom = top + new_h

    if right <= left or bottom <= top:
        return None

    try:
        cropped_img = img.crop((left, top, right, bottom))
    except Exception:
        return None

    w, h = cropped_img.size
    if max(w, h) > 4 * min(w, h):
        return None

    return cropped_img

def get_class_files(class_path):
    return [f for f in os.listdir(class_path) if os.path.splitext(f)[1].lower() in IMAGE_EXTENSIONS]

def prepare_split_folders():
    for split in SPLITS:
        for class_name in os.listdir(source_dir):
            src_class_dir = os.path.join(source_dir, class_name)
            if os.path.isdir(src_class_dir):
                split_class_dir = os.path.join(output_dir, split, class_name)
                os.makedirs(split_class_dir, exist_ok=True)

def split_dataset():
    prepare_split_folders()

    for class_name in os.listdir(source_dir):
        src_class_dir = os.path.join(source_dir, class_name)
        if not os.path.isdir(src_class_dir):
            continue

        files = get_class_files(src_class_dir)
        if not files:
            continue

        # Split into train, val, test
        train_val, test = train_test_split(files, test_size=RATIOS['test'], random_state=42)
        train, val = train_test_split(train_val, test_size=RATIOS['val'] / (RATIOS['train'] + RATIOS['val']), random_state=42)
        split_map = {'train': train, 'val': val, 'test': test}

        for split in SPLITS:
            for f in split_map[split]:
                src = os.path.join(src_class_dir, f)
                dst = os.path.join(output_dir, split, class_name, f)
                shutil.copy2(src, dst)

def get_image_counts_per_split(split_root):
    image_counts = {}
    for class_name in os.listdir(split_root):
        class_path = os.path.join(split_root, class_name)
        if not os.path.isdir(class_path):
            continue
        count = len(get_class_files(class_path))
        image_counts[class_name] = count
    return image_counts

def augment_split(split):
    split_root = os.path.join(output_dir, split)
    image_counts = get_image_counts_per_split(split_root)

    class_names = list(image_counts.keys())
    max_img_count = max([v for k, v in image_counts.items() if k.lower() != 'other'])

    for class_name in class_names:
        class_path = os.path.join(split_root, class_name)
        files = get_class_files(class_path)
        current_count = image_counts[class_name]

        num_needed = max_img_count - current_count
        if num_needed <= 0:
            continue

        originals = []
        for file in files:
            path = os.path.join(class_path, file)
            name, ext = os.path.splitext(file)

            # Skip already augmented images
            if '_v' in name and name[-2:].isdigit():
                continue

            try:
                img = Image.open(path).convert("RGB")
                base_name = name
                if '_ok' in name:
                    base_name = name.split('_ok')[0] + '_ok'
                elif '_nok' in name:
                    base_name = name.split('_nok')[0] + '_nok'

                # Save v0 and delete original
                v0_name = f"{base_name}_v0{ext}"
                v0_path = os.path.join(class_path, v0_name)
                if not os.path.exists(v0_path):
                    img.save(v0_path)
                    os.remove(path)
                    print(f"[{split}] Saved _v0 and deleted original: {file}")
                originals.append((base_name, img))

            except Exception as e:
                print(f"[{split}] Failed to process {file}: {e}")

        # Augment each original with v1, v2, v3
        aug_count = 0
        for base_name, original in originals:
            if aug_count >= num_needed:
                break

            ext = ".jpg"  # standardize or detect from original

            # v1: flipped
            if aug_count < num_needed:
                try:
                    v1 = flip(original)
                    v1.save(os.path.join(class_path, f"{base_name}_v1{ext}"))
                    aug_count += 1
                except Exception as e:
                    print(f"[{split}] Failed to create v1 for {base_name}: {e}")

            # v2: rotated
            if aug_count < num_needed:
                try:
                    angle = random.uniform(5, 10)
                    v2 = rotate_crop_borders(original.copy(), angle)
                    if v2:
                        v2.save(os.path.join(class_path, f"{base_name}_v2{ext}"))
                        aug_count += 1
                except Exception as e:
                    print(f"[{split}] Failed to create v2 for {base_name}: {e}")

            # v3: rotated + flipped
            if aug_count < num_needed:
                try:
                    angle = random.uniform(5, 10)
                    rotated = rotate_crop_borders(original.copy(), angle)
                    if rotated:
                        v3 = flip(rotated)
                        v3.save(os.path.join(class_path, f"{base_name}_v3{ext}"))
                        aug_count += 1
                except Exception as e:
                    print(f"[{split}] Failed to create v3 for {base_name}: {e}")
def trim_other_class(split):
    split_root = os.path.join(output_dir, split)
    image_counts = get_image_counts_per_split(split_root)
    
    max_count = max([v for k, v in image_counts.items() if k.lower() != 'other'])
    other_path = os.path.join(split_root, 'other')

    if not os.path.exists(other_path):
        return

    other_images = [f for f in os.listdir(other_path) if f.lower().endswith(valid_exts)]
    if len(other_images) <= max_count:
        return

    # Randomly keep max_count files, delete rest
    to_keep = set(random.sample(other_images, max_count))
    for f in other_images:
        if f not in to_keep:
            try:
                os.remove(os.path.join(other_path, f))
                print(f"[{split}] Deleted extra from 'other': {f}")
            except Exception as e:
                print(f"[{split}] Failed to delete {f}: {e}")

def flatten_split_folders(base_output_dir):
    for split in ['train', 'val', 'test']:
        split_path = os.path.join(base_output_dir, split)
        if not os.path.exists(split_path):
            continue

        for class_name in os.listdir(split_path):
            class_folder = os.path.join(split_path, class_name)
            if not os.path.isdir(class_folder):
                continue

            for file in os.listdir(class_folder):
                src_file = os.path.join(class_folder, file)
                if not file.lower().endswith(valid_exts):
                    continue

                # Add class name prefix if not already present
                if not file.startswith(f"{class_name}_"):
                    new_name = f"{class_name}_{file}"
                else:
                    new_name = file

                dst_file = os.path.join(split_path, new_name)

                # Avoid overwriting files with the same name
                if os.path.exists(dst_file):
                    base, ext = os.path.splitext(new_name)
                    suffix = 1
                    while os.path.exists(os.path.join(split_path, f"{base}_{suffix}{ext}")):
                        suffix += 1
                    dst_file = os.path.join(split_path, f"{base}_{suffix}{ext}")

                try:
                    shutil.move(src_file, dst_file)
                except Exception as e:
                    print(f"[{split}] Failed to move {src_file} → {dst_file}: {e}")

            # Remove the now-empty class folder
            try:
                os.rmdir(class_folder)
                print(f"[{split}] Removed folder: {class_folder}")
            except Exception as e:
                print(f"[{split}] Could not remove folder {class_folder}: {e}")

split_dataset()


for split in SPLITS:
    print(f"\nAugmenting {split}...")
    augment_split(split)
    trim_other_class(split)

flatten_split_folders(output_dir)
print("\n✅ Done. Dataset split and augmented.")


Augmenting train...
[train] Saved _v0 and deleted original: IMG_9657.JPEG
[train] Saved _v0 and deleted original: IMG_9731.JPG
[train] Saved _v0 and deleted original: IMG_9771.JPEG
[train] Saved _v0 and deleted original: IMG_9872.JPEG
[train] Saved _v0 and deleted original: IMG_9780.JPEG
[train] Saved _v0 and deleted original: IMG_9727.JPG
[train] Saved _v0 and deleted original: IMG_9751.JPEG
[train] Saved _v0 and deleted original: IMG_9782.JPEG
[train] Saved _v0 and deleted original: IMG_9903.JPEG
[train] Saved _v0 and deleted original: IMG_9699.JPG
[train] Saved _v0 and deleted original: IMG_9701.JPG
[train] Saved _v0 and deleted original: IMG_9772.JPEG
[train] Saved _v0 and deleted original: IMG_9774.JPEG
[train] Saved _v0 and deleted original: IMG_9785.JPEG
[train] Saved _v0 and deleted original: IMG_9697.JPG
[train] Saved _v0 and deleted original: IMG_9738.JPG
[train] Saved _v0 and deleted original: IMG_9784.JPEG
[train] Saved _v0 and deleted original: IMG_9695.JPG
[train] Saved 

In [11]:
def print_flat_split_statistics(base_output_dir):
    print("Dataset statistics per split and class (from filename prefixes):\n")
    for split in SPLITS:
        split_path = os.path.join(base_output_dir, split)
        if not os.path.exists(split_path):
            print(f"{split}: folder does not exist.")
            continue
        class_counts = {}
        v0_counts = {}
        non_v0_counts = {}
        files = [f for f in os.listdir(split_path) if os.path.splitext(f)[1].lower() in IMAGE_EXTENSIONS]
        for f in files:
            prefix = f.split('_')[0]
            class_counts[prefix] = class_counts.get(prefix, 0) + 1
            if '_v0' in f:
                v0_counts[prefix] = v0_counts.get(prefix, 0) + 1
            else:
                non_v0_counts[prefix] = non_v0_counts.get(prefix, 0) + 1
        print(f"{split.capitalize()}:")
        for cname in sorted(class_counts.keys()):
            total = class_counts[cname]
            v0 = v0_counts.get(cname, 0)
            non_v0 = non_v0_counts.get(cname, 0)
            print(f"  {cname}: {total} (v0: {v0}, not v0: {non_v0})")
        print()

print_flat_split_statistics(output_dir)


Dataset statistics per split and class (from filename prefixes):

Train:
  Apollo: 88 (v0: 46, not v0: 42)
  Helios: 88 (v0: 57, not v0: 31)
  Nyx: 88 (v0: 28, not v0: 60)
  Selene: 88 (v0: 0, not v0: 88)

Val:
  Apollo: 19 (v0: 10, not v0: 9)
  Helios: 19 (v0: 13, not v0: 6)
  Nyx: 19 (v0: 7, not v0: 12)
  Selene: 19 (v0: 0, not v0: 19)

Test:
  Apollo: 20 (v0: 11, not v0: 9)
  Helios: 20 (v0: 13, not v0: 7)
  Nyx: 20 (v0: 7, not v0: 13)
  Selene: 20 (v0: 0, not v0: 20)

