In [None]:
import os
import cv2
import numpy as np
import pandas as pd
import customtkinter as ctk
from tkinter import filedialog, messagebox
from PIL import Image, ImageTk
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from skimage.feature import local_binary_pattern
import threading
import time
import traceback

# Global customtkinter look, applied at import time before the window class is
# defined: follow the OS light/dark setting and use the built-in "blue" theme.
ctk.set_appearance_mode("System")
ctk.set_default_color_theme("blue")

class LeafClassifierGUI(ctk.CTk):
    def __init__(self):
        """Create the main window, load persisted history and build every widget.

        This was previously spelled ``___init___`` (three underscores), which
        Python never invokes as a constructor. None of the attributes or
        widgets below were created, and the inner ``super().___init___()``
        would have raised AttributeError anyway.
        """
        super().__init__()
        self.title("Leaf Classification System")
        self.geometry("1400x800")
        self.resizable(True, True)
        self.protocol("WM_DELETE_WINDOW", self.on_closing)

        # History file
        self.history_file = "history.txt"
        self.history = self.load_history()

        # Store processed image and features
        self.processed_image = None
        self.extracted_features = None
        self.original_image_path = None
        self.prediction = None

        # Zoom and pan variables
        self.zoom_factor = 1.0
        self.pan_x = 0
        self.pan_y = 0
        # Drag anchor for panning; _start_pan overwrites it on each mouse press.
        # Initialised so the attributes always exist.
        self.pan_start_x = 0
        self.pan_start_y = 0

        # Batch processing queue
        self.batch_queue = []
        self.batch_processing = False
        self.batch_paused = False

        # Image cache for performance
        self.image_cache = {}

        # Main container
        self.grid_columnconfigure(1, weight=1)
        self.grid_columnconfigure(2, weight=1)
        self.grid_rowconfigure(1, weight=1)
        self.grid_rowconfigure(2, weight=1)

        # Create placeholder frame (gridded first, so the widgets created below
        # are stacked above it in the same cells)
        self.placeholder_frame = ctk.CTkFrame(self)
        self.placeholder_frame.grid(row=0, column=0, columnspan=3, rowspan=3, sticky="nsew")
        ctk.CTkLabel(self.placeholder_frame,
                     text="Use 'Upload Image' or 'Batch Process' in the sidebar",
                     font=ctk.CTkFont(size=14)).pack(expand=True)

        # Create widgets
        self._create_menu_bar()
        self._create_sidebar()
        self._create_main_frame()
        self._create_result_frame()
        self._create_history_frame()

    def on_closing(self):
        """Handle the window-manager close request.

        The batch flag is cleared before the window is torn down, so a running
        batch worker leaves its loop at the next check.
        """
        self.batch_processing = False  # tell the batch worker to stop
        self.destroy()

    def _create_menu_bar(self):
        """Create the top strip of buttons: File, View, Help and Export Viz."""
        self.menu_bar = ctk.CTkFrame(self, height=30)
        self.menu_bar.grid(row=0, column=0, columnspan=3, sticky="ew")

        entries = (
            ("File", self._show_file_menu),
            ("View", self._toggle_theme),
            ("Help", self._show_help),
            ("Export Viz", self.export_visualization),
        )
        # One fixed-width button per entry, laid out left to right.
        for column, (caption, handler) in enumerate(entries):
            button = ctk.CTkButton(self.menu_bar, text=caption, width=80, command=handler)
            button.grid(row=0, column=column, padx=5, pady=5)

    def _show_file_menu(self):
        """Pop up an info dialog listing the available file operations."""
        options = "Options: Upload Image, Batch Process, Save Results, Export Visualization"
        messagebox.showinfo("File Menu", options)

    def _toggle_theme(self):
        """Switch between Light and Dark appearance and report it in the status bar."""
        if ctk.get_appearance_mode() == "Dark":
            target = "Light"
        else:
            target = "Dark"
        ctk.set_appearance_mode(target)
        self._update_status(f"Switched to {target} mode", "green")
        self.update()

    def _show_help(self):
        """Show a short usage guide in an info dialog."""
        lines = (
            "Upload an image or folder using the sidebar buttons.",
            "Select feature types to display.",
            "Zoom/pan images with mouse.",
            "Click 'Reset View' to restore original size.",
        )
        messagebox.showinfo("Help", "\n".join(lines))

    def _create_sidebar(self):
        """Build the left-hand control column (grid row 1, column 0).

        Packed top to bottom:
        - the title;
        - the image-source buttons (Upload Image, Batch Process, Pause/Resume Batch);
        - the Shape/Color/Texture feature checkboxes;
        - the "Show Steps" option;
        - the progress bar and status label;
        - the Save Results and Reset View buttons.

        Everything is packed, so statement order here is the on-screen order.
        """
        self.sidebar_frame = ctk.CTkFrame(self, width=250)
        self.sidebar_frame.grid(row=1, column=0, padx=10, pady=10, sticky="ns")
        # NOTE(review): the children below are packed, not gridded, so
        # grid_propagate(False) probably does not pin the 250 px width;
        # pack_propagate(False) may have been intended — confirm before changing.
        self.sidebar_frame.grid_propagate(False)

        sidebar_label = ctk.CTkLabel(self.sidebar_frame, text="Leaf Classifier",
                                     font=ctk.CTkFont(size=18, weight="bold"))
        sidebar_label.pack(pady=(10, 5))

        # --- Image source: single upload, folder batch, batch pause/resume ---
        upload_frame = ctk.CTkFrame(self.sidebar_frame)
        upload_frame.pack(fill="x", padx=5, pady=5)

        self.upload_btn = ctk.CTkButton(upload_frame, text="Upload Image",
                                       command=self.upload_image,
                                       fg_color="#4CAF50", hover_color="#45a049")
        self.upload_btn.pack(pady=5, fill="x")

        self.batch_btn = ctk.CTkButton(upload_frame, text="Batch Process",
                                      command=self.batch_process,
                                      fg_color="#2196F3", hover_color="#1976D2")
        self.batch_btn.pack(pady=5, fill="x")

        # Pause/Resume start disabled; batch_process_folder, pause_batch and
        # resume_batch enable/disable them while a batch runs.
        self.pause_btn = ctk.CTkButton(upload_frame, text="Pause Batch",
                                      command=self.pause_batch, state="disabled",
                                      fg_color="#FF5722", hover_color="#E64A19")
        self.pause_btn.pack(pady=5, fill="x")

        self.resume_btn = ctk.CTkButton(upload_frame, text="Resume Batch",
                                       command=self.resume_batch, state="disabled",
                                       fg_color="#4CAF50", hover_color="#45a049")
        self.resume_btn.pack(pady=5, fill="x")

        # --- Feature groups shown in the results panel ---
        features_frame = ctk.CTkFrame(self.sidebar_frame)
        features_frame.pack(fill="x", padx=5, pady=5)

        features_label = ctk.CTkLabel(features_frame, text="Feature Selection",
                                     font=ctk.CTkFont(size=14, weight="bold"))
        features_label.pack(pady=(0, 5))

        # All groups visible by default; toggling a box re-renders the panel
        # through update_features_display.
        self.shape_var = ctk.BooleanVar(value=True)
        self.color_var = ctk.BooleanVar(value=True)
        self.texture_var = ctk.BooleanVar(value=True)

        self.shape_checkbox = ctk.CTkCheckBox(features_frame, text="Shape",
                                             variable=self.shape_var,
                                             command=self.update_features_display)
        self.shape_checkbox.pack(pady=5, padx=10, anchor="w")

        self.color_checkbox = ctk.CTkCheckBox(features_frame, text="Color",
                                             variable=self.color_var,
                                             command=self.update_features_display)
        self.color_checkbox.pack(pady=5, padx=10, anchor="w")

        self.texture_checkbox = ctk.CTkCheckBox(features_frame, text="Texture",
                                               variable=self.texture_var,
                                               command=self.update_features_display)
        self.texture_checkbox.pack(pady=5, padx=10, anchor="w")

        # --- Processing options ---
        processing_frame = ctk.CTkFrame(self.sidebar_frame)
        processing_frame.pack(fill="x", padx=5, pady=5)

        processing_label = ctk.CTkLabel(processing_frame, text="Processing Options",
                                       font=ctk.CTkFont(size=14, weight="bold"))
        processing_label.pack(pady=(0, 5))

        # Read by _process_image and passed as show= to improved_leaf_segmentation.
        self.show_steps_var = ctk.BooleanVar(value=True)
        self.show_steps_checkbox = ctk.CTkCheckBox(processing_frame,
                                                 text="Show Steps",
                                                 variable=self.show_steps_var)
        self.show_steps_checkbox.pack(pady=5, padx=10, anchor="w")

        # Progress (0.0-1.0) and status text, updated by the processing code.
        self.progress_bar = ctk.CTkProgressBar(self.sidebar_frame)
        self.progress_bar.pack(pady=5, padx=10, fill="x")
        self.progress_bar.set(0)

        self.status_label = ctk.CTkLabel(self.sidebar_frame, text="Status: Ready",
                                        text_color="green", font=ctk.CTkFont(size=12))
        self.status_label.pack(pady=5)

        # Disabled until _process_image completes successfully.
        self.save_btn = ctk.CTkButton(self.sidebar_frame, text="Save Results",
                                     command=self.save_results, state="disabled",
                                     fg_color="#FF9800", hover_color="#F57C00")
        self.save_btn.pack(pady=5, padx=10, fill="x")

        self.reset_view_btn = ctk.CTkButton(self.sidebar_frame, text="Reset View",
                                           command=self.reset_view,
                                           fg_color="#9C27B0", hover_color="#7B1FA2")
        self.reset_view_btn.pack(pady=5, padx=10, fill="x")

    def reset_view(self):
        """Return to 1x zoom with no pan offset and redraw any loaded images."""
        self.zoom_factor, self.pan_x, self.pan_y = 1.0, 0, 0
        source_path = self.original_image_path
        if source_path:
            self._display_original_image(source_path)
        result = self.processed_image
        if result is not None:
            self._display_processed_image(result)
        self._update_status("View reset to original size", "green")
        self.update()

    def _create_main_frame(self):
        """Build the centre column with the original-image and processed-image panes."""
        self.main_frame = ctk.CTkFrame(self)
        self.main_frame.grid(row=1, column=1, padx=10, pady=10, sticky="nsew")
        self.main_frame.grid_columnconfigure(0, weight=1)
        self.main_frame.grid_rowconfigure((0, 1), weight=1)

        def build_pane(row, title, placeholder):
            # A pane is a frame holding a bold title above an interactive image label.
            pane = ctk.CTkFrame(self.main_frame)
            pane.grid(row=row, column=0, padx=5, pady=5, sticky="nsew")
            heading = ctk.CTkLabel(pane, text=title,
                                   font=ctk.CTkFont(size=14, weight="bold"))
            heading.pack(pady=5)
            viewer = ctk.CTkLabel(pane, text=placeholder, font=ctk.CTkFont(size=12))
            viewer.pack(expand=True, fill="both", padx=5, pady=5)
            self._bind_image_interactions(viewer)
            return pane, heading, viewer

        (self.orig_image_frame,
         self.orig_image_label,
         self.orig_image_display) = build_pane(0, "Original Image", "Upload an image")
        (self.proc_image_frame,
         self.proc_image_label,
         self.proc_image_display) = build_pane(1, "Processed Image", "Image will be processed here")

    def _bind_image_interactions(self, label):
        """Attach wheel-zoom and left-drag pan handlers to an image label."""
        handlers = {
            "<MouseWheel>": self._zoom_image,
            "<ButtonPress-1>": self._start_pan,
            "<B1-Motion>": self._pan_image,
        }
        for sequence, callback in handlers.items():
            label.bind(sequence, callback)

    def _zoom_image(self, event):
        """Zoom 10% in on a positive wheel delta, otherwise 10% out; clamp to [0.5, 3.0]."""
        step = 1.1
        zoomed = self.zoom_factor * step if event.delta > 0 else self.zoom_factor / step
        self.zoom_factor = min(max(zoomed, 0.5), 3.0)
        self._update_image_display()
        self.update()

    def _start_pan(self, event):
        """Record the pointer position where a left-button drag begins."""
        self.pan_start_x, self.pan_start_y = event.x, event.y

    def _pan_image(self, event):
        """Add the pointer movement since the last event to the pan offset and redraw."""
        dx = event.x - self.pan_start_x
        dy = event.y - self.pan_start_y
        self.pan_x += dx
        self.pan_y += dy
        # The current position becomes the anchor for the next motion event.
        self.pan_start_x, self.pan_start_y = event.x, event.y
        self._update_image_display()
        self.update()

    def _update_image_display(self):
        """Redraw both image panes for the current zoom factor and pan offset.

        Each image is first fitted within 400 px, exactly as at 1x zoom. That
        fitted size is also the viewport. Zoom scales the fitted image rather
        than the full-resolution source, which made the first wheel step jump
        from 400 px to the source size. When the zoomed image is larger than
        the viewport, a viewport-sized window is cropped out, centred and
        shifted by (pan_x, pan_y). Before this, the pan offset was tracked but
        never applied.

        Rendered images are cached per (image identity, zoom, pan). The cache
        evicts its oldest entry once it exceeds 10 items.
        """
        for display, image in [(self.orig_image_display, self.original_image_path),
                               (self.proc_image_display, self.processed_image)]:
            if image is None:
                continue
            cache_key = (id(image), self.zoom_factor, self.pan_x, self.pan_y)
            cached = self.image_cache.get(cache_key)
            if cached is not None:
                display.configure(image=cached, text="")
                display.image = cached
                continue

            if isinstance(image, str):
                img = cv2.imread(image)
                if img is None:
                    print(f"Failed to load image: {image}")
                    continue
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            else:
                img = image

            fitted = self._resize_image_for_display(img, 400)
            if fitted is None:
                print("Image is None in _update_image_display")
                continue
            view_h, view_w = fitted.shape[:2]

            if self.zoom_factor == 1.0:
                shown = fitted
            else:
                zoom_w = max(1, int(view_w * self.zoom_factor))
                zoom_h = max(1, int(view_h * self.zoom_factor))
                shown = cv2.resize(fitted, (zoom_w, zoom_h))

            zoom_h, zoom_w = shown.shape[:2]
            if zoom_w > view_w or zoom_h > view_h:
                # Centre the viewport, move it against the drag direction (dragging
                # right reveals content to the left), and clamp it inside the image.
                x0 = min(max((zoom_w - view_w) // 2 - self.pan_x, 0), max(zoom_w - view_w, 0))
                y0 = min(max((zoom_h - view_h) // 2 - self.pan_y, 0), max(zoom_h - view_h, 0))
                shown = np.ascontiguousarray(shown[y0:y0 + view_h, x0:x0 + view_w])

            pil_img = Image.fromarray(shown)
            ctk_img = ctk.CTkImage(light_image=pil_img, dark_image=pil_img,
                                   size=(shown.shape[1], shown.shape[0]))
            display.configure(image=ctk_img, text="")
            display.image = ctk_img  # keep a reference so the image is not garbage-collected
            self.image_cache[cache_key] = ctk_img

            if len(self.image_cache) > 10:
                # dicts preserve insertion order, so the first key is the oldest entry
                self.image_cache.pop(next(iter(self.image_cache)))

    def _create_result_frame(self):
        """Build the right-hand results column (grid row 1, column 2).

        Top: a scrollable "Feature Analysis" panel (self.feature_display).
        update_features_display later clears and refills it.
        Bottom: a "Classification Result" box whose text label
        (self.classification_label) update_classification_display rewrites.
        """
        self.result_frame = ctk.CTkFrame(self)
        self.result_frame.grid(row=1, column=2, padx=10, pady=10, sticky="nsew")
        self.result_frame.grid_columnconfigure(0, weight=1)

        self.result_header = ctk.CTkLabel(self.result_frame, text="Feature Analysis",
                                         font=ctk.CTkFont(size=16, weight="bold"))
        self.result_header.pack(pady=5)

        # Scrollable container: feature rows and preview charts are packed into it.
        self.feature_display = ctk.CTkScrollableFrame(self.result_frame)
        self.feature_display.pack(expand=True, fill="both", padx=5, pady=5)

        # Initial placeholder; destroyed when the feature panel is first rebuilt.
        self.feature_label = ctk.CTkLabel(self.feature_display,
                                         text="Features will be displayed here",
                                         font=ctk.CTkFont(size=12))
        self.feature_label.pack(pady=5)

        self.classification_frame = ctk.CTkFrame(self.result_frame)
        self.classification_frame.pack(fill="x", padx=5, pady=5)

        self.classification_header = ctk.CTkLabel(self.classification_frame,
                                                text="Classification Result",
                                                font=ctk.CTkFont(size=14, weight="bold"))
        self.classification_header.pack(pady=5)

        self.classification_label = ctk.CTkLabel(self.classification_frame,
                                               text="No classification yet",
                                               font=ctk.CTkFont(size=12))
        self.classification_label.pack(pady=5)

    def _create_history_frame(self):
        """Build the bottom strip that lists previously processed images."""
        frame = ctk.CTkFrame(self)
        self.history_frame = frame
        frame.grid(row=2, column=0, columnspan=3, padx=10, pady=10, sticky="nsew")
        frame.grid_columnconfigure(0, weight=1)

        title = ctk.CTkLabel(frame, text="Processing History",
                             font=ctk.CTkFont(size=16, weight="bold"))
        title.pack(pady=5)

        self.history_display = ctk.CTkScrollableFrame(frame)
        self.history_display.pack(expand=True, fill="both", padx=5, pady=5)

        # Fill the list from the history loaded at start-up.
        self._update_history_display()

    def _update_history_display(self):
        """Rebuild the history list with one clickable button per history entry.

        Clicking a button reloads that image via _load_history_item. The list
        is rebuilt from scratch on every call. The earlier incremental version
        had three problems:
        - it reused buttons it had just destroyed, so pack_forget raised
          TclError as soon as the history grew;
        - it merged entries with identical text;
        - it stacked extra "No history yet" labels.
        """
        for widget in self.history_display.winfo_children():
            widget.destroy()

        if not self.history:
            ctk.CTkLabel(self.history_display, text="No history yet",
                         font=ctk.CTkFont(size=12)).pack(pady=5)
            return

        for item in self.history:
            if not isinstance(item, dict):
                continue  # entries are read with .get() below; skip anything else
            file_path = item.get("file_path", "Unknown")
            pred = item.get("prediction", "Unknown")
            conf = item.get("confidence", 0.0)
            try:
                conf_text = f"{float(conf):.2f}"
            except (TypeError, ValueError):
                conf_text = str(conf)  # non-numeric confidence: show it verbatim
            text = f"Image: {os.path.basename(file_path)} | Prediction: {pred} | Confidence: {conf_text}"

            btn = ctk.CTkButton(self.history_display, text=text,
                                command=lambda p=file_path: self._load_history_item(p),
                                fg_color="#2196F3", hover_color="#1976D2")
            btn.pack(pady=5, padx=5, fill="x")
        self.update()

    def _load_history_item(self, file_path):
        """Reload an image picked from the history list and reprocess it in the background."""
        if not os.path.exists(file_path):
            messagebox.showerror("Error", "File not found!")
            return
        self.original_image_path = file_path
        self._update_status("Loading history item...")
        self._display_original_image(file_path)
        worker = threading.Thread(target=self._process_image, args=(file_path,), daemon=True)
        worker.start()

    def upload_image(self):
        """Ask for an image file, show it, and process it on a background thread.

        The file dialog offers the same extensions batch_process_folder accepts
        (.png/.jpg/.jpeg/.bmp), so single-image mode is consistent with batch mode.
        """
        file_path = filedialog.askopenfilename(
            filetypes=[("Image Files", "*.png *.jpg *.jpeg *.bmp")])
        if not file_path:
            return  # dialog cancelled
        self.original_image_path = file_path
        self._update_status("Loading image...", "orange")
        self.progress_bar.set(0.1)
        self._display_original_image(file_path)
        # Segmentation/feature extraction runs off the UI thread.
        threading.Thread(target=self._process_image, args=(file_path,), daemon=True).start()

    def _display_original_image(self, file_path):
        """Load *file_path* and show it in the original-image pane, fitted within 400 px.

        Also resets zoom/pan to the default view and clears the render cache.
        The CTkImage gets the fitted image's own (width, height); a fixed
        (400, 400) size stretched every non-square image. If the file cannot
        be read, an error is reported in the status bar and the pane is left
        unchanged.
        """
        image = cv2.imread(file_path)
        if image is None:
            self._update_status(f"Error: Could not read image file {file_path}", "red")
            print(f"Failed to load original image: {file_path}")
            return

        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
        image_resized = self._resize_image_for_display(image_rgb, 400)
        height, width = image_resized.shape[:2]

        pil_image = Image.fromarray(image_resized)
        ctk_image = ctk.CTkImage(light_image=pil_image, dark_image=pil_image, size=(width, height))

        self.orig_image_display.configure(image=ctk_image, text="")
        self.orig_image_display.image = ctk_image  # keep a reference alive
        self.zoom_factor = 1.0
        self.pan_x = 0
        self.pan_y = 0
        self.image_cache.clear()
        self.update()

    def _resize_image_for_display(self, image, max_size):
        """Scale *image* so its longer side is *max_size* pixels, keeping the aspect ratio.

        Images smaller than *max_size* are scaled up. Returns None when *image*
        is None or empty (an empty image previously divided by zero). The
        shorter side is clamped to at least 1 px: for very elongated images it
        used to round to 0, which cv2.resize rejects.
        """
        if image is None or image.size == 0:
            return None
        h, w = image.shape[:2]
        if h > w:
            new_h = max_size
            new_w = max(1, int(w * max_size / h))
        else:
            new_w = max_size
            new_h = max(1, int(h * max_size / w))
        return cv2.resize(image, (new_w, new_h))  # cv2 takes (width, height)

    def _process_image(self, file_path):
        """Segment, featurize and classify one image; intended to run on a worker thread.

        The pipeline is improved_leaf_segmentation, then
        extract_features_from_image, then classify_leaf. Results are stored on
        self (processed_image, extracted_features, prediction) and appended to
        the persisted history.

        Tk widgets should only be driven from the main-loop thread. Every widget
        update here (progress bar, processed-image pane, feature and
        classification panels, Save button, history list) is therefore queued
        with ``self.after(0, ...)`` instead of being called directly from this
        thread. State is assigned here before the UI callback that reads it is
        queued. Any exception is printed with its traceback, shown in the status
        bar, and resets the progress bar.
        """
        def on_ui(callback, *args):
            # Run callback(*args) on the Tk main-loop thread.
            self.after(0, callback, *args)

        print(f"Starting _process_image for {file_path}")
        self._update_status("Processing image...", "orange")
        on_ui(self.progress_bar.set, 0.3)

        try:
            # Step 1: Segment the image
            print("Calling improved_leaf_segmentation")
            processed_image = self.improved_leaf_segmentation(file_path,
                                                              show=self.show_steps_var.get())
            if processed_image is None:
                raise ValueError("Leaf segmentation returned None")
            print("Segmentation complete")
            self.processed_image = processed_image
            on_ui(self.progress_bar.set, 0.6)

            # Step 2: Display processed image
            print("Displaying processed image")
            on_ui(self._display_processed_image, processed_image)

            # Step 3: Extract features
            print("Extracting features")
            self._update_status("Extracting features...", "orange")
            features = self.extract_features_from_image(processed_image)
            if features is None:
                raise ValueError("Feature extraction returned None")
            self.extracted_features = features
            on_ui(self.progress_bar.set, 0.8)

            # Step 4: Classify leaf
            print("Classifying leaf")
            self._update_status("Classifying leaf...", "orange")
            pred_label, confidence = self.classify_leaf(features)
            self.prediction = (pred_label, confidence)
            on_ui(self.progress_bar.set, 1.0)

            # Step 5: Update GUI
            print("Updating feature and classification displays")
            on_ui(self.update_features_display)
            on_ui(self.update_classification_display)

            self._update_status("Processing complete!", "green")
            on_ui(lambda: self.save_btn.configure(state="normal"))

            # Step 6: Update history
            print("Updating history")
            self.history.append({
                "file_path": file_path,
                "features": features,
                "prediction": pred_label,
                "confidence": confidence
            })
            self.save_history()
            on_ui(self._update_history_display)

        except Exception as e:
            error_msg = f"Error in _process_image: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            self._update_status(f"Error: {str(e)}", "red")
            on_ui(self.progress_bar.set, 0)

    def _display_processed_image(self, processed_image):
        """Show *processed_image* in the processed-image pane, fitted within 400 px.

        Also resets zoom/pan to the default view and clears the render cache.
        The CTkImage gets the fitted array's own (width, height); a fixed
        (400, 400) size stretched every non-square result. On a missing or
        unresizable image, an error is reported in the status bar and the pane
        is left unchanged.
        """
        if processed_image is None:
            self._update_status("Error: Processed image is None", "red")
            print("Processed image is None in _display_processed_image")
            return

        image_resized = self._resize_image_for_display(processed_image, 400)
        if image_resized is None:
            self._update_status("Error: Failed to resize processed image", "red")
            print("Failed to resize processed image")
            return
        height, width = image_resized.shape[:2]

        pil_image = Image.fromarray(image_resized)
        ctk_image = ctk.CTkImage(light_image=pil_image, dark_image=pil_image, size=(width, height))

        self.proc_image_display.configure(image=ctk_image, text="")
        self.proc_image_display.image = ctk_image  # keep a reference alive
        self.zoom_factor = 1.0
        self.pan_x = 0
        self.pan_y = 0
        self.image_cache.clear()
        print("Processed image displayed")
        self.update()

    def update_features_display(self):
        """Rebuild the feature panel from self.extracted_features.

        Keys are bucketed by name into shape, colour and texture groups. A
        group is rendered only when its sidebar checkbox is ticked. The colour
        and texture groups also get a small preview chart.
        """
        for child in self.feature_display.winfo_children():
            child.destroy()

        features = self.extracted_features
        if features is None:
            ctk.CTkLabel(self.feature_display, text="No features extracted yet",
                         font=ctk.CTkFont(size=12)).pack(pady=5)
            print("No features to display")
            return

        print("Updating feature display")
        shape_keys = ('aspect_ratio', 'compactness', 'contour_complexity')
        groups = {"shape": {}, "color": {}, "texture": {}}
        for name, val in features.items():
            if name == 'image_name':
                continue
            if name in shape_keys:
                groups["shape"][name] = val
            elif name.startswith(('mean_', 'hist_')):
                groups["color"][name] = val
            elif name == 'contrast' or name.startswith('lbp_'):
                groups["texture"][name] = val

        shape, color, texture = groups["shape"], groups["color"], groups["texture"]
        if self.shape_var.get() and shape:
            self._add_feature_section("Shape Features", shape)
        if self.color_var.get() and color:
            self._add_feature_section("Color Features", color)
            self._preview_color_histogram(color)
        if self.texture_var.get() and texture:
            self._add_feature_section("Texture Features", texture)
            self._preview_lbp_pattern(texture)
        self.update()

    def _preview_color_histogram(self, color_features):
        """Embed a small line chart of the 8-bin R/G/B histograms in the feature panel.

        Bins come from ``hist_<r|g|b>_<0..7>`` keys; a missing key plots as 0.
        This builds a standalone matplotlib ``Figure`` instead of calling
        ``plt.subplots``. The chart is only shown through the embedded Tk
        canvas, and this method runs while images are being processed, possibly
        off the main thread, where touching pyplot's global figure registry is
        not thread-safe.
        """
        from matplotlib.figure import Figure

        fig = Figure(figsize=(4, 2))
        ax = fig.add_subplot(111)
        for channel in ['r', 'g', 'b']:
            hist = [color_features.get(f'hist_{channel}_{i}', 0) for i in range(8)]
            ax.plot(hist, label=channel.upper())
        ax.set_title("Color Histogram Preview")
        ax.legend()
        canvas = FigureCanvasTkAgg(fig, master=self.feature_display)
        canvas.draw()
        canvas.get_tk_widget().pack(pady=5)

    def _preview_lbp_pattern(self, texture_features):
        """Embed a small bar chart of the 9 LBP bins (``lbp_0``..``lbp_8``) in the feature panel.

        A missing bin plots as 0. This uses a standalone matplotlib ``Figure``
        rather than pyplot, because the chart is only shown through the
        embedded Tk canvas and this method may run off the main thread, where
        pyplot's global state is not thread-safe.
        """
        from matplotlib.figure import Figure

        lbp_hist = [texture_features.get(f'lbp_{i}', 0) for i in range(9)]
        fig = Figure(figsize=(4, 2))
        ax = fig.add_subplot(111)
        ax.bar(range(9), lbp_hist)
        ax.set_title("LBP Pattern Preview")
        canvas = FigureCanvasTkAgg(fig, master=self.feature_display)
        canvas.draw()
        canvas.get_tk_widget().pack(pady=5)

    def update_classification_display(self):
        """Show the latest (label, confidence) prediction, or a placeholder when there is none."""
        prediction = getattr(self, 'prediction', None)
        if prediction is None:
            self.classification_label.configure(text="No classification yet")
            print("No classification to display")
        else:
            pred_label, confidence = prediction
            summary = f"Predicted Species: {pred_label}\nConfidence: {confidence:.4f}"
            self.classification_label.configure(text=summary)
            print(f"Classification updated: {summary}")
        self.update()

    def _add_feature_section(self, title, features):
        """Append a bold heading, one ``key: value`` row per numeric feature, and a separator.

        Floating-point values are shown with 4 decimals and integers (including
        bools) as-is. Other value types are skipped. NumPy scalars are accepted
        explicitly: only np.float64 subclasses ``float``, so np.float32 and all
        NumPy integer types failed the old ``(int, float)`` check and were
        silently hidden.
        """
        header = ctk.CTkLabel(self.feature_display, text=title,
                              font=ctk.CTkFont(size=14, weight="bold"))
        header.pack(pady=(10, 5), anchor="w")

        for key, value in features.items():
            if isinstance(value, (float, np.floating)):
                formatted_value = f"{value:.4f}"
            elif isinstance(value, (int, np.integer)):
                formatted_value = str(value)
            else:
                continue  # non-scalar or non-numeric feature: not shown
            feature_label = ctk.CTkLabel(self.feature_display,
                                         text=f"{key}: {formatted_value}",
                                         font=ctk.CTkFont(size=12))
            feature_label.pack(pady=2, padx=5, anchor="w")

        separator = ctk.CTkFrame(self.feature_display, height=1, fg_color="gray")
        separator.pack(fill="x", padx=5, pady=5)

    def _update_status(self, message, color="orange"):
        """Queue a status-label update on the Tk event loop.

        The change is scheduled with ``after(0, ...)`` rather than applied
        immediately, so worker threads can call this too.
        """
        def apply_status():
            self.status_label.configure(text=f"Status: {message}", text_color=color)
            self.update()

        self.after(0, apply_status)

    def batch_process(self):
        """Ask for a folder and hand it to batch_process_folder; no-op if cancelled."""
        chosen = filedialog.askdirectory(title="Select Folder with Leaf Images")
        if not chosen:
            return
        self.batch_process_folder(chosen)

    def batch_process_folder(self, folder_path):
        """Queue the images directly inside *folder_path* and start the batch worker if idle.

        Only files with a .jpg/.jpeg/.png/.bmp extension (case-insensitive) are
        queued, in sorted name order so runs are reproducible. A folder with no
        such files is rejected with a message, because the worker derives its
        output folder from the first queued file; previously it crashed on the
        empty queue and left the batch marked as running. If a batch is already
        running, the files are only appended to its queue.
        """
        image_extensions = {'.jpg', '.jpeg', '.png', '.bmp'}
        files = [os.path.join(folder_path, f) for f in sorted(os.listdir(folder_path))
                 if os.path.splitext(f)[1].lower() in image_extensions]
        if not files:
            messagebox.showinfo("Batch Process", f"No image files found in:\n{folder_path}")
            return

        self.batch_queue.extend(files)
        if not self.batch_processing:
            self.batch_processing = True
            self.batch_paused = False  # a new batch always starts running
            self.pause_btn.configure(state="normal")
            self.resume_btn.configure(state="disabled")
            threading.Thread(target=self._process_batch_queue, daemon=True).start()

    def _process_batch_queue(self):
        """Drain self.batch_queue on a worker thread and write the batch features CSV.

        Each queued image is segmented, featurized and classified. The segmented
        image goes to ``<first queued file's folder>/processed``, and all feature
        rows go to ``leaf_features.csv`` in that folder.

        Files are popped off the front of the queue as they are taken, whether
        they succeed or fail. The queue is therefore empty afterwards, and files
        appended mid-run by batch_process_folder are processed too.

        Pausing blocks between files. When on_closing clears batch_processing,
        the loop stops and the CSV of finished files is still written, but no Tk
        call is made (the window is being destroyed). Widget updates are queued
        with ``self.after(0, ...)`` rather than made from this thread.
        """
        def on_ui(callback, *args):
            # Run callback(*args) on the Tk main-loop thread.
            self.after(0, callback, *args)

        def reset_controls():
            self.pause_btn.configure(state="disabled")
            self.resume_btn.configure(state="disabled")

        if not self.batch_queue:
            # Nothing queued: release the "running" flag instead of failing on [0].
            self.batch_processing = False
            self.batch_paused = False
            on_ui(reset_controls)
            self._update_status("No images to process", "red")
            return

        output_folder = os.path.join(os.path.dirname(self.batch_queue[0]), 'processed')
        os.makedirs(output_folder, exist_ok=True)
        output_csv = os.path.join(output_folder, 'leaf_features.csv')

        all_features = []
        done = 0
        while self.batch_queue and self.batch_processing:
            # Wait between files while paused; clearing batch_processing ends the wait.
            while self.batch_paused and self.batch_processing:
                time.sleep(0.1)
            if not self.batch_processing:
                break

            # Dequeue up front so a failed file does not linger in the queue.
            file = self.batch_queue.pop(0)
            done += 1
            total_files = done + len(self.batch_queue)
            try:
                self._update_status(f"Processing image {done}/{total_files}: {os.path.basename(file)}", "orange")
                on_ui(self.progress_bar.set, done / total_files)
                output_path = os.path.join(output_folder, os.path.basename(file))
                processed_image = self.improved_leaf_segmentation(file, save_path=output_path, show=False)
                if processed_image is None:
                    print(f"Segmentation failed for {file}")
                    continue
                features = self.extract_features_from_image(processed_image)
                if features:
                    features['image_name'] = os.path.basename(file)
                    pred_label, confidence = self.classify_leaf(features)
                    features['predicted_species'] = pred_label
                    features['confidence'] = confidence
                    all_features.append(features)
            except Exception as e:
                print(f"Error processing {file}: {str(e)}\n{traceback.format_exc()}")
                self._update_status(f"Error processing {file}: {str(e)}", "red")

        # batch_processing is cleared early only by on_closing: the window is going away.
        window_closing = not self.batch_processing

        if all_features:
            pd.DataFrame(all_features).to_csv(output_csv, index=False)

        self.batch_processing = False
        self.batch_paused = False
        if window_closing:
            return  # do not schedule Tk work on a window that is being destroyed

        self._update_status("Batch processing complete!", "green")
        on_ui(self.progress_bar.set, 1.0)
        on_ui(reset_controls)
        on_ui(lambda: messagebox.showinfo("Batch Processing Complete",
                                          f"Processed images saved to:\n{output_folder}\n\n"
                                          f"Features saved to:\n{output_csv}"))

    def pause_batch(self):
        """Make the batch worker wait before its next file and swap the button states."""
        self.batch_paused = True
        for button, state in ((self.pause_btn, "disabled"), (self.resume_btn, "normal")):
            button.configure(state=state)
        self._update_status("Batch processing paused", "orange")

    def resume_batch(self):
        """Let a paused batch worker continue and swap the button states back."""
        self.batch_paused = False
        for button, state in ((self.pause_btn, "normal"), (self.resume_btn, "disabled")):
            button.configure(state=state)
        self._update_status("Batch processing resumed", "green")

    def export_visualization(self):
        """Save a two-panel PNG (RGB histogram + LBP bars) of the current features.

        The file is written as ``feature_visualization.png`` in a folder chosen
        by the user. The figure's own ``tight_layout``/``savefig`` are used
        instead of pyplot's "current figure" functions, so the saved image is
        always this figure even if another figure was created meanwhile. The
        figure is closed in ``finally`` so a failed save does not leak it.
        """
        if not self.extracted_features:
            messagebox.showerror("Error", "No features to visualize!")
            return

        save_dir = filedialog.askdirectory(title="Select Folder to Save Visualizations")
        if not save_dir:
            return

        try:
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
            try:
                for channel in ['r', 'g', 'b']:
                    hist = [self.extracted_features.get(f'hist_{channel}_{i}', 0) for i in range(8)]
                    ax1.plot(hist, label=channel.upper())
                ax1.set_title("Color Histogram")
                ax1.legend()

                lbp_hist = [self.extracted_features.get(f'lbp_{i}', 0) for i in range(9)]
                ax2.bar(range(9), lbp_hist)
                ax2.set_title("LBP Pattern")

                fig.tight_layout()
                save_path = os.path.join(save_dir, "feature_visualization.png")
                fig.savefig(save_path)
            finally:
                plt.close(fig)

            messagebox.showinfo("Export Successful", f"Visualization saved to:\n{save_path}")
        except Exception as e:
            messagebox.showerror("Export Error", f"Error exporting visualization: {str(e)}")

    def save_results(self):
        """Write the processed image (PNG) and its features (one-row CSV) to a chosen folder.

        Output files are ``<stem>_processed.png`` and ``<stem>_features.csv``.
        The stem is the original image's file name, or "leaf" when no original
        path is recorded; that case used to raise NameError on ``base_name`` and
        never save the image. The in-memory image is treated as RGB(A), as
        elsewhere in this class, and converted to OpenCV's BGR(A) order. A
        grayscale image is written unchanged. A failed ``cv2.imwrite`` (it
        returns False rather than raising) is now reported instead of claiming
        success.
        """
        if self.processed_image is None or self.extracted_features is None:
            messagebox.showerror("Error", "No results to save!")
            return

        save_dir = filedialog.askdirectory(title="Select Folder to Save Results")
        if not save_dir:
            return

        try:
            if self.original_image_path:
                base_name = os.path.splitext(os.path.basename(self.original_image_path))[0]
            else:
                base_name = "leaf"

            image = self.processed_image
            if image.ndim == 3 and image.shape[2] == 3:
                image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            elif image.ndim == 3 and image.shape[2] == 4:
                image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGRA)

            processed_img_path = os.path.join(save_dir, f"{base_name}_processed.png")
            if not cv2.imwrite(processed_img_path, image):
                raise OSError(f"Could not write {processed_img_path}")

            features_path = os.path.join(save_dir, f"{base_name}_features.csv")
            pd.DataFrame([self.extracted_features]).to_csv(features_path, index=False)

            messagebox.showinfo("Save Successful",
                                f"Results saved to:\n{save_dir}")

        except Exception as e:
            messagebox.showerror("Save Error", f"Error saving results: {str(e)}")

    def load_history(self):
        """Load history entries from ``self.history_file``.

        Each non-blank line holds one entry, as written by ``save_history``
        (``str(item)``). Each line is turned back into an object by evaluating
        it as a Python expression.

        Returns:
            list: The parsed entries. Returns ``[]`` when the file does not
            exist, or when any line cannot be read or evaluated. Loading is
            all-or-nothing.
        """
        if not os.path.exists(self.history_file):
            return []
        try:
            with open(self.history_file, "r") as f:
                lines = f.readlines()
            # NOTE(review): eval() executes arbitrary code from the history
            # file. ast.literal_eval would be safe, but it only accepts literals,
            # and the item types stored in self.history are not visible here.
            # Confirm they are plain literals before switching.
            return [eval(line.strip()) for line in lines if line.strip()]
        except Exception:
            # Best effort: a corrupt or unreadable history file must not stop
            # the app from starting. Unlike a bare ``except:``, this still lets
            # KeyboardInterrupt and SystemExit propagate.
            return []

    def save_history(self):
        """Overwrite ``self.history_file`` with one ``str()``-rendered entry per line."""
        with open(self.history_file, "w") as handle:
            handle.writelines(str(entry) + "\n" for entry in self.history)

    def apply_gaussian_blur(self, image, kernel_size=(5, 5), sigma=0):
        """Return ``image`` smoothed with ``cv2.GaussianBlur``.

        Returns None when ``image`` is None. ``kernel_size`` and ``sigma`` are
        passed straight through as OpenCV's ``ksize`` and ``sigmaX``.
        """
        return None if image is None else cv2.GaussianBlur(image, kernel_size, sigma)

    def improved_leaf_segmentation(self, image_path, save_path=None, show=False):
        """Segment the leaf in an image file and return it on a 256x256 canvas.

        Pipeline:
            1. Gaussian blur.
            2. Build a mask from the intersection of three cues:
               - the union of HSV colour masks (green, purple, yellow, white);
               - an inverted Otsu threshold of the LAB a* channel;
               - an adaptive threshold of the CLAHE-equalised grey image.
            3. Morphological clean-up.
            4. Keep external contours with area > 1000 px.
            5. Crop to the mask's bounding box.
            6. Letterbox onto a black 256x256 canvas.
            7. Sharpen.

        Args:
            image_path: Path of the image, read with ``cv2.imread``.
            save_path: Optional path. Only its *directory* is used; it receives
                ``<input file stem>_final.png``. A path without a directory
                component writes to the current working directory.
            show: When true, open a window with the intermediate masks.

        Returns:
            The final RGB ``uint8`` array of shape (256, 256, 3), or ``None``
            if the image could not be loaded or the crop was empty.
        """
        print(f"Starting segmentation for {image_path}")
        image_bgr = cv2.imread(image_path)
        if image_bgr is None:
            print(f"Failed to load image: {image_path}")
            return None

        image_bgr = self.apply_gaussian_blur(image_bgr, kernel_size=(5, 5), sigma=1)
        if image_bgr is None:
            print("Gaussian blur returned None")
            return None
        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        image_copy = image_rgb.copy()

        hsv = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2HSV)
        lab = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2LAB)
        gray = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)

        # CLAHE evens out local contrast before the adaptive threshold below.
        clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
        gray_clahe = clahe.apply(gray)

        # --- HSV colour masks (8-bit OpenCV HSV: H in 0..179, S/V in 0..255) ---
        # Two overlapping green ranges. The second also admits lower
        # saturation (S >= 20) within H 35..85.
        lower_green1 = np.array([20, 30, 20])
        upper_green1 = np.array([100, 255, 255])
        mask_green1 = cv2.inRange(hsv, lower_green1, upper_green1)

        lower_green2 = np.array([35, 20, 20])
        upper_green2 = np.array([85, 255, 255])
        mask_green2 = cv2.inRange(hsv, lower_green2, upper_green2)

        mask_green = cv2.bitwise_or(mask_green1, mask_green2)

        lower_purple = np.array([125, 30, 30])
        upper_purple = np.array([150, 255, 255])
        mask_purple = cv2.inRange(hsv, lower_purple, upper_purple)

        lower_yellow = np.array([20, 50, 50])
        upper_yellow = np.array([40, 255, 255])
        mask_yellow = cv2.inRange(hsv, lower_yellow, upper_yellow)

        lower_white = np.array([0, 0, 200])
        upper_white = np.array([180, 40, 255])
        mask_white = cv2.inRange(hsv, lower_white, upper_white)

        mask_hsv = cv2.bitwise_or(mask_green, mask_purple)
        mask_hsv = cv2.bitwise_or(mask_hsv, mask_yellow)
        mask_hsv = cv2.bitwise_or(mask_hsv, mask_white)

        # --- LAB mask: Otsu split of the a* channel, inverted so pixels at or
        # below the threshold (the greener side) become foreground. ---
        _, a_channel, _ = cv2.split(lab)
        _, mask_lab = cv2.threshold(a_channel, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        mask_lab = cv2.bitwise_not(mask_lab)

        # --- Local (Gaussian-weighted, 11x11 block, C=2) adaptive threshold. ---
        mask_adaptive = cv2.adaptiveThreshold(gray_clahe, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                              cv2.THRESH_BINARY, 11, 2)

        # A pixel is kept only where all three cues agree.
        combined_mask = cv2.bitwise_and(mask_hsv, mask_adaptive)
        combined_mask = cv2.bitwise_and(combined_mask, mask_lab)

        # Close gaps, grow the region, then open once to drop small specks.
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9, 9))
        mask_clean = cv2.morphologyEx(combined_mask, cv2.MORPH_CLOSE, kernel, iterations=4)
        mask_clean = cv2.dilate(mask_clean, kernel, iterations=3)
        mask_processed = cv2.morphologyEx(mask_clean, cv2.MORPH_OPEN, kernel, iterations=1)

        contours, _ = cv2.findContours(mask_processed, cv2.RETR_EXTERNAL,
                                       cv2.CHAIN_APPROX_SIMPLE)
        leaf_mask = np.zeros_like(mask_processed)
        min_area = 1000
        valid_contours = [cnt for cnt in contours if cv2.contourArea(cnt) > min_area]
        if valid_contours:
            cv2.drawContours(leaf_mask, valid_contours, -1, 255, cv2.FILLED)
        else:
            # Nothing large enough: fall back to the closed+dilated mask
            # (the mask before the opening step).
            leaf_mask = mask_clean.copy()

        segmented_leaf = cv2.bitwise_and(image_copy, image_copy, mask=leaf_mask)

        coords = cv2.findNonZero(leaf_mask)
        if coords is not None:
            x, y, w, h = cv2.boundingRect(coords)
            cropped_leaf = segmented_leaf[y:y+h, x:x+w]
        else:
            cropped_leaf = segmented_leaf

        # --- Letterbox onto a black canvas, preserving aspect ratio. ---
        target_size = (256, 256)  # (width, height)
        h_, w_ = cropped_leaf.shape[:2]
        if w_ == 0 or h_ == 0:
            print("Invalid crop dimensions")
            return None
        scale = min(target_size[0] / w_, target_size[1] / h_)
        # max(1, ...) stops a very thin crop from rounding a side down to 0,
        # which cv2.resize rejects.
        new_size = (max(1, int(w_ * scale)), max(1, int(h_ * scale)))
        resized_leaf = cv2.resize(cropped_leaf, new_size, interpolation=cv2.INTER_AREA)
        resized = np.zeros((target_size[1], target_size[0], 3), dtype=np.uint8)
        x_offset = (target_size[0] - new_size[0]) // 2
        y_offset = (target_size[1] - new_size[1]) // 2
        resized[y_offset:y_offset+new_size[1], x_offset:x_offset+new_size[0]] = resized_leaf

        # 3x3 sharpening kernel. Its weights sum to 1, so flat regions keep
        # their brightness.
        sharpen_kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
        final_image = cv2.filter2D(resized, -1, sharpen_kernel)

        if save_path:
            save_dir = os.path.dirname(save_path)
            # dirname() is "" for a bare file name, and os.makedirs("") raises.
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            base_name = os.path.splitext(os.path.basename(image_path))[0]
            final_image_bgr = cv2.cvtColor(final_image, cv2.COLOR_RGB2BGR)
            output_file = os.path.join(save_dir, f"{base_name}_final.png")
            # cv2.imwrite signals failure by returning False, not by raising.
            if not cv2.imwrite(output_file, final_image_bgr):
                print(f"Failed to write segmented image: {output_file}")

        if show:
            self._show_processing_steps(image_rgb, mask_hsv, mask_lab,
                                        mask_adaptive, segmented_leaf, final_image)

        print("Segmentation successful")
        return final_image

    def _show_processing_steps(self, original, mask_hsv, mask_lab, mask_otsu, segmented, final):
        """Open a window showing the six segmentation stages side by side.

        Args:
            original: RGB input image (after blurring).
            mask_hsv: Union of the HSV colour masks.
            mask_lab: Inverted Otsu mask of the LAB a* channel.
            mask_otsu: Threshold mask of the CLAHE-equalised grey image.
            segmented: RGB image with everything outside the leaf mask zeroed.
            final: Letterboxed, sharpened RGB output.
        """
        # Use matplotlib's object API instead of pyplot. pyplot registers
        # every figure globally until plt.close(), and this embedded figure
        # was never closed. A bare Figure lives only as long as its window.
        from matplotlib.figure import Figure

        steps_window = ctk.CTkToplevel(self)
        steps_window.title("Processing Steps")
        steps_window.geometry("1200x400")

        fig = Figure(figsize=(18, 3))
        axes = fig.subplots(1, 6)

        panels = (
            (original, None, "Original"),
            (mask_hsv, 'gray', "HSV Mask"),
            (mask_lab, 'gray', "LAB Mask"),
            (mask_otsu, 'gray', "Otsu/CLAHE Mask"),
            (segmented, None, "Segmented"),
            (final, None, "Final Output"),
        )
        for ax, (img, cmap, title) in zip(axes, panels):
            ax.imshow(img, cmap=cmap)
            ax.set_title(title)
            ax.axis('off')

        canvas = FigureCanvasTkAgg(fig, master=steps_window)
        # Lay out after the Tk canvas is attached, so a renderer is available.
        fig.tight_layout()
        canvas.draw()
        canvas.get_tk_widget().pack(fill="both", expand=True)

    def process_folder(self, folder_path, output_folder=None):
        """Segment every image in ``folder_path`` and save the results.

        Files with extension .jpg/.jpeg/.png/.bmp (case-insensitive) are passed
        to ``improved_leaf_segmentation``. Output goes to ``output_folder``,
        which defaults to ``<folder_path>/processed`` and is created if needed.
        Status text and the progress bar are updated per file. Per-file
        exceptions are printed and the loop continues.
        """
        accepted = ('.jpg', '.jpeg', '.png', '.bmp')
        images = []
        for name in os.listdir(folder_path):
            if os.path.splitext(name)[1].lower() in accepted:
                images.append(name)
        count = len(images)

        if output_folder is not None:
            destination = output_folder
        else:
            destination = os.path.join(folder_path, 'processed')
        os.makedirs(destination, exist_ok=True)

        for position, name in enumerate(images, start=1):
            try:
                self._update_status(f"Processing image {position}/{count}: {name}", "orange")
                self.progress_bar.set(position / count)
                source = os.path.join(folder_path, name)
                target = os.path.join(destination, name)
                self.improved_leaf_segmentation(source, save_path=target, show=False)
            except Exception as err:
                print(f"Failed to process {name}: {str(err)}")

    def extract_features_from_image(self, image):
        """Extract features from a segmented RGB image on a black background.

        Pixels whose grey value exceeds 1 form the foreground mask handed to
        ``extract_features``. Returns that method's result, or None when
        ``image`` is None.
        """
        if image is None:
            print("Input image is None in extract_features_from_image")
            return None

        grey = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        foreground = cv2.threshold(grey, 1, 255, cv2.THRESH_BINARY)[1]

        result = self.extract_features(image, foreground)
        if result is None:
            print("Feature extraction failed")
        return result

    def extract_features(self, image, mask=None):
        """Compute shape, colour and texture features of a segmented leaf.

        Args:
            image: RGB ``uint8`` image. Grey (2-D or single-channel) and RGBA
                inputs are converted to RGB first.
            mask: Optional single-channel leaf mask. When omitted, every pixel
                whose grey value exceeds 1 is foreground, which suits the
                black-background output of ``improved_leaf_segmentation``.

        Returns:
            dict with these keys, or None if the image is missing/empty or the
            mask has no contour:
                - ``aspect_ratio``: bounding-box w/h of the largest contour.
                - ``compactness``: 4*pi*area / perimeter**2.
                - ``contour_complexity``: vertices of a 1%-of-perimeter polygon
                  approximation.
                - ``mean_r``, ``mean_g``, ``mean_b``: channel means over the mask.
                - ``hist_{r,g,b}_{0..7}``: normalised 8-bin histograms over the mask.
                - ``contrast``: see the note below.
                - ``lbp_0``..``lbp_8``: uniform LBP histogram (P=24, R=3).
        """
        print("Starting feature extraction")
        if image is None or image.size == 0:
            print("No image provided for feature extraction")
            return None

        if image.ndim == 2 or image.shape[2] == 1:
            image_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            # Drop the alpha channel; the RGB->grey conversions below need 3 channels.
            image_rgb = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
        else:
            image_rgb = image.copy()

        if mask is None:
            gray = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)
            _, mask = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)

        # Re-binarise so a caller-supplied mask becomes strictly 0/255.
        _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)

        features = {}

        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            print("No contours found")
            return None

        # Shape features describe only the largest external blob.
        contour = max(contours, key=cv2.contourArea)

        _, _, w, h = cv2.boundingRect(contour)
        features['aspect_ratio'] = w / h if h > 0 else 0

        area = cv2.contourArea(contour)
        perimeter = cv2.arcLength(contour, True)
        # Isoperimetric quotient: 1.0 for a circle, lower for elongated or ragged outlines.
        features['compactness'] = (4 * np.pi * area) / (perimeter ** 2) if perimeter > 0 else 0

        approx = cv2.approxPolyDP(contour, 0.01 * perimeter, True)
        features['contour_complexity'] = len(approx)

        # Colour and texture features use every mask pixel, not just the largest blob.
        masked_image = cv2.bitwise_and(image_rgb, image_rgb, mask=mask)
        mean_rgb = cv2.mean(masked_image, mask=mask)[:3]
        features['mean_r'] = mean_rgb[0]
        features['mean_g'] = mean_rgb[1]
        features['mean_b'] = mean_rgb[2]

        for i, channel in enumerate(['r', 'g', 'b']):
            hist = cv2.calcHist([image_rgb], [i], mask, [8], [0, 256])
            hist = hist.flatten() / hist.sum() if hist.sum() > 0 else hist.flatten()
            for j in range(8):
                features[f'hist_{channel}_{j}'] = hist[j]

        gray = cv2.cvtColor(masked_image, cv2.COLOR_RGB2GRAY)
        # NOTE(review): this std is taken over the whole image, including the
        # zeroed background, so it also depends on how much of the frame is
        # background. Kept unchanged so existing feature values stay stable.
        features['contrast'] = gray.std() if mask.sum() > 0 else 0

        radius = 3
        n_points = 8 * radius
        lbp = local_binary_pattern(gray, n_points, radius, method='uniform')
        # NOTE(review): with method='uniform', codes run from 0 to n_points + 1
        # (0..25 here). The 9-bin range (0, 9) therefore ignores codes 9..25.
        # Kept as-is because lbp_0..lbp_8 is the schema used elsewhere in this
        # file. Confirm whether that is intended.
        lbp_hist, _ = np.histogram(lbp[mask > 0], bins=9, range=(0, 9), density=True)
        for i in range(9):
            features[f'lbp_{i}'] = lbp_hist[i] if not np.isnan(lbp_hist[i]) else 0

        print("Feature extraction successful")
        return features

    def process_folder_for_features(self, input_folder, output_csv):
        """Extract features and a prediction for each ``*_final.png`` in a folder.

        Writes one CSV row per successfully processed image. Each row holds all
        ``extract_features`` keys plus ``image_name``, ``predicted_species``
        and ``confidence``. Files that cannot be read, or that yield no
        contour, are reported and skipped. Nothing is written if no image
        produced features.

        Args:
            input_folder: Folder scanned (non-recursively) for ``*_final.png``.
            output_csv: Destination CSV path. Its directory is created if
                needed; a bare file name writes to the current directory.
        """
        files = [f for f in os.listdir(input_folder) if f.endswith('_final.png')]
        all_features = []
        total_files = len(files)

        for i, file in enumerate(files):
            try:
                self._update_status(f"Extracting features {i+1}/{total_files}: {file}", "orange")
                self.progress_bar.set((i + 1) / total_files)
                image_path = os.path.join(input_folder, file)
                image_bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
                if image_bgr is None:
                    print(f"Failed to read image: {file}")
                    continue
                # cv2.imread yields BGR, but extract_features reads channel 0
                # as R (the order improved_leaf_segmentation returns). Without
                # this conversion the R/B means and histograms are swapped.
                image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
                features = self.extract_features(image_rgb)
                if features:
                    features['image_name'] = file
                    pred_label, confidence = self.classify_leaf(features)
                    features['predicted_species'] = pred_label
                    features['confidence'] = confidence
                    all_features.append(features)
                else:
                    print(f"No valid contours found in: {file}")
            except Exception as e:
                print(f"Failed to process {file}: {str(e)}")

        if all_features:
            df = pd.DataFrame(all_features)
            output_dir = os.path.dirname(output_csv)
            # dirname() is "" for a bare file name, and os.makedirs("") raises.
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            df.to_csv(output_csv, index=False)
            print(f"Features saved to: {output_csv}")
        else:
            print("No features extracted.")

    def classify_leaf(self, features):
        """Label a leaf using hand-tuned aspect-ratio / mean-green rules.

        Rules are tried in order and the first match wins. Each rule pairs a
        match test with a tighter "core" test that lifts confidence from 0.7
        to 0.9. Missing feature keys default to 0.

        Returns:
            tuple: ``(label, confidence)``. The result is ``("Unknown", 0.5)``
            when no rule matches, and ``("Unknown", 0.0)`` when ``features`` is
            empty/None or evaluating the rules raises.
        """
        print("Starting classification")
        if not features:
            print("No features provided for classification")
            return "Unknown", 0.0

        try:
            ar = features.get('aspect_ratio', 0)
            mg = features.get('mean_g', 0)

            # (label, match test, high-confidence test). Order matters: with
            # 0.8 <= ar < 0.9 and mean_g == 150, DwarUmbrellaTree wins.
            rules = (
                ("SpiderPlant",
                 lambda: ar < 0.5 and mg > 170,
                 lambda: 0.1 <= ar <= 0.2 and 175 <= mg <= 185),
                ("PhilodendronRugosum",
                 lambda: 0.4 <= ar < 0.5 and 140 <= mg <= 150,
                 lambda: 0.45 <= ar <= 0.48 and 141 <= mg <= 145),
                ("DendrobiumNobile",
                 lambda: 0.6 <= ar < 0.7 and 160 <= mg <= 165,
                 lambda: 0.62 <= ar <= 0.65 and 162 <= mg <= 164),
                ("DracaenaSurculosa",
                 lambda: 0.7 <= ar < 0.8 and 150 <= mg <= 155,
                 lambda: 0.73 <= ar <= 0.75 and 153 <= mg <= 155),
                ("DwarUmbrellaTree",
                 lambda: 0.8 <= ar < 0.9 and 150 <= mg <= 152,
                 lambda: 0.85 <= ar <= 0.87 and 150.5 <= mg <= 151.5),
                ("Philodendron",
                 lambda: 0.8 <= ar < 0.9 and 145 <= mg <= 150,
                 lambda: 0.85 <= ar <= 0.87 and 147 <= mg <= 149),
            )

            label, confidence = "Unknown", 0.5
            for name, matches, is_core in rules:
                if matches():
                    label = name
                    confidence = 0.9 if is_core() else 0.7
                    break

            print(f"Classification result: {label}, {confidence}")
            return label, confidence
        except Exception as e:
            print(f"Error classifying leaf: {str(e)}\n{traceback.format_exc()}")
            return "Unknown", 0.0

# Launch the GUI only when this file is executed directly, not when imported.
# (The dunder names need exactly two underscores on each side.)
if __name__ == "__main__":
    app = LeafClassifierGUI()
    app.mainloop()