In [1]:
import os
import random
import shutil
from pathlib import Path 
import sys
from pathlib import Path
# The notebook is assumed to live one level below the project root (e.g. <root>/notebooks/).
project_root = Path.cwd().parent
# Guard the append so re-running this cell does not keep growing sys.path.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))
# Input: class-per-folder resized images; output: per-class train/val/test copies.
data_dir = project_root / "data/resized_data/"
target_dir = project_root / "data/vae/"


In [2]:
def _count_for(total, pct):
    """Return floor(total * pct), tolerant of float round-off."""
    # Without the epsilon, e.g. 100 * 0.29 == 28.999999999999996 would truncate to 28.
    return int(total * pct + 1e-9)


def _list_images(class_dir, extensions=None):
    """Return the candidate image files of ``class_dir`` in a stable (sorted) order.

    Hidden entries (``.DS_Store``, ``.ipynb_checkpoints`` ...) and anything that
    is not a regular file are skipped. With ``extensions=None`` every file that
    has a suffix is kept (the old ``"*.*"`` glob, minus hidden files and
    directories); otherwise only files whose suffix is in ``extensions``,
    compared case-insensitively.
    """
    wanted = None
    if extensions is not None:
        if isinstance(extensions, str):
            extensions = [extensions]
        wanted = {
            (ext if ext.startswith(".") else "." + ext).lower() for ext in extensions
        }

    files = []
    # Sorting makes the pre-shuffle order independent of the filesystem, which
    # is what makes the seeded shuffle reproducible across machines.
    for path in sorted(class_dir.iterdir()):
        if path.name.startswith(".") or not path.is_file():
            continue
        suffix = path.suffix.lower()
        if not suffix or (wanted is not None and suffix not in wanted):
            continue
        files.append(path)
    return files


def split_dataset(source_dir, output_dir, train_pct=0.7, val_pct=0.15, test_pct=0.15, seed=42,
                  extensions=None):
    """Copy a class-per-folder image dataset into per-class train/val/test folders.

    Reads ``source_dir/<class_name>/<files>`` and writes
    ``output_dir/<class_name>/{train,val,test}/<files>``. Each class is shuffled
    and split on its own, so every class keeps the requested proportions.

    Args:
        source_dir: Directory containing one sub-directory per class.
        output_dir: Destination root; created if missing.
        train_pct, val_pct, test_pct: Non-negative fractions summing to 1.0.
            Train and val counts are floored; test takes the remainder, so no
            file is dropped.
        seed: Seed for a private ``random.Random`` (global ``random`` state is
            not touched). Classes and files are sorted before shuffling, so the
            same seed and the same files always give the same split.
        extensions: Optional suffix or iterable of suffixes to keep
            (e.g. ``(".jpg", ".png")``, case-insensitive). ``None`` keeps every
            non-hidden file that has a suffix.

    Raises:
        ValueError: If a fraction is negative or the fractions do not sum to 1.0.

    Note:
        Files are added to existing folders; files left over from an earlier run
        with a different seed or different fractions are NOT removed and can leak
        images across splits. Delete ``output_dir`` before re-splitting.
    """
    source_dir = Path(source_dir)
    output_dir = Path(output_dir)

    fractions = (train_pct, val_pct, test_pct)
    if any(pct < 0 for pct in fractions):
        raise ValueError("train_pct, val_pct, and test_pct must all be non-negative")
    if abs(sum(fractions) - 1.0) > 1e-6:
        raise ValueError("The sum of train_pct, val_pct, and test_pct must be 1.0")

    rng = random.Random(seed)
    splits = ["train", "val", "test"]

    # Hidden directories (e.g. Jupyter's .ipynb_checkpoints) are not classes.
    class_dirs = sorted(
        p for p in source_dir.iterdir() if p.is_dir() and not p.name.startswith(".")
    )
    for class_dir in class_dirs:
        class_name = class_dir.name
        images = _list_images(class_dir, extensions)
        rng.shuffle(images)
        total = len(images)
        n_train = _count_for(total, train_pct)
        n_val = _count_for(total, val_pct)

        train_imgs = images[:n_train]
        val_imgs = images[n_train:n_train + n_val]
        test_imgs = images[n_train + n_val:]  # remainder, absorbs rounding

        print(f"Class '{class_name}': total={total}, train={len(train_imgs)}, val={len(val_imgs)}, test={len(test_imgs)}")

        for split, imgs in zip(splits, [train_imgs, val_imgs, test_imgs]):
            split_class_dir = output_dir / class_name / split
            split_class_dir.mkdir(parents=True, exist_ok=True)
            for img_path in imgs:
                shutil.copy2(img_path, split_class_dir / img_path.name)

In [3]:
# 70/20/10 split (differs from the function's 70/15/15 defaults); default seed.
split_dataset(data_dir, target_dir, train_pct=0.7, val_pct=0.2, test_pct=0.1)

Class 'zebra': total=376, train=263, val=75, test=38
Class 'cat': total=2852, train=1996, val=570, test=286
Class 'buffalo': total=376, train=263, val=75, test=38
Class 'rabbit': total=938, train=656, val=187, test=95
Class 'sheep': total=1820, train=1274, val=364, test=182
Class 'elephant': total=12037, train=8425, val=2407, test=1205
Class 'mouse': total=570, train=399, val=114, test=57
Class 'cow': total=1866, train=1306, val=373, test=187
Class 'horse': total=2623, train=1836, val=524, test=263
Class 'spider': total=4821, train=3374, val=964, test=483
Class 'rhino': total=376, train=263, val=75, test=38
Class 'squirrel': total=1862, train=1303, val=372, test=187
Class 'fox': total=6499, train=4549, val=1299, test=651
Class 'tiger': total=6976, train=4883, val=1395, test=698
Class 'dog': total=4967, train=3476, val=993, test=498
Class 'hen': total=3098, train=2168, val=619, test=311
Class 'butterfly': total=2112, train=1478, val=422, test=212
Class 'bird': total=1528, train=1069, va