In [334]:
from PIL import Image
import numpy as np
import os
import matplotlib as plt
import cv2
filepath = "/Users/natejly/Desktop/sorted_digits"
# filepath = "/Users/natejly/Desktop/chest_xray"
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Sequence, Tuple, Callable
import math

import numpy as np
from PIL import Image
from sklearn.model_selection import StratifiedShuffleSplit
import tensorflow as tf
import os
import json
from pathlib import Path


In [335]:
# Generate folder tree as JSON for LLM input, including file count and folder count for each folder, and collect leaf files
def folder_tree_json(path):
    leaf_files = []
    path = Path(path)
    def build_tree(p):
        temp_leaf_files = []
        children = [c for c in p.iterdir() if c.is_dir()]
        is_leaf = len(children) == 0
        folder_count = len(children)
        if is_leaf:
            for f in p.iterdir():
                if f.is_file():
                    temp_leaf_files.append(str(f))
            file_count = len(temp_leaf_files)
            leaf_files.extend(temp_leaf_files)
        else:
            file_count = None
        return {
            "folder_name": p.name,
            "is_leaf": is_leaf,
            "file_count": file_count,
            "folder_count": folder_count,
            "sub_folders": [build_tree(child) for child in children]
        }
    tree = build_tree(path)
    return tree, leaf_files
def get_image_size(leaf_files, sample_rate=50):
    sampled_files = leaf_files[::sample_rate]
    sizes = []
    pxtodim = {}
    for file in sampled_files:
        with Image.open(file) as img:
            pixels = 1
            for dim in img.size:
                pixels *= dim
            pxtodim[pixels] = img.size
            sizes.append(pixels)

    median_size = np.median(sizes)
    return pxtodim.get(median_size)



In [336]:
#sample every 1/20th file in the leaf files

tree_json, leaf_files = folder_tree_json(filepath)
img_dims = get_image_size(leaf_files, sample_rate=100)
print(json.dumps(tree_json, indent=2))
print(img_dims)

{
  "folder_name": "sorted_digits",
  "is_leaf": false,
  "file_count": null,
  "folder_count": 13,
  "sub_folders": [
    {
      "folder_name": "test",
      "is_leaf": false,
      "file_count": null,
      "folder_count": 10,
      "sub_folders": [
        {
          "folder_name": "9",
          "is_leaf": true,
          "file_count": 152,
          "folder_count": 0,
          "sub_folders": []
        },
        {
          "folder_name": "0",
          "is_leaf": true,
          "file_count": 147,
          "folder_count": 0,
          "sub_folders": []
        },
        {
          "folder_name": "7",
          "is_leaf": true,
          "file_count": 155,
          "folder_count": 0,
          "sub_folders": []
        },
        {
          "folder_name": "6",
          "is_leaf": true,
          "file_count": 145,
          "folder_count": 0,
          "sub_folders": []
        },
        {
          "folder_name": "1",
          "is_leaf": true,
          "file_count": 

In [337]:
def analyze_dataset_structure_and_splits(tree_json, filepath):
    """
    Analyze dataset structure to detect train/test/val splits and organize folder paths
    """
    # Check if dataset has train/test/val split structure
    split_keywords = ['train', 'test', 'val', 'validation', 'dev']
    
    def has_split_structure(node):
        folder_names = [child['folder_name'].lower() for child in node.get('sub_folders', [])]
        return any(keyword in name for keyword in split_keywords for name in folder_names)
    
    has_splits = has_split_structure(tree_json)
    
    # Organize file paths
    dataset_splits = {
        "train_test_val_split_exists": has_splits,
        "splits": {}
    }
    
    if has_splits:
        # Organize by detected splits
        def organize_by_splits(node, current_path=""):
            for child in node.get('sub_folders', []):
                child_name = child['folder_name'].lower()
                path_key = f"{current_path}/{child['folder_name']}" if current_path else child['folder_name']
                
                # Determine split type
                split_type = None
                if 'train' in child_name:
                    split_type = 'train'
                elif 'test' in child_name:
                    split_type = 'test'
                elif 'val' in child_name or 'validation' in child_name:
                    split_type = 'val'
                
                if split_type and child.get('is_leaf'):
                    # Leaf folder with files
                    if split_type not in dataset_splits["splits"]:
                        dataset_splits["splits"][split_type] = {}
                    
                    class_name = child['folder_name']
                    if current_path:
                        folder_path = f"{filepath}/{current_path}/{child['folder_name']}"
                    else:
                        folder_path = f"{filepath}/{child['folder_name']}"
                    dataset_splits["splits"][split_type][class_name] = folder_path
                
                elif split_type and not child.get('is_leaf'):
                    # Split folder with class subfolders
                    if split_type not in dataset_splits["splits"]:
                        dataset_splits["splits"][split_type] = {}
                    
                    for class_folder in child.get('sub_folders', []):
                        if class_folder.get('is_leaf'):
                            class_name = class_folder['folder_name']
                            folder_path = f"{filepath}/{path_key}/{class_name}"
                            dataset_splits["splits"][split_type][class_name] = folder_path
                
                # Recursively check subfolders
                organize_by_splits(child, path_key)
        
        organize_by_splits(tree_json)
    else:
        # No splits detected - put all files in training set
        dataset_splits["splits"]["train"] = {}
        
        def organize_all_as_train(node, current_path=""):
            if node.get('is_leaf'):
                class_name = node['folder_name']
                if current_path:
                    folder_path = f"{filepath}/{current_path}"
                else:
                    folder_path = f"{filepath}"
                dataset_splits["splits"]["train"][class_name] = folder_path
            
            for child in node.get('sub_folders', []):
                child_path = f"{current_path}/{child['folder_name']}" if current_path else child['folder_name']
                organize_all_as_train(child, child_path)
        
        organize_all_as_train(tree_json)
    
    return dataset_splits

dataset_splits = analyze_dataset_structure_and_splits(tree_json, filepath)

print(json.dumps(dataset_splits, indent=2))

{
  "train_test_val_split_exists": true,
  "splits": {
    "test": {
      "9": "/Users/natejly/Desktop/sorted_digits/test/9",
      "0": "/Users/natejly/Desktop/sorted_digits/test/0",
      "7": "/Users/natejly/Desktop/sorted_digits/test/7",
      "6": "/Users/natejly/Desktop/sorted_digits/test/6",
      "1": "/Users/natejly/Desktop/sorted_digits/test/1",
      "8": "/Users/natejly/Desktop/sorted_digits/test/8",
      "4": "/Users/natejly/Desktop/sorted_digits/test/4",
      "3": "/Users/natejly/Desktop/sorted_digits/test/3",
      "2": "/Users/natejly/Desktop/sorted_digits/test/2",
      "5": "/Users/natejly/Desktop/sorted_digits/test/5"
    },
    "train": {
      "9": "/Users/natejly/Desktop/sorted_digits/train/9",
      "0": "/Users/natejly/Desktop/sorted_digits/train/0",
      "7": "/Users/natejly/Desktop/sorted_digits/train/7",
      "6": "/Users/natejly/Desktop/sorted_digits/train/6",
      "1": "/Users/natejly/Desktop/sorted_digits/train/1",
      "8": "/Users/natejly/Desktop/

In [338]:
import os
import shutil
import random

def split_dataset(parent_folder, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
    # Ensure ratios sum to 1
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6, "Ratios must sum to 1"

    # Create output folders
    for split in ["train", "val", "test"]:
        split_path = os.path.join(parent_folder, split)
        os.makedirs(split_path, exist_ok=True)

    # Loop through each class folder inside parent
    for class_name in os.listdir(parent_folder):
        class_path = os.path.join(parent_folder, class_name)
        
        # Skip the folders we just created
        if class_name in ["train", "val", "test"]:
            continue
        if not os.path.isdir(class_path):
            continue

        # Collect all files for this class
        files = os.listdir(class_path)
        random.shuffle(files)

        # Split indices
        n_total = len(files)
        n_train = int(train_ratio * n_total)
        n_val = int(val_ratio * n_total)

        train_files = files[:n_train]
        val_files = files[n_train:n_train + n_val]
        test_files = files[n_train + n_val:]

        # Function to move files into split folders
        def move_files(file_list, split):
            split_class_dir = os.path.join(parent_folder, split, class_name)
            os.makedirs(split_class_dir, exist_ok=True)
            for f in file_list:
                src = os.path.join(class_path, f)
                dst = os.path.join(split_class_dir, f)
                shutil.move(src, dst)

        # Move files
        move_files(train_files, "train")
        move_files(val_files, "val")
        move_files(test_files, "test")
    train_dir = os.path.join(parent_folder, "train")
    val_dir = os.path.join(parent_folder, "val")
    test_dir = os.path.join(parent_folder, "test")

    return train_dir, val_dir, test_dir



In [None]:
def gen_splits(ds_split_tree):
    #TODO: Implement train test split
    data = json
    train_folders = []
    val_folders = []
    test_folders = []
    # organize the data into splits
    has_splits = data["train_test_val_split_exists"]
    if has_splits:
        if "train" in data["splits"]:
            for value in data["splits"]["train"].values():
                train_folders.append(value)
        if "val" in data["splits"]:
            for value in data["splits"]["val"].values():
                val_folders.append(value)
        if "test" in data["splits"]:
            for value in data["splits"]["test"].values():
                test_folders.append(value)
    else:
        return split_dataset(filepath)

    return train_folders, val_folders, test_folders

train_folders, val_folders, test_folders = gen_splits(dataset_splits)


In [340]:
import tensorflow as tf
train_dir = os.path.dirname(train_folders[0])
val_dir = os.path.dirname(val_folders[0]) 
test_dir = os.path.dirname(test_folders[0]) 
print("train dir:", train_dir)
print("val dir:", val_dir)
print("test dir:", test_dir)

train dir: /Users/natejly/Desktop/sorted_digits/train
val dir: /Users/natejly/Desktop/sorted_digits/val
test dir: /Users/natejly/Desktop/sorted_digits/test


In [341]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications.efficientnet import preprocess_input

def get_classes_from_train(train_dir):
    return len([d.name for d in os.scandir(train_dir) if d.is_dir()])
#TODO: figure out image scaling
def bucket_dims(img_dims):
    if img_dims[0] < 128:
        dim = max(32, max(img_dims[0], img_dims[1]))
        return (dim, dim)
    if img_dims[0] < 256:
        return (128, 128)
    elif img_dims[0] < 512:
        return (256, 256)
    else:
        return (512, 512)

# -----------------------------
# CONFIGURATION
# -----------------------------
IMG_SIZE = bucket_dims(img_dims)
BATCH_SIZE = 64

NUM_CLASSES = get_classes_from_train(train_dir)
print("Detected classes:", NUM_CLASSES)
print("Using image size:", IMG_SIZE)

BASE_MODEL_NAME = "EfficientNetB0"
INITIAL_LEARNING_RATE = 1e-3  # Higher LR for frozen backbone
FINE_TUNE_LEARNING_RATE = 1e-5  # Lower LR for fine-tuning

# -----------------------------
# FUNCTION: Load Base Model
# -----------------------------
def get_base_model(model_name, input_shape, weights="imagenet"):
    """
    Returns a pretrained model (EfficientNet variants) without the top layer.
    """
    ModelClass = getattr(tf.keras.applications, model_name)
    base_model = ModelClass(
        include_top=False,
        weights="imagenet",
        input_shape=input_shape + (3,)
    )
    base_model.trainable = False  # Start with frozen backbone
    return base_model

def build_model(base_model, num_classes):
    inputs = tf.keras.Input(shape=IMG_SIZE + (3,))
    x = base_model(inputs, training=False)         # backbone
    x = layers.GlobalAveragePooling2D()(x)         # flatten features
    x = layers.Dropout(0.2)(x)                     # regularization
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    return models.Model(inputs, outputs)

# -----------------------------
# DATA PREPROCESSING
# -----------------------------
AUTOTUNE = tf.data.AUTOTUNE

def preprocess(image, label):
    image = tf.cast(image, tf.float32)
    image = preprocess_input(image)   # scale to [-1,1]
    return image, label

# Data augmentation for training (optional but recommended)
def augment(image, label):
    return image, label

train_ds = tf.keras.utils.image_dataset_from_directory(
    train_dir, image_size=IMG_SIZE, batch_size=BATCH_SIZE, shuffle=True
).map(preprocess, num_parallel_calls=AUTOTUNE).map(augment, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)

val_ds = tf.keras.utils.image_dataset_from_directory(
    val_dir, image_size=IMG_SIZE, batch_size=BATCH_SIZE, shuffle=False
).map(preprocess, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)

test_ds = tf.keras.utils.image_dataset_from_directory(
    test_dir, image_size=IMG_SIZE, batch_size=BATCH_SIZE, shuffle=False
).map(preprocess, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)

# -----------------------------
# BUILD AND COMPILE MODEL
# -----------------------------
base_model = get_base_model(BASE_MODEL_NAME, IMG_SIZE)
model = build_model(base_model, NUM_CLASSES)

print(f"Total parameters: {model.count_params():,}")
print(f"Trainable parameters: {sum([tf.keras.backend.count_params(w) for w in model.trainable_weights]):,}")
# -----------------------------
# STAGE 1: Train with frozen backbone
# -----------------------------
base_model.trainable = False
print("\n=== STAGE 1: Training with frozen backbone ===")

model.compile(
    optimizer=tf.keras.optimizers.Adam(INITIAL_LEARNING_RATE),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)

callbacks_stage1 = [
    tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=2,
        restore_best_weights=True
    )
]

history_stage1 = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10,   # effectively "infinite"
    callbacks=callbacks_stage1,
    verbose=1
)

# -----------------------------
# STAGE 2: Fine-tune with unfrozen backbone
# -----------------------------


Detected classes: 10
Using image size: (32, 32)
Found 6996 files belonging to 10 classes.
Found 1496 files belonging to 10 classes.
Found 1508 files belonging to 10 classes.
Total parameters: 4,062,381
Trainable parameters: 12,810

=== STAGE 1: Training with frozen backbone ===
Epoch 1/10
[1m110/110[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 41ms/step - accuracy: 0.3063 - loss: 2.0233 - val_accuracy: 0.6604 - val_loss: 1.4584
Epoch 2/10
[1m110/110[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 35ms/step - accuracy: 0.6026 - loss: 1.4236 - val_accuracy: 0.6878 - val_loss: 1.1945
Epoch 3/10
[1m110/110[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 37ms/step - accuracy: 0.6704 - loss: 1.1939 - val_accuracy: 0.7447 - val_loss: 1.0432
Epoch 4/10
[1m110/110[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 31ms/step - accuracy: 0.6976 - loss: 1.0740 - val_accuracy: 0.7600 - val_loss: 0.9540
Epoch 5/10
[1m110/110[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1