In [None]:
import os
from pathlib import Path
import numpy as np
from PIL import Image
from tqdm import tqdm
from sklearn.model_selection import train_test_split


# ============================================================
# CONFIGURATION
# ============================================================

INPUT_DIR: str = "UTKFace"  # folder holding the raw UTKFace image files
OUTPUT_DIR: str = "dataset_by_age"  # root of the generated train/validation/test tree
IMG_SIZE: tuple = (128, 128)  # (width, height) handed to PIL's Image.resize


# ============================================================
# LABEL EXTRACTION
# ============================================================

def extract_age_from_filename(filename):
    """
    Extract age from UTKFace-style filenames:
    - '1_0_0_20161219140623097.jpg'
    - '1_0_0_20161219140623097.jpg.chip.jpg'
    - '12_1_3_201701022015.png'

    The age is the token before the first underscore and must consist only
    of decimal digits.

    Args:
        filename: Bare file name (not a full path).

    Returns:
        The age as a non-negative ``int``, or ``None`` when the name does not
        follow the ``<age>_...`` scheme (e.g. '.DS_Store', 'README', '-3_...').
    """
    age_str, _, _ = filename.partition("_")  # first token is age
    # isdecimal() rejects '', signs, whitespace and non-numeric tokens up front,
    # so int() below cannot fail and no blanket except is needed.
    if not age_str.isdecimal():
        return None
    return int(age_str)


def age_to_category(age):
    """Map an age to one of three (Spanish) class labels.

    Buckets: ``age < 30`` -> "joven" (young), ``30 <= age <= 61`` -> "medio"
    (middle-aged), anything else -> "anciano" (elderly).
    """
    return "joven" if age < 30 else "medio" if age <= 61 else "anciano"


# ============================================================
# IMAGE PROCESSING
# ============================================================

def process_image(path, size=None):
    """Load an image, convert it to RGB and resize it.

    Args:
        path: Path (or filename) of the source image.
        size: Target ``(width, height)``; falls back to the module-level
            ``IMG_SIZE`` when omitted.

    Returns:
        A new in-memory PIL image, independent of the source file, ready to
        be saved.
    """
    if size is None:
        size = IMG_SIZE
    # The context manager releases the file handle even if decoding fails.
    # convert() and resize() both return new images, so the result remains
    # valid after the source file is closed.
    with Image.open(path) as src:
        return src.convert("RGB").resize(size)


# ============================================================
# DIRECTORY CREATION
# ============================================================

def recreate_structure(base_dir, classes, splits=("train", "validation", "test")):
    """Create ``base_dir/<split>/<class>`` directories for every split/class pair.

    Existing directories are accepted (``exist_ok=True``). Files already
    inside them are NOT removed, so images left over from a previous run stay
    in place.

    Args:
        base_dir: Root output directory (str or Path).
        classes: Iterable of class names; one sub-directory per class.
        splits: Split names; one top-level sub-directory per split. Defaults
            to the train/validation/test layout used by this script.
    """
    root = Path(base_dir)
    # Materialize once: classes is iterated again for every split, which would
    # silently exhaust a generator after the first split.
    class_names = list(classes)
    for split in splits:
        for cls in class_names:
            (root / split / cls).mkdir(parents=True, exist_ok=True)


# ============================================================
# SCAN DATASET (LIGHTWEIGHT — NO LOADING)
# ============================================================

print("Scanning UTKFace image files...")

input_dir = Path(INPUT_DIR)
# iterdir() yields entries in arbitrary, filesystem-dependent order. Sorting
# makes the seeded train_test_split below give the same split on every machine.
files = sorted(f for f in input_dir.iterdir() if f.is_file())

records = []
for f in tqdm(files):
    age = extract_age_from_filename(f.name)
    if age is None:
        continue  # not an '<age>_...' file name
    category = age_to_category(age)
    records.append((f, category))

print(f"Valid images found: {len(records)}")
if not records:
    raise RuntimeError(f"No UTKFace-style images found in {input_dir.resolve()}")

paths = np.array([r[0] for r in records])
labels = np.array([r[1] for r in records])

# Derive classes from the plain-str records; a set over the numpy array would
# hold np.str_ objects, which print as "np.str_('...')".
classes = sorted({label for _, label in records})
print("Classes:", classes)


# ============================================================
# STRATIFIED SPLIT (ONLY FILENAMES)
# ============================================================

# Carve off 30% as a holdout, then halve the holdout into validation and test:
# 70% / 15% / 15% overall, with class proportions stratified at each step.
train_paths, temp_paths, train_labels, temp_labels = train_test_split(
    paths,
    labels,
    test_size=0.30,
    stratify=labels,
    random_state=42,
)

val_paths, test_paths, val_labels, test_labels = train_test_split(
    temp_paths,
    temp_labels,
    test_size=0.5,
    stratify=temp_labels,
    random_state=42,
)

splits = {
    split_name: list(zip(split_paths, split_labels))
    for split_name, split_paths, split_labels in (
        ("train", train_paths, train_labels),
        ("validation", val_paths, val_labels),
        ("test", test_paths, test_labels),
    )
}


# ============================================================
# CREATE DIRECTORY STRUCTURE
# ============================================================

recreate_structure(base_dir=OUTPUT_DIR, classes=classes)  # one dir per split x class


# ============================================================
# SAVE IMAGES STREAMING (RAM-SAFE)
# ============================================================

def save_split(name, split_data, output_base, *, skip_existing=False):
    """Resize and write every image of one split to ``output_base/name/<label>/``.

    Images are loaded and resized one at a time via ``process_image``, so
    only a single image is held in memory at any moment.

    Args:
        name: Split name; used as the sub-directory under ``output_base``.
        split_data: Iterable of ``(path, label)`` pairs, where ``path`` is a
            ``pathlib.Path`` to the source image and ``label`` names the class
            sub-directory. The ``name/label`` directories must already exist.
        output_base: Root output directory (str or Path).
        skip_existing: If True, images whose destination file already exists
            are not re-processed, which lets an interrupted run be resumed.
            Note that a file truncated by an interrupted write would also be
            skipped.

    Returns:
        List of ``(path, error)`` tuples for source images that could not be
        read; they are reported and skipped instead of aborting the split.
    """
    print(f"\nSaving {name}...")
    split_dir = Path(output_base) / name
    failures = []
    for path, label in tqdm(split_data):
        dest = split_dir / label / path.name
        if skip_existing and dest.exists():
            continue
        try:
            img = process_image(path)  # load + resize one image only
        except OSError as err:
            # Corrupt or unreadable inputs (PIL's UnidentifiedImageError and
            # truncated-file errors are OSError subclasses) must not abort a
            # long export. Write errors below still propagate.
            failures.append((path, err))
            continue
        img.save(dest)
    if failures:
        print(f"Skipped {len(failures)} unreadable image(s) in {name}.")
    return failures


for split_name in ("train", "validation", "test"):
    save_split(split_name, splits[split_name], OUTPUT_DIR)


print("\nDataset ready!")
print("Output directory:", OUTPUT_DIR)


Scanning UTKFace image files...


100%|██████████| 23708/23708 [00:00<00:00, 1366683.97it/s]


Valid images found: 23708
Classes: [np.str_('anciano'), np.str_('joven'), np.str_('medio')]

Saving train...


 86%|████████▌ | 14270/16595 [00:14<00:02, 1004.61it/s]


KeyboardInterrupt: 