In [35]:
#-----Import Necessary libraries and Packages-----------#
import os
import io
import yaml
import joblib
import optuna
import logging
import datetime
import warnings
import numpy as np
import pandas as pd
import seaborn as sns
from arch import unitroot
from scipy.stats import zscore
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import matplotlib.dates as mdates
import plotly.graph_objects as go
from ipywidgets import FileUpload
from sklearn.cluster import KMeans
from IPython.display import FileLink
import matplotlib.patches as mpatches
from arch.unitroot import PhillipsPerron
from plotly.subplots import make_subplots
from tensorflow.keras.models import Sequential
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers import Adam, SGD, RMSprop
from tensorflow.keras.losses import MeanSquaredError, Huber
from statsmodels.tsa.stattools import adfuller, kpss, zivot_andrews
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from tensorflow.keras.layers import Input, LSTM, Conv1D, MaxPooling1D, Flatten, Dense, Dropout, ConvLSTM2D, BatchNormalization

#--------------------Predefined Workflow Functions-----------#
def configure_logging(level=logging.INFO):
    logging.basicConfig(level=level, format="%(asctime)s - %(levelname)s - %(message)s")
    logging.info("Logging configured successfully.")

def load_dataset(file_path):
    """
    Load the dataset from a specified file path.

    Parameters:
    file_path (str): Path to the dataset file.

    Returns:
    pd.DataFrame: Loaded Dataset.
    """
    try:
        # Validate file path
        if not isinstance(file_path, str) or not os.path.exists(file_path):
            raise FileNotFoundError(f"Invalid file path: {file_path}")

        # Attempt to load the dataset
        df = pd.read_csv(file_path, encoding='utf-8')  # Use utf-8 by default
        if df.empty:
            raise ValueError("The dataset is empty.")
        
        # Log successful load
        logging.info(f"Dataset successfully loaded from: {file_path}")
        return df

    except FileNotFoundError:
        logging.error(f"File not found at: {file_path}")
        raise
    except pd.errors.ParserError:
        logging.error("Failed to parse the dataset file. Ensure it is in CSV format.")
        raise
    except UnicodeDecodeError:
        logging.error("Encoding error: Ensure the dataset is UTF-8 encoded.")
        raise
    except Exception as e:
        logging.error(f"Unexpected error while loading the dataset: {str(e)}")
        raise

def handle_missing_values(df, drop_cells=False, drop_rows=False, threshold=0.5):
    """
    Handle missing and redundant values in the dataset.

    Parameters:
    df (pd.DataFrame): Input dataset.
    drop_cells (bool): Drop individual cells with missing values.
    drop_rows (bool): Drop rows with missing values based on a threshold.
    threshold (float): Fraction of missing values in a row to trigger row removal.

    Returns:
    pd.DataFrame: Updated dataset with missing values handled.
    """
    try:
        # Validate and clean YEAR and Month columns
        df = validate_year_month(df)

        # Convert YEAR and Month to datetime
        try:
            df['Month'] = pd.to_datetime(df[['YEAR', 'Month']].assign(day=1), errors='coerce')
            if df['Month'].isna().any():
                raise ValueError("Some values in 'YEAR' and 'Month' could not be converted to datetime.")
        except Exception as e:
            logging.error(f"Error converting YEAR and Month to datetime: {e}")
            raise ValueError("YEAR or Month columns may still contain invalid data after cleaning.")

        # Handle missing cells
        if drop_cells:
            missing_cells_count = df.isna().sum().sum()
            logging.info(f"Dropping columns with any missing values. Total missing cells: {missing_cells_count}")
            df = df.dropna(axis=1, how='any')

        # Handle rows with excessive missing values
        if drop_rows:
            missing_rows_count = df.isna().any(axis=1).sum()
            row_threshold = int(threshold * len(df.columns))
            logging.info(f"Dropping rows with more than {threshold*100}% missing values. Total rows affected: {missing_rows_count}")
            df = df.dropna(thresh=row_threshold)

        logging.info(f"Final dataset has {df.isna().sum().sum()} missing cells.")
        return df

    except Exception as e:
        logging.error(f"Unexpected error in handle_missing_values: {e}")
        raise

def extract_features_and_metrics(
    df,
    daily_columns_prefix="",
    rolling_window=12,
    validate_values=True,
    negative_threshold=0,
    log_level=logging.INFO
):
    """
    Extract features and compute metrics like 'Monthly_Total' and 'Monthly_Average'.

    Parameters:
    - df (pd.DataFrame): Input dataset.
    - daily_columns_prefix (str): Prefix for daily rainfall columns (default is "").
    - rolling_window (int): Rolling average window size.
    - validate_values (bool): Validate rainfall values for anomalies.
    - negative_threshold (float): Threshold for invalid values.
    - log_level (int): Logging level.

    Returns:
    - pd.DataFrame: Updated dataset with metrics.
    - dict: Metadata about processing.
    - list: List of identified daily rainfall columns.
    """
    logging.basicConfig(level=log_level)
    
    try:
        logging.info("Starting feature extraction and metric computation.")

        # Identify daily columns dynamically
        if daily_columns_prefix:
            daily_columns = [col for col in df.columns if col.startswith(daily_columns_prefix)]
        else:
            # Fallback: Identify numeric columns that are not explicitly metadata
            daily_columns = [col for col in df.columns if pd.api.types.is_numeric_dtype(df[col]) and col not in ['YEAR', 'Month']]

        if not daily_columns:
            raise KeyError("No daily rainfall columns found.")

        # Filter numeric daily columns
        numeric_daily_columns = [col for col in daily_columns if pd.api.types.is_numeric_dtype(df[col])]
        non_numeric_columns = set(daily_columns) - set(numeric_daily_columns)

        if non_numeric_columns:
            logging.warning(f"Non-numeric data found in columns: {list(non_numeric_columns)}")

        # Optional value validation for numeric columns
        if validate_values:
            invalid_values = df[numeric_daily_columns][df[numeric_daily_columns] < negative_threshold]
            if not invalid_values.empty:
                logging.warning(f"Detected invalid values below {negative_threshold}: {invalid_values.count().sum()} entries.")

        # Compute summary metrics for numeric columns
        df['Monthly_Total'] = df[numeric_daily_columns].sum(axis=1, skipna=True)
        df['Monthly_Average'] = df[numeric_daily_columns].mean(axis=1, skipna=True)
        df['Rolling_Average'] = df['Monthly_Total'].rolling(window=rolling_window, min_periods=1).mean()

        # Extract additional time-based features
        df['Year'] = df['Month'].dt.year
        df['Month_Name'] = df['Month'].dt.strftime('%B')

        logging.info("Feature extraction and metric computation completed successfully.")

        # Return metadata for observability
        metadata = {
            "daily_columns": len(daily_columns),
            "numeric_columns": len(numeric_daily_columns),
            "non_numeric_columns": len(non_numeric_columns),
            "invalid_values": invalid_values.count().sum() if validate_values else 0,
        }
        return df, metadata, daily_columns

    except KeyError as e:
        logging.error(f"KeyError in extract_features_and_metrics: {e}")
        raise
    except Exception as e:
        logging.error(f"Unexpected error in extract_features_and_metrics: {e}")
        raise

def drop_columns(df, columns_to_drop=None, include_missing_log=True):
    """
    Drop specified columns from a DataFrame.

    Parameters:
    df (pd.DataFrame): Input DataFrame.
    columns_to_drop (list or None): List of columns to drop. If None or empty, no columns are dropped.
    include_missing_log (bool): Whether to log columns that are not found in the DataFrame.

    Returns:
    pd.DataFrame: Updated DataFrame after removing the specified columns.

    Raises:
    ValueError: If the input `df` is not a DataFrame.
    """
    try:
        # Validate input DataFrame
        if not isinstance(df, pd.DataFrame):
            raise ValueError("Input `df` must be a pandas DataFrame.")

        # Handle None or empty column list
        if not columns_to_drop:
            logging.info("No columns specified for dropping. Returning the original DataFrame.")
            return df

        # Identify columns to actually drop
        columns_present = [col for col in columns_to_drop if col in df.columns]
        columns_missing = [col for col in columns_to_drop if col not in df.columns]

        # Log missing columns if enabled
        if include_missing_log and columns_missing:
            logging.warning(f"The following columns were not found in the dataset: {columns_missing}")

        # Drop columns that exist in the DataFrame
        if columns_present:
            logging.info(f"Removing the following columns: {columns_present}")
            df = df.drop(columns=columns_present)

        return df

    except ValueError as ve:
        logging.error(f"ValueError in drop_columns: {ve}")
        raise
    except Exception as e:
        logging.error(f"Unexpected error in drop_columns: {e}")
        raise

def validate_data(df):
    """
    Validate the dataset's integrity and check for anomalies.

    Actions:
    - Check for negative values in 'Monthly_Total' and 'Monthly_Average'.
    - Log the counts of negative values for each column.
    - Perform basic data integrity checks (e.g., missing columns, NaNs, infinite values).

    Parameters:
    df (pd.DataFrame): Input dataset.

    Returns:
    dict: Validation summary containing counts of negative values, missing values, and other anomalies.

    Raises:
    ValueError: If required columns are missing or other critical issues are found.
    """
    try:
        # Validate input
        if not isinstance(df, pd.DataFrame):
            raise ValueError("Input `df` must be a pandas DataFrame.")

        # Ensure required columns exist
        required_columns = ['Monthly_Total', 'Monthly_Average']
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            raise KeyError(f"Missing required columns: {missing_columns}")

        # Initialize validation summary
        validation_summary = {}

        # Check for negative values
        validation_summary['Negative Monthly_Total Values'] = df[df['Monthly_Total'] < 0].shape[0]
        validation_summary['Negative Monthly_Average Values'] = df[df['Monthly_Average'] < 0].shape[0]

        # Check for missing values
        validation_summary['Missing Values'] = df.isnull().sum().sum()

        # Check for infinite values in numeric columns only
        numeric_df = df.select_dtypes(include=[np.number])
        validation_summary['Infinite Values'] = np.isinf(numeric_df.values).sum()

        # Log validation summary
        logging.info("Dataset Validation Summary:")
        for key, value in validation_summary.items():
            logging.info(f"{key}: {value}")

        return validation_summary

    except KeyError as e:
        logging.error(f"KeyError in validate_data: {e}")
        raise
    except ValueError as e:
        logging.error(f"ValueError in validate_data: {e}")
        raise
    except Exception as e:
        logging.error(f"Unexpected error in validate_data: {e}")
        raise

def plot_yearly_trends_with_metrics(data, metrics, title_prefix="Yearly Trends"):
    """
    Create an interactive plot for yearly trends using Plotly with anomaly detection and clustering.

    Parameters:
    data (pd.DataFrame): Aggregated yearly data containing metrics to plot.
    metrics (list of dict): List of dictionaries where each defines:
        - "column": Column name in the data to plot.
        - "title": Title for the metric (e.g., 'Yearly Total Rainfall').
        - "color": Color for the line plot.
    title_prefix (str): Prefix for the overall plot title.

    Returns:
    pd.DataFrame: DataFrame with additional anomaly and cluster labels.
    """
    try:
        # Validate input data
        if not isinstance(data, pd.DataFrame):
            raise ValueError("The `data` parameter must be a pandas DataFrame.")
        if not all("column" in metric and "title" in metric and "color" in metric for metric in metrics):
            raise ValueError("Each metric in `metrics` must include 'column', 'title', and 'color'.")

        # Check required columns in data
        required_columns = [metric['column'] for metric in metrics] + ['Year']
        missing_columns = [col for col in required_columns if col not in data.columns]
        if missing_columns:
            raise KeyError(f"Missing required columns in data: {missing_columns}")

        # Z-score-based anomaly detection for each metric
        for metric in metrics:
            column = metric["column"]
            data[f"{column}_Zscore"] = zscore(data[column])
            data[f"{column}_Anomaly"] = abs(data[f"{column}_Zscore"]) > 2  # Mark anomalies where |Z-score| > 2

        # Clustering analysis (K-Means) based on selected metrics
        clustering_features = [metric['column'] for metric in metrics]
        kmeans = KMeans(n_clusters=3, random_state=42)
        data['Cluster'] = kmeans.fit_predict(data[clustering_features])

        # Create a subplot for each metric
        fig = make_subplots(
            rows=len(metrics) + 1,  # Extra row for cluster analysis
            cols=1,
            shared_xaxes=True,
            subplot_titles=[m["title"] for m in metrics] + ["Rainfall Clustering"]
        )

        for idx, metric in enumerate(metrics, start=1):
            column = metric["column"]
            title = metric["title"]
            color = metric["color"]

            # Add trace for each metric with anomaly highlighting
            fig.add_trace(
                go.Scatter(
                    x=data['Year'],
                    y=data[column],
                    mode='lines+markers',
                    name=title,
                    line=dict(color=color)
                ),
                row=idx, col=1
            )
            # Highlight anomalies
            anomalies = data[data[f"{column}_Anomaly"]]
            fig.add_trace(
                go.Scatter(
                    x=anomalies['Year'],
                    y=anomalies[column],
                    mode='markers',
                    name=f"{title} Anomalies",
                    marker=dict(color='red', size=10, symbol='x')
                ),
                row=idx, col=1
            )

        # Add clustering results
        fig.add_trace(
            go.Scatter(
                x=data['Year'],
                y=data['Cluster'],
                mode='markers',
                name="Clusters",
                marker=dict(color=data['Cluster'], size=10, colorscale='Viridis')
            ),
            row=len(metrics) + 1, col=1
        )

        # Update layout
        fig.update_layout(
            title=f"{title_prefix}: {data.index.min().year} - {data.index.max().year}",
            xaxis_title="Year",
            height=400 * (len(metrics) + 1),  # Adjust height dynamically
            showlegend=True,
            template="plotly_white"
        )
        fig.update_xaxes(tickformat="%Y")

        # Show the figure
        fig.show()

        return data  # Return data with additional metrics

    except KeyError as e:
        logging.error(f"KeyError in plot_yearly_trends_with_metrics: {e}")
        raise
    except ValueError as e:
        logging.error(f"ValueError in plot_yearly_trends_with_metrics: {e}")
        raise
    except Exception as e:
        logging.error(f"Unexpected error in plot_yearly_trends_with_metrics: {e}")
        raise

def evaluate_model(y_actual, y_pred, metrics_to_compute=None):
    """
    Evaluate model predictions with a variety of metrics.

    Parameters:
        y_actual (np.ndarray): True target values.
        y_pred (np.ndarray): Predicted values.
        metrics_to_compute (list): List of metrics to compute. Default includes MAE, MSE, R², SMAPE, RMSE, MAPE, Explained Variance, and MBE.

    Returns:
        dict: A dictionary of evaluation metrics.
    """
    try:
        # Validate input
        if len(y_actual) != len(y_pred):
            raise ValueError("Length of `y_actual` and `y_pred` must match.")
        
        if metrics_to_compute is None:
            metrics_to_compute = [
                "MAE", "MSE", "R²", "SMAPE", "RMSE", "MAPE", 
                "Explained Variance", "MBE"
            ]

        # Initialize metrics dictionary
        metrics = {}

        # Mean Absolute Error (MAE)
        if "MAE" in metrics_to_compute:
            metrics["Mean Absolute Error (MAE)"] = mean_absolute_error(y_actual, y_pred)
        
        # Mean Squared Error (MSE)
        if "MSE" in metrics_to_compute:
            metrics["Mean Squared Error (MSE)"] = mean_squared_error(y_actual, y_pred)
        
        # R² Score
        if "R²" in metrics_to_compute:
            metrics["R² Score"] = r2_score(y_actual, y_pred)
        
        # Symmetric Mean Absolute Percentage Error (SMAPE)
        if "SMAPE" in metrics_to_compute:
            smape = np.mean(2 * np.abs(y_actual - y_pred) / 
                            (np.abs(y_actual) + np.abs(y_pred) + 1e-10)) * 100
            metrics["Symmetric Mean Absolute Percentage Error (SMAPE)"] = smape
        
        # Root Mean Squared Error (RMSE)
        if "RMSE" in metrics_to_compute:
            metrics["Root Mean Squared Error (RMSE)"] = np.sqrt(mean_squared_error(y_actual, y_pred))
        
        # Mean Absolute Percentage Error (MAPE)
        if "MAPE" in metrics_to_compute:
            mape = np.mean(np.abs((y_actual - y_pred) / (y_actual + 1e-10))) * 100
            metrics["Mean Absolute Percentage Error (MAPE)"] = mape
        
        # Explained Variance Score
        if "Explained Variance" in metrics_to_compute:
            explained_variance = 1 - np.var(y_actual - y_pred) / np.var(y_actual)
            metrics["Explained Variance"] = explained_variance
        
        # Mean Bias Error (MBE)
        if "MBE" in metrics_to_compute:
            mbe = np.mean(y_pred - y_actual)
            metrics["Mean Bias Error (MBE)"] = mbe

        return metrics

    except ValueError as e:
        logging.error(f"ValueError in evaluate_model: {e}")
        raise
    except Exception as e:
        logging.error(f"Unexpected error in evaluate_model: {e}")
        raise

def plot_distributions(
    df,
    columns_to_plot=None,
    plot_type="both",
    bins=30,
    kde_bw_adjust=1.0,
    save_plots=False,
    output_path="plots/",
    dataset_type="Original",
    model_name=None,
):
    """
    Plot the distributions of rainfall metrics with dynamic analysis and evaluation metrics.

    Parameters:
    df (pd.DataFrame): Input dataset.
    columns_to_plot (list): List of column names to plot. Defaults to 'Monthly_Total' and 'Monthly_Average'.
    plot_type (str): Type of plot ('hist', 'kde', or 'both'). Default is 'both'.
    bins (int): Number of bins for the histogram. Default is 30.
    kde_bw_adjust (float): Bandwidth adjustment for KDE. Default is 1.0.
    save_plots (bool): Whether to save plots to a directory. Default is False.
    output_path (str): Directory path to save plots. Default is "plots/".
    dataset_type (str): Type of dataset (e.g., "Original", "Training", "Predicted").
    model_name (str): Name of the model for visualization. Default is None.

    Returns:
    dict: Evaluation metrics for each column plotted.
    """
    try:
        # Default columns to plot
        if columns_to_plot is None:
            columns_to_plot = ["Monthly_Total", "Monthly_Average"]

        # Ensure columns exist in DataFrame
        missing_columns = [col for col in columns_to_plot if col not in df.columns]
        if missing_columns:
            raise KeyError(f"Missing required columns for plotting: {missing_columns}")

        # Initialize metrics dictionary
        all_metrics = {}

        sns.set(style="whitegrid")

        # Iterate through columns and plot distributions
        for column in columns_to_plot:
            # Ensure the column has valid data
            if df[column].isna().all():
                raise ValueError(f"Column '{column}' contains no valid data to plot.")

            # Calculate metrics
            metrics = {
                "Mean": df[column].mean(),
                "Median": df[column].median(),
                "Skewness": df[column].skew(),
                "Kurtosis": df[column].kurt(),
                "Min": df[column].min(),
                "Max": df[column].max(),
                "Standard Deviation": df[column].std(),
            }
            all_metrics[column] = metrics

            # Print metrics for the column
            logging.info(f"\nMetrics for {column} ({dataset_type}, Model: {model_name or 'N/A'}):")
            for metric, value in metrics.items():
                logging.info(f"  {metric}: {value:.4f}")

            # Plotting
            plt.figure(figsize=(16, 8))
            if plot_type in ["hist", "both"]:
                sns.histplot(df[column], kde=False, bins=bins, label="Histogram", color="blue", alpha=0.7)
            if plot_type in ["kde", "both"]:
                sns.kdeplot(df[column], bw_adjust=kde_bw_adjust, label="KDE", color="red", linestyle="--")

            # Customize title
            plot_title = f"Distribution of {column} ({dataset_type})"
            if model_name:
                plot_title += f" - {model_name}"
            plt.title(plot_title)
            plt.xlabel(f"{column} (mm)")
            plt.ylabel("Frequency")
            plt.legend()
            plt.tight_layout()

            # Save plot if required
            if save_plots:
                os.makedirs(output_path, exist_ok=True)
                plot_filename = f"{output_path}{dataset_type}_{model_name or 'General'}_{column}_distribution.png"
                plt.savefig(plot_filename)
                logging.info(f"Plot saved to: {plot_filename}")

            plt.show()

        return all_metrics

    except KeyError as e:
        logging.error(f"KeyError in plot_distributions: {e}")
        raise
    except ValueError as e:
        logging.error(f"ValueError in plot_distributions: {e}")
        raise
    except Exception as e:
        logging.error(f"Unexpected error in plot_distributions: {e}")
        raise

def handle_file(file_input=None, output_dir="output"):
    """
    Handle file input dynamically and ensure it works across environments.

    Parameters:
        file_input (str, FileUpload, or None): 
            - If `str`: Path to the input file (for backend or CLI environments).
            - If `FileUpload`: Jupyter Notebook's interactive file upload object.
            - If `None`: Automatically prompt for file upload in a notebook environment.
        output_dir (str): Directory to save processed files.

    Returns:
        pd.DataFrame: Loaded DataFrame from the input file.
        str: Path to save the processed output file.
    """
    try:
        # Ensure output directory exists
        os.makedirs(output_dir, exist_ok=True)

        if isinstance(file_input, str):
            # Case 1: File path provided as string (Backend/CLI)
            if not os.path.exists(file_input):
                raise FileNotFoundError(f"File not found at: {file_input}")
            df = pd.read_csv(file_input)
            output_file = os.path.join(output_dir, f"processed_{os.path.basename(file_input)}")
            logging.info(f"File loaded successfully from path: {file_input}")
            return df, output_file

        elif isinstance(file_input, FileUpload):
            # Case 2: Interactive file upload in Jupyter Notebook
            uploaded_file = list(file_input.value.values())[0]
            content = uploaded_file["content"]
            file_name = uploaded_file["metadata"]["name"]
            df = pd.read_csv(io.StringIO(content.decode("utf-8")))
            output_file = os.path.join(output_dir, f"processed_{file_name}")
            logging.info(f"File uploaded successfully via Jupyter widget: {file_name}")
            return df, output_file

        elif file_input is None:
            # Case 3: Automatically launch file upload widget in Jupyter Notebook
            print("Please upload a file:")
            upload_widget = FileUpload(accept=".csv", multiple=False)
            display(upload_widget)

            # Wait for file upload with timeout
            import time
            timeout = 60  # Set timeout to 60 seconds
            start_time = time.time()

            while not upload_widget.value:
                if time.time() - start_time > timeout:
                    raise TimeoutError("File upload timed out. Please try again.")
                time.sleep(0.5)  # Prevent busy-waiting

            logging.info("File uploaded via Jupyter widget and detected successfully.")

            # Process the uploaded file without recursion
            uploaded_file = list(upload_widget.value.values())[0]
            content = uploaded_file["content"]
            file_name = uploaded_file["metadata"]["name"]
            df = pd.read_csv(io.StringIO(content.decode("utf-8")))
            output_file = os.path.join(output_dir, f"processed_{file_name}")
            return df, output_file

        else:
            raise ValueError("Invalid file input type. Provide a valid file path or upload widget.")

    except TimeoutError as te:
        logging.error(f"TimeoutError: {te}")
        raise
    except Exception as e:
        logging.error(f"Error handling file: {e}")
        raise

def create_month_column(df):
    """
    Create a 'Month' column using 'YEAR' and 'Month' columns.

    Parameters:
        df (pd.DataFrame): Input DataFrame containing 'YEAR' and 'Month'.

    Returns:
        pd.DataFrame: DataFrame with a 'Month' column in datetime format.

    Raises:
        KeyError: If 'YEAR' or 'Month' columns are missing.
    """
    try:
        if 'YEAR' in df.columns and 'Month' in df.columns:
            df['Month'] = pd.to_datetime(df[['YEAR', 'Month']].assign(day=1), errors='coerce')
            if df['Month'].isna().any():
                raise ValueError("Some 'Month' values could not be converted to datetime.")
            return df
        else:
            raise KeyError("Missing 'YEAR' or 'Month' columns required for datetime conversion.")
    except Exception as e:
        logging.error(f"Failed to create 'Month' column: {e}")
        raise

def validate_year_month(df):
    """
    Validate and clean YEAR and Month columns for datetime conversion.

    Parameters:
    df (pd.DataFrame): Input DataFrame.

    Returns:
    pd.DataFrame: DataFrame with cleaned YEAR and Month columns.
    """
    try:
        # Validate YEAR column
        if 'YEAR' in df.columns:
            invalid_years = df[~df['YEAR'].apply(lambda x: isinstance(x, (int, float)) and 1800 <= x <= 2100)]
            if not invalid_years.empty:
                logging.warning(f"Invalid YEAR entries found: {invalid_years['YEAR'].unique()}")
            df['YEAR'] = pd.to_numeric(df['YEAR'], errors='coerce').astype('Int64')  # Ensure YEAR remains integers
        
        # Validate and convert Month column to datetime
        if 'Month' in df.columns:
            invalid_months = df[~df['Month'].apply(lambda x: isinstance(x, (int, float)) and 1 <= x <= 12)]
            if not invalid_months.empty:
                logging.warning(f"Invalid Month entries found: {invalid_months['Month'].unique()}")
            df['Month'] = pd.to_numeric(df['Month'], errors='coerce')  # Ensure Month is numeric

            # Combine YEAR and Month into a datetime column
            df['Month'] = pd.to_datetime(df[['YEAR', 'Month']].assign(day=1), errors='coerce')

        # Log final validation state
        logging.info(f"YEAR column type: {df['YEAR'].dtype}")
        logging.info(f"Month column type: {df['Month'].dtype}")

        return df

    except Exception as e:
        logging.error(f"Error in validate_year_month: {e}")
        raise


def save_dataset(df, output_path, save_file=False):
    """
    Optionally save the cleaned dataset to a specified file path and provide a download link.

    Parameters:
    df (pd.DataFrame): Dataset to save.
    output_path (str): Path to save the dataset, including the file name.
    save_file (bool): Whether to save the dataset to a file. Default is False.

    Returns:
    str: Path to the saved file or a message indicating it was not saved.
    """
    try:
        if not save_file:
            logging.info("File saving is skipped by user preference.")
            return "File saving skipped."

        # Validate input parameters
        if not isinstance(output_path, str) or not output_path.endswith('.csv'):
            raise ValueError("Output path must be a valid CSV file path.")

        if df.empty:
            raise ValueError("The dataset is empty. Nothing to save.")

        # Save the dataset
        df.to_csv(output_path, index=False)
        logging.info(f"Processed dataset saved successfully at {output_path}")

        # Provide download link in Jupyter Notebook
        if "IPython" in globals():
            return FileLink(output_path)

        return output_path

    except PermissionError:
        logging.error(f"Permission denied: Cannot save to {output_path}. Check write permissions.")
        raise
    except FileNotFoundError:
        logging.error(f"FileNotFoundError: The directory for {output_path} does not exist.")
        raise
    except Exception as e:
        logging.error(f"Unexpected error in save_dataset: {str(e)}")
        raise

# ----------------------- Main Workflow ----------------------- #

#Step 1: Load the dataset

file_path = "/Users/feysel/Developer/Practices/Python/refined/precipitation_data.csv"  # Replace with your file path
df = pd.read_csv(file_path)
print(df.head())

#Step 1.1: Create 'Month' column
df = create_month_column(df)
logging.info("Created 'Month' column successfully.")

# try:
#     # Choose the file input method
#     # If you want to use a static file path, set it here
#     static_file_path = "/Users/feysel/Developer/Practices/Python/refined/precipitation_data.csv"  # Replace with the actual path

#     if static_file_path:
#         # Use the static file path if specified
#         input_file = static_file_path
#         logging.info(f"Using static file path: {static_file_path}")
#     else:
#         # Default to the dynamic file handling function
#         if 'get_ipython' not in globals():
#             # CLI environment: prompt for a file path
#             input_file = input("Enter the path to your dataset (CSV format): ").strip()
#             if not os.path.exists(input_file):
#                 raise FileNotFoundError(f"File not found at the specified path: {input_file}")
#         else:
#             # Jupyter Notebook environment: use upload widget
#             input_file = None  # This triggers the upload widget in handle_file

#     # Load the dataset
#     df, output_file = handle_file(file_input=input_file)
#     if df.empty:
#         raise ValueError("The loaded dataset is empty. Aborting further steps.")
    
#     logging.info(f"Dataset loaded successfully with {df.shape[0]} rows and {df.shape[1]} columns.")

# except FileNotFoundError as e:
#     logging.error(f"File error: {e}")
#     raise
# except Exception as e:
#     logging.error(f"Unexpected error while handling the file: {e}")
#     raise


#Step 2: Handle missing values and redundant columns
logging.info("Handling missing values...")
df = handle_missing_values(df, drop_cells=True, drop_rows=False, threshold=0.5)
if 'Month' not in df.columns:
    df = create_month_column(df)  # Recreate if dropped
logging.info(f"Missing values handled. Dataset now has {df.isna().sum().sum()} missing cells.")


print("Columns in DataFrame:", df.columns)

# Step 3: Extract Features and Metrics
logging.info("Performing feature engineering...")
try:
    df, metadata, daily_columns = extract_features_and_metrics(
        df,
        daily_columns_prefix="",  # Leave blank if no specific prefix exists
        rolling_window=12,
        validate_values=True
    )
    logging.info(f"Feature engineering completed. Metadata: {metadata}")
except Exception as e:
    logging.error(f"Error during feature extraction: {e}")
    raise

#Step 4: Remove redundant columns
logging.info("Removing redundant columns...")
df = drop_columns(df, columns_to_drop=daily_columns)
logging.info(f"Redundant columns removed. Dataset now has {df.shape[1]} columns.")

#Step 5: Validate the dataset
logging.info("Validating the dataset...")
validation_summary = validate_data(df)
if validation_summary['Missing Values'] > 0:
    logging.warning("Dataset contains missing values after handling. Review required.")
logging.info(f"Validation completed. Summary: {validation_summary}")

#Step 6: Visualize distributions
logging.info("Visualizing distributions...")
plot_metrics = plot_distributions(
    df, 
    columns_to_plot=["Monthly_Total", "Monthly_Average"], 
    plot_type="both", 
    save_plots=True, 
    output_path="plots/",
    dataset_type="Processed"
)
logging.info(f"Distribution visualizations saved successfully.")

# Step 7: Aggregate data by year with new metrics
try:
    # Log columns for debugging
    print(f"Dataset loaded successfully with {df.shape[0]} rows and {df.shape[1]} columns.")
    
    # Validate time series index
    if 'Month' in df.columns and not isinstance(df.index, pd.DatetimeIndex):
        df.set_index('Month', inplace=True)

    # Resample by year using the updated 'YE' frequency
    yearly_data = df.resample('YE').agg({
        'Monthly_Total': 'sum',  # Total rainfall per year
        'Monthly_Average': 'mean'  # Average monthly rainfall per year
    })

    # Add 'Year' column for plotting, ensuring it is datetime
    yearly_data['Year'] = yearly_data.index.year.astype(int)

    # Validate yearly data for issues
    if yearly_data.isnull().any().any():
        logging.warning("Yearly data contains missing values.")
    if (yearly_data[['Monthly_Total', 'Monthly_Average']] <= 0).any().any():
        logging.warning("Yearly data contains zero or negative values.")

    # Define metrics for interactive plotting
    metrics_to_plot = [
        {"column": "Monthly_Total", "title": "Yearly Total Rainfall", "color": "blue"},
        {"column": "Monthly_Average", "title": "Yearly Average Rainfall", "color": "green"}
    ]

    # Plot enhanced interactive yearly trends
    plot_yearly_trends_interactive(yearly_data, metrics_to_plot)

except Exception as e:
    logging.error(f"Failed to aggregate data by year: {e}")
    raise

# Step 8: Save the processed dataset (optional)
save_file = False  # Set to False to skip saving
if save_file:
    logging.info("Saving the processed dataset...")
    save_result = save_dataset(df, output_file, save_file=save_file)
    if isinstance(save_result, FileLink):
        display(save_result)  # Display download link in Jupyter
    logging.info(f"Processed dataset saved to {output_file}.")
else:
    logging.info("Skipping dataset save as per user preference.")

# --------------------------------- 1: DATA VISUALIZATION -------------------------------- #
def load_visualization_data(file_path=None, df=None):
    """
    Load the processed dataset for visualization.

    Parameters:
    file_path (str): File path to load the processed dataset.
    df (pd.DataFrame or None): Optionally provide a preloaded DataFrame.

    Returns:
    pd.DataFrame: Processed DataFrame ready for visualization.
    """
    try:
        # Validate input
        if df is not None:
            logging.info("Using provided DataFrame for visualization.")
            if not isinstance(df, pd.DataFrame):
                raise ValueError("Provided `df` is not a valid DataFrame.")
            return df

        if file_path is None:
            raise ValueError("Either `file_path` or `df` must be provided.")

        # Load the dataset dynamically
        df = pd.read_csv(file_path)
        if df.empty:
            raise ValueError(f"The dataset at {file_path} is empty.")

        # Ensure 'Month' column is in datetime format
        if 'Month' not in df.columns:
            raise KeyError("'Month' column is missing in the dataset.")
        df['Month'] = pd.to_datetime(df['Month'])

        logging.info(f"Dataset for visualization successfully loaded from: {file_path}")
        return df

    except FileNotFoundError:
        logging.error(f"File not found at: {file_path}")
        raise
    except ValueError as ve:
        logging.error(f"ValueError in load_visualization_data: {ve}")
        raise
    except KeyError as ke:
        logging.error(f"KeyError in load_visualization_data: {ke}")
        raise
    except Exception as e:
        logging.error(f"Unexpected error in load_visualization_data: {str(e)}")
        raise

# -------------------- STEP 1: TIME-SERIES PLOTS -------------------- #
def plot_time_series(
    df,
    column_to_plot="Monthly_Total",
    rolling_column="Rolling_Average",
    time_span_years=5,
    figsize=(16, 8),
    save_plots=False,
    output_path="plots/",
    interactive=False,
    include_metrics=True,
):
    """
    Visualize rainfall trends across specific time spans and the entire timeline.

    Parameters:
    df (pd.DataFrame): DataFrame containing the data.
    column_to_plot (str): Column name for the main metric to plot (e.g., Monthly_Total).
    rolling_column (str): Column name for the rolling average to overlay.
    time_span_years (int): Number of years for each time-span subset.
    figsize (tuple): Figure size for plots.
    save_plots (bool): Whether to save the plots to files.
    output_path (str): Directory to save the plots if enabled.
    interactive (bool): Whether to create interactive plots using Plotly.
    include_metrics (bool): Whether to include aggregate metrics on the plots.

    Returns:
    None
    """
    try:
        # Validate input DataFrame
        if column_to_plot not in df.columns or rolling_column not in df.columns:
            raise KeyError(f"Columns {column_to_plot} or {rolling_column} not found in DataFrame.")

        # Ensure 'Month' column is in datetime format
        if not pd.api.types.is_datetime64_any_dtype(df['Month']):
            raise ValueError("The 'Month' column must be a datetime type.")

        # Calculate time spans dynamically
        min_year = df['Month'].dt.year.min()
        max_year = df['Month'].dt.year.max()
        total_years = max_year - min_year + 1
        subsets_count = (total_years + time_span_years - 1) // time_span_years  # Round up

        for i in range(subsets_count):
            start_year = min_year + i * time_span_years
            end_year = min(start_year + time_span_years - 1, max_year)

            # Subset data
            subset = df[(df['Month'].dt.year >= start_year) & (df['Month'].dt.year <= end_year)]
            if subset.empty:
                logging.warning(f"No data found for the time span: {start_year} - {end_year}")
                continue

            # Initialize the plot
            plt.figure(figsize=figsize)

            # Plot the main metric and rolling average
            plt.plot(subset['Month'], subset[column_to_plot], label=column_to_plot, color='blue')
            plt.plot(subset['Month'], subset[rolling_column], label=f"{rolling_column} (Rolling)", color='green')

            # Include metrics if enabled
            if include_metrics:
                total_mean = subset[column_to_plot].mean()
                total_median = subset[column_to_plot].median()
                plt.axhline(total_mean, color='red', linestyle='--', linewidth=1, label=f'Mean: {total_mean:.2f}')
                plt.axhline(total_median, color='orange', linestyle='--', linewidth=1, label=f'Median: {total_median:.2f}')

            # Customize x-axis for legibility
            plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y'))  # Display years only
            plt.gca().xaxis.set_major_locator(mdates.YearLocator(1))  # Major ticks every year
            plt.gca().xaxis.set_minor_locator(mdates.MonthLocator())  # Minor ticks every month
            plt.grid(which='minor', linestyle=':', linewidth=0.5, color='gray', alpha=0.5)

            # Add titles and labels
            plt.title(f"Rainfall Trends: {start_year} to {end_year}")
            plt.xlabel("Year")
            plt.ylabel("Rainfall (mm)")
            plt.legend()
            plt.xticks(rotation=45)
            plt.tight_layout()

            # Save the plot if required
            if save_plots:
                os.makedirs(output_path, exist_ok=True)
                plot_filename = f"{output_path}rainfall_trend_{start_year}_{end_year}.png"
                plt.savefig(plot_filename)
                logging.info(f"Plot saved to: {plot_filename}")

            # Show the plot
            if interactive:
                import plotly.express as px
                fig = px.line(subset, x='Month', y=[column_to_plot, rolling_column],
                              labels={"value": "Rainfall (mm)", "variable": "Metric"},
                              title=f"Rainfall Trends ({start_year}-{end_year})")
                fig.show()
            else:
                plt.show()

        # Plot the entire timeline
        plt.figure(figsize=figsize)
        plt.plot(df['Month'], df[column_to_plot], label=column_to_plot, color='blue')
        plt.plot(df['Month'], df[rolling_column], label=f"{rolling_column} (Rolling)", color='green')

        # Include metrics if enabled
        if include_metrics:
            total_mean = df[column_to_plot].mean()
            total_median = df[column_to_plot].median()
            plt.axhline(total_mean, color='red', linestyle='--', linewidth=1, label=f'Mean: {total_mean:.2f}')
            plt.axhline(total_median, color='orange', linestyle='--', linewidth=1, label=f'Median: {total_median:.2f}')

        plt.title("Rainfall Trends: Full Timeline")
        plt.xlabel("Year")
        plt.ylabel("Rainfall (mm)")
        plt.legend()
        plt.xticks(rotation=45)
        plt.tight_layout()

        # Save the plot if required
        if save_plots:
            full_timeline_filename = f"{output_path}rainfall_trend_full_timeline.png"
            plt.savefig(full_timeline_filename)
            logging.info(f"Full timeline plot saved to: {full_timeline_filename}")

        # Show the plot
        if interactive:
            fig = px.line(df, x='Month', y=[column_to_plot, rolling_column],
                          labels={"value": "Rainfall (mm)", "variable": "Metric"},
                          title="Rainfall Trends: Full Timeline")
            fig.show()
        else:
            plt.show()

    except KeyError as e:
        logging.error(f"KeyError in plot_time_series: {e}")
        raise
    except ValueError as e:
        logging.error(f"ValueError in plot_time_series: {e}")
        raise
    except Exception as e:
        logging.error(f"Unexpected error in plot_time_series: {e}")
        raise

# -------------------- STEP 2: YEARLY TRENDS -------------------- #
def plot_yearly_trends_interactive(data, metrics, title_prefix="Yearly Trends"):
    """
    Create an interactive plot for yearly trends using Plotly with robust error handling.

    Parameters:
    data (pd.DataFrame): Aggregated yearly data containing metrics to plot.
    metrics (list of dict): List of dictionaries where each defines:
        - "column": Column name in the data to plot.
        - "title": Title for the metric (e.g., 'Yearly Total Rainfall').
        - "color": Color for the line plot.
    title_prefix (str): Prefix for the overall plot title.

    Returns:
    None
    """
    try:
        # Validate input data
        if not isinstance(data, pd.DataFrame):
            raise ValueError("The `data` parameter must be a pandas DataFrame.")
        if not all("column" in metric and "title" in metric and "color" in metric for metric in metrics):
            raise ValueError("Each metric in `metrics` must include 'column', 'title', and 'color'.")

        # Check required columns in data
        required_columns = [metric['column'] for metric in metrics] + ['Year']
        missing_columns = [col for col in required_columns if col not in data.columns]
        if missing_columns:
            raise KeyError(f"Missing required columns in data: {missing_columns}")

        # Check for invalid or missing data
        if data.isnull().any().any():
            logging.warning("The data contains missing values. Please investigate.")
    
        # Validate for negative values in numeric columns only
        numeric_columns = data.select_dtypes(include=[np.number]).columns
        for column in numeric_columns:
            if (data[column] < 0).any():
                logging.warning(f"The column '{column}' contains negative values. Ensure data validity.")

        # Create a subplot for each metric
        fig = make_subplots(rows=len(metrics), cols=1, shared_xaxes=True, subplot_titles=[m["title"] for m in metrics])

        for idx, metric in enumerate(metrics, start=1):
            column = metric["column"]
            title = metric["title"]
            color = metric["color"]

            # Add trace for each metric
            fig.add_trace(
                go.Scatter(
                    x=data['Year'],
                    y=data[column],
                    mode='lines+markers',
                    name=title,
                    line=dict(color=color)
                ),
                row=idx, col=1
            )

        # Update layout
        fig.update_layout(
            title=f"{title_prefix}: {data['Year'].min()} - {data['Year'].max()}",
            xaxis_title="Year",
            height=300 * len(metrics),
            showlegend=True,
            template="plotly_white"
        )
        fig.update_xaxes(tickformat="%Y")

        # Show the figure
        fig.show()

    except KeyError as e:
        logging.error(f"KeyError in plot_yearly_trends_interactive: {e}")
        raise
    except ValueError as e:
        logging.error(f"ValueError in plot_yearly_trends_interactive: {e}")
        raise
    except Exception as e:
        logging.error(f"Unexpected error in plot_yearly_trends_interactive: {e}")
        raise

# Validate time series index
if 'Month' in df.columns and not isinstance(df.index, pd.DatetimeIndex):
    df.set_index('Month', inplace=True)

# Aggregate data by year
try:
    yearly_data = df.resample('Y').agg({
        'Monthly_Total': 'sum',  # Total rainfall per year
        'Monthly_Average': 'mean'  # Average monthly rainfall per year
    })

    # Add 'Year' column for plotting
    yearly_data['Year'] = yearly_data.index

    # Validate yearly data for issues
    if yearly_data.isnull().any().any():
        logging.warning("Yearly data contains missing values.")
    if (yearly_data[['Monthly_Total', 'Monthly_Average']] <= 0).any().any():
        logging.warning("Yearly data contains zero or negative values.")

except Exception as e:
    logging.error(f"Error during yearly data aggregation: {e}")
    raise

# Define metrics for interactive plotting
metrics_to_plot = [
    {"column": "Monthly_Total", "title": "Yearly Total Rainfall", "color": "blue"},
    {"column": "Monthly_Average", "title": "Yearly Average Rainfall", "color": "green"}
]

# Plot enhanced interactive yearly trends
try:
    plot_yearly_trends_interactive(yearly_data, metrics_to_plot)
except Exception as e:
    logging.error(f"Failed to plot yearly trends: {e}")
    raise

# -------------------- STEP 3: INTERACTIVE BOX PLOTS -------------------- #
from plotly.subplots import make_subplots
import plotly.graph_objects as go

def plot_boxplots_interactive(data, grouping_column, value_column, title, x_labels=None, annotation=None, color='lightblue'):
    """
    Create an interactive box plot using Plotly.

    Parameters:
    data (pd.DataFrame): DataFrame containing the data to plot.
    grouping_column (str): Column to group data by (e.g., 'Year', 'Month_Number').
    value_column (str): Column containing the values to plot.
    title (str): Title of the box plot.
    x_labels (list, optional): Custom labels for the x-axis (e.g., month names for 'Month_Number').
    annotation (str, optional): Annotation text to describe the box plot features.
    color (str): Color of the box plot.

    Returns:
    None
    """
    try:
        # Validate inputs
        if grouping_column not in data.columns or value_column not in data.columns:
            raise KeyError(f"Columns '{grouping_column}' or '{value_column}' are missing in the dataset.")

        if data[value_column].isnull().any():
            logging.warning(f"Column '{value_column}' contains missing values.")

        # Create a box plot
        fig = go.Figure()

        # Group data for the box plot
        for group in sorted(data[grouping_column].unique()):
            group_data = data[data[grouping_column] == group][value_column]
            fig.add_trace(go.Box(
                y=group_data,
                name=str(group),
                boxpoints='all',  # Show all points
                jitter=0.5,  # Add jitter for visibility
                whiskerwidth=0.5,
                fillcolor=color,
                line=dict(width=1),
                marker=dict(size=3),
                showlegend=False
            ))

        # Set plot layout
        fig.update_layout(
            title=title,
            xaxis=dict(title=grouping_column, tickvals=list(range(len(x_labels))) if x_labels else None, ticktext=x_labels),
            yaxis=dict(title=value_column),
            template="plotly_white"
        )

        # Add annotation if provided
        if annotation:
            fig.add_annotation(
                text=annotation,
                align='left',
                showarrow=False,
                xref='paper',
                yref='paper',
                x=1.05,
                y=0.5,
                bordercolor='black',
                borderwidth=1,
                bgcolor='lightgray',
                font=dict(size=10)
            )

        # Show the plot
        fig.show()

    except KeyError as e:
        logging.error(f"KeyError in plot_boxplots_interactive: {e}")
        raise
    except ValueError as e:
        logging.error(f"ValueError in plot_boxplots_interactive: {e}")
        raise
    except Exception as e:
        logging.error(f"Unexpected error in plot_boxplots_interactive: {e}")
        raise

# -------------------- STEP 4: INTERACTIVE HISTOGRAMS -------------------- #
import plotly.express as px

def plot_histograms_interactive(data, columns, titles, bins=30, colors=None):
    """
    Create interactive histograms using Plotly for analyzing data distributions.

    Parameters:
    data (pd.DataFrame): The dataset containing the columns to plot.
    columns (list): List of column names to create histograms for.
    titles (list): Titles for each histogram.
    bins (int): Number of bins for histograms.
    colors (list, optional): List of colors for the histograms.

    Returns:
    None
    """
    try:
        # Validate inputs
        if len(columns) != len(titles):
            raise ValueError("The `columns` and `titles` lists must have the same length.")
        if colors and len(columns) != len(colors):
            raise ValueError("If `colors` are provided, their length must match `columns` and `titles`.")

        # Loop through each column to plot
        for idx, column in enumerate(columns):
            if column not in data.columns:
                raise KeyError(f"Column '{column}' not found in the dataset.")
            if data[column].isnull().any():
                logging.warning(f"Column '{column}' contains missing values.")

            # Create the histogram
            fig = px.histogram(
                data_frame=data,
                x=column,
                nbins=bins,
                title=titles[idx],
                color_discrete_sequence=[colors[idx]] if colors else None,
                marginal="box",  # Add a box plot for better distribution analysis
            )

            # Update layout
            fig.update_layout(
                xaxis_title=f"{column} (mm)",
                yaxis_title="Frequency",
                template="plotly_white"
            )

            # Show the plot
            fig.show()

    except KeyError as e:
        logging.error(f"KeyError in plot_histograms_interactive: {e}")
        raise
    except ValueError as e:
        logging.error(f"ValueError in plot_histograms_interactive: {e}")
        raise
    except Exception as e:
        logging.error(f"Unexpected error in plot_histograms_interactive: {e}")
        raise

# -------------------- STEP 5: ENHANCED STATIONARITY CHECK -------------------- #
def perform_stationarity_check(data, column, window=12):
    """
    Perform enhanced stationarity checks using rolling statistics and statistical tests.

    Parameters:
    data (pd.DataFrame): Input dataset.
    column (str): Column to perform the stationarity check on.
    window (int): Window size for rolling statistics.

    Returns:
    dict: Results of the statistical tests.
    """
    try:
        # Validate input
        if column not in data.columns:
            raise KeyError(f"Column '{column}' not found in the dataset.")
        if data[column].isnull().any():
            logging.warning(f"Column '{column}' contains missing values.")

        # Suppress statsmodels warnings
        warnings.filterwarnings("ignore", category=UserWarning)

        # Calculate rolling mean and standard deviation
        rolling_mean = data[column].rolling(window=window).mean()
        rolling_std = data[column].rolling(window=window).std()

        # Interactive Rolling Statistics Plot
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=data['Month'], y=data[column],
            mode='lines', name='Original Data', line=dict(color='blue')
        ))
        fig.add_trace(go.Scatter(
            x=data['Month'], y=rolling_mean,
            mode='lines', name=f'Rolling Mean ({window} months)', line=dict(color='orange')
        ))
        fig.add_trace(go.Scatter(
            x=data['Month'], y=rolling_std,
            mode='lines', name=f'Rolling Std Dev ({window} months)', line=dict(color='green')
        ))
        fig.update_layout(
            title="Rolling Statistics (Interactive)",
            xaxis_title="Time",
            yaxis_title="Rainfall (mm)",
            template="plotly_white"
        )
        fig.show()

        # Statistical Tests
        adf_result = adfuller(data[column].dropna())
        try:
            kpss_result = kpss(data[column].dropna(), regression='c', nlags="legacy")
        except ValueError as e:
            kpss_result = (None, None)
            logging.warning(f"KPSS Test Error: {e}")

        pp_result = PhillipsPerron(data[column].dropna())
        za_result = zivot_andrews(data[column].dropna())

        # Store results in a dictionary
        results = {
            "ADF Test": {"Statistic": adf_result[0], "p-value": adf_result[1]},
            "KPSS Test": {
                "Statistic": kpss_result[0] if kpss_result[0] is not None else "Unavailable",
                "p-value": kpss_result[1] if kpss_result[1] is not None else "Unavailable",
            },
            "Phillips-Perron Test": {"Statistic": pp_result.stat, "p-value": pp_result.pvalue},
            "Zivot-Andrews Test": {"Statistic": za_result[0], "p-value": za_result[1]},
        }

        # Log results
        for test, result in results.items():
            logging.info(f"{test}: Statistic={result['Statistic']}, p-value={result['p-value']}")

        return results

    except KeyError as e:
        logging.error(f"KeyError in perform_stationarity_check: {e}")
        raise
    except Exception as e:
        logging.error(f"Unexpected error in perform_stationarity_check: {e}")
        raise

# ------------------------------- DATA PREPARATION ------------------------------- #

# ------------------------------- STEP 1: LOAD AND PREPARE DATA ------------------------------- #
def prepare_data(file_path, scaler_path='scaler.pkl', features_to_keep=None, output_dir='output'):
    """
    Load, validate, and prepare the dataset for model training.

    Parameters:
    file_path (str): Path to the preprocessed dataset.
    scaler_path (str): Path to save the MinMaxScaler object for inverse transformations.
    features_to_keep (list, optional): List of columns to retain in the prepared dataset.
    output_dir (str): Directory to save intermediate files and logs.

    Returns:
    pd.DataFrame: Prepared and scaled dataset.
    MinMaxScaler: Fitted scaler object for further transformations.
    """
    try:
        # Ensure output directory exists
        os.makedirs(output_dir, exist_ok=True)

        # Validate input
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found at: {file_path}")

        logging.info(f"Loading dataset from {file_path}...")
        df = pd.read_csv(file_path)

        # Ensure 'Month' column is in datetime format
        if 'Month' not in df.columns:
            raise KeyError("'Month' column is missing in the dataset.")
        df['Month'] = pd.to_datetime(df['Month'])

        # Select relevant features
        default_features = ['Month', 'Monthly_Total', 'Rolling_Average']
        features_to_keep = features_to_keep or default_features
        missing_features = [col for col in features_to_keep if col not in df.columns]
        if missing_features:
            raise KeyError(f"Missing required columns: {missing_features}")
        df_selected = df[features_to_keep]

        # Validate selected features
        numerical_columns = [col for col in features_to_keep if col != 'Month']
        for col in numerical_columns:
            if not np.issubdtype(df_selected[col].dtype, np.number):
                raise ValueError(f"Column '{col}' contains non-numeric data.")
            if df_selected[col].isnull().any():
                logging.warning(f"Column '{col}' contains missing values.")
            if (df_selected[col] < 0).any():
                logging.warning(f"Column '{col}' contains negative values.")

        # Normalize numerical features
        logging.info(f"Normalizing features: {numerical_columns}")
        scaler = MinMaxScaler()
        df_selected[numerical_columns] = scaler.fit_transform(df_selected[numerical_columns])

        # Save the scaler
        scaler_path = os.path.join(output_dir, scaler_path)
        joblib.dump(scaler, scaler_path)
        logging.info(f"Scaler saved at {scaler_path}")

        # Log and return
        logging.info(f"Data preparation completed. Dataset contains {df_selected.shape[0]} rows and {df_selected.shape[1]} columns.")
        return df_selected, scaler

    except FileNotFoundError as e:
        logging.error(f"FileNotFoundError: {e}")
        raise
    except KeyError as e:
        logging.error(f"KeyError: {e}")
        raise
    except ValueError as e:
        logging.error(f"ValueError: {e}")
        raise
    except Exception as e:
        logging.error(f"Unexpected error in prepare_data: {e}")
        raise

# ------------------------------- STEP 2: CREATE SEQUENCES ------------------------------- #
def create_sequences(data, target_column, sequence_length, output_dir=None):
    """
    Create overlapping sequences for time-series data.

    Parameters:
    data (pd.DataFrame): Preprocessed and scaled DataFrame containing features.
    target_column (str): The name of the target column.
    sequence_length (int): Number of time steps in each sequence.
    output_dir (str, optional): Directory to save the generated sequences. If None, sequences are not saved.

    Returns:
    np.ndarray: Input sequences (samples, sequence_length, features).
    np.ndarray: Target values for each sequence.
    list: List of columns used as features.
    """
    try:
        # Validate inputs
        if target_column not in data.columns:
            raise KeyError(f"Target column '{target_column}' is not in the dataset.")
        if sequence_length <= 0:
            raise ValueError("Sequence length must be greater than 0.")
        if len(data) <= sequence_length:
            raise ValueError("Dataset size must be larger than the sequence length.")

        # Select feature columns
        feature_columns = [col for col in data.columns if col not in [target_column, 'Month']]
        if not feature_columns:
            raise ValueError("No feature columns available after excluding target and 'Month' columns.")

        # Generate sequences
        x, y = [], []
        for i in range(len(data) - sequence_length):
            x.append(data.iloc[i:i + sequence_length][feature_columns].values)
            y.append(data.iloc[i + sequence_length][target_column])
        x = np.array(x)
        y = np.array(y)

        # Log sequence details
        logging.info(f"Sequences created with shape: x={x.shape}, y={y.shape}")
        logging.info(f"Feature columns used: {feature_columns}")

        # Optionally save sequences
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
            np.save(os.path.join(output_dir, 'x_sequences.npy'), x)
            np.save(os.path.join(output_dir, 'y_sequences.npy'), y)
            logging.info(f"Sequences saved to directory: {output_dir}")

        return x, y, feature_columns

    except KeyError as e:
        logging.error(f"KeyError in create_sequences: {e}")
        raise
    except ValueError as e:
        logging.error(f"ValueError in create_sequences: {e}")
        raise
    except Exception as e:
        logging.error(f"Unexpected error in create_sequences: {e}")
        raise

# ------------------------------- STEP 3: TRAIN-TEST SPLIT ------------------------------- #
def split_data(x, y, train_size=0.7, validation_split=0.1):
    """
    Split the data into training, validation, and testing sets.

    Parameters:
    x (np.ndarray): Input sequences (samples, sequence_length, features).
    y (np.ndarray): Target values for each sequence.
    train_size (float): Fraction of data to use for training (default is 70%).
    validation_split (float): Fraction of training data to use for validation (default is 10%).

    Returns:
    dict: Dictionary containing train, validation, and test splits for x and y.
    """
    try:
        # Validate split ratios
        if not (0 < train_size < 1):
            raise ValueError("`train_size` must be a float between 0 and 1.")
        if not (0 <= validation_split < 1):
            raise ValueError("`validation_split` must be a float between 0 and 1.")

        # Determine split indices
        train_end = int(len(x) * train_size)
        x_train, x_test = x[:train_end], x[train_end:]
        y_train, y_test = y[:train_end], y[train_end:]

        # Further split training data for validation
        if validation_split > 0:
            val_end = int(len(x_train) * (1 - validation_split))
            x_train, x_val = x_train[:val_end], x_train[val_end:]
            y_train, y_val = y_train[:val_end], y_train[val_end:]
        else:
            x_val, y_val = None, None

        # Log the data split sizes
        logging.info(f"Train data: {x_train.shape}, {y_train.shape}")
        if x_val is not None:
            logging.info(f"Validation data: {x_val.shape}, {y_val.shape}")
        logging.info(f"Test data: {x_test.shape}, {y_test.shape}")

        return {
            "x_train": x_train, "y_train": y_train,
            "x_val": x_val, "y_val": y_val,
            "x_test": x_test, "y_test": y_test
        }

    except Exception as e:
        logging.error(f"Error in data splitting: {e}")
        raise

# Perform the split
split_data_dict = split_data(x, y, train_size=0.7, validation_split=0.1)

# Extract the splits for convenience
x_train = split_data_dict["x_train"]
y_train = split_data_dict["y_train"]
x_val = split_data_dict["x_val"]
y_val = split_data_dict["y_val"]
x_test = split_data_dict["x_test"]
y_test = split_data_dict["y_test"]

# Reshape for LSTM/CNN if needed
x_train_reshaped = x_train.reshape((x_train.shape[0], x_train.shape[1], x_train.shape[2]))
x_test_reshaped = x_test.reshape((x_test.shape[0], x_test.shape[1], x_test.shape[2]))
if x_val is not None:
    x_val_reshaped = x_val.reshape((x_val.shape[0], x_val.shape[1], x_val.shape[2]))

# Save metadata for reproducibility
metadata = {
    "train_size": len(x_train),
    "val_size": len(x_val) if x_val is not None else 0,
    "test_size": len(x_test),
    "sequence_length": x_train.shape[1],
    "num_features": x_train.shape[2]
}
joblib.dump(metadata, "data_split_metadata.pkl")
logging.info(f"Data split metadata saved: {metadata}")

# ------------------------------- MODEL TRAINING ------------------------------- #

# ------------------------- OPTUNA-LSTM HYPERPARAMETER OPTIMIZATION ------------------------- #
def objective_lstm(trial, x_train, y_train, x_test, y_test):
    """
    Objective function for LSTM hyperparameter optimization using Optuna.

    Parameters:
        trial (optuna.trial.Trial): Optuna trial object.
        x_train (np.ndarray): Training input data.
        y_train (np.ndarray): Training target data.
        x_test (np.ndarray): Test input data.
        y_test (np.ndarray): Test target data.

    Returns:
        float: Validation loss for the trial.
    """
    try:
        # Define hyperparameter search space
        lstm_units = trial.suggest_int("lstm_units", 32, 128, step=16)
        dropout_rate = trial.suggest_float("dropout_rate", 0.1, 0.5, step=0.1)
        learning_rate = trial.suggest_loguniform("learning_rate", 1e-4, 1e-2)
        optimizer_name = trial.suggest_categorical("optimizer", ["adam", "rmsprop"])
        batch_size = trial.suggest_int("batch_size", 16, 64, step=16)
        epochs = 50

        # Build the model
        model = Sequential([
            LSTM(units=lstm_units, activation='relu', input_shape=(x_train.shape[1], x_train.shape[2])),
            Dropout(rate=dropout_rate),
            Dense(1)
        ])
        optimizer = Adam(learning_rate=learning_rate) if optimizer_name == "adam" else RMSprop(learning_rate=learning_rate)
        model.compile(optimizer=optimizer, loss=MeanSquaredError())

        # Train the model with early stopping
        early_stopping = EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True)
        history = model.fit(
            x_train, y_train,
            validation_split=0.2,
            epochs=epochs,
            batch_size=batch_size,
            callbacks=[early_stopping],
            verbose=0
        )

        # Return the minimum validation loss
        return min(history.history['val_loss'])

    except Exception as e:
        logging.error(f"Error during trial optimization: {str(e)}")
        raise


def optimize_lstm(x_train, y_train, x_test, y_test, n_trials=50):
    """
    Optimize LSTM hyperparameters using Optuna.

    Parameters:
        x_train (np.ndarray): Training input data.
        y_train (np.ndarray): Training target data.
        x_test (np.ndarray): Test input data.
        y_test (np.ndarray): Test target data.
        n_trials (int): Number of optimization trials.

    Returns:
        dict: Best hyperparameters found by Optuna.
        keras.Model: Trained LSTM model with best hyperparameters.
    """
    try:
        # Define the study and optimize
        study = optuna.create_study(direction="minimize")
        study.optimize(lambda trial: objective_lstm(trial, x_train, y_train, x_test, y_test), n_trials=n_trials)

        logging.info(f"Best trial: {study.best_trial.number}")
        logging.info(f"Best hyperparameters: {study.best_params}")

        # Train the best model with the best parameters
        best_params = study.best_params
        final_model = Sequential([
            LSTM(units=best_params["lstm_units"], activation='relu', input_shape=(x_train.shape[1], x_train.shape[2])),
            Dropout(rate=best_params["dropout_rate"]),
            Dense(1)
        ])
        final_optimizer = Adam(learning_rate=best_params["learning_rate"]) if best_params["optimizer"] == "adam" else RMSprop(learning_rate=best_params["learning_rate"])
        final_model.compile(optimizer=final_optimizer, loss=MeanSquaredError())

        final_model.fit(
            x_train, y_train,
            validation_split=0.2,
            epochs=50,
            batch_size=best_params["batch_size"],
            verbose=1
        )

        # Evaluate the model
        y_pred = final_model.predict(x_test).flatten()
        metrics_to_compute = ["MSE", "MAE", "R²", "RMSE", "SMAPE"]
        metrics = evaluate_model(y_test, y_pred, metrics_to_compute)
        logging.info(f"Final model metrics: {metrics}")

        return best_params, final_model

    except optuna.exceptions.OptunaError as oe:
        logging.error(f"Optuna encountered an error: {str(oe)}")
        raise
    except ValueError as ve:
        logging.error(f"ValueError in optimize_lstm: {str(ve)}")
        raise
    except Exception as e:
        logging.error(f"Unexpected error during LSTM optimization: {str(e)}")
        raise

# -------------------- PREDICTION ON TRAINING DATA WITH LSTM -------------------- #
def predict_and_visualize_training(model, x_train, y_train, scaler, feature_column='Monthly_Total', sequence_length=12):
    """
    Predict on training data and visualize actual vs predicted values.

    Parameters:
        model (keras.Model): Trained LSTM model.
        x_train (np.ndarray): Training input data.
        y_train (np.ndarray): True target values for training data.
        scaler (MinMaxScaler): Scaler used for normalization.
        feature_column (str): Original feature for inverse scaling.
        sequence_length (int): Length of input sequences.

    Returns:
        pd.DataFrame: DataFrame with actual and predicted values.
    """
    try:
        # Generate predictions
        logging.info("Generating predictions on training data...")
        y_pred = model.predict(x_train).flatten()

        # Inverse transform predictions and actuals to original scale
        y_pred_inverse = scaler.inverse_transform(
            np.hstack([np.zeros((len(y_pred), x_train.shape[2] - 1)), y_pred.reshape(-1, 1)])
        )[:, -1]

        y_train_inverse = scaler.inverse_transform(
            np.hstack([np.zeros((len(y_train), x_train.shape[2] - 1)), y_train.reshape(-1, 1)])
        )[:, -1]

        # Prepare DataFrame for visualization
        training_results = pd.DataFrame({
            'Time Index': range(sequence_length, len(y_train) + sequence_length),
            'Actual': y_train_inverse,
            'Predicted': y_pred_inverse
        })

        # Plot Actual vs Predicted
        plt.figure(figsize=(16, 8))
        plt.plot(training_results['Time Index'], training_results['Actual'], label='Actual', color='blue')
        plt.plot(training_results['Time Index'], training_results['Predicted'], label='Predicted', color='orange', linestyle='--')
        plt.title("LSTM Model: Actual vs Predicted on Training Data")
        plt.xlabel("Time Index")
        plt.ylabel(f"Rainfall ({feature_column})")
        plt.legend()
        plt.grid()
        plt.tight_layout()
        plt.show()

        logging.info("Prediction and visualization completed successfully.")

        return training_results

    except Exception as e:
        logging.error(f"Error in predict_and_visualize_training: {str(e)}")
        raise

# -------------------- EXECUTION -------------------- #
try:
    # Call the prediction function
    training_results = predict_and_visualize_training(
        model=lstm_model, 
        x_train=x_train, 
        y_train=y_train, 
        scaler=scaler, 
        feature_column='Monthly_Total',
        sequence_length=sequence_length
    )

    # Save training results for further inspection
    training_results.to_csv("training_predictions.csv", index=False)
    logging.info("Training predictions saved to 'training_predictions.csv'.")

except Exception as e:
    logging.error(f"Failed to predict or visualize on training data: {str(e)}")


['1993-01-01 00:00:00', '1993-02-01 00:00:00', '1993-03-01 00:00:00',
 '1993-04-01 00:00:00', '1993-05-01 00:00:00', '1993-06-01 00:00:00',
 '1993-07-01 00:00:00', '1993-08-01 00:00:00', '1993-09-01 00:00:00',
 '1993-10-01 00:00:00',
 ...
 '2022-05-01 00:00:00', '2022-06-01 00:00:00', '2022-07-01 00:00:00',
 '2022-08-01 00:00:00', '2022-09-01 00:00:00', '2022-10-01 00:00:00',
 '2022-11-01 00:00:00', '2022-12-01 00:00:00', '2023-01-01 00:00:00',
 '2023-02-01 00:00:00']
Length: 360, dtype: datetime64[ns]
ERROR:root:Error converting YEAR and Month to datetime: Some values in 'YEAR' and 'Month' could not be converted to datetime.
ERROR:root:Unexpected error in handle_missing_values: YEAR or Month columns may still contain invalid data after cleaning.


              NAME     GH_ID   GEOGR2   GEOGR1  ELEVATION Element  YEAR  \
0  Addis Ababa Obs  SHADDI21  9.01891  38.7475       2386  PRECIP  1993   
1  Addis Ababa Obs  SHADDI21  9.01891  38.7475       2386  PRECIP  1993   
2  Addis Ababa Obs  SHADDI21  9.01891  38.7475       2386  PRECIP  1993   
3  Addis Ababa Obs  SHADDI21  9.01891  38.7475       2386  PRECIP  1993   
4  Addis Ababa Obs  SHADDI21  9.01891  38.7475       2386  PRECIP  1993   

   Month    1    2  ...   22    23   24    25   26    27   28   29   30   31  
0      1  0.0  0.0  ...  0.0   0.0  7.8   0.1  0.0   0.0  0.0  0.0  0.0  2.5  
1      2  0.0  0.0  ...  0.0   0.0  0.3   0.0  1.2   0.0  0.0  NaN  NaN  NaN  
2      3  0.0  0.0  ...  0.0   0.0  0.0   0.0  0.0   0.0  0.0  0.0  0.0  0.0  
3      4  0.0  0.0  ...  3.4   0.0  0.0   5.5  0.0  42.8  0.0  0.6  0.0  NaN  
4      5  0.0  0.0  ...  0.5  14.7  0.7  19.5  0.0   0.0  2.2  0.7  1.4  1.7  

[5 rows x 39 columns]


ValueError: YEAR or Month columns may still contain invalid data after cleaning.