In [2]:
from pathlib import Path
import pickle
import sys

import geopandas as gpd
import matplotlib.pyplot as plt
from neuralhydrology.evaluation import get_tester
from neuralhydrology.utils.config import Config
import numpy as np
import pandas as pd

sys.path.append("../")
from src.readers.geom_reader import load_geodata
from src.timeseries_stats.metrics import evaluate_model
from src.utils.logger import setup_logger

# Notebook-wide logger named "fine_tune"; setup_logger is handed
# ../logs/fine_tuning.log as its log file.
LOG = setup_logger("fine_tune", log_file="../logs/fine_tuning.log")

In [3]:
# Watershed geometries and gauge locations come from the project geometry reader.
ws, gauges = load_geodata(folder_depth="../")
common_index = gauges.index.tolist()

# Basemap layer read from the 2023 GeoPackage.
basemap_data = gpd.read_file("../data/geometry/basemap_2023.gpkg")

# Cluster assignments (from Chapter 1) are currently NOT loaded; the loading
# code is kept below for reference only.
# gauge_mapping = pd.read_csv(
#     "../res/chapter_one/gauge_hybrid_mapping.csv",
#     index_col="gauge_id",
#     dtype={"gauge_id": str},
# )

print(f"Loaded {len(gauges)} gauges with hybrid classification")


Loaded 996 gauges with hybrid classification


In [5]:
# Gauges selected for fine-tuning: keep only the identifier, name, geometry and
# lstm_nse_* columns of the GeoPackage and index the frame by gauge_id.
fine_tune_gauges = gpd.read_file("../res/FineTuneGauges.gpkg")[
    [
        "gauge_id",
        "name_ru",
        "name_en",
        "geometry",
        "lstm_nse_mswep",
        "lstm_nse_e5l",
        "lstm_nse_e5",
        "lstm_nse_gpcp",
    ]
].set_index("gauge_id")

# Split the full gauge set into the fine-tuning subset and everything else.
ft_index = fine_tune_gauges.index.to_list()
rest_gauges = gauges[~gauges.index.isin(ft_index)]
rest_index = rest_gauges.index.to_list()

### Draw predictions vs observations for fine-tuned gauges (before)

In [54]:
# Log the median of each stored lstm_nse_* column for the fine-tuning gauges.
LOG.info("Initial parameters for fine-tuning gauges: ")
for _message, _column in (
    ("MSWEP NSE: %.2f", "lstm_nse_mswep"),
    ("E5L NSE: %.2f", "lstm_nse_e5l"),
    ("E5 NSE: %.2f", "lstm_nse_e5"),
    ("GPCP NSE: %.2f", "lstm_nse_gpcp"),
):
    LOG.info(_message, fine_tune_gauges[_column].median())


[38;5;39m2025-12-12 16:11:34 | INFO     | PhDLogger | fine_tune | ℹ️  Initial parameters for fine-tuning gauges: [0m
[38;5;39m2025-12-12 16:11:34 | INFO     | PhDLogger | fine_tune | ℹ️  MSWEP NSE: 0.25[0m
[38;5;39m2025-12-12 16:11:34 | INFO     | PhDLogger | fine_tune | ℹ️  E5L NSE: 0.17[0m
[38;5;39m2025-12-12 16:11:34 | INFO     | PhDLogger | fine_tune | ℹ️  E5 NSE: 0.12[0m
[38;5;39m2025-12-12 16:11:34 | INFO     | PhDLogger | fine_tune | ℹ️  GPCP NSE: 0.16[0m


In [87]:
# Root folder that holds the pre-trained model runs.
MODEL_RUNS_DIR = "../data/lstm_configs/model_runs"

# Single source of truth per forcing key: (run folder, evaluated epoch).
# Both dictionaries below are derived from it, so the epoch embedded in the
# results path can never disagree with the epoch stored next to the config.
LSTM_RUNS = {
    "gpcp": ("cudalstm_q_mm_day_gpcp_no_autocorr_static_1203_080402", 24),
    "e5": ("cudalstm_q_mm_day_era5_no_autocorr_static_1203_220232", 20),
    "e5l": ("cudalstm_q_mm_day_era5l_no_autocorr_static_1003_133332", 26),
    "mswep": ("cudalstm_q_mm_day_mswep_no_autocorr_static_1103_191754", 24),
    "ealstm": ("ealstm_q_mm_day_mswep_no_static_0205_144958", 38),
}

# Forcing key -> path (str) of the pickled test results of that run.
lstm_pathes = {
    key: f"{MODEL_RUNS_DIR}/{run}/test/model_epoch{run_epoch:03d}/test_results.p"
    for key, (run, run_epoch) in LSTM_RUNS.items()
}

# Forcing key -> {"path": config.yml as Path, "epoch": int}; only the runs
# that are re-evaluated in this notebook are listed.
cfg_pathes = {
    key: {
        "path": Path(f"{MODEL_RUNS_DIR}/{LSTM_RUNS[key][0]}/config.yml"),
        "epoch": LSTM_RUNS[key][1],
    }
    for key in ("gpcp", "mswep")
}

# Write one gauge id per line (cast to int) to the basin list used for the
# fine-tuning gauges. The target folder is created first so the cell also
# works on a fresh checkout where ../data/models/fine_tune does not exist yet.
poor_gauges_file = Path("../data/models/fine_tune/poor_gauges.txt")
poor_gauges_file.parent.mkdir(parents=True, exist_ok=True)
with poor_gauges_file.open("w") as the_file:
    the_file.writelines(f"{int(gauge_name)}\n" for gauge_name in ft_index)

In [88]:
# Re-evaluate the pre-trained MSWEP run on the fine-tuning gauges only.
_mswep_run = cfg_pathes["mswep"]
lstm_cfg = _mswep_run["path"]
epoch = _mswep_run["epoch"]

# Point every basin list at the fine-tuning gauges and set the test window.
_basin_file = "../data/models/fine_tune/poor_gauges.txt"
_overrides = {
    "train_basin_file": _basin_file,
    "validate_n_random_basins": len(ft_index),
    "validation_basin_file": _basin_file,
    "test_basin_file": _basin_file,
    "test_start_date": "01/01/2009",
    "test_end_date": "31/12/2020",
}

cfg_run = Config(lstm_cfg)
cfg_run.update_config(_overrides)

# NOTE(review): save_results=True presumably writes test_results.p into the
# run's test/model_epoch024 folder, i.e. the same file lstm_pathes["mswep"]
# points to. Confirm that overwriting the stored results is intended.
tester = get_tester(cfg=cfg_run, run_dir=cfg_run.run_dir, period="test", init_model=True)
pred_results = tester.evaluate(epoch=epoch, save_results=True)

# Evaluation: 100%|██████████| 265/265 [01:02<00:00,  4.26it/s]


In [None]:
# Long-format NSE records ({"gauge_id", "period", "nse"}): one per gauge and
# per calendar year / aggregate period, pivoted below to one row per gauge.
yearly_nse_data = []

# Bind the outputs up front so that a run which produces nothing leaves empty
# results behind instead of values left over from an earlier execution.
nse_pivot = pd.DataFrame()
year_cols = []
period_cols = []

print("Calculating yearly and period NSE for simulation results...")

if not pred_results:
    print(
        "No prediction results found. Please ensure the evaluation cell above ran successfully."
    )
else:
    for gauge_id, result in pred_results.items():
        # result["1D"]["xr"] is expected to be an xarray Dataset holding the
        # q_mm_day_obs / q_mm_day_sim series; keep 2009 onwards.
        try:
            df = result["1D"]["xr"].to_dataframe().loc["2009":]
        except KeyError:
            print(f"No data found for gauge {gauge_id} in the specified period.")
            continue

        # Handle multi-index if present: only the first (date) level is kept
        if isinstance(df.index, pd.MultiIndex):
            df = df.droplevel(1)

        # Every calendar year first, then the two aggregate periods, so the
        # NSE logic below exists only once.
        segments = list(df.groupby(df.index.year))
        segments.append(("2009-2018", df.loc["2009":"2018"]))
        segments.append(("2019-2020", df.loc["2019":"2020"]))

        for period, segment in segments:
            obs = segment["q_mm_day_obs"]
            sim = segment["q_mm_day_sim"]

            # NSE stays NaN unless there are more than 10 valid observations.
            nse = np.nan
            if len(obs.dropna()) > 10:
                try:
                    nse = evaluate_model(observed=obs, simulated=sim)["NSE"]
                except Exception as exc:
                    # Best effort: keep NaN, but report the failure instead
                    # of swallowing it silently.
                    LOG.warning(
                        "NSE failed for gauge %s, period %s: %s", gauge_id, period, exc
                    )

            yearly_nse_data.append({"gauge_id": gauge_id, "period": period, "nse": nse})

    if not yearly_nse_data:
        print("No NSE data calculated.")
    else:
        # Create pivot table: Rows=Gauge, Cols=Period
        nse_df = pd.DataFrame(yearly_nse_data)
        nse_pivot = nse_df.pivot(index="gauge_id", columns="period", values="nse")

        # Reorder columns: years first (plain or NumPy integers), then the
        # aggregate periods that are actually present.
        year_cols = sorted(c for c in nse_pivot.columns if isinstance(c, (int, np.integer)))
        period_cols = [c for c in ("2009-2018", "2019-2020") if c in nse_pivot.columns]
        nse_pivot = nse_pivot[year_cols + period_cols]

        # Heatmap styling; the colour scale is clipped to [0, 1], so every
        # negative NSE is drawn with the same colour as 0.
        styled_df = nse_pivot.style.background_gradient(
            cmap="RdYlGn", vmin=0, vmax=1
        ).format("{:.2f}")
        # display(styled_df)

In [105]:
# Median NSE across all rows of the pivot (one row per gauge) for each aggregate period.
nse_pivot[period_cols].median()

period
2009-2018    0.724910
2019-2020    0.251876
dtype: float64

### Remaining gauges (not selected for fine-tuning)

In [9]:
# Root folder that holds the pre-trained model runs.
MODEL_RUNS_DIR = "../data/lstm_configs/model_runs"

# Single source of truth per forcing key: (run folder, evaluated epoch).
# Both dictionaries below are derived from it, so the epoch embedded in the
# results path can never disagree with the epoch stored next to the config.
LSTM_RUNS = {
    "gpcp": ("cudalstm_q_mm_day_gpcp_no_autocorr_static_1203_080402", 24),
    "e5": ("cudalstm_q_mm_day_era5_no_autocorr_static_1203_220232", 20),
    "e5l": ("cudalstm_q_mm_day_era5l_no_autocorr_static_1003_133332", 26),
    "mswep": ("cudalstm_q_mm_day_mswep_no_autocorr_static_1103_191754", 24),
    "ealstm": ("ealstm_q_mm_day_mswep_no_static_0205_144958", 38),
}

# Forcing key -> path (str) of the pickled test results of that run.
lstm_pathes = {
    key: f"{MODEL_RUNS_DIR}/{run}/test/model_epoch{run_epoch:03d}/test_results.p"
    for key, (run, run_epoch) in LSTM_RUNS.items()
}

# Forcing key -> {"path": config.yml as Path, "epoch": int}; only the runs
# that are re-evaluated in this notebook are listed.
cfg_pathes = {
    key: {
        "path": Path(f"{MODEL_RUNS_DIR}/{LSTM_RUNS[key][0]}/config.yml"),
        "epoch": LSTM_RUNS[key][1],
    }
    for key in ("gpcp", "mswep")
}

# Write one gauge id per line (cast to int) to the basin list used for the
# gauges that are not part of the fine-tuning subset. The target folder is
# created first so the cell also works when ../data/models/fine_tune is missing.
rest_gauges_file = Path("../data/models/fine_tune/rest_gauges.txt")
rest_gauges_file.parent.mkdir(parents=True, exist_ok=True)
with rest_gauges_file.open("w") as the_file:
    the_file.writelines(f"{int(gauge_name)}\n" for gauge_name in rest_index)
# Re-evaluate the pre-trained MSWEP run on the gauges that are NOT part of the
# fine-tuning subset.
lstm_cfg = cfg_pathes["mswep"]["path"]
epoch = cfg_pathes["mswep"]["epoch"]

cfg_run = Config(lstm_cfg)

cfg_run.update_config(
    {
        "train_basin_file": "../data/models/fine_tune/rest_gauges.txt",
        # Sized from rest_index, the list written to rest_gauges.txt; the
        # previous len(ft_index) was the size of the fine-tuning subset.
        "validate_n_random_basins": len(rest_index),
        "validation_basin_file": "../data/models/fine_tune/rest_gauges.txt",
        "test_basin_file": "../data/models/fine_tune/rest_gauges.txt",
        "test_start_date": "01/01/2009",
        "test_end_date": "31/12/2020",
    }
)

# NOTE(review): save_results=True presumably writes test_results.p into the
# run's test/model_epoch024 folder, i.e. the same file lstm_pathes["mswep"]
# points to. Confirm that overwriting the stored results is intended.
tester = get_tester(cfg=cfg_run, run_dir=cfg_run.run_dir, period="test", init_model=True)
pred_results = tester.evaluate(epoch=epoch, save_results=True)

# Evaluation: 100%|██████████| 731/731 [02:43<00:00,  4.46it/s]


In [10]:
# Long-format NSE records ({"gauge_id", "period", "nse"}): one per gauge and
# per calendar year / aggregate period, pivoted below to one row per gauge.
yearly_nse_data = []

# Bind the outputs up front so that a run which produces nothing leaves empty
# results behind instead of values left over from an earlier execution.
nse_pivot = pd.DataFrame()
year_cols = []
period_cols = []

print("Calculating yearly and period NSE for simulation results...")

if not pred_results:
    print(
        "No prediction results found. Please ensure the evaluation cell above ran successfully."
    )
else:
    for gauge_id, result in pred_results.items():
        # result["1D"]["xr"] is expected to be an xarray Dataset holding the
        # q_mm_day_obs / q_mm_day_sim series; keep 2009 onwards.
        try:
            df = result["1D"]["xr"].to_dataframe().loc["2009":]
        except KeyError:
            print(f"No data found for gauge {gauge_id} in the specified period.")
            continue

        # Handle multi-index if present: only the first (date) level is kept
        if isinstance(df.index, pd.MultiIndex):
            df = df.droplevel(1)

        # Every calendar year first, then the two aggregate periods, so the
        # NSE logic below exists only once.
        segments = list(df.groupby(df.index.year))
        segments.append(("2009-2018", df.loc["2009":"2018"]))
        segments.append(("2019-2020", df.loc["2019":"2020"]))

        for period, segment in segments:
            obs = segment["q_mm_day_obs"]
            sim = segment["q_mm_day_sim"]

            # NSE stays NaN unless there are more than 10 valid observations.
            nse = np.nan
            if len(obs.dropna()) > 10:
                try:
                    nse = evaluate_model(observed=obs, simulated=sim)["NSE"]
                except Exception as exc:
                    # Best effort: keep NaN, but report the failure instead
                    # of swallowing it silently.
                    LOG.warning(
                        "NSE failed for gauge %s, period %s: %s", gauge_id, period, exc
                    )

            yearly_nse_data.append({"gauge_id": gauge_id, "period": period, "nse": nse})

    if not yearly_nse_data:
        print("No NSE data calculated.")
    else:
        # Create pivot table: Rows=Gauge, Cols=Period
        nse_df = pd.DataFrame(yearly_nse_data)
        nse_pivot = nse_df.pivot(index="gauge_id", columns="period", values="nse")

        # Reorder columns: years first (plain or NumPy integers), then the
        # aggregate periods that are actually present.
        year_cols = sorted(c for c in nse_pivot.columns if isinstance(c, (int, np.integer)))
        period_cols = [c for c in ("2009-2018", "2019-2020") if c in nse_pivot.columns]
        nse_pivot = nse_pivot[year_cols + period_cols]

        # Heatmap styling; the colour scale is clipped to [0, 1], so every
        # negative NSE is drawn with the same colour as 0.
        styled_df = nse_pivot.style.background_gradient(
            cmap="RdYlGn", vmin=0, vmax=1
        ).format("{:.2f}")
        # display(styled_df)


Calculating yearly and period NSE for GPCP simulation results...


In [11]:
# Median NSE across all rows of the pivot (one row per gauge) for each aggregate period.
nse_pivot[period_cols].median()


period
2009-2018    0.837036
2019-2020    0.708212
dtype: float64

### Plots

In [None]:
# Load the pickled test results of every run in lstm_pathes and keep, per
# dataset, one obs/sim DataFrame for each fine-tuning gauge. No metrics are
# computed in this cell.
# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only point lstm_pathes at result files produced by trusted local runs.
lstm_dataset_dfs = {}

# Set membership instead of scanning the ft_index list once per gauge.
_ft_gauge_ids = set(ft_index)

for _dataset_name, _pickle_path in lstm_pathes.items():
    print(f"Processing LSTM {_dataset_name}...")

    with open(_pickle_path, "rb") as _f:
        _lstm_data = pickle.load(_f)

    # Gauges outside the fine-tuning subset are skipped by the filter below.
    lstm_dataset_dfs[_dataset_name] = {
        _gauge_id: (
            _gauge_results["1D"]["xr"]
            .to_dataframe()
            .droplevel(1)
            .rename(columns={"q_mm_day_obs": "obs", "q_mm_day_sim": "sim"})
        )
        for _gauge_id, _gauge_results in _lstm_data.items()
        if _gauge_id in _ft_gauge_ids
    }

In [38]:
# Output folder for the pre-fine-tuning series plots; created together with
# any missing parent folders, and left untouched if it already exists.
img_dir = Path("../data/images/series_before_finetuning")
img_dir.mkdir(parents=True, exist_ok=True)

In [56]:
# One figure per gauge: observed discharge plus the simulation of every loaded
# dataset, each labelled with its NSE, saved as PNG into img_dir.
# matplotlib.pyplot (plt) is imported in the notebook's top import cell.

# Create output directory
img_dir = Path("../data/images/series_before_finetuning")
img_dir.mkdir(parents=True, exist_ok=True)

# Fixed colour per dataset; built once because it never changes between
# gauges. Datasets not listed here are drawn in gray.
colors = {
    "gpcp": "red",
    "mswep": "blue",
    "e5l": "green",
    "e5": "orange",
    "ealstm": "purple",
}

# Get list of gauges from the first dataset (assuming all have same gauges).
# The empty-dict default avoids StopIteration when no dataset was loaded.
first_dataset = next(iter(lstm_dataset_dfs.values()), {})
gauge_ids = first_dataset.keys()

for gauge_id in gauge_ids:
    # Observed data is expected to be the same in every dataset. Prefer the
    # MSWEP copy, but fall back to the first dataset (which always contains
    # gauge_id) so a gauge missing from MSWEP no longer raises KeyError.
    obs_source = lstm_dataset_dfs.get("mswep", {})
    if gauge_id not in obs_source:
        obs_source = first_dataset
    obs_data = obs_source[gauge_id]["obs"]

    fig, ax = plt.subplots(figsize=(12, 6))

    # Plot observed data (only once)
    ax.plot(
        obs_data.index,
        obs_data,
        label="Observed",
        color="black",
        linewidth=1.5,
        alpha=0.7,
    )

    title_parts = [f"Gauge {gauge_id}"]

    # Plot simulations for each dataset that contains this gauge
    for dataset_name, dataset_dfs in lstm_dataset_dfs.items():
        if gauge_id not in dataset_dfs:
            continue

        sim_data = dataset_dfs[gauge_id]["sim"]

        # Calculate metrics
        metrics = evaluate_model(observed=obs_data, simulated=sim_data)

        # Add to plot
        ax.plot(
            sim_data.index,
            sim_data,
            label=f"{dataset_name} (NSE: {metrics['NSE']:.2f})",
            color=colors.get(dataset_name, "gray"),
            linewidth=1,
            alpha=0.6,
        )

        # Add metrics to title (simplified)
        title_parts.append(
            f"{dataset_name}: NSE={metrics['NSE']:.2f}, KGE={metrics['KGE']:.2f}"
        )

    ax.set_xlabel("Date")
    ax.set_ylabel("Discharge (mm/day)")
    # First title line: gauge id; second line: metrics of all datasets.
    ax.set_title("\n".join([title_parts[0], " | ".join(title_parts[1:])]), fontsize=10)
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Save figure, then close it so figures do not pile up in memory
    fig.tight_layout()
    fig.savefig(img_dir / f"gauge_{gauge_id}_comparison.png", dpi=150)
    plt.close(fig)

### Read pre-trained configs

### Fine-tune for poor-performing gauges