# In [7]:
import os
import shutil
from glob import glob
import random
import tkinter as tk
from tkinter import messagebox, filedialog
from tkinter import ttk
import threading

# Module-level flag read by the copy loops; set to True to abort a running split
stop_process = False

# Seed the global random generator once, at import time, with a fixed value
random.seed(42)

def _list_files(directory):
    """Return the regular files directly inside *directory*, sorted by path.

    Sub-directories are skipped so they are never handed to ``shutil.copy``
    or counted as samples. Sorting makes the seeded shuffle that follows
    independent of the order in which the filesystem lists entries.
    """
    pattern = os.path.join(directory, '*')
    return sorted(path for path in glob(pattern) if os.path.isfile(path))


def _list_subdirs(directory):
    """Return the immediate sub-directories of *directory*, sorted by path."""
    pattern = os.path.join(directory, '*')
    return sorted(path for path in glob(pattern) if os.path.isdir(path))


def _count_files(directory):
    """Return the number of files found anywhere below *directory*."""
    return sum(len(files) for _, _, files in os.walk(directory))


def _copy_files(files, dest, progress_callback):
    """Copy *files* into *dest*, calling *progress_callback* after each copy.

    Returns as soon as the module-level ``stop_process`` flag is set.
    Failures are printed and do not abort the remaining copies.
    """
    for file in files:
        if stop_process:
            return
        try:
            if os.path.exists(file):
                shutil.copy(file, dest)
                progress_callback()
            else:
                print(f"File does not exist: {file}")
        except PermissionError as e:
            print(f"PermissionError: {e}")
        except Exception as e:
            # Best effort: report the problem and continue with the next file
            print(f"Error copying file {file} to {dest}: {e}")


def split_dataset(base_path, progress_callback):
    """Copy the dataset at *base_path* into train/valid/test folders.

    The output is written to a ``dl_data`` directory created next to
    *base_path* (in its parent directory). Two input layouts are handled:

    * *base_path* contains both ``train`` and ``test``: every class
      sub-directory of ``train`` is split 75% / 25% into the new ``train``
      and ``valid`` folders, and every class sub-directory of ``test`` is
      copied unchanged. Files lying directly in ``train`` or ``test``
      (outside a sub-directory) are not copied.
    * otherwise: the files directly inside *base_path* are split
      60% / 20% / remainder into ``train``, ``valid`` and ``test``.

    Copying stops early when the module-level ``stop_process`` flag is set.

    Args:
        base_path: Directory holding the source dataset. A trailing path
            separator is accepted.
        progress_callback: Zero-argument callable invoked after each
            successfully copied file.

    Returns:
        A tuple ``(train_count, valid_count, test_count, total_files)``.
        The first three are the number of files present below the new
        ``train``/``valid``/``test`` directories when the function returns
        (files left there by an earlier run are included); ``total_files``
        is the number of source files that were selected for copying.
    """
    # abspath drops a trailing separator; without that, dirname() would return
    # the dataset directory itself and dl_data would be created inside it.
    source_root = os.path.abspath(base_path)
    new_base_path = os.path.join(os.path.dirname(source_root), 'dl_data')

    # Directories for new train, validation, and test
    new_train_dir = os.path.join(new_base_path, 'train')
    new_valid_dir = os.path.join(new_base_path, 'valid')
    new_test_dir = os.path.join(new_base_path, 'test')
    for directory in (new_train_dir, new_valid_dir, new_test_dir):
        os.makedirs(directory, exist_ok=True)

    # Split ratios for the flat layout; the test set receives the remainder
    train_split = 0.6
    valid_split = 0.2
    # Share of an existing train folder that stays in train (rest -> valid)
    presplit_train_ratio = 0.75

    original_train_dir = os.path.join(source_root, 'train')
    original_test_dir = os.path.join(source_root, 'test')
    total_files = 0

    if os.path.exists(original_train_dir) and os.path.exists(original_test_dir):
        for subdir in _list_subdirs(original_train_dir):
            if stop_process:
                break
            # Class name, e.g. '0', '1', ..., '9'
            subdir_name = os.path.basename(subdir)
            new_train_subdir = os.path.join(new_train_dir, subdir_name)
            new_valid_subdir = os.path.join(new_valid_dir, subdir_name)
            os.makedirs(new_train_subdir, exist_ok=True)
            os.makedirs(new_valid_subdir, exist_ok=True)

            files = _list_files(subdir)
            total_files += len(files)
            random.shuffle(files)

            train_end = int(len(files) * presplit_train_ratio)
            _copy_files(files[:train_end], new_train_subdir, progress_callback)
            _copy_files(files[train_end:], new_valid_subdir, progress_callback)

        # The original test folder is copied as-is, class by class
        for subdir in _list_subdirs(original_test_dir):
            if stop_process:
                break
            new_test_subdir = os.path.join(new_test_dir, os.path.basename(subdir))
            os.makedirs(new_test_subdir, exist_ok=True)
            files = _list_files(subdir)
            total_files += len(files)
            _copy_files(files, new_test_subdir, progress_callback)
    else:
        files = _list_files(source_root)
        total_files = len(files)
        random.shuffle(files)

        train_end = int(len(files) * train_split)
        valid_end = train_end + int(len(files) * valid_split)

        _copy_files(files[:train_end], new_train_dir, progress_callback)
        _copy_files(files[train_end:valid_end], new_valid_dir, progress_callback)
        _copy_files(files[valid_end:], new_test_dir, progress_callback)

    # Report what is actually on disk, also when the run was stopped early
    return (
        _count_files(new_train_dir),
        _count_files(new_valid_dir),
        _count_files(new_test_dir),
        total_files,
    )

def select_folder():
    """Ask the user for a directory and load it into the path entry.

    When a directory is chosen, the entry text is replaced with it and
    ``check_datasets`` is called on it. Nothing happens if the dialog
    returns an empty selection.
    """
    chosen_dir = filedialog.askdirectory()
    if not chosen_dir:
        return

    # Replace whatever is currently typed in the entry
    entry.delete(0, tk.END)
    entry.insert(0, chosen_dir)
    check_datasets(chosen_dir)

def check_datasets(base_path):
    """Show in the status label which layout was detected at *base_path*.

    The dataset counts as already split when both a ``train`` and a
    ``test`` entry exist directly inside *base_path*.
    """
    has_train = os.path.exists(os.path.join(base_path, 'train'))
    has_test = os.path.exists(os.path.join(base_path, 'test'))

    message = (
        "Train set and test set found. Start splitting."
        if has_train and has_test
        else "Creating train set, valid set, and test set. Start splitting."
    )
    status_label.config(text=message)

def start_processing():
    """Validate the entered path and launch the split on a worker thread.

    Clears the stop flag first. If the path in the entry does not exist,
    an error dialog is shown and no thread is started.
    """
    global stop_process
    stop_process = False

    source_dir = entry.get()
    if not os.path.exists(source_dir):
        messagebox.showerror("Error", "The specified path does not exist.")
        return

    progress_label.config(text="Processing...")
    root.update_idletasks()

    # The split itself runs in a separate thread
    worker = threading.Thread(target=process_dataset, args=(source_dir,))
    worker.start()

def process_dataset(base_path):
    """Split the dataset at *base_path* and report progress and results.

    Used as the target of the thread started by ``start_processing``.
    Sizes and resets the progress bar, runs ``split_dataset`` and then
    either shows the resulting file counts or a "stopped" message,
    depending on the module-level ``stop_process`` flag.

    Args:
        base_path: Directory holding the source dataset.
    """
    # Every file below base_path is a potential copy; use that as the bar's range
    expected_total = sum(len(files) for _, _, files in os.walk(base_path))
    progress_bar['maximum'] = expected_total
    # Start every run from an empty bar, not from where the last run ended
    progress_bar['value'] = 0

    copied = 0

    def progress_callback():
        nonlocal copied
        copied += 1
        # Set the value explicitly instead of calling step(): in determinate
        # mode step() wraps around at the maximum, so the bar would jump back
        # to empty exactly when the last file has been copied.
        progress_bar['value'] = copied
        root.update_idletasks()

    train_count, valid_count, test_count, _ = split_dataset(base_path, progress_callback)

    if stop_process:
        progress_label.config(text="Processing stopped.")
        return

    progress_label.config(text="Processing complete.")
    messagebox.showinfo("Done", "Dataset split complete.")

    total_count = train_count + valid_count + test_count
    total_label.config(text=f"Total images: {total_count}")
    train_label.config(text=f"Train images: {train_count}")
    valid_label.config(text=f"Valid images: {valid_count}")
    test_label.config(text=f"Test images: {test_count}")

def stop_processing():
    """Raise the stop flag and show the "stopped" message in the GUI."""
    global stop_process
    # The copy loops check this flag before every file
    stop_process = True
    progress_label.configure(text="Processing stopped.")

# Main application window
root = tk.Tk()
root.title("Dataset Splitter")


def _place_full_width_label(text, row, pady):
    """Create a label spanning all three grid columns on *row* and return it."""
    label = tk.Label(root, text=text)
    label.grid(row=row, column=0, columnspan=3, pady=pady)
    return label


# Row 0: caption, path entry and folder picker
tk.Label(root, text="Base Path:").grid(row=0, column=0, padx=10, pady=10)
entry = tk.Entry(root, width=50)
entry.grid(row=0, column=1, padx=10, pady=10)
browse_button = tk.Button(root, text="Browse", command=select_folder)
browse_button.grid(row=0, column=2, padx=10, pady=10)

# Row 1: start / stop controls
start_button = tk.Button(root, text="Start", command=start_processing)
start_button.grid(row=1, column=0, columnspan=2, pady=10)
stop_button = tk.Button(root, text="Stop", command=stop_processing)
stop_button.grid(row=1, column=2, pady=10)

# Rows 2-3: detected layout and current processing state
status_label = _place_full_width_label("", 2, 10)
progress_label = _place_full_width_label("", 3, 10)

# Row 4: copy progress
progress_bar = ttk.Progressbar(root, orient="horizontal", length=400, mode="determinate")
progress_bar.grid(row=4, column=0, columnspan=3, pady=10)

# Rows 5-8: file counts filled in after a completed split
total_label = _place_full_width_label("Total images: 0", 5, 5)
train_label = _place_full_width_label("Train images: 0", 6, 5)
valid_label = _place_full_width_label("Valid images: 0", 7, 5)
test_label = _place_full_width_label("Test images: 0", 8, 5)

# Hand control to the Tk event loop
root.mainloop()