In [None]:
import os
import shutil
from sklearn.model_selection import train_test_split
from collections import Counter

In [None]:
# Source folders: DatasetDivider looks for one sub-folder per class name
# inside each of these.
OLD_TRAIN_DIR = "./data/Training"
OLD_TEST_DIR = "./data/Testing"

# Output root; Training/Validation/Testing sub-trees are created under it.
NEW_DIR = "./new_data"

# Class names, used as sub-folder names in both the source and output trees.
CONTAIN_CLASSES = [
    "glioma_tumor",
    "meningioma_tumor",
    "no_tumor",
    "pituitary_tumor",
]

# Fractions of the pooled images per output subset; intended to sum to 1.0.
NEW_TRAIN_SPLIT = 0.7
NEW_VAL_SPLIT = 0.15
NEW_TEST_SPLIT = 0.15

In [None]:
class DatasetDivider:
    """Pool the old Training/Testing image folders and re-split them.

    Images of every class in ``CONTAIN_CLASSES`` are gathered from both
    ``OLD_TRAIN_DIR`` and ``OLD_TEST_DIR``. They are then split (stratified by
    class, seeded) into Training / Validation / Testing subsets using the
    fractions ``NEW_TRAIN_SPLIT`` / ``NEW_VAL_SPLIT`` / ``NEW_TEST_SPLIT``.
    Finally they are copied to ``NEW_DIR/<subset>/<class>/``.

    Attributes:
        dataset: list of ``(image_path, class_name)`` tuples, or ``None``
            until ``collect_all_images`` has run.
        split_dict: ``{"train" | "val" | "test": (paths, classes)}``, or
            ``None`` until ``split_and_copy`` has run.
    """

    # Lower-case filename suffixes that are treated as images.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg")
    # Output subset folder names, in split order.
    SUBSETS = ("Training", "Validation", "Testing")

    def __init__(self):
        self.dataset = None
        self.split_dict = None

    @classmethod
    def _list_images(cls, class_dir):
        """Return sorted paths of image files directly inside ``class_dir``.

        Returns an empty list when ``class_dir`` is not an existing directory.
        """
        if not os.path.isdir(class_dir):
            return []
        # os.listdir order is arbitrary; sorting keeps the seeded split
        # reproducible across filesystems and platforms.
        return [
            os.path.join(class_dir, filename)
            for filename in sorted(os.listdir(class_dir))
            if filename.lower().endswith(cls.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(class_dir, filename))
        ]

    @staticmethod
    def _unique_destination(dst_dir, filename, taken):
        """Return a path in ``dst_dir`` for ``filename`` that is not in ``taken``.

        Candidates are tried in order: ``filename``, ``dup_<filename>``,
        ``dup2_<filename>``, ``dup3_<filename>``, ...
        """
        candidate = os.path.join(dst_dir, filename)
        counter = 1
        while candidate in taken:
            prefix = "dup_" if counter == 1 else f"dup{counter}_"
            candidate = os.path.join(dst_dir, f"{prefix}{filename}")
            counter += 1
        return candidate

    def collect_all_images(self):
        """Gather the images of every class from the old train and test dirs.

        For each class, training-folder images come first, then testing-folder
        images. Within a folder, files are in sorted filename order.

        Returns:
            list of ``(image_path, class_name)`` tuples (also stored in
            ``self.dataset``).
        """
        all_imgs = []
        for class_name in CONTAIN_CLASSES:
            for source_dir in (OLD_TRAIN_DIR, OLD_TEST_DIR):
                class_dir = os.path.join(source_dir, class_name)
                all_imgs.extend(
                    (img_path, class_name)
                    for img_path in self._list_images(class_dir)
                )

        self.dataset = all_imgs
        return self.dataset

    def create_directory_structure(self):
        """Create ``NEW_DIR/<subset>/<class>`` for every subset and class."""
        for subset in self.SUBSETS:
            for class_name in CONTAIN_CLASSES:
                os.makedirs(os.path.join(NEW_DIR, subset, class_name), exist_ok=True)

    def split_and_copy(self):
        """Split ``self.dataset`` into three stratified subsets and copy files.

        The first split keeps ``NEW_TRAIN_SPLIT`` of the data for training. The
        remainder is then divided between validation and testing in the ratio
        ``NEW_VAL_SPLIT : NEW_TEST_SPLIT``. Both splits use ``random_state=42``.

        Within one run, same-named files landing in the same subset/class
        folder are disambiguated with ``dup_`` / ``dupN_`` prefixes. On a
        re-run, files written by the previous run are overwritten in place
        rather than duplicated.

        NOTE: files already in ``NEW_DIR`` that this run does not write (e.g.
        from an earlier run with different fractions) are not removed.

        Returns:
            ``{"train" | "val" | "test": (paths, classes)}`` (also stored in
            ``self.split_dict``).

        Raises:
            ValueError: if ``collect_all_images`` has not run or found nothing.
        """
        if not self.dataset:
            raise ValueError("Dataset is empty. Call collect_all_images first.")

        paths, class_names = map(list, zip(*self.dataset))

        train_paths, temp_paths, train_classes, temp_classes = train_test_split(
            paths,
            class_names,
            train_size=NEW_TRAIN_SPLIT,
            stratify=class_names,
            random_state=42,
        )

        # Share of the held-out remainder that goes to validation.
        val_ratio_adjusted = NEW_VAL_SPLIT / (NEW_VAL_SPLIT + NEW_TEST_SPLIT)

        val_paths, test_paths, val_classes, test_classes = train_test_split(
            temp_paths,
            temp_classes,
            train_size=val_ratio_adjusted,
            stratify=temp_classes,
            random_state=42,
        )

        subsets = {
            "Training": (train_paths, train_classes),
            "Validation": (val_paths, val_classes),
            "Testing": (test_paths, test_classes),
        }

        # Make sure every destination folder exists even when this method is
        # called without run().
        self.create_directory_structure()

        for subset, (s_paths, s_classes) in subsets.items():
            # Collisions are tracked against destinations written in this run,
            # not against files on disk. Otherwise a re-run would treat its own
            # previous output as collisions and create spurious "dup_" copies.
            written = set()
            for src, class_name in zip(s_paths, s_classes):
                dst_dir = os.path.join(NEW_DIR, subset, class_name)
                dst = self._unique_destination(
                    dst_dir, os.path.basename(src), written
                )
                written.add(dst)
                shutil.copy(src, dst)

        self.split_dict = {
            "train": (train_paths, train_classes),
            "val": (val_paths, val_classes),
            "test": (test_paths, test_classes),
        }
        return self.split_dict

    def print_statistics(self):
        """Print each subset's size and its per-class counts and percentages.

        Raises:
            ValueError: if ``split_and_copy`` has not run yet.
        """
        if self.split_dict is None:
            raise ValueError("No split data. Call split_and_copy first.")

        print("STATS FOR DATASET AFTER RANDOM DIVISION")
        for subset, (s_paths, s_classes) in self.split_dict.items():
            total = len(s_paths)
            print(f"{subset.upper()}: {total} images total")
            counts = Counter(s_classes)
            for class_name in CONTAIN_CLASSES:
                count = counts.get(class_name, 0)
                # An empty subset would otherwise raise ZeroDivisionError.
                percentage = count / total * 100 if total else 0.0
                print(f"\t{class_name}: {count} ({percentage:.1f}%)")

    def run(self):
        """Execute the entire pipeline: collect, create dirs, split/copy, report.

        Returns:
            The split dictionary produced by ``split_and_copy``.
        """
        self.collect_all_images()
        self.create_directory_structure()
        self.split_and_copy()
        self.print_statistics()
        return self.split_dict