# In [1]:
from appgeopy import *
from my_packages import *

# In [60]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler

# Folder that receives one comparison figure per layer (kept as a plain string
# because it is joined with file names further down).
fld2savefig = "ApproachCompare_Figs"

# exist_ok avoids the check-then-create race and makes re-running the cell
# a no-op; pathlib is used so this step does not depend on `os` being in
# scope through a star import.
Path(fld2savefig).mkdir(parents=True, exist_ok=True)

# Configuration
# One row per compared source: (label, results CSV, bar colour).
# Keeping the three fields side by side guarantees FILE_MAP and COLORS
# always share the same keys in the same order.
_SOURCE_TABLE = (
    ("ARIMA_S1", "arima_output_seasonal1_combined_results.csv", "#1f77b4"),
    ("ARIMA_S3", "arima_output_seasonal3_combined_results.csv", "#ff7f0e"),
    ("ARIMA_S5", "arima_output_seasonal5_combined_results.csv", "#d62728"),
    ("Approach_1", "figs2_Approach_1_combined_results.csv", "#2ca02c"),
)

# Source label -> CSV file holding that source's combined results
FILE_MAP = {label: csv_name for label, csv_name, _ in _SOURCE_TABLE}

# Source label -> hex colour used for that source's bars
COLORS = {label: colour for label, _, colour in _SOURCE_TABLE}

# Load data
# Columns carried into the combined table; "source" is added per file below.
_RESULT_COLUMNS = ["station", "layer", "source", "r2", "rmse", "mae", "bias"]

dfs = []
for name, file in FILE_MAP.items():
    # A missing results file is skipped so the remaining sources can still
    # be compared.
    if not Path(file).exists():
        continue

    df = pd.read_csv(file)
    df["source"] = name  # tag every row with the label of the file it came from
    # Literal (non-regex) replacement: underscores in layer names become spaces
    df["layer"] = df["layer"].str.replace("_", " ", regex=False)
    dfs.append(df[_RESULT_COLUMNS])

# pd.concat([]) fails with an opaque "No objects to concatenate"; fail early
# with a message that names the files that were expected.
if not dfs:
    raise FileNotFoundError(
        "None of the configured result files were found: "
        + ", ".join(FILE_MAP.values())
    )

all_data = pd.concat(dfs, ignore_index=True)

def _zero_aligned_limits(ref_lim, lim):
    """Return y-limits for a twin axis whose zero lines up with a reference axis.

    Args:
        ref_lim: ``(bottom, top)`` limits of the reference axis.
        lim: ``(bottom, top)`` limits currently set on the twin axis.

    Returns:
        ``(bottom, top)`` limits that place zero at the same relative height
        as on the reference axis. The returned range always contains ``lim``
        (one side is extended, neither is shrunk), so nothing that was
        visible gets clipped. ``lim`` is returned unchanged when zero is not
        strictly inside both ranges.
    """
    ref_bottom, ref_top = ref_lim
    bottom, top = lim
    if not (ref_bottom < 0 < ref_top and bottom < 0 < top):
        return bottom, top

    # Share of the axis below zero relative to the share above it
    ratio = -ref_bottom / ref_top
    # Smallest top that keeps both `top` and `bottom` inside the new range
    new_top = max(top, -bottom / ratio)
    return -ratio * new_top, new_top


# Bar geometry shared by every subplot (loop-invariant)
source_order = list(FILE_MAP)
n_sources = len(source_order)
width = 0.18
x_r2 = np.array([0])  # group centre for the R² bars
x_errors = np.array([1.5, 2.5, 3.5])  # group centres for rmse, mae, bias

# Plot by layer: one figure per layer, one subplot per station
for layer in sorted(all_data["layer"].unique()):
    layer_data = all_data[all_data["layer"] == layer]
    stations = sorted(layer_data["station"].unique())

    ncols = 3
    nrows = (len(stations) + ncols - 1) // ncols  # ceiling division
    # 16.5 x 11.7 in at 300 dpi equals the 4950 x 3510 px requested on save
    fig, axes = plt.subplots(nrows, ncols, figsize=(16.5, 11.7), squeeze=False)
    axes = axes.flatten()

    # First bar container drawn for each source in this figure. The legend is
    # built from these so it does not depend on which sources happen to be
    # present at the first station.
    legend_handles = {}

    for ax1, station in zip(axes, stations):
        ax2 = ax1.twinx()

        station_rows = layer_data[layer_data["station"] == station].set_index(
            "source"
        )
        # Reindex to the configured order; sources without a row become NaN
        station_data = station_rows.reindex(source_order)

        # Scale R² values for this station to [-1, 1].
        # The scaler is fitted per station, so bar heights are only comparable
        # between sources inside one subplot, not across subplots.
        scaler = MinMaxScaler(feature_range=(-1, 1))
        r2_values = station_data["r2"].values.reshape(-1, 1)
        r2_scaled = scaler.fit_transform(r2_values).flatten()

        for j, (source, r2_val) in enumerate(zip(source_order, r2_scaled)):
            # Membership is tested on the rows *before* reindexing: after the
            # reindex every configured source is in the index.
            if source not in station_rows.index:
                continue

            # Centre the group of bars on the tick position
            offset = (j - n_sources / 2 + 0.5) * width
            row = station_data.loc[source]

            # R² on left axis (scaled)
            r2_bars = ax1.bar(
                x_r2 + offset,
                r2_val,
                width,
                color=COLORS[source],
            )
            legend_handles.setdefault(source, r2_bars)

            # Error metrics on right axis
            ax2.bar(
                x_errors + offset,
                [row["rmse"], row["mae"], row["bias"]],
                width,
                color=COLORS[source],
            )

        # Left axis (R²)
        # Scaled values lie in [-1, 1]; the top limit of 2 is kept from the
        # original settings (presumably headroom for the station label).
        ax1.set_ylim(-1.1, 2)
        ax1.axhline(0, color="gray", linestyle="--", linewidth=0.8, alpha=0.5)
        ax1.spines["top"].set_visible(False)
        ax1.spines["right"].set_visible(False)
        ax1.tick_params(axis="y", labelsize=10)

        # Right axis (Errors)
        ax2.spines["top"].set_visible(False)
        ax2.spines["left"].set_visible(False)
        ax2.grid(axis="y", alpha=0.3, linestyle=":")
        ax2.tick_params(axis="y", labelsize=10)

        # Line up the zero of the error axis with the zero of the R² axis when
        # the error bars extend on both sides of zero.
        ax2.set_ylim(*_zero_aligned_limits(ax1.get_ylim(), ax2.get_ylim()))

        # X-axis
        # NOTE(review): the last tick is labelled "PBIAS" but plots the "bias"
        # column — confirm the label matches the metric.
        ax1.set_xticks([0, 1.5, 2.5, 3.5])
        ax1.set_xticklabels(["R²", "RMSE", "MAE", "PBIAS"], fontsize=10)
        ax1.set_xlim(-0.5, 4)

        # Station label
        ax1.text(
            0.98,
            0.95,
            station,
            transform=ax1.transAxes,
            ha="right",
            va="top",
            fontsize=11,
            fontweight="bold",
            bbox=dict(
                boxstyle="round,pad=0.3",
                fc="gold",
                ec="gray",
                alpha=0.5,
                lw=0.8,
            ),
        )

    # Hide unused grid cells
    for unused_ax in axes[len(stations):]:
        unused_ax.axis("off")

    # Figure-level axis labels
    fig.text(
        0.02,
        0.5,
        "R² (scaled)",
        va="center",
        rotation="vertical",
        fontsize=14,
        fontweight="bold",
    )
    fig.text(
        0.98,
        0.5,
        "Error (mm)",
        va="center",
        rotation="vertical",
        fontsize=14,
        fontweight="bold",
    )

    # Legend: sources drawn in this figure, in configured order, two columns
    drawn_sources = [s for s in source_order if s in legend_handles]
    fig.legend(
        [legend_handles[s] for s in drawn_sources],
        drawn_sources,
        loc="lower right",  # anchor the legend by its lower-right corner
        ncol=2,
        bbox_to_anchor=(0.99, 0.05),  # figure coordinates of that corner
        fontsize=16,
        frameon=False,
    )

    fig.suptitle(f"{layer}", fontsize=20, fontweight="bold", y=0.98)
    fig.tight_layout(rect=[0.04, 0.04, 0.96, 0.96], h_pad=2, w_pad=1.5)
    visualize.save_figure_with_exact_dimensions(
        fig=fig,
        # Built with pathlib so this step does not rely on `os` being in scope
        savepath=str(Path(fld2savefig) / layer.replace(" ", "_")),
        width_px=4950,
        height_px=3510,
        dpi=300,
    )
    # Close this specific figure rather than whichever one is current
    plt.close(fig)