In [None]:
import pandas as pd
import geopandas as gpd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from scipy.optimize import curve_fit
from skgstat.models import exponential
from dateutil.relativedelta import relativedelta
from typing import Tuple

In [None]:
# =============================================================================
# 1. CONFIGURATION
# =============================================================================
class Config:
    """Central place for every setting a user is expected to edit.

    Everything is a plain class attribute read as ``Config.<NAME>``. The last two
    statements also change matplotlib's global style as soon as this class body
    executes, so every figure created afterwards uses those defaults.
    """

    # Inputs: point shapefile (provides PointKey -> STATION) and the folder that
    # is searched for model-output CSV files.
    SHAPEFILE_PATH: Path = Path(r"D:\1000_SCRIPTS\003_Project002\20250222_GTWR001\2_KrigingInterpolation\points_fld\mlcw_twd97.shp")
    MODEL_FOLDER: Path = Path(r"D:\1000_SCRIPTS\003_Project002\20250222_GTWR001\5_GTWR_Prediction")

    # Output: a relative path, so figures land under the current working directory.
    OUTPUT_FOLDER: Path = Path(r"Curve_Fitting_Results")

    # Analysis: substring used to select the model CSV, and the calendar date a
    # time_stamp of 0 maps to (other stamps are offsets in whole months).
    KERNEL_NAME: str = "bisquare"
    START_DATE: pd.Timestamp = pd.Timestamp("2016-05-01")

    # Plotting defaults. figsize is applied after the style sheet so the style
    # cannot override it.
    plt.style.use("seaborn-v0_8-talk")
    plt.rcParams.update({"figure.figsize": (12, 5)})

In [None]:

# =============================================================================
# 2. HELPER FUNCTIONS
# =============================================================================

def load_and_prepare_data(file_path: Path) -> pd.DataFrame:
    """Read a model-output CSV and index its rows by a coordinate-derived key.

    The key has the form ``X<int(X_TWD97*1000)>Y<int(Y_TWD97*1000)>``. It becomes
    the index (named ``PointKey``), so the coordinate columns remain as ordinary
    columns.

    NOTE(review): ``int()`` truncates, so a float artifact such as
    ``...122.9999999`` yields a key one unit lower. This presumably must match
    however the shapefile's ``PointKey`` column was produced. Verify before
    changing it to ``round()``.
    """
    frame = pd.read_csv(file_path)
    keys = [
        "X{}Y{}".format(int(x_coord * 1000), int(y_coord * 1000))
        for x_coord, y_coord in zip(frame["X_TWD97"], frame["Y_TWD97"])
    ]
    frame.insert(0, "PointKey", keys)
    return frame.set_index("PointKey")

def get_time_series_for_point(df: pd.DataFrame, point_key: str, start_date: pd.Timestamp) -> Tuple[pd.Series, np.ndarray]:
    """Build a month-start CUMDISP series and elapsed-day offsets for one point.

    Args:
        df: Frame indexed by PointKey with ``time_stamp`` (whole months after
            ``start_date``) and ``CUMDISP`` columns.
        point_key: Index label of the point to extract.
        start_date: Date corresponding to ``time_stamp == 0``.

    Returns:
        ``(series, time_days)``. ``series`` is CUMDISP indexed by date (index name
        ``datetime``), sorted, and regularised to month starts with ``asfreq("MS")``,
        so months with no observation appear as NaN. ``time_days`` is a float
        array of days elapsed since the first index entry, aligned with ``series``.

    Raises:
        KeyError: if ``point_key`` is not in ``df.index``.
    """
    # Selecting with a list always returns a DataFrame. A scalar label would
    # collapse a single-row point into a Series, and then "time_stamp" would
    # come back as a scalar that cannot be iterated.
    point_rows = df.loc[[point_key]]

    dates = pd.DatetimeIndex(
        [start_date + pd.DateOffset(months=int(t)) for t in point_rows["time_stamp"]],
        name="datetime",
    )
    coefficient_series = (
        pd.Series(point_rows["CUMDISP"].to_numpy(), index=dates, name="CUMDISP")
        .sort_index()
        .asfreq("MS")
    )

    # Days since the first month; identical to the cumulative sum of successive gaps.
    elapsed = coefficient_series.index - coefficient_series.index[0]
    time_days = np.asarray(elapsed.days, dtype=float)
    return coefficient_series, time_days

def fit_exponential_model(series: pd.Series, time_days: np.ndarray) -> Tuple[np.ndarray, pd.Series]:
    """Fit skgstat's exponential model to the non-NaN part of ``series``.

    The model is called as ``exponential(h, r, c0, b)`` with ``h = time_days``.
    The search is bounded as follows:

    * ``r`` (range) within ``[min, max]`` of the observed day offsets;
    * ``c0`` (sill) and ``b`` (nugget) within ``[min, max]`` of the observed values.

    The initial guess is the median offset, the median value, and the first
    observed value.

    Returns:
        ``(params, predictions)``. ``params`` holds the fitted ``(r, c0, b)``.
        ``predictions`` evaluates the fitted model at every entry of
        ``time_days`` (including the NaN gaps) and is indexed like ``series``.

    Raises:
        ValueError: if fewer than 3 non-NaN observations are available.
        Errors raised by ``scipy.optimize.curve_fit`` are not caught here.
    """
    observed = ~np.isnan(series)
    y_obs = series[observed]
    x_obs = time_days[observed]

    if not len(y_obs) >= 3:
        raise ValueError("Not enough valid data points to fit model.")

    x_low, x_high = np.min(x_obs), np.max(x_obs)
    y_low, y_high = np.min(y_obs), np.max(y_obs)
    lower = [x_low, y_low, y_low]
    upper = [x_high, y_high, y_high]
    start = [np.median(x_obs), np.median(y_obs), y_obs.iloc[0]]

    fitted_params, _covariance = curve_fit(
        exponential, x_obs, y_obs, p0=start, bounds=(lower, upper)
    )
    model_values = exponential(time_days, *fitted_params)
    return fitted_params, pd.Series(model_values, index=series.index)

def plot_and_save_figure(original_series: pd.Series, predicted_series: pd.Series, title: str, params: np.ndarray, output_path: Path):
    """Plot observed data against the fitted model, annotate the parameters, and save.

    Args:
        original_series: Observed values indexed by date. Drawn as hollow grey
            markers joined by a line.
        predicted_series: Model values on the same index. Drawn as a dashed red line.
        title: Axes title.
        params: Three fitted parameters in the order ``(range, sill, nugget)``,
            i.e. the ``(r, c0, b)`` order passed to ``exponential``.
        output_path: Destination image file (saved at 300 dpi).

    The figure is always closed, even if plotting or saving raises. A batch loop
    that catches errors and continues therefore does not accumulate open figures.
    """
    fig, ax = plt.subplots()  # explicit Axes so every call below targets this figure
    try:
        original_series.plot(ax=ax, marker='o', linestyle='-', label='Original Data', markerfacecolor='none', color='gray', alpha=0.8)
        predicted_series.plot(ax=ax, label='Fitted Exponential Model', color='red', linestyle='--')

        ax.set_title(title)
        ax.set_xlabel("Date")
        ax.set_ylabel("Cumulative Displacement (CUMDISP)")
        ax.legend(loc='lower right')  # keep the legend away from the top-left parameter panel
        ax.grid(True, which='both', linestyle='--', linewidth=0.5)

        # Parameter panel: a monospaced text box anchored to the top-left corner.
        r, c0, b = params
        param_text = (
            f"Model Parameters\n"
            f"------------------\n"
            f"Range: {r:.2f}\n"
            f"Sill: {c0:.4f}\n"
            f"Nugget: {b:.4f}"
        )
        props = dict(boxstyle='round,pad=0.5', facecolor='white', alpha=0.8)
        # transform=ax.transAxes -> coordinates relative to the axes (0,0 bottom-left, 1,1 top-right).
        ax.text(0.03, 0.97, param_text, transform=ax.transAxes, fontsize=10,
                verticalalignment='top', bbox=props, fontfamily='monospace')

        fig.tight_layout()
        # Save this specific figure rather than whatever pyplot considers "current".
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure on success and on failure alike.
        plt.close(fig)



In [None]:
# =============================================================================
# 3. MAIN BATCH PROCESSING WORKFLOW
# =============================================================================
def main():
    """Fit and plot the exponential model for every point in one model-output CSV.

    Workflow:
      1. Create ``Config.OUTPUT_FOLDER`` if needed.
      2. Take the first (path-sorted) CSV in ``Config.MODEL_FOLDER`` whose name
         contains ``Config.KERNEL_NAME``.
      3. Map each PointKey to its STATION name using ``Config.SHAPEFILE_PATH``.
      4. For each point: build the monthly series, fit the model, and save a PNG.

    A failure on one point is reported and the loop continues with the next point.

    Raises:
        FileNotFoundError: if no CSV matches the configured kernel.
    """
    print("--- Starting Batch Analysis ---")

    Config.OUTPUT_FOLDER.mkdir(parents=True, exist_ok=True)
    print(f"Output figures will be saved to: {Config.OUTPUT_FOLDER}")

    output_files = sorted(Config.MODEL_FOLDER.glob(f"*{Config.KERNEL_NAME}*.csv"))
    if not output_files:
        raise FileNotFoundError(f"No files with kernel '{Config.KERNEL_NAME}' found in {Config.MODEL_FOLDER}")

    selected_file = output_files[0]
    if len(output_files) > 1:
        print(f"{len(output_files)} files match kernel '{Config.KERNEL_NAME}'; using the first: {selected_file.name}")
    else:
        print(f"Using model file: {selected_file.name}")

    df_processed = load_and_prepare_data(selected_file)
    mlcw_gdf = gpd.read_file(Config.SHAPEFILE_PATH)

    # Build the PointKey -> STATION map once instead of querying the GeoDataFrame
    # for every point. drop_duplicates keeps the first row per key, matching the
    # previous "first query match" behaviour.
    unique_stations = mlcw_gdf.drop_duplicates(subset="PointKey")
    station_lookup = dict(zip(unique_stations["PointKey"], unique_stations["STATION"]))

    unique_point_keys = sorted(df_processed.index.unique())
    total_points = len(unique_point_keys)
    print(f"Found {total_points} unique points to process.\n")

    failures = 0
    for i, point_key in enumerate(unique_point_keys, start=1):
        station_name = "Unknown"  # shown in the failure message if the lookup itself fails
        try:
            if point_key not in station_lookup:
                raise KeyError(f"PointKey not found in {Config.SHAPEFILE_PATH.name}")
            station_name = str(station_lookup[point_key])
            print(f"[{i}/{total_points}] Processing Station: {station_name} ({point_key})")

            coefficient_series, time_days = get_time_series_for_point(df_processed, point_key, Config.START_DATE)
            fitted_params, predicted_values = fit_exponential_model(coefficient_series, time_days)

            plot_title = f"Exponential Model Fit for Station: {station_name}"
            # Keep only characters that are safe in a filename.
            safe_station_name = "".join(c for c in station_name if c.isalnum() or c in (' ', '_')).rstrip()
            output_path = Config.OUTPUT_FOLDER / f"{safe_station_name}_{point_key}.png"

            plot_and_save_figure(coefficient_series, predicted_values, plot_title, fitted_params, output_path)

        except Exception as e:  # batch boundary: report and move on to the next point
            failures += 1
            print(f"  -> FAILED for station {station_name} ({point_key}). Reason: {type(e).__name__}: {e}")

    print(f"\n{total_points - failures}/{total_points} points processed successfully.")
    print("--- Batch Analysis Complete ---")

if __name__ == "__main__":
    # Run the batch workflow when this file/cell is executed directly.
    main()
