In [2]:
import os
import pickle
import numpy as np
from PIL import Image
import struct
import tkinter as tk
from tkinter import filedialog, messagebox
from tkinter import ttk
import threading

def load_c10_meta(filename):
    """Return the CIFAR-10 class names stored in a ``batches.meta`` pickle.

    Args:
        filename: Path to the CIFAR-10 ``batches.meta`` file.

    Returns:
        List of class-name strings, in label-id order (``save_images`` looks
        names up as ``class_names[label]``).
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # The path comes from a user-selected folder; only use trusted downloads.
    with open(filename, 'rb') as f:
        meta = pickle.load(f, encoding='bytes')
    return [name.decode('utf-8') for name in meta[b'label_names']]

def load_cifar10_batch(filename):
    """Load one CIFAR-10 batch pickle as channel-last images plus labels.

    Args:
        filename: Path to a ``data_batch_N`` or ``test_batch`` file.

    Returns:
        ``(images, labels)``. ``images`` has shape ``(N, 32, 32, 3)``, the
        layout ``PIL.Image.fromarray`` expects. ``labels`` is the list stored
        under ``b'labels'``.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only open batches from a trusted CIFAR download.
    with open(filename, 'rb') as f:
        batch = pickle.load(f, encoding='bytes')
    flat = batch[b'data']
    labels = batch[b'labels']
    # Each row holds three 32x32 channel planes back to back:
    # (N, 3072) -> (N, 3, 32, 32) -> (N, 32, 32, 3).
    images = flat.reshape(len(flat), 3, 32, 32).transpose(0, 2, 3, 1)
    return images, labels

def load_cifar100_meta(filename, label_key=b'fine_label_names'):
    """Return class names from a CIFAR-100 ``meta`` pickle.

    Args:
        filename: Path to the CIFAR-100 ``meta`` file.
        label_key: Pickle key to read. The default gives the fine
            (100-class) names; pass ``b'coarse_label_names'`` for the
            superclass names.

    Returns:
        List of class-name strings, in label-id order.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only open files from a trusted CIFAR download.
    with open(filename, 'rb') as f:
        meta = pickle.load(f, encoding='bytes')
    return [name.decode('utf-8') for name in meta[label_key]]

def load_cifar100_batch(filename, label_key=b'fine_labels'):
    """Load a CIFAR-100 ``train``/``test`` pickle as channel-last images plus labels.

    Args:
        filename: Path to the CIFAR-100 ``train`` or ``test`` file.
        label_key: Pickle key holding the labels. The default gives the fine
            labels; pass ``b'coarse_labels'`` for superclass labels.

    Returns:
        ``(images, labels)``. ``images`` has shape ``(N, 32, 32, 3)``;
        ``labels`` is the list stored under ``label_key``.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only open files from a trusted CIFAR download.
    with open(filename, 'rb') as f:
        batch = pickle.load(f, encoding='bytes')
    flat = batch[b'data']
    labels = batch[label_key]
    # (N, 3072) -> (N, 3, 32, 32) channel planes -> (N, 32, 32, 3) for PIL.
    images = flat.reshape(len(flat), 3, 32, 32).transpose(0, 2, 3, 1)
    return images, labels

# IDX element-type codes (third header byte) -> NumPy dtypes. Multi-byte
# IDX values are big-endian.
_IDX_DTYPES = {
    0x08: np.dtype(np.uint8),
    0x09: np.dtype(np.int8),
    0x0B: np.dtype('>i2'),
    0x0C: np.dtype('>i4'),
    0x0D: np.dtype('>f4'),
    0x0E: np.dtype('>f8'),
}


def read_idx(filename):
    """Read an IDX file (the MNIST container format) into a NumPy array.

    The header is two zero bytes, one element-type byte and one
    dimension-count byte. Then comes one big-endian uint32 size per
    dimension, then the raw element data.

    Args:
        filename: Path to an uncompressed IDX file.

    Returns:
        Array shaped by the header sizes. Multi-byte types are converted to
        native byte order; uint8 data (as in MNIST) is returned as-is.

    Raises:
        ValueError: If the magic zero bytes are wrong or the element-type
            code is not a known IDX type.
    """
    with open(filename, 'rb') as f:
        zero, data_type, dims = struct.unpack('>HBB', f.read(4))
        if zero != 0:
            raise ValueError(f"{filename!r} is not an IDX file (bad magic number)")
        try:
            dtype = _IDX_DTYPES[data_type]
        except KeyError:
            raise ValueError(
                f"Unsupported IDX data type 0x{data_type:02X} in {filename!r}"
            ) from None
        shape = struct.unpack(f'>{dims}I', f.read(4 * dims))
        data = np.frombuffer(f.read(), dtype=dtype).reshape(shape)
    if dtype.itemsize > 1:
        # Convert big-endian storage to native order so downstream code
        # does not have to handle non-native dtypes.
        data = data.astype(dtype.newbyteorder('='))
    return data

def save_images(images, labels, class_names, path='images', prefix='image', progress_callback=None, stop_event=None):
    """Write each image as ``<path>/<class>/<prefix>_<i>.png``.

    Args:
        images: Sequence of image arrays accepted by ``PIL.Image.fromarray``.
        labels: Integer labels, paired with ``images`` by position.
        class_names: Label-id -> name list used as the sub-folder name.
            If falsy, the label itself (as a string) is used.
        path: Output root directory; created if missing.
        prefix: Filename prefix for every saved image.
        progress_callback: Optional ``callback(done, total)`` called after
            each saved image.
        stop_event: Optional ``threading.Event``. When it is set, the loop
            stops before saving the next image.
    """
    # exist_ok avoids the race between checking for the directory and creating it.
    os.makedirs(path, exist_ok=True)
    total = len(images)
    ensured_dirs = set()  # class folders already created, so makedirs runs once per class
    for i, (image, label) in enumerate(zip(images, labels)):
        if stop_event and stop_event.is_set():
            break
        class_name = class_names[label] if class_names else str(label)
        class_path = os.path.join(path, class_name)
        if class_path not in ensured_dirs:
            os.makedirs(class_path, exist_ok=True)
            ensured_dirs.add(class_path)
        Image.fromarray(image).save(os.path.join(class_path, f'{prefix}_{i}.png'))
        if progress_callback:
            progress_callback(i + 1, total)

def check_cifar10_folder(folder_path):
    """Locate the CIFAR-10 (python version) files inside *folder_path*.

    Returns:
        ``(batch_files, test_file, meta_file)``, where ``batch_files`` lists
        ``data_batch_1`` .. ``data_batch_5``, or ``(None, None, None)`` when
        any of the seven expected files is missing.
    """
    def in_folder(name):
        return os.path.join(folder_path, name)

    batch_files = [in_folder(f'data_batch_{n}') for n in range(1, 6)]
    test_file = in_folder('test_batch')
    meta_file = in_folder('batches.meta')

    missing = [p for p in (*batch_files, test_file, meta_file) if not os.path.exists(p)]
    if missing:
        return None, None, None
    return batch_files, test_file, meta_file

def check_cifar100_folder(folder_path):
    """Locate the CIFAR-100 (python version) files inside *folder_path*.

    Returns:
        ``(train_file, test_file, meta_file)`` paths, or
        ``(None, None, None)`` unless all three exist.
    """
    paths = tuple(os.path.join(folder_path, name) for name in ('train', 'test', 'meta'))
    if not all(map(os.path.exists, paths)):
        return None, None, None
    return paths

def check_mnist_folder(folder_path):
    """Locate the four uncompressed MNIST IDX files inside *folder_path*.

    Two naming schemes are accepted for each file. One uses a dot before
    ``idx`` (``train-images.idx3-ubyte``). The other is all hyphens
    (``train-images-idx3-ubyte``), the original distribution's name after
    un-gzipping. If both exist, the dotted name wins.

    Returns:
        ``(train_images, train_labels, test_images, test_labels)`` paths, or
        four ``None``s if any of the files is missing.
    """
    wanted = (
        ('train-images', 'idx3'),
        ('train-labels', 'idx1'),
        ('t10k-images', 'idx3'),
        ('t10k-labels', 'idx1'),
    )
    found = []
    for stem, kind in wanted:
        candidates = [os.path.join(folder_path, f'{stem}{sep}{kind}-ubyte') for sep in ('.', '-')]
        match = next((c for c in candidates if os.path.exists(c)), None)
        if match is None:
            return None, None, None, None
        found.append(match)
    return tuple(found)

def process_cifar10_folder(cifar10_folder, progress_callback=None, stop_event=None):
    """Export CIFAR-10 train and test splits as PNGs under ``<folder>/images``.

    Train images go to ``images/train/<class>/`` and test images to
    ``images/test/<class>/``. If the expected files are missing, an error
    dialog is shown instead.

    Args:
        cifar10_folder: Folder containing the CIFAR-10 python-version files.
        progress_callback: Forwarded to ``save_images`` for each split.
        stop_event: Optional ``threading.Event``; once set, no further batch
            is loaded and no further split is started.
    """
    batch_files, test_file, meta_file = check_cifar10_folder(cifar10_folder)
    if not (batch_files and test_file and meta_file):
        messagebox.showerror("Error", "CIFAR-10 files not found in the specified folder.")
        return

    def stopped():
        return bool(stop_event and stop_event.is_set())

    class_names = load_c10_meta(meta_file)

    train_images, train_labels = [], []
    for batch_file in batch_files:
        if stopped():
            return
        images, labels = load_cifar10_batch(batch_file)
        train_images.append(images)
        train_labels.append(labels)

    train_output_path = os.path.join(cifar10_folder, 'images', 'train')
    save_images(np.concatenate(train_images), np.concatenate(train_labels), class_names,
                path=train_output_path, progress_callback=progress_callback, stop_event=stop_event)
    # Drop the training arrays before loading the test batch to lower peak memory.
    del train_images, train_labels

    # Skip loading the test batch entirely if the user stopped during the train split.
    if stopped():
        return

    test_images, test_labels = load_cifar10_batch(test_file)
    test_output_path = os.path.join(cifar10_folder, 'images', 'test')
    save_images(test_images, test_labels, class_names, path=test_output_path,
                progress_callback=progress_callback, stop_event=stop_event)

def process_cifar100_folder(cifar100_folder, progress_callback=None, stop_event=None):
    """Export CIFAR-100 train and test splits (fine labels) as PNGs.

    Output goes to ``<folder>/images/train/<class>/`` and
    ``<folder>/images/test/<class>/``. If the expected files are missing,
    an error dialog is shown instead.

    Args:
        cifar100_folder: Folder containing the CIFAR-100 python-version files.
        progress_callback: Forwarded to ``save_images`` for each split.
        stop_event: Optional ``threading.Event``; once set, the test split is
            not loaded.
    """
    train_file, test_file, meta_file = check_cifar100_folder(cifar100_folder)
    if not (train_file and test_file and meta_file):
        messagebox.showerror("Error", "CIFAR-100 files not found in the specified folder.")
        return

    class_names = load_cifar100_meta(meta_file)

    train_images, train_labels = load_cifar100_batch(train_file)
    train_output_path = os.path.join(cifar100_folder, 'images', 'train')
    save_images(train_images, train_labels, class_names, path=train_output_path,
                progress_callback=progress_callback, stop_event=stop_event)
    del train_images, train_labels  # release before loading the test file

    # Previously the whole test file was still loaded after a stop; bail out early.
    if stop_event and stop_event.is_set():
        return

    test_images, test_labels = load_cifar100_batch(test_file)
    test_output_path = os.path.join(cifar100_folder, 'images', 'test')
    save_images(test_images, test_labels, class_names, path=test_output_path,
                progress_callback=progress_callback, stop_event=stop_event)

def process_mnist_folder(mnist_folder, progress_callback=None, stop_event=None):
    """Export MNIST train and test splits as PNGs under ``<folder>/images``.

    Images are grouped into folders named after the numeric label
    (``class_names=None``, so ``save_images`` uses ``str(label)``). If the
    expected files are missing, an error dialog is shown instead.

    Args:
        mnist_folder: Folder containing the four uncompressed IDX files.
        progress_callback: Forwarded to ``save_images`` for each split.
        stop_event: Optional ``threading.Event``; once set, the test split is
            not loaded.
    """
    train_images_path, train_labels_path, test_images_path, test_labels_path = check_mnist_folder(mnist_folder)
    if not (train_images_path and train_labels_path and test_images_path and test_labels_path):
        messagebox.showerror("Error", "MNIST files not found in the specified folder.")
        return

    train_images = read_idx(train_images_path)
    train_labels = read_idx(train_labels_path)
    train_output_path = os.path.join(mnist_folder, 'images', 'train')
    save_images(train_images, train_labels, class_names=None, path=train_output_path,
                progress_callback=progress_callback, stop_event=stop_event)

    # Don't read the test files once the user has asked to stop.
    if stop_event and stop_event.is_set():
        return

    test_images = read_idx(test_images_path)
    test_labels = read_idx(test_labels_path)
    test_output_path = os.path.join(mnist_folder, 'images', 'test')
    save_images(test_images, test_labels, class_names=None, path=test_output_path,
                progress_callback=progress_callback, stop_event=stop_event)

def process_dataset_folder(dataset_folder, progress_callback=None, stop_event=None):
    """Detect which dataset is in *dataset_folder* and export it as PNGs.

    Detection order is CIFAR-10, then CIFAR-100, then MNIST. The first
    dataset whose files are all present is processed.

    Returns:
        True if a supported dataset was found and processed, False if none
        was found (an error dialog is shown in that case).
    """
    handlers = (
        ("CIFAR-10", check_cifar10_folder, process_cifar10_folder),
        ("CIFAR-100", check_cifar100_folder, process_cifar100_folder),
        ("MNIST", check_mnist_folder, process_mnist_folder),
    )
    for dataset_name, check, process in handlers:
        if check(dataset_folder)[0]:
            status_label.config(text=f"Processing {dataset_name}...")
            process(dataset_folder, progress_callback, stop_event)
            return True
    messagebox.showerror("Error", "Dataset files not found in the specified folder.")
    return False


def select_folder():
    """Ask the user for a directory and store it in ``folder_path``."""
    folder_selected = filedialog.askdirectory()
    if folder_selected:
        folder_path.set(folder_selected)


# Worker thread of the current run, or None before the first run. Used to
# refuse a second concurrent run; otherwise the global stop_event would be
# replaced and the first run could no longer be stopped.
_worker_thread = None


def start_processing():
    """Start exporting the selected folder on a background thread."""
    global stop_event, _worker_thread
    if _worker_thread is not None and _worker_thread.is_alive():
        messagebox.showwarning("Warning", "Processing is already running.")
        return
    stop_event = threading.Event()
    folder = folder_path.get()
    if not folder:
        messagebox.showwarning("Warning", "Please select a folder first.")
        return
    status_label.config(text="Processing...")
    # The bar is driven explicitly by update_progress. progress_bar.start()
    # would turn on ttk's auto-increment timer, which fights those updates.
    progress_bar['value'] = 0
    _worker_thread = threading.Thread(target=process_and_notify, args=(folder, stop_event))
    _worker_thread.start()


def process_and_notify(folder, event=None):
    """Worker-thread body: process *folder*, then report the outcome in the UI.

    Args:
        folder: Dataset directory chosen by the user.
        event: The ``threading.Event`` for this run. Defaults to the
            module-level ``stop_event`` so single-argument calls still work.
    """
    if event is None:
        event = stop_event
    try:
        found = process_dataset_folder(folder, update_progress, event)
    except Exception as exc:
        # Top-level boundary of the worker thread: reset the UI and tell the
        # user, then re-raise so threading's excepthook prints the traceback.
        status_label.config(text="Processing failed.")
        progress_bar['value'] = 0
        messagebox.showerror("Error", f"Processing failed: {exc}")
        raise
    if not found:
        # process_dataset_folder has already shown an error dialog; don't
        # claim success.
        status_label.config(text="No supported dataset found.")
    elif event.is_set():
        status_label.config(text="Processing stopped.")
    else:
        status_label.config(text="Processing completed!")
        progress_bar['value'] = 100  # keep the bar full while the dialog is open
        messagebox.showinfo("Info", "Processing completed!")
    progress_bar['value'] = 0

def update_progress(current, total):
    """Show ``current`` out of ``total`` as a percentage on the progress bar.

    Used as the ``progress_callback`` passed to ``save_images``.
    """
    fraction_done = current / total
    progress_bar['value'] = 100 * fraction_done

def stop_loading():
    """Signal the running export (if any) to stop and reset the progress UI.

    ``stop_event`` is only created by ``start_processing``. It is read
    through ``globals()`` so that clicking Stop before any run no longer
    raises NameError.
    """
    event = globals().get("stop_event")
    if event is not None:
        event.set()
    status_label.config(text="Processing stopped.")
    progress_bar.stop()
    progress_bar['value'] = 0

# No export has started yet. start_processing() replaces this with a
# threading.Event; defining it here means stop_loading() works (instead of
# raising NameError) if Stop is clicked before Process.
stop_event = None

# Create the main window
root = tk.Tk()
root.title("Dataset Processor")

# Folder path variable, bound to the entry box and read by start_processing()
folder_path = tk.StringVar()

# Create and place widgets
tk.Label(root, text="Select Dataset Folder:").grid(row=0, column=0, padx=10, pady=10)
tk.Entry(root, textvariable=folder_path, width=50).grid(row=0, column=1, padx=10, pady=10)
tk.Button(root, text="Browse", command=select_folder).grid(row=0, column=2, padx=10, pady=10)
tk.Button(root, text="Process", command=start_processing).grid(row=0, column=3, padx=10, pady=10)
status_label = tk.Label(root, text="")
status_label.grid(row=1, column=0, columnspan=4, pady=5)
# Determinate bar: its value (0-100) is set by update_progress during a run.
progress_bar = ttk.Progressbar(root, orient="horizontal", length=400, mode="determinate")
progress_bar.grid(row=2, column=0, columnspan=4, pady=5)
tk.Button(root, text="Stop", command=stop_loading).grid(row=3, column=0, columnspan=4, pady=10)

# Run the application
root.mainloop()