# In [None]:
import tkinter as tk
from tkinter import ttk, filedialog, messagebox, scrolledtext
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import seaborn as sns
import threading
import time
import os
from io import BytesIO
import joblib
from sklearn.ensemble import RandomForestRegressor, ExtraTreesRegressor, RandomForestClassifier, ExtraTreesClassifier
from xgboost import XGBRegressor, XGBClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler, MinMaxScaler
from sklearn.metrics import mean_squared_error, r2_score, accuracy_score, f1_score
from boruta import BorutaPy
from sklearn.ensemble import RandomForestClassifier
import matplotlib
matplotlib.use('Agg')
import math

class WaterQualityApp:
    """Tkinter GUI for the River Benue water-quality prediction workflow.

    One notebook tab per step: data loading/preprocessing, Boruta feature
    selection, sub-index, weight and WQI calculation, model training and
    prediction. Data loading, preprocessing and feature selection operate on
    the loaded data. The sub-index, weight, WQI, training and prediction
    steps are placeholders that fill their widgets with sample or randomly
    generated values.
    """

    def __init__(self, root):
        """Build the main window and every tab inside ``root`` (a ``tk.Tk``)."""
        self.root = root
        self.root.title("River Benue Water Quality Prediction System")
        self.root.geometry("1200x800")
        self.root.configure(bg='#f0f8ff')

        # Initialize data structures
        self.data = None             # DataFrame loaded by load_data()
        self.guidelines = None       # DataFrame loaded by load_guidelines()
        self.selected_features = []  # filled by run_feature_selection()
        self.models = {}
        self.current_wqi_method = "WAMWQI"
        # (lower, upper, label) WQI bands. The bounds are written as integers;
        # classify_wqi() treats each lower bound as the start of its band so
        # fractional scores between them (e.g. 89.5) are still classified.
        self.classification_ranges = [
            (90, 100, "Excellent"),
            (70, 89, "Good"),
            (50, 69, "Medium"),
            (25, 49, "Poor"),
            (0, 24, "Very Poor")
        ]

        # Normalization parameters (set during preprocessing)
        self.year_min = None
        self.year_max = None
        self.month_mapping = {
            "January": 1, "February": 2, "March": 3, "April": 4,
            "May": 5, "June": 6, "July": 7, "August": 8,
            "September": 9, "October": 10, "November": 11, "December": 12
        }

        # Status bar. Packed before the notebook so pack() reserves its row
        # first and the bar stays visible when the window is made smaller.
        self.status_var = tk.StringVar()
        self.status_var.set("Ready")
        self.status_bar = tk.Label(root, textvariable=self.status_var, bd=1, relief=tk.SUNKEN, anchor=tk.W)
        self.status_bar.pack(side=tk.BOTTOM, fill=tk.X)

        # Create notebook (tabbed interface)
        self.notebook = ttk.Notebook(root)
        self.notebook.pack(fill='both', expand=True, padx=10, pady=10)

        # Create tabs
        self.setup_data_tab()
        self.setup_feature_tab()
        self.setup_subindex_tab()
        self.setup_weight_tab()
        self.setup_wqi_tab()
        self.setup_model_tab()
        self.setup_prediction_tab()

    def update_status(self, message):
        """Show ``message`` in the status bar and repaint it immediately."""
        self.status_var.set(message)
        self.root.update_idletasks()

    def classify_wqi(self, wqi_value):
        """Return the quality label for a WQI score.

        Bands come from ``self.classification_ranges``. A score belongs to the
        band with the highest lower bound that does not exceed it, so values
        between the integer bounds (e.g. 89.5) are classified instead of
        falling through. Scores above the top band's upper bound, below
        every lower bound, NaN, or non-numeric yield ``"Unknown"``.
        """
        try:
            value = float(wqi_value)
        except (TypeError, ValueError):
            return "Unknown"
        if math.isnan(value):
            return "Unknown"

        bands = sorted(self.classification_ranges, key=lambda band: band[0], reverse=True)
        if not bands or value > bands[0][1]:
            return "Unknown"
        for low, _high, label in bands:
            if value >= low:
                return label
        return "Unknown"

    # ----------------------------------------------------------------------
    # Data Loading and Preprocessing
    # ----------------------------------------------------------------------
    def setup_data_tab(self):
        """Create the tab with file pickers, a data preview and the preprocess button."""
        self.data_tab = ttk.Frame(self.notebook)
        self.notebook.add(self.data_tab, text="Data Loading")

        # Data file section
        ttk.Label(self.data_tab, text="Water Quality Data:").grid(row=0, column=0, padx=10, pady=10, sticky="w")
        self.data_path = tk.StringVar()
        ttk.Entry(self.data_tab, textvariable=self.data_path, width=50).grid(row=0, column=1, padx=10, pady=10)
        ttk.Button(self.data_tab, text="Browse", command=self.load_data).grid(row=0, column=2, padx=10, pady=10)

        # Guidelines file section
        ttk.Label(self.data_tab, text="Quality Guidelines:").grid(row=1, column=0, padx=10, pady=10, sticky="w")
        self.guidelines_path = tk.StringVar()
        ttk.Entry(self.data_tab, textvariable=self.guidelines_path, width=50).grid(row=1, column=1, padx=10, pady=10)
        ttk.Button(self.data_tab, text="Browse", command=self.load_guidelines).grid(row=1, column=2, padx=10, pady=10)

        # Preview section
        ttk.Label(self.data_tab, text="Data Preview:").grid(row=2, column=0, padx=10, pady=10, sticky="w")
        self.preview_text = scrolledtext.ScrolledText(self.data_tab, width=100, height=15)
        self.preview_text.grid(row=3, column=0, columnspan=3, padx=10, pady=10)

        # Process button
        ttk.Button(self.data_tab, text="Preprocess Data", command=self.preprocess_data).grid(row=4, column=1, pady=20)

    def load_data(self):
        """Ask for a data file, load it into ``self.data`` and preview it.

        Files ending in ``.xlsx``/``.xls`` (any letter case) are read with
        ``pd.read_excel``. Anything else is read as CSV.
        """
        file_path = filedialog.askopenfilename(
            title="Select Water Quality Data",
            filetypes=(("Excel files", "*.xlsx *.xls"), ("CSV files", "*.csv"), ("All files", "*.*"))
        )
        if not file_path:
            return

        self.data_path.set(file_path)
        try:
            extension = os.path.splitext(file_path)[1].lower()
            if extension in ('.xlsx', '.xls'):
                self.data = pd.read_excel(file_path)
            else:
                self.data = pd.read_csv(file_path)

            self.preview_text.delete(1.0, tk.END)
            self.preview_text.insert(tk.END, self.data.head().to_string())
            self.update_status("Data loaded successfully")
        except Exception as e:
            messagebox.showerror("Error", f"Failed to load data: {str(e)}")

    def load_guidelines(self):
        """Ask for the guideline CSV and load it into ``self.guidelines``.

        The first two lines of the file are skipped. The remaining rows are
        read as Parameter, Unit, Lower, Upper. Parameter names are stripped
        and upper-cased.
        """
        file_path = filedialog.askopenfilename(
            title="Select Water Quality Guidelines",
            filetypes=(("CSV files", "*.csv"), ("All files", "*.*"))
        )
        if not file_path:
            return

        self.guidelines_path.set(file_path)
        try:
            self.guidelines = pd.read_csv(file_path, skiprows=2, header=None,
                                          names=['Parameter', 'Unit', 'Lower', 'Upper'])
            self.guidelines['Parameter'] = self.guidelines['Parameter'].str.strip().str.upper()
            self.update_status("Guidelines loaded successfully")
        except Exception as e:
            messagebox.showerror("Error", f"Failed to load guidelines: {str(e)}")

    def month_to_numeric(self, month_name):
        """Convert a month to its number (1-12).

        Accepts full English month names ("March") and any prefix of at least
        three letters ("Mar", "Sept"). Matching is case-insensitive and
        surrounding whitespace is ignored. Month numbers 1-12 are accepted as
        int, float or digit string. Anything else, including NaN, falls back
        to 1 (January).
        """
        if isinstance(month_name, str):
            key = month_name.strip().lower()
            if key.isdecimal():
                number = int(key)
                return number if 1 <= number <= 12 else 1
            # Three-letter prefixes are unique among English month names.
            if len(key) >= 3:
                for name, number in self.month_mapping.items():
                    if name.lower().startswith(key):
                        return number
            return 1

        if isinstance(month_name, (int, float, np.integer, np.floating)) and not isinstance(month_name, bool):
            value = float(month_name)
            if value.is_integer() and 1 <= value <= 12:
                return int(value)
        return 1

    def normalize_year(self, year):
        """Min-max scale ``year`` using the range stored by preprocess_data().

        Returns 0.5 when no range has been stored, or when the data covers a
        single year (zero-width range), instead of dividing by zero.
        """
        if self.year_min is None or self.year_max is None:
            return 0.5  # default if not available
        span = self.year_max - self.year_min
        if span == 0:
            return 0.5
        return (year - self.year_min) / span

    def normalize_month(self, month_name):
        """Return the cyclical (sin, cos) encoding of a month on a 12-step circle."""
        month_num = self.month_to_numeric(month_name)
        month_sin = math.sin(2 * math.pi * month_num / 12)
        month_cos = math.cos(2 * math.pi * month_num / 12)
        return month_sin, month_cos

    def normalize_season(self, season):
        """Encode season as binary: Wet=1, anything else=0.

        Matching ignores case and surrounding whitespace ("wet", " Wet ").
        """
        return 1 if isinstance(season, str) and season.strip().lower() == "wet" else 0

    def preprocess_data(self):
        """Clean ``self.data`` and add the temporal feature columns.

        Coerces the measured parameters, ``PS`` and ``Year`` to numbers
        (invalid entries become NaN) and fills numeric NaNs with column
        medians. It then (re)writes ``Year_Normalized``, ``Month_Sin``,
        ``Month_Cos`` and ``Season_Encoded``. The columns are assigned rather
        than concatenated, so running it twice does not duplicate them.
        """
        if self.data is None:
            messagebox.showwarning("Warning", "Please load water quality data first")
            return

        try:
            # Convert numerical columns. Year is included so text years still
            # support min-max scaling.
            num_cols = ['pH', 'TDS', 'EC', 'TUR', 'DOX', 'NO3', 'PO4', 'COD', 'TC',
                        'Pb', 'Cd', 'Cu', 'Fe', 'PS', 'Year']
            for col in num_cols:
                if col in self.data.columns:
                    self.data[col] = pd.to_numeric(self.data[col], errors='coerce')

            # Handle missing values
            self.data.fillna(self.data.median(numeric_only=True), inplace=True)

            # Year: store the range so predict_wqi() scales new years the same way.
            if 'Year' in self.data.columns and self.data['Year'].notna().any():
                self.year_min = float(self.data['Year'].min())
                self.year_max = float(self.data['Year'].max())
                self.data['Year_Normalized'] = self.data['Year'].apply(self.normalize_year)
            else:
                # Forget any range left over from a previously loaded data set.
                self.year_min = None
                self.year_max = None
                self.data['Year_Normalized'] = 0.5

            # Cyclical month features, via the same encoder used at prediction time.
            if 'Month' in self.data.columns:
                encoded = [self.normalize_month(m) for m in self.data['Month']]
                self.data['Month_Sin'] = [sin_value for sin_value, _ in encoded]
                self.data['Month_Cos'] = [cos_value for _, cos_value in encoded]
            else:
                self.data['Month_Sin'] = 0
                self.data['Month_Cos'] = 0

            # Encode season
            if 'Season' in self.data.columns:
                self.data['Season_Encoded'] = self.data['Season'].apply(self.normalize_season)
            else:
                self.data['Season_Encoded'] = 0

            self.update_status("Data preprocessed and normalized successfully")
            messagebox.showinfo("Success", "Data preprocessing completed.")
        except Exception as e:
            messagebox.showerror("Error", f"Data preprocessing failed: {str(e)}")

    # ----------------------------------------------------------------------
    # Feature Selection (Boruta)
    # ----------------------------------------------------------------------
    def setup_feature_tab(self):
        """Create the Boruta parameter controls, result table and ranking plot."""
        self.feature_tab = ttk.Frame(self.notebook)
        self.notebook.add(self.feature_tab, text="Feature Selection")

        # Parameters frame
        param_frame = ttk.LabelFrame(self.feature_tab, text="Boruta Parameters")
        param_frame.pack(fill=tk.X, padx=10, pady=10)

        ttk.Label(param_frame, text="Max Iterations:").grid(row=0, column=0, padx=5, pady=5)
        self.max_iter = tk.IntVar(value=100)
        ttk.Entry(param_frame, textvariable=self.max_iter, width=10).grid(row=0, column=1, padx=5, pady=5)

        ttk.Label(param_frame, text="Tree Depth:").grid(row=0, column=2, padx=5, pady=5)
        self.max_depth = tk.IntVar(value=5)
        ttk.Entry(param_frame, textvariable=self.max_depth, width=10).grid(row=0, column=3, padx=5, pady=5)

        ttk.Label(param_frame, text="Include Tentative:").grid(row=0, column=4, padx=5, pady=5)
        self.include_tentative = tk.BooleanVar(value=False)
        ttk.Checkbutton(param_frame, variable=self.include_tentative).grid(row=0, column=5, padx=5, pady=5)

        # Run button
        ttk.Button(param_frame, text="Run Feature Selection", command=self.run_feature_selection).grid(
            row=0, column=6, padx=10, pady=5)

        # Results frame
        results_frame = ttk.LabelFrame(self.feature_tab, text="Feature Selection Results")
        results_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        # Treeview for results
        columns = ("Parameter", "Rank")
        self.feature_tree = ttk.Treeview(results_frame, columns=columns, show="headings")
        for col in columns:
            self.feature_tree.heading(col, text=col)
            self.feature_tree.column(col, width=150)
        self.feature_tree.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        # Feature importance plot
        self.feature_fig, self.feature_ax = plt.subplots(figsize=(8, 4))
        self.feature_canvas = FigureCanvasTkAgg(self.feature_fig, master=results_frame)
        self.feature_canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

    def run_feature_selection(self):
        """Run Boruta feature selection with ``PS`` as the target.

        Candidates are the measured parameters plus the temporal columns
        added by preprocess_data(); only columns present in the data are
        used. Boruta is fitted on an 80 % stratified training split. Confirmed
        features (plus tentative ones when "Include Tentative" is ticked) are
        stored in ``self.selected_features`` and listed with their Boruta
        rank. The ranking of every candidate is plotted.

        NOTE(review): this runs on the Tk event thread, so the window does not
        respond until Boruta finishes.
        """
        if self.data is None:
            messagebox.showwarning("Warning", "Please load and preprocess data first")
            return
        if 'PS' not in self.data.columns:
            messagebox.showwarning("Warning", "The data has no 'PS' column to use as the target")
            return

        try:
            self.update_status("Running Boruta feature selection...")

            # Get Boruta parameters (IntVar.get raises on non-numeric entries)
            max_iter = self.max_iter.get()
            max_depth = self.max_depth.get()
            include_tentative = self.include_tentative.get()

            # Define features and target
            water_quality_vars = ['pH', 'TDS', 'EC', 'TUR', 'DOX', 'NO3', 'PO4',
                                  'COD', 'TC', 'Pb', 'Cd', 'Cu', 'Fe']
            temporal_vars = ['Year_Normalized', 'Month_Sin', 'Month_Cos', 'Season_Encoded']
            feature_cols = [col for col in water_quality_vars + temporal_vars if col in self.data.columns]
            if not feature_cols:
                self.update_status("Feature selection aborted")
                messagebox.showwarning("Warning", "None of the candidate feature columns are present in the data")
                return
            X = self.data[feature_cols]
            y = self.data['PS']

            # Only the training split is used by Boruta.
            X_train, _, y_train, _ = train_test_split(
                X, y, test_size=0.2, random_state=42, stratify=y
            )

            rf = RandomForestClassifier(
                n_jobs=-1,
                class_weight='balanced',
                max_depth=max_depth,
                random_state=42
            )
            boruta_selector = BorutaPy(
                estimator=rf,
                n_estimators='auto',
                max_iter=max_iter,
                random_state=42,
                verbose=0
            )
            boruta_selector.fit(np.array(X_train), np.array(y_train))

            # Selected features (confirmed, optionally followed by tentative ones)
            rank_by_feature = dict(zip(X.columns, boruta_selector.ranking_))
            self.selected_features = [
                col for col, keep in zip(X.columns, boruta_selector.support_) if keep
            ]
            if include_tentative:
                self.selected_features += [
                    col for col, weak in zip(X.columns, boruta_selector.support_weak_) if weak
                ]

            # Show Boruta's own rank rather than the list position.
            self.feature_tree.delete(*self.feature_tree.get_children())
            for feature in self.selected_features:
                self.feature_tree.insert("", "end", values=(feature, int(rank_by_feature[feature])))

            # Plot the ranking of every candidate (best first)
            self.feature_ax.clear()
            feature_ranks = sorted(rank_by_feature.items(), key=lambda item: item[1])
            features_sorted, ranks_sorted = zip(*feature_ranks)

            self.feature_ax.barh(features_sorted, ranks_sorted, color='skyblue')
            self.feature_ax.axvline(x=1, color='red', linestyle='--', label='Confirmation Threshold')
            self.feature_ax.set_xlabel('Boruta Ranking (1 = confirmed)')
            self.feature_ax.set_title('Water Quality Variable Importance Ranking')
            self.feature_ax.invert_yaxis()
            self.feature_ax.legend()
            self.feature_canvas.draw()

            self.update_status("Feature selection completed")
            messagebox.showinfo("Success", f"Selected {len(self.selected_features)} features.")
        except Exception as e:
            self.update_status("Feature selection failed")
            messagebox.showerror("Error", f"Feature selection failed: {str(e)}")

    # ----------------------------------------------------------------------
    # Sub-index Calculation (placeholder)
    # ----------------------------------------------------------------------
    def setup_subindex_tab(self):
        """Create the sub-index table, the plot area and their buttons."""
        self.subindex_tab = ttk.Frame(self.notebook)
        self.notebook.add(self.subindex_tab, text="Sub-index Calculation")

        # Calculate button
        ttk.Button(self.subindex_tab, text="Calculate Sub-indices", command=self.calculate_subindices).pack(pady=10)

        # Results frame
        results_frame = ttk.LabelFrame(self.subindex_tab, text="Sub-index Results")
        results_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        # Treeview for results
        columns = ["Year", "Month", "Season"] + [f"{p}_SI" for p in ['pH', 'TUR', 'DOX', 'NO3', 'PO4', 'COD', 'TC']]
        self.subindex_tree = ttk.Treeview(results_frame, columns=columns, show="headings")
        for col in columns:
            self.subindex_tree.heading(col, text=col)
            self.subindex_tree.column(col, width=80)
        self.subindex_tree.pack(fill=tk.BOTH, expand=True, side=tk.LEFT, padx=10, pady=10)

        # Plot frame
        plot_frame = ttk.Frame(results_frame)
        plot_frame.pack(fill=tk.BOTH, expand=True, side=tk.RIGHT, padx=10, pady=10)
        self.subindex_fig, self.subindex_ax = plt.subplots(figsize=(8, 6))
        self.subindex_canvas = FigureCanvasTkAgg(self.subindex_fig, master=plot_frame)
        self.subindex_canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)
        ttk.Button(plot_frame, text="Show Seasonal Comparison", command=self.show_seasonal_subindices).pack(pady=10)

    def calculate_subindices(self):
        """Fill the sub-index table with fixed sample rows (placeholder)."""
        # Placeholder: In real implementation, compute sub-indices using guidelines and equations.
        self.update_status("Sub-indices calculated successfully")

        # Update treeview with sample data
        self.subindex_tree.delete(*self.subindex_tree.get_children())
        for _ in range(5):
            self.subindex_tree.insert("", "end", values=(
                2022, "Jan", "Dry", 85.2, 92.4, 78.3, 65.7, 88.9, 72.1, 95.6
            ))

    # ----------------------------------------------------------------------
    # Weight Calculation (placeholder)
    # ----------------------------------------------------------------------
    def setup_weight_tab(self):
        """Create the weight table and comparison plot."""
        self.weight_tab = ttk.Frame(self.notebook)
        self.notebook.add(self.weight_tab, text="Weight Calculation")

        # Calculate button
        ttk.Button(self.weight_tab, text="Calculate Weights", command=self.calculate_weights).pack(pady=10)

        # Results frame
        results_frame = ttk.LabelFrame(self.weight_tab, text="Weight Results")
        results_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        # Treeview for weights
        columns = ("Parameter", "Dry Weight", "Wet Weight")
        self.weight_tree = ttk.Treeview(results_frame, columns=columns, show="headings")
        for col in columns:
            self.weight_tree.heading(col, text=col)
            self.weight_tree.column(col, width=150)
        self.weight_tree.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        # Weight comparison plot
        self.weight_fig, self.weight_ax = plt.subplots(figsize=(8, 4))
        self.weight_canvas = FigureCanvasTkAgg(self.weight_fig, master=results_frame)
        self.weight_canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

    def calculate_weights(self):
        """Fill the weight table with random weights in [0.1, 0.2) (placeholder)."""
        self.update_status("Weights calculated successfully")

        # Update treeview with sample data
        self.weight_tree.delete(*self.weight_tree.get_children())
        parameters = ['pH', 'TUR', 'DOX', 'NO3', 'PO4', 'COD', 'TC']
        for param in parameters:
            self.weight_tree.insert("", "end", values=(
                param, f"{np.random.uniform(0.1, 0.2):.4f}", f"{np.random.uniform(0.1, 0.2):.4f}"
            ))

    # ----------------------------------------------------------------------
    # WQI Calculation (placeholder)
    # ----------------------------------------------------------------------
    def setup_wqi_tab(self):
        """Create the WQI method selector, result table and trend plot."""
        self.wqi_tab = ttk.Frame(self.notebook)
        self.notebook.add(self.wqi_tab, text="WQI Calculation")

        # Method selection
        method_frame = ttk.Frame(self.wqi_tab)
        method_frame.pack(fill=tk.X, padx=10, pady=10)

        ttk.Label(method_frame, text="Select WQI Method:").pack(side=tk.LEFT, padx=5)
        self.wqi_method = tk.StringVar(value="WAMWQI")
        methods = ["WAMWQI", "WQMWQI", "DWQI", "RMSWQI", "AMWQI", "FLWQI"]
        ttk.Combobox(method_frame, textvariable=self.wqi_method, values=methods, width=10).pack(side=tk.LEFT, padx=5)
        ttk.Button(method_frame, text="Calculate WQI", command=self.calculate_wqis).pack(side=tk.LEFT, padx=10)

        # Results frame
        results_frame = ttk.LabelFrame(self.wqi_tab, text="WQI Results")
        results_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        # Treeview for WQI results
        columns = ["Year", "Month", "Season", "WQI_Value", "WQI_Class"]
        self.wqi_tree = ttk.Treeview(results_frame, columns=columns, show="headings")
        for col in columns:
            self.wqi_tree.heading(col, text=col)
            self.wqi_tree.column(col, width=100)
        self.wqi_tree.pack(fill=tk.BOTH, expand=True, side=tk.LEFT, padx=10, pady=10)

        # Plot frame
        plot_frame = ttk.Frame(results_frame)
        plot_frame.pack(fill=tk.BOTH, expand=True, side=tk.RIGHT, padx=10, pady=10)
        self.wqi_fig, self.wqi_ax = plt.subplots(figsize=(8, 6))
        self.wqi_canvas = FigureCanvasTkAgg(self.wqi_fig, master=plot_frame)
        self.wqi_canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)
        ttk.Button(plot_frame, text="Show WQI Trend", command=self.show_wqi_trend).pack(pady=10)

    def calculate_wqis(self):
        """Fill the WQI table with random sample scores (placeholder).

        Each score is shown to two decimals, and its class is derived from
        that score with classify_wqi(), so value and label always agree.
        """
        method = self.wqi_method.get()
        self.update_status(f"WQI calculated using {method} method")

        # Update treeview with sample data
        self.wqi_tree.delete(*self.wqi_tree.get_children())
        months = ["Jan", "Feb", "Mar", "Apr", "May"]
        seasons = ["Dry", "Dry", "Dry", "Dry", "Wet"]
        for month, season in zip(months, seasons):
            wqi_value = round(np.random.uniform(70, 95), 2)
            self.wqi_tree.insert("", "end", values=(
                2022, month, season, f"{wqi_value:.2f}", self.classify_wqi(wqi_value)
            ))

    # ----------------------------------------------------------------------
    # Model Training (placeholder)
    # ----------------------------------------------------------------------
    def setup_model_tab(self):
        """Create the model selectors, the metrics grid and the model plot."""
        self.model_tab = ttk.Frame(self.notebook)
        self.notebook.add(self.model_tab, text="Model Training")

        # Model selection
        model_frame = ttk.Frame(self.model_tab)
        model_frame.pack(fill=tk.X, padx=10, pady=10)

        ttk.Label(model_frame, text="Select WQI Method:").pack(side=tk.LEFT, padx=5)
        self.model_wqi_method = tk.StringVar(value="WAMWQI")
        methods = ["WAMWQI", "WQMWQI", "DWQI", "RMSWQI", "AMWQI", "FLWQI"]
        ttk.Combobox(model_frame, textvariable=self.model_wqi_method, values=methods, width=10).pack(side=tk.LEFT, padx=5)

        ttk.Label(model_frame, text="Model Type:").pack(side=tk.LEFT, padx=5)
        self.model_type = tk.StringVar(value="Regression")
        ttk.Combobox(model_frame, textvariable=self.model_type, values=["Regression", "Classification"], width=12).pack(side=tk.LEFT, padx=5)

        ttk.Button(model_frame, text="Train Models", command=self.train_models).pack(side=tk.LEFT, padx=10)

        # Results frame
        results_frame = ttk.LabelFrame(self.model_tab, text="Model Performance")
        results_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        # Performance metrics
        metrics_frame = ttk.Frame(results_frame)
        metrics_frame.pack(fill=tk.X, padx=10, pady=10)

        ttk.Label(metrics_frame, text="Model").grid(row=0, column=0, padx=5, pady=5)
        ttk.Label(metrics_frame, text="RMSE").grid(row=0, column=1, padx=5, pady=5)
        ttk.Label(metrics_frame, text="R²").grid(row=0, column=2, padx=5, pady=5)
        ttk.Label(metrics_frame, text="Accuracy").grid(row=0, column=3, padx=5, pady=5)
        ttk.Label(metrics_frame, text="F1 Score").grid(row=0, column=4, padx=5, pady=5)

        # (model, metric) -> StringVar shown in the grid
        self.model_metrics = {}
        for i, model in enumerate(["RF", "ETR", "XGB"]):
            ttk.Label(metrics_frame, text=model).grid(row=i+1, column=0, padx=5, pady=5)
            for j, metric in enumerate(["rmse", "r2", "accuracy", "f1"]):
                var = tk.StringVar(value="")
                ttk.Label(metrics_frame, textvariable=var).grid(row=i+1, column=j+1, padx=5, pady=5)
                self.model_metrics[(model, metric)] = var

        # Model visualization
        self.model_fig, self.model_ax = plt.subplots(figsize=(10, 6))
        self.model_canvas = FigureCanvasTkAgg(self.model_fig, master=results_frame)
        self.model_canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

    def train_models(self):
        """Fill the metrics grid with random sample scores (placeholder, no training)."""
        method = self.model_wqi_method.get()
        model_type = self.model_type.get()
        self.update_status(f"Training {model_type} models for {method} WQI")

        # Update metrics with sample data
        for model in ["RF", "ETR", "XGB"]:
            self.model_metrics[(model, "rmse")].set(f"{np.random.uniform(2.5, 5.0):.3f}")
            self.model_metrics[(model, "r2")].set(f"{np.random.uniform(0.85, 0.95):.3f}")
            self.model_metrics[(model, "accuracy")].set(f"{np.random.uniform(0.88, 0.96):.3f}")
            self.model_metrics[(model, "f1")].set(f"{np.random.uniform(0.87, 0.95):.3f}")

    # ----------------------------------------------------------------------
    # Prediction Tab
    # ----------------------------------------------------------------------
    def setup_prediction_tab(self):
        """Create the prediction input form, result labels and monthly table."""
        self.prediction_tab = ttk.Frame(self.notebook)
        self.notebook.add(self.prediction_tab, text="Prediction")

        # Input frame
        input_frame = ttk.LabelFrame(self.prediction_tab, text="Input Parameters")
        input_frame.pack(fill=tk.X, padx=10, pady=10)

        # Create input fields for each parameter (four per grid row)
        self.input_vars = {}
        parameters = ["pH", "TUR", "DOX", "NO3", "PO4", "COD", "TC"]
        for i, param in enumerate(parameters):
            ttk.Label(input_frame, text=f"{param}:").grid(row=i//4, column=(i%4)*2, padx=5, pady=5, sticky="e")
            var = tk.DoubleVar()
            ttk.Entry(input_frame, textvariable=var, width=10).grid(row=i//4, column=(i%4)*2+1, padx=5, pady=5)
            self.input_vars[param] = var

        # Month selection
        ttk.Label(input_frame, text="Month:").grid(row=2, column=6, padx=5, pady=5, sticky="e")
        self.month_var = tk.StringVar(value="January")
        months = ["January", "February", "March", "April", "May", "June",
                  "July", "August", "September", "October", "November", "December"]
        ttk.Combobox(input_frame, textvariable=self.month_var, values=months, width=10).grid(row=2, column=7, padx=5, pady=5)

        # Season selection
        ttk.Label(input_frame, text="Season:").grid(row=2, column=8, padx=5, pady=5, sticky="e")
        self.season_var = tk.StringVar(value="Dry")
        ttk.Combobox(input_frame, textvariable=self.season_var, values=["Dry", "Wet"], width=8).grid(row=2, column=9, padx=5, pady=5)

        # Year input
        ttk.Label(input_frame, text="Year:").grid(row=3, column=6, padx=5, pady=5, sticky="e")
        self.year_var = tk.IntVar(value=2025)
        ttk.Entry(input_frame, textvariable=self.year_var, width=10).grid(row=3, column=7, padx=5, pady=5)

        # Prediction button
        ttk.Button(input_frame, text="Predict WQI", command=self.predict_wqi).grid(row=3, column=8, columnspan=2, padx=10, pady=10)

        # Prediction results
        results_frame = ttk.LabelFrame(self.prediction_tab, text="Prediction Results")
        results_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        # WQI value
        ttk.Label(results_frame, text="Predicted WQI Value:").pack(pady=5)
        self.wqi_value_var = tk.StringVar(value="")
        ttk.Label(results_frame, textvariable=self.wqi_value_var, font=("Arial", 16)).pack(pady=5)

        # WQI class
        ttk.Label(results_frame, text="Water Quality Class:").pack(pady=5)
        self.wqi_class_var = tk.StringVar(value="")
        ttk.Label(results_frame, textvariable=self.wqi_class_var, font=("Arial", 16)).pack(pady=5)

        # Monthly prediction
        ttk.Button(results_frame, text="Predict Monthly WQI", command=self.predict_monthly_wqi).pack(pady=10)

        # Monthly prediction results
        columns = ["Month", "Season", "Predicted WQI", "Quality Class"]
        self.monthly_tree = ttk.Treeview(results_frame, columns=columns, show="headings")
        for col in columns:
            self.monthly_tree.heading(col, text=col)
            self.monthly_tree.column(col, width=120)
        self.monthly_tree.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        # Save prediction button
        ttk.Button(results_frame, text="Save Predictions to CSV", command=self.save_predictions).pack(pady=10)

    def predict_wqi(self):
        """Predict a single WQI from the form inputs (placeholder).

        Reads the form and encodes year, month and season the same way as
        preprocess_data(). Reading the Tk variables also validates the numeric
        entries: invalid text raises and is reported in an error dialog. No
        trained model is used yet; the shown value is random in [60, 95).
        """
        try:
            features = {param: var.get() for param, var in self.input_vars.items()}
            month_sin, month_cos = self.normalize_month(self.month_var.get())
            features.update({
                'Year_Normalized': self.normalize_year(self.year_var.get()),
                'Month_Sin': month_sin,
                'Month_Cos': month_cos,
                'Season_Encoded': self.normalize_season(self.season_var.get()),
            })

            # TODO: predict from `features` with a trained model once
            # train_models() produces one; random value for demonstration.
            wqi_value = round(np.random.uniform(60, 95), 2)

            self.wqi_value_var.set(f"{wqi_value:.2f}")
            self.wqi_class_var.set(self.classify_wqi(wqi_value))
            self.update_status("WQI predicted successfully")
        except Exception as e:
            messagebox.showerror("Error", f"Prediction failed: {str(e)}")

    def predict_monthly_wqi(self):
        """Fill the monthly table with a WQI for every month (placeholder).

        Seasons follow a fixed Dry/Wet calendar. WQI values are random in
        [60, 95) until a trained model is wired in.
        """
        self.monthly_tree.delete(*self.monthly_tree.get_children())

        months = ["January", "February", "March", "April", "May", "June",
                  "July", "August", "September", "October", "November", "December"]
        # Typical season assignment for River Benue (adjust as needed)
        seasons = ["Dry", "Dry", "Dry", "Dry", "Wet", "Wet", "Wet", "Wet", "Wet", "Dry", "Dry", "Dry"]

        for month, season in zip(months, seasons):
            wqi_value = round(np.random.uniform(60, 95), 2)
            self.monthly_tree.insert("", "end", values=(
                month, season, f"{wqi_value:.2f}", self.classify_wqi(wqi_value)
            ))

        self.update_status("Monthly WQI predictions generated")

    def save_predictions(self):
        """Write the monthly prediction table to a CSV file chosen by the user.

        Uses pandas' CSV writer, so fields containing commas or quotes are
        escaped correctly. Warns instead of writing a header-only file when
        the table is empty.
        """
        rows = [list(self.monthly_tree.item(item, "values"))
                for item in self.monthly_tree.get_children()]
        if not rows:
            messagebox.showwarning("Warning", "No predictions to save. Run 'Predict Monthly WQI' first.")
            return

        file_path = filedialog.asksaveasfilename(
            title="Save Predictions",
            defaultextension=".csv",
            filetypes=(("CSV files", "*.csv"), ("All files", "*.*"))
        )
        if not file_path:
            return

        try:
            columns = ["Month", "Season", "Predicted WQI", "Quality Class"]
            pd.DataFrame(rows, columns=columns).to_csv(file_path, index=False, encoding="utf-8")
            self.update_status(f"Predictions saved to {file_path}")
        except Exception as e:
            messagebox.showerror("Error", f"Failed to save: {str(e)}")

    # ----------------------------------------------------------------------
    # Visualization helpers
    # ----------------------------------------------------------------------
    def show_seasonal_subindices(self):
        """Plot dry vs wet sub-index bars from random sample values (placeholder)."""
        self.subindex_ax.clear()
        parameters = ['pH', 'TUR', 'DOX', 'NO3', 'PO4', 'COD', 'TC']
        dry_means = np.random.uniform(70, 95, len(parameters))
        wet_means = np.random.uniform(65, 90, len(parameters))

        x = np.arange(len(parameters))
        width = 0.35

        self.subindex_ax.bar(x - width/2, dry_means, width, label='Dry Season', color='sandybrown')
        self.subindex_ax.bar(x + width/2, wet_means, width, label='Wet Season', color='lightseagreen')

        self.subindex_ax.set_xlabel('Water Quality Parameters')
        self.subindex_ax.set_ylabel('Sub-index Value')
        self.subindex_ax.set_title('Seasonal Comparison of Sub-indices')
        self.subindex_ax.set_xticks(x)
        self.subindex_ax.set_xticklabels(parameters)
        self.subindex_ax.legend()
        self.subindex_ax.grid(axis='y', alpha=0.3)
        self.subindex_canvas.draw()

    def show_wqi_trend(self):
        """Plot a monthly WQI line from random sample values with band thresholds."""
        self.wqi_ax.clear()
        months = ["Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]
        wqi_values = np.random.uniform(70, 95, len(months))

        self.wqi_ax.plot(months, wqi_values, marker='o', color='blue')
        self.wqi_ax.axhline(y=90, color='green', linestyle='--', label='Excellent')
        self.wqi_ax.axhline(y=70, color='orange', linestyle='--', label='Good')
        self.wqi_ax.axhline(y=50, color='red', linestyle='--', label='Medium')

        self.wqi_ax.set_xlabel('Month')
        self.wqi_ax.set_ylabel('WQI Value')
        self.wqi_ax.set_title('Monthly WQI Trend')
        self.wqi_ax.legend()
        self.wqi_ax.grid(True)
        self.wqi_canvas.draw()

if __name__ == "__main__":
    # Script entry point: build the Tk root window, attach the app, run the event loop.
    main_window = tk.Tk()
    app = WaterQualityApp(main_window)
    main_window.mainloop()