## Setting Up:

In [None]:
# --- Standard library
from concurrent.futures import ProcessPoolExecutor, as_completed
from contextlib import redirect_stdout
from datetime import datetime
import io
import logging
import multiprocessing as mp
import os
import sys
import warnings

# Make repo root importable (for MBM & scripts/*)
sys.path.append(os.path.join(os.getcwd(), '../../'))

# --- Third-party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from cmcrameri import cm
import torch
from tqdm.auto import tqdm
import xarray as xr
from matplotlib.lines import Line2D

import massbalancemachine as mbm

# --- Project-local
from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.nn_helpers import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.NN_networks import *
from scripts.geodata_plots import *

# --- Notebook settings
# Silence ALL warnings for the rest of the session.
# NOTE(review): this also hides pandas SettingWithCopy / FutureWarning messages
# that cells further down would otherwise emit.
warnings.filterwarnings('ignore')
# Automatically re-import edited project modules (scripts/*, massbalancemachine).
%load_ext autoreload
%autoreload 2

# Swiss configuration object; cfg.seed, cfg.dataPath and cfg.metaData are used below.
cfg = mbm.SwitzerlandConfig()

# Plot styles (path is relative to the current working directory):
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)

# Project helper (star-imported); presumably seeds Python/NumPy/torch RNGs — TODO confirm.
seed_all(cfg.seed)
print("Using seed:", cfg.seed)

if torch.cuda.is_available():
    print("CUDA is available")
    # Project helper (star-imported); presumably releases cached GPU memory — TODO confirm.
    free_up_cuda()
else:
    print("CUDA is NOT available")

# Device used below for model building, checkpoint loading and evaluation.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Input data:

In [None]:
# Read GLAMOS stake data
data_glamos = getStakesData(cfg)

# Head/tail month padding derived from the stake table (mbm private helper).
months_head_pad, months_tail_pad = mbm.data_processing.utils._compute_head_tail_pads_from_df(
    data_glamos)

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Inputs of the monthly transformation (only needed when RUN is True).
paths = {
    'csv_path': cfg.dataPath + path_PMB_GLAMOS_csv,
    'era5_climate_data':
    cfg.dataPath + path_ERA5_raw + 'era5_monthly_averaged_data.nc',
    'geopotential_data':
    cfg.dataPath + path_ERA5_raw + 'era5_geopotential_pressure.nc',
    'radiation_save_path': cfg.dataPath + path_pcsr + 'zarr/'
}
RUN = False  # True: recompute the monthly table; False: load output_file
data_monthly = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=VOIS_CLIMATE,
    vois_topographical=VOIS_TOPOGRAPHICAL,
    output_file='CH_wgms_dataset_monthly_LSTM.csv')

# Remove 2025 (used elsewhere for validation). This is done BEFORE building the
# DataLoader so that the loader does not contain the held-out year either.
data_monthly = data_monthly[data_monthly['YEAR'] < 2025]

dataloader_gl = mbm.dataloader.DataLoader(cfg,
                                          data=data_monthly,
                                          random_seed=cfg.seed,
                                          meta_data_columns=cfg.metaData)

In [None]:
# Variant of the stake table whose FROM_DATE is moved to 1 August of the same
# year (YYYY0801), so every sequence starts at the beginning of August.
data_glamos_Aug_ = data_glamos.copy()
data_glamos_Aug_["FROM_DATE"] = (
    data_glamos_Aug_["FROM_DATE"].astype(str).str.slice(0, 4)  # year YYYY
    .astype(int) * 10000 + 801  # YYYY -> YYYY0801
)

# Head/tail month padding for the August-start table (used when the LSTM
# sequences are built further down).
months_head_pad_Aug_, months_tail_pad_Aug_ = mbm.data_processing.utils._compute_head_tail_pads_from_df(
    data_glamos_Aug_)

# Same monthly transformation as in the previous cell (logging is configured
# there; a second logging.basicConfig call would not change anything).
# NOTE(review): `paths` must still be the ERA5 dict from the previous cell —
# the cache cell further down rebinds that name.
RUN = False
data_monthly_Aug_ = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos_Aug_,
    paths=paths,
    cfg=cfg,
    vois_climate=VOIS_CLIMATE,
    vois_topographical=VOIS_TOPOGRAPHICAL,
    output_file='CH_wgms_dataset_monthly_LSTM_Aug_.csv')

# Remove 2025 (used elsewhere for validation) BEFORE building the DataLoader,
# so the loader does not contain the held-out year either.
data_monthly_Aug_ = data_monthly_Aug_[data_monthly_Aug_['YEAR'] < 2025]

# Create DataLoader
dataloader_gl_Aug_ = mbm.dataloader.DataLoader(cfg,
                                               data=data_monthly_Aug_,
                                               random_seed=cfg.seed,
                                               meta_data_columns=cfg.metaData)

## Monthly distributions:

In [None]:
# Root folder holding the distributed mass-balance grids of every model.
_MB_GRIDS_DIR = os.path.join(cfg.dataPath, "GLAMOS", "distributed_MB_grids")

# LSTM, in-sample run (name suggests original, un-normalized targets).
# Alternative run with normalized targets: "MBM/paper/LSTM_IS_NORM_Y_PAST".
PATH_PREDICTIONS_LSTM_IS = os.path.join(_MB_GRIDS_DIR,
                                        "MBM/paper/LSTM_IS_ORIGINAL_Y_PAST")
PATH_PREDICTIONS_NN = os.path.join(_MB_GRIDS_DIR, 'MBM/paper/NN')
PATH_PREDICTIONS_XGB = os.path.join(_MB_GRIDS_DIR, 'MBM/paper/XGB')

# Months of the hydrological year, October through September.
hydro_months = "oct nov dec jan feb mar apr may jun jul aug sep".split()

In [None]:
# fig = plot_glacier_monthly_series_lstm_sharedcmap_center0(
#     glacier_name="rhone",
#     year=2008,
#     path_pred_lstm=PATH_PREDICTIONS_NN,
#     apply_smoothing_fn=apply_gaussian_filter,
# )

# fig = plot_glacier_monthly_series_lstm_sharedcmap_center0(
#     glacier_name="rhone",
#     year=2008,
#     path_pred_lstm=PATH_PREDICTIONS_LSTM_IS,
#     apply_smoothing_fn=apply_gaussian_filter,
# )

In [None]:
# Parquet cache of the per-model monthly prediction tables and of the GLAMOS
# reference tables. The expensive loaders only run when REBUILD_CACHE is True
# or at least one cache file is missing.
CACHE_DIR = os.path.join(
    cfg.dataPath,
    "GLAMOS/distributed_MB_grids/MBM/paper/processed_dfs",
)
os.makedirs(CACHE_DIR, exist_ok=True)

REBUILD_CACHE = False  # set True only when inputs change

# One cache file per table.
# NOTE(review): this rebinds `paths`, which earlier cells use for the ERA5 inputs.
paths = {
    key: os.path.join(CACHE_DIR, file_name)
    for key, file_name in (
        ("LSTM", "df_months_LSTM.parquet"),
        ("NN", "df_months_NN.parquet"),
        ("XGB", "df_months_XGB.parquet"),
        ("GW", "df_GLAMOS_w.parquet"),
        ("GA", "df_GLAMOS_a.parquet"),
    )
}

use_cache = (not REBUILD_CACHE) and all(
    os.path.exists(p) for p in paths.values())

if use_cache:
    print("Loading cached monthly prediction DataFrames...")

    df_months_LSTM = pd.read_parquet(paths["LSTM"])
    df_months_NN = pd.read_parquet(paths["NN"])
    df_months_XGB = pd.read_parquet(paths["XGB"])
    df_GLAMOS_w = pd.read_parquet(paths["GW"])
    df_GLAMOS_a = pd.read_parquet(paths["GA"])

    print(f"Loaded DataFrames from {CACHE_DIR}")
else:
    print("Building monthly prediction DataFrames...")

    # NN and XGB grids are read with the same loader.
    df_months_LSTM = load_glwd_lstm_predictions(PATH_PREDICTIONS_LSTM_IS,
                                                hydro_months)
    df_months_NN = load_glwd_nn_predictions(PATH_PREDICTIONS_NN, hydro_months)
    df_months_XGB = load_glwd_nn_predictions(PATH_PREDICTIONS_XGB,
                                             hydro_months)

    # GLAMOS reference grids; the glacier -> sorted years mapping taken from
    # the LSTM table is handed to load_all_glamos (presumably to pick the
    # matching years — TODO confirm).
    PATH_GLAMOS = os.path.join(cfg.dataPath, path_distributed_MB_glamos,
                               "GLAMOS")
    glaciers = os.listdir(PATH_GLAMOS)  # not used further in this cell
    glacier_years = df_months_LSTM.groupby("glacier")["year"].unique().apply(
        sorted).to_dict()
    df_GLAMOS_w, df_GLAMOS_a = load_all_glamos(cfg, glacier_years, PATH_GLAMOS)

    # Write the cache
    df_months_LSTM.to_parquet(paths["LSTM"])
    df_months_NN.to_parquet(paths["NN"])
    df_months_XGB.to_parquet(paths["XGB"])
    df_GLAMOS_w.to_parquet(paths["GW"])
    df_GLAMOS_a.to_parquet(paths["GA"])

    print(f"Cached DataFrames to {CACHE_DIR}")

# Optional: restrict to years >= 2000
# df_months_LSTM = df_months_LSTM[df_months_LSTM["year"] >= 2000]
# df_months_NN = df_months_NN[df_months_NN["year"] >= 2000]
# df_months_XGB = df_months_XGB[df_months_XGB["year"] >= 2000]

In [None]:
# Preview only; the full table is large.
df_months_LSTM.head()

### Glacier-wide:

In [None]:
# --- 1. Glacier-wide mean per glacier and year (mean over all rows of a
# glacier-year, i.e. over its grid cells — presumably) ---
# numeric_only=True, as in the elevation-band helpers further down: keeps the
# mean from raising on non-numeric columns (pandas >= 2.0 no longer drops them).
glwd_months_NN = df_months_NN.groupby(['glacier', 'year'
                                       ]).mean(numeric_only=True).reset_index()
glwd_months_XGB = df_months_XGB.groupby(['glacier', 'year'
                                         ]).mean(numeric_only=True).reset_index()
glwd_months_LSTM = df_months_LSTM.groupby(
    ['glacier', 'year']).mean(numeric_only=True).reset_index()
glwd_months_GLAMOS_w = df_GLAMOS_w.groupby(
    ['glacier', 'year']).mean(numeric_only=True).reset_index()
glwd_months_GLAMOS_a = df_GLAMOS_a.groupby(
    ['glacier', 'year']).mean(numeric_only=True).reset_index()

# --- 2. Glacier–year pairs present in ALL five datasets ---
valid_pairs = set.intersection(*(set(zip(frame['glacier'], frame['year']))
                                 for frame in (
                                     glwd_months_NN,
                                     glwd_months_XGB,
                                     glwd_months_LSTM,
                                     glwd_months_GLAMOS_w,
                                     glwd_months_GLAMOS_a,
                                 )))


# --- 3. Helper function for filtering by glacier–year pairs ---
def filter_to_valid(df, pairs=None):
    """
    Keep only the rows whose (glacier, year) pair is in ``pairs``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must have 'glacier' and 'year' columns.
    pairs : set of (glacier, year) tuples, optional
        Defaults to the notebook-global ``valid_pairs``, looked up at call
        time. Pass it explicitly to avoid depending on whichever cell last
        rebound that global.

    Returns
    -------
    pandas.DataFrame
        Matching rows in their original order, re-indexed from 0.
    """
    if pairs is None:
        pairs = valid_pairs
    # Build the key tuples with zip (vectorised) instead of a row-wise
    # .apply(tuple, axis=1), which is slow and returns a DataFrame on empty input.
    keys = pd.Series(list(zip(df["glacier"], df["year"])),
                     index=df.index,
                     dtype=object)
    return df[keys.isin(pairs)].reset_index(drop=True)


# --- 4. Apply filtering to all datasets ---
glwd_months_NN_filtered = filter_to_valid(glwd_months_NN)
glwd_months_XGB_filtered = filter_to_valid(glwd_months_XGB)
glwd_months_LSTM_filtered = filter_to_valid(glwd_months_LSTM)
glwd_months_GLAMOS_filtered_w = filter_to_valid(glwd_months_GLAMOS_w)
glwd_months_GLAMOS_filtered_a = filter_to_valid(glwd_months_GLAMOS_a)

print(
    len(glwd_months_GLAMOS_filtered_w),
    len(glwd_months_GLAMOS_filtered_a),
    len(glwd_months_NN_filtered),
    len(glwd_months_XGB_filtered),
    len(glwd_months_LSTM_filtered),
)

# --- 5. Prepare for plotting ---
df_months_nn_long = prepare_monthly_long_df(
    glwd_months_LSTM_filtered,
    glwd_months_NN_filtered,
    glwd_months_XGB_filtered,
    glwd_months_GLAMOS_filtered_w,
    glwd_months_GLAMOS_filtered_a,
)

df_months_nn_long.head(2)

In [None]:
# ---- x-axis range covering every series drawn: the three models AND GLAMOS
# (color_glamos is passed, so GLAMOS is plotted as well). Leaving mb_glamos
# out could clip GLAMOS ridges that extend beyond the model range. ----
plotted_cols = ['mb_nn', 'mb_lstm', 'mb_xgb', 'mb_glamos']
min_ = df_months_nn_long[plotted_cols].min().min()
max_ = df_months_nn_long[plotted_cols].max().max()

# ---- Plot ----
fig = plot_monthly_joyplot(
    df_months_nn_long,
    x_range=(np.floor(min_), np.ceil(max_)),
    color_lstm=color_annual,
    color_nn=color_winter,
    color_xgb="darkgreen",
    color_glamos="gray",
)

In [None]:
# x-axis range over the same columns as the elevation-band joyplots. This
# includes the plotted variable mb_lstm, which the range used to ignore, so
# LSTM ridges could fall outside the axis.
plotted_cols = ['mb_nn', 'mb_lstm', 'mb_xgb', 'mb_glamos']
min_ = df_months_nn_long[plotted_cols].min().min()
max_ = df_months_nn_long[plotted_cols].max().max()

fig = plot_monthly_joyplot_single(df_months_nn_long,
                                  variable="mb_lstm",
                                  color_model=color_annual,
                                  x_range=(np.floor(min_), np.ceil(max_)),
                                  model_name='MBM')

# Make sure the output folder exists before saving.
os.makedirs('figures/paper', exist_ok=True)
fig.savefig('figures/paper/CH_LSTM_vs_GLAMOS_monthly_joyplot_glwd.png',
            dpi=300,
            bbox_inches='tight')

### Elevation bands:

#### Highest:

In [None]:
# Width of the elevation bands and depth of the "highest" band below (same
# units as the `elevation` column). Named band_width so it does not shadow the
# builtin `bin`.
band_width = 200
bins = np.arange(1200, 4500, band_width)
labels = [f"{b}-{b + band_width}" for b in bins[:-1]]

# Copy datasets so the band column is not added to the original frames
df_months_NN_ = df_months_NN.copy()
df_months_XGB_ = df_months_XGB.copy()
df_months_LSTM_ = df_months_LSTM.copy()
df_GLAMOS_a_ = df_GLAMOS_a.copy()
df_GLAMOS_w_ = df_GLAMOS_w.copy()

# Assign elevation bands; elevations outside (1200, 4400] get NaN.
# NOTE(review): elev_band is not used by the highest-band extraction below.
for df_ in [
        df_months_NN_,
        df_months_XGB_,
        df_months_LSTM_,
        df_GLAMOS_a_,
        df_GLAMOS_w_,
]:
    df_["elev_band"] = pd.cut(df_["elevation"], bins=bins, labels=labels)


def extract_highest_band(df, bin_width):
    """
    Per glacier and year, average the rows whose elevation lies within
    ``bin_width`` of that glacier's maximum elevation (maximum taken over all
    years). Only numeric columns are averaged.
    """
    max_elev = df.groupby("glacier")["elevation"].transform("max")
    highest_band = df[df["elevation"] >= (max_elev - bin_width)]
    return (highest_band.groupby(["glacier", "year"
                                  ]).mean(numeric_only=True).reset_index())


glwd_high_NN = extract_highest_band(df_months_NN_, band_width)
glwd_high_XGB = extract_highest_band(df_months_XGB_, band_width)
glwd_high_LSTM = extract_highest_band(df_months_LSTM_, band_width)
glwd_high_GLAMOS_a = extract_highest_band(df_GLAMOS_a_, band_width)
glwd_high_GLAMOS_w = extract_highest_band(df_GLAMOS_w_, band_width)

# Glacier–year pairs present in all five highest-band tables
valid_pairs = set.intersection(*(set(zip(d["glacier"], d["year"])) for d in (
    glwd_high_NN,
    glwd_high_XGB,
    glwd_high_LSTM,
    glwd_high_GLAMOS_w,
    glwd_high_GLAMOS_a,
)))


def filter_to_valid(df, pairs=None):
    """
    Keep the rows whose (glacier, year) pair is in ``pairs`` (default: the
    notebook-global ``valid_pairs`` at call time); re-indexed from 0.
    """
    if pairs is None:
        pairs = valid_pairs
    keys = pd.Series(list(zip(df["glacier"], df["year"])),
                     index=df.index,
                     dtype=object)
    return df[keys.isin(pairs)].reset_index(drop=True)


glwd_high_NN_filt = filter_to_valid(glwd_high_NN, valid_pairs)
glwd_high_XGB_filt = filter_to_valid(glwd_high_XGB, valid_pairs)
glwd_high_LSTM_filt = filter_to_valid(glwd_high_LSTM, valid_pairs)
glwd_high_GLAMOS_a_filt = filter_to_valid(glwd_high_GLAMOS_a, valid_pairs)
glwd_high_GLAMOS_w_filt = filter_to_valid(glwd_high_GLAMOS_w, valid_pairs)

print(
    len(glwd_high_GLAMOS_w_filt),
    len(glwd_high_GLAMOS_a_filt),
    len(glwd_high_NN_filt),
    len(glwd_high_XGB_filt),
    len(glwd_high_LSTM_filt),
)

# Own name (as with df_months_nn_long_low in the lowest-band cell), so the
# glacier-wide df_months_nn_long used by the earlier joyplot cells is not
# overwritten.
df_months_nn_long_high = prepare_monthly_long_df(
    glwd_high_LSTM_filt,
    glwd_high_NN_filt,
    glwd_high_XGB_filt,
    glwd_high_GLAMOS_w_filt,
    glwd_high_GLAMOS_a_filt,
)

plotted_cols = ["mb_nn", "mb_lstm", "mb_xgb", "mb_glamos"]
min_, max_ = (
    df_months_nn_long_high[plotted_cols].min().min(),
    df_months_nn_long_high[plotted_cols].max().max(),
)

fig = plot_monthly_joyplot_single(
    df_months_nn_long_high,
    variable="mb_lstm",
    color_model=color_annual,
    x_range=(np.floor(min_), np.ceil(max_)),
    model_name="MBM",
)

os.makedirs("figures/paper", exist_ok=True)
fig.savefig(
    "figures/paper/CH_LSTM_vs_GLAMOS_monthly_joyplot_high_elv.png",
    dpi=300,
    bbox_inches="tight",
)

#### Lowest:

In [None]:
# =========================
# Lowest-elevation band
# =========================

# Width of the elevation bands and height of the "lowest" band (same units as
# the `elevation` column). Named band_width so it does not shadow the builtin `bin`.
band_width = 200
bins = np.arange(1200, 4500, band_width)
labels = [f"{b}-{b + band_width}" for b in bins[:-1]]

# --- Copy to avoid modifying originals ---
df_months_NN_ = df_months_NN.copy()
df_months_XGB_ = df_months_XGB.copy()
df_months_LSTM_ = df_months_LSTM.copy()
df_GLAMOS_a_ = df_GLAMOS_a.copy()
df_GLAMOS_w_ = df_GLAMOS_w.copy()

# --- Assign elevation bands (outside (1200, 4400] -> NaN) ---
# NOTE(review): elev_band is not used by the lowest-band extraction below.
for df_ in [
        df_months_NN_,
        df_months_XGB_,
        df_months_LSTM_,
        df_GLAMOS_a_,
        df_GLAMOS_w_,
]:
    df_["elev_band"] = pd.cut(df_["elevation"], bins=bins, labels=labels)


# --- Helper: extract lowest-elevation band per glacier ---
def extract_lowest_band(df, bin_width):
    """
    Per glacier and year, average the rows whose elevation lies within
    ``bin_width`` of that glacier's minimum elevation (minimum taken over all
    years). Only numeric columns are averaged.
    """
    min_elev = df.groupby("glacier")["elevation"].transform("min")
    lowest_band = df[df["elevation"] <= (min_elev + bin_width)]
    return (lowest_band.groupby(["glacier", "year"
                                 ]).mean(numeric_only=True).reset_index())


# --- Compute lowest-elevation bands ---
glwd_low_NN = extract_lowest_band(df_months_NN_, band_width)
glwd_low_XGB = extract_lowest_band(df_months_XGB_, band_width)
glwd_low_LSTM = extract_lowest_band(df_months_LSTM_, band_width)
glwd_low_GLAMOS_a = extract_lowest_band(df_GLAMOS_a_, band_width)
glwd_low_GLAMOS_w = extract_lowest_band(df_GLAMOS_w_, band_width)

# --- Glacier–year pairs present in all five lowest-band tables ---
valid_pairs = set.intersection(*(set(zip(d["glacier"], d["year"])) for d in (
    glwd_low_NN,
    glwd_low_XGB,
    glwd_low_LSTM,
    glwd_low_GLAMOS_w,
    glwd_low_GLAMOS_a,
)))


def filter_to_valid(df, pairs=None):
    """
    Keep the rows whose (glacier, year) pair is in ``pairs`` (default: the
    notebook-global ``valid_pairs`` at call time); re-indexed from 0.
    """
    if pairs is None:
        pairs = valid_pairs
    keys = pd.Series(list(zip(df["glacier"], df["year"])),
                     index=df.index,
                     dtype=object)
    return df[keys.isin(pairs)].reset_index(drop=True)


# --- Apply consistent filtering ---
glwd_low_NN_filt = filter_to_valid(glwd_low_NN, valid_pairs)
glwd_low_XGB_filt = filter_to_valid(glwd_low_XGB, valid_pairs)
glwd_low_LSTM_filt = filter_to_valid(glwd_low_LSTM, valid_pairs)
glwd_low_GLAMOS_a_filt = filter_to_valid(glwd_low_GLAMOS_a, valid_pairs)
glwd_low_GLAMOS_w_filt = filter_to_valid(glwd_low_GLAMOS_w, valid_pairs)

print(
    len(glwd_low_GLAMOS_w_filt),
    len(glwd_low_GLAMOS_a_filt),
    len(glwd_low_NN_filt),
    len(glwd_low_XGB_filt),
    len(glwd_low_LSTM_filt),
)

# --- Prepare long-format dataframe for plotting ---
df_months_nn_long_low = prepare_monthly_long_df(
    glwd_low_LSTM_filt,
    glwd_low_NN_filt,
    glwd_low_XGB_filt,
    glwd_low_GLAMOS_w_filt,
    glwd_low_GLAMOS_a_filt,
)

# --- Determine x-axis limits ---
plotted_cols = ["mb_nn", "mb_lstm", "mb_xgb", "mb_glamos"]
min_, max_ = (
    df_months_nn_long_low[plotted_cols].min().min(),
    df_months_nn_long_low[plotted_cols].max().max(),
)

# Manual clamp of the lower x-limit for the ablation-dominated lowest band
# (overrides the data minimum computed above).
min_ = -10

# --- Plot ---
fig = plot_monthly_joyplot_single(df_months_nn_long_low,
                                  variable="mb_lstm",
                                  color_model=color_annual,
                                  x_range=(np.floor(min_), np.ceil(max_)),
                                  model_name="MBM",
                                  y_offset=0.15)

# --- Save figure ---
os.makedirs("figures/paper", exist_ok=True)
fig.savefig(
    "figures/paper/CH_LSTM_vs_GLAMOS_monthly_joyplot_low_elv.png",
    dpi=300,
    bbox_inches="tight",
)

## Feature importance:

### Aggregated:

In [None]:
# Blocking on glaciers:
# Model is trained on all glaciers --> "Within sample"
# remove 2025
data_monthly_train = data_monthly[data_monthly.YEAR < 2025]

existing_glaciers = set(data_monthly_train.GLACIER.unique())
train_glaciers = existing_glaciers  # within-sample: every glacier is kept
# .copy(): the 'y' column below must go into an independent frame, not into a
# slice of data_monthly (SettingWithCopy; hidden by the global warnings filter).
data_train = data_monthly_train[data_monthly_train.GLACIER.isin(
    train_glaciers)].copy()
print('Size of monthly train data:', len(data_train))

# Target column 'y' (copy of POINT_BALANCE)
data_train['y'] = data_train['POINT_BALANCE']

# Same selection for the August-start table:
data_monthly_train_Aug_ = data_monthly_Aug_[data_monthly_Aug_.YEAR < 2025]

existing_glaciers = set(data_monthly_train_Aug_.GLACIER.unique())
train_glaciers = existing_glaciers
data_train_Aug_ = data_monthly_train_Aug_[data_monthly_train_Aug_.GLACIER.isin(
    train_glaciers)].copy()
print('Size of monthly train data:', len(data_train_Aug_))

# Target column 'y' (copy of POINT_BALANCE)
data_train_Aug_['y'] = data_train_Aug_['POINT_BALANCE']

MONTHLY_COLS = [
    't2m',
    'tp',
    'slhf',
    'sshf',
    'ssrd',
    'fal',
    'str',
    'pcsr',
    'ELEVATION_DIFFERENCE',
]
STATIC_COLS = ['aspect_sgi', 'slope_sgi', 'svf']

feature_columns = MONTHLY_COLS + STATIC_COLS

seed_all(cfg.seed)
# normalize_target=False: the checkpoint loaded below (LSTM_IS_ORIGIN_Y_PAST)
# is the original-target model, and the monthly-PFI cell builds the dataset for
# the same checkpoint with normalize_target=False. With True, ds.y would be
# standardized while the model still predicts in original units, and the
# `* y_std + y_mean` de-normalization applied in evaluation would distort the
# predictions. NOTE(review): confirm against the checkpoint's training setup.
ds_train = build_combined_LSTM_dataset(df_loss=data_train,
                                       df_full=data_train_Aug_,
                                       monthly_cols=MONTHLY_COLS,
                                       static_cols=STATIC_COLS,
                                       months_head_pad=months_head_pad_Aug_,
                                       months_tail_pad=months_tail_pad_Aug_,
                                       normalize_target=False,
                                       expect_target=True)
train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
    len(ds_train), val_ratio=0.2, seed=cfg.seed)

# --- loaders (fit scalers on TRAIN, apply to whole ds_train) ---
ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train)

# The loaders themselves are not used below; the call is kept for its side
# effect of fitting the scalers on ds_train_copy (reused by make_test_loader).
train_dl, val_dl = ds_train_copy.make_loaders(
    train_idx=train_idx,
    val_idx=val_idx,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=
    True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True  # use weighted sampler for training
)

# --- build model, resolve loss, reload trained weights ---
model = mbm.models.LSTM_MB.build_model_from_params(cfg, PARAMS_LSTM_IS_past,
                                                   device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(PARAMS_LSTM_IS_past)

ds_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train)

test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test_copy, ds_train_copy, batch_size=128, seed=cfg.seed)

state = torch.load(LSTM_IS_ORIGIN_Y_PAST, map_location=device)
model.load_state_dict(state)
test_metrics, test_df_preds = model.evaluate_with_preds(
    device, test_dl, ds_test_copy)
test_rmse_a, test_rmse_w = test_metrics['RMSE_annual'], test_metrics[
    'RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

In [None]:
import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm


@torch.no_grad()
def eval_rmse_pfi(model, device, dl, ds):
    """
    Compute (RMSE_winter, RMSE_annual) the same way evaluate_with_preds does:
    - ds.keys routes each sample to winter or annual (last key element equal
      to "winter" -> winter, anything else -> annual)
    - targets and predictions are de-normalized with ds.y_std / ds.y_mean
    Assumes dl iterates ds in order (shuffle=False): batch rows are matched to
    ds.keys purely by position.

    Returns
    -------
    (rmse_winter, rmse_annual) as Python floats; NaN for a period without samples.
    """
    model.eval()

    y_std = ds.y_std.to(device)
    y_mean = ds.y_mean.to(device)

    # Period label of every sample, in dataset order.
    periods = np.asarray([key[-1] for key in ds.keys], dtype=object)

    true_w, pred_w = [], []
    true_a, pred_a = [], []
    offset = 0

    def _to_host(t):
        # One device->host copy per batch instead of one .cpu() per sample
        # (each of which forces a device sync). Flattened to one value per
        # sample and upcast to float64 for the RMSE.
        return t.reshape(-1).cpu().double().numpy()

    for batch in dl:
        bs = batch["x_m"].shape[0]
        is_winter = periods[offset:offset + bs] == "winter"
        offset += bs

        _, y_w, y_a = model(
            batch["x_m"].to(device),
            batch["x_s"].to(device),
            batch["mv"].to(device),
            batch["mw"].to(device),
            batch["ma"].to(device),
        )

        # denormalize (matches evaluate_with_preds)
        y_true = _to_host(batch["y"].to(device) * y_std + y_mean)
        y_w = _to_host(y_w * y_std + y_mean)
        y_a = _to_host(y_a * y_std + y_mean)

        true_w.append(y_true[is_winter])
        pred_w.append(y_w[is_winter])
        true_a.append(y_true[~is_winter])
        pred_a.append(y_a[~is_winter])

    def rmse(t_parts, p_parts):
        t = np.concatenate(t_parts) if t_parts else np.empty(0)
        p = np.concatenate(p_parts) if p_parts else np.empty(0)
        if t.size == 0:
            return float("nan")
        return float(np.sqrt(np.mean((p - t)**2)))

    return rmse(true_w, pred_w), rmse(true_a, pred_a)


def permutation_importance_LSTM_MB_full(
    model,
    device,
    ds_test,
    monthly_cols,
    static_cols,
    n_repeats=5,
    seed=0,
    batch_size=128,
    w_winter=1.0,
    w_annual=1.0,
):
    """
    Aggregated PFI (period-aware, denormalized):
      - Monthly features: permute feature k across samples for ALL months ([:, :, k])
      - Static features:  permute feature j across samples ([:, j])

    For every feature and repeat, the increase in RMSE over the unpermuted
    baseline is recorded separately for winter and annual samples.

    Parameters
    ----------
    model, device
        Evaluated through ``eval_rmse_pfi``.
    ds_test
        Dataset exposing Xm, Xs, iw, ia (plus keys, y_std, y_mean used by
        eval_rmse_pfi). Xm / Xs are permuted IN PLACE and always restored
        before returning, also when an evaluation raises.
    monthly_cols, static_cols : list of str
        Feature names in the order of Xm's last axis / Xs's second axis.
    n_repeats : int
        Permutations per feature.
    seed : int
        Seed of the NumPy generator that draws the permutations.
    batch_size : int
        Evaluation batch size (the loader is never shuffled).
    w_winter, w_annual : float
        Weights of the combined metric
            combined = w_winter * ΔRMSE_winter + w_annual * ΔRMSE_annual
        (a period without samples is left out).

    Returns
    -------
    pandas.DataFrame with one row per feature: mean/std over repeats of the
    absolute Δ (winter, annual), of the relative Δ / baseline, of the combined
    weighted Δ, and of the sample-weighted global relative Δ
        global_rel = (Nw * dw_rel + Na * da_rel) / (Nw + Na)
    (periods without samples excluded). Sorted by global_rel, then annual_rel,
    then winter_rel, descending.
    """
    rng = np.random.default_rng(seed)

    print("\n▶️ Running aggregated permutation feature importance (PFI)...")

    def _evaluate():
        # Fresh, unshuffled loader each time: eval_rmse_pfi matches batch rows
        # to ds_test.keys by position.
        dl = torch.utils.data.DataLoader(ds_test,
                                         batch_size=batch_size,
                                         shuffle=False)
        return eval_rmse_pfi(model, device, dl, ds_test)

    base_w, base_a = _evaluate()

    # counts for sample-weighted global
    Nw = int(ds_test.iw.sum())
    Na = int(ds_test.ia.sum())

    print(
        f"[Baseline RMSE] winter={base_w:.3f} | annual={base_a:.3f} (Nw={Nw}, Na={Na})"
    )

    Xm0 = ds_test.Xm.clone()
    Xs0 = ds_test.Xs.clone()

    rows = []

    def _weighted_sum(terms, like):
        # Sum of weight * values over the periods that have samples. A period
        # without samples has a NaN baseline, and 0 * NaN would turn the whole
        # result into NaN. Returns all-NaN when no period has samples.
        kept = [(wt, vals) for wt, vals, count in terms if count > 0]
        if not kept:
            return np.full_like(like, np.nan)
        return sum(wt * vals for wt, vals in kept)

    def _record(fname, ftype, dw, da):
        dw = np.asarray(dw, dtype=float)
        da = np.asarray(da, dtype=float)

        # relative deltas
        dw_rel = dw / base_w
        da_rel = da / base_a

        # sample-weighted global (relative), over periods that have samples
        n_kept = sum(n for n in (Nw, Na) if n > 0)
        global_rel = _weighted_sum([(Nw, dw_rel, Nw), (Na, da_rel, Na)], dw)
        if n_kept:
            global_rel = global_rel / n_kept

        # weighted combination of the absolute deltas
        combined = _weighted_sum([(w_winter, dw, Nw), (w_annual, da, Na)], dw)

        rows.append(
            dict(
                feature=fname,
                type=ftype,
                baseline_winter=float(base_w),
                baseline_annual=float(base_a),
                mean_delta_winter=float(dw.mean()),
                std_delta_winter=float(dw.std(ddof=0)),
                mean_delta_annual=float(da.mean()),
                std_delta_annual=float(da.std(ddof=0)),
                mean_delta_winter_rel=float(dw_rel.mean()),
                std_delta_winter_rel=float(dw_rel.std(ddof=0)),
                mean_delta_annual_rel=float(da_rel.mean()),
                std_delta_annual_rel=float(da_rel.std(ddof=0)),
                mean_delta_global_rel=float(global_rel.mean()),
                std_delta_global_rel=float(global_rel.std(ddof=0)),
                mean_delta_combined=float(combined.mean()),
                std_delta_combined=float(combined.std(ddof=0)),
            ))

    total_steps = (len(monthly_cols) + len(static_cols)) * n_repeats
    pbar = tqdm(total=total_steps, desc="Aggregated permutation importance")

    try:
        # ---------- Monthly features ----------
        for k, fname in enumerate(monthly_cols):
            dw, da = [], []

            for _ in range(n_repeats):
                perm = rng.permutation(len(ds_test))
                ds_test.Xm[:, :, k] = Xm0[perm, :, k]

                w, a = _evaluate()

                dw.append(w - base_w)
                da.append(a - base_a)
                pbar.update(1)

            _record(fname, "monthly", dw, da)
            ds_test.Xm[:] = Xm0  # restore

        # ---------- Static features ----------
        for j, fname in enumerate(static_cols):
            dw, da = [], []

            for _ in range(n_repeats):
                perm = rng.permutation(len(ds_test))
                ds_test.Xs[:, j] = Xs0[perm, j]

                w, a = _evaluate()

                dw.append(w - base_w)
                da.append(a - base_a)
                pbar.update(1)

            _record(fname, "static", dw, da)
            ds_test.Xs[:] = Xs0  # restore
    finally:
        # Never leave ds_test permuted, even if an evaluation raised midway.
        ds_test.Xm[:] = Xm0
        ds_test.Xs[:] = Xs0
        pbar.close()

    out = pd.DataFrame(rows)

    # Useful default sorting: by global_rel, then annual_rel, then winter_rel
    out = out.sort_values(
        [
            "mean_delta_global_rel", "mean_delta_annual_rel",
            "mean_delta_winter_rel"
        ],
        ascending=False,
    ).reset_index(drop=True)

    return out

In [None]:
def plot_pfi_annual(df):
    """
    Horizontal bar chart of the mean increase in annual RMSE per feature
    (error bars = std over repeats), most important feature at the top.
    """
    ranked = df.sort_values("mean_delta_annual", ascending=False)

    fig, ax = plt.subplots(figsize=(8, max(3, 0.35 * len(ranked))))
    ax.barh(ranked["feature"],
            ranked["mean_delta_annual"],
            xerr=ranked["std_delta_annual"])
    ax.invert_yaxis()

    baseline = ranked.baseline_annual.iloc[0]
    ax.set_title(
        f"Permutation Importance – Annual (baseline RMSE={baseline:.3f})")
    ax.set_xlabel("Increase in RMSE_annual")
    fig.tight_layout()
    plt.show()


def plot_pfi_winter(df):
    """
    Horizontal bar chart of the mean increase in winter RMSE per feature
    (error bars = std over repeats), most important feature at the top.
    """
    ranked = df.sort_values("mean_delta_winter", ascending=False)

    fig, ax = plt.subplots(figsize=(8, max(3, 0.35 * len(ranked))))
    ax.barh(ranked["feature"],
            ranked["mean_delta_winter"],
            xerr=ranked["std_delta_winter"])
    ax.invert_yaxis()

    baseline = ranked.baseline_winter.iloc[0]
    ax.set_title(
        f"Permutation Importance – Winter (baseline RMSE={baseline:.3f})")
    ax.set_xlabel("Increase in RMSE_winter")
    fig.tight_layout()
    plt.show()


def plot_pfi_combined(df):
    """
    Horizontal bar chart of the combined winter + annual importance, most
    important feature at the top.

    Uses ``mean_delta_combined`` / ``std_delta_combined`` when ``df`` has those
    columns. Otherwise it falls back to the sample-weighted relative metric
    ``mean_delta_global_rel`` / ``std_delta_global_rel`` produced by
    permutation_importance_LSTM_MB_full, instead of raising KeyError.
    """
    if "mean_delta_combined" in df.columns:
        mean_col, std_col = "mean_delta_combined", "std_delta_combined"
        xlabel = "Weighted increase in RMSE"
    else:
        mean_col, std_col = "mean_delta_global_rel", "std_delta_global_rel"
        xlabel = "Sample-weighted relative increase in RMSE"

    d = df.sort_values(mean_col, ascending=False)

    fig, ax = plt.subplots(figsize=(8, max(3, 0.35 * len(d))))
    ax.barh(d["feature"], d[mean_col], xerr=d[std_col])
    ax.invert_yaxis()
    ax.set_title("Permutation Importance – Combined (winter + annual)")
    ax.set_xlabel(xlabel)
    fig.tight_layout()
    plt.show()


# Aggregated PFI on ds_test_copy (a clone of the within-sample training
# dataset), 8 permutations per feature.
# Only the core arguments are passed: w_winter / w_annual are not parameters of
# the function as defined in the previous cell, so passing them raised
# TypeError. Equal winter/annual weighting is the intended setting.
pfi_full = permutation_importance_LSTM_MB_full(
    model=model,
    device=device,
    ds_test=ds_test_copy,
    monthly_cols=MONTHLY_COLS,
    static_cols=STATIC_COLS,
    n_repeats=8,
    seed=cfg.seed,
    batch_size=128,
)
plot_pfi_annual(pfi_full)
plot_pfi_winter(pfi_full)
plot_pfi_combined(pfi_full)

### Monthly:

In [None]:
# Blocking on glaciers:
# Model is trained on all glaciers --> "Within sample"
# remove 2025
data_monthly_train = data_monthly[data_monthly.YEAR < 2025]

existing_glaciers = set(data_monthly_train.GLACIER.unique())
train_glaciers = existing_glaciers  # within-sample: every glacier is kept
# .copy(): the 'y' column below must go into an independent frame, not into a
# slice of data_monthly (SettingWithCopy; hidden by the global warnings filter).
data_train = data_monthly_train[data_monthly_train.GLACIER.isin(
    train_glaciers)].copy()
print('Size of monthly train data:', len(data_train))

# Target column 'y' (copy of POINT_BALANCE)
data_train['y'] = data_train['POINT_BALANCE']

# Same selection for the August-start table:
data_monthly_train_Aug_ = data_monthly_Aug_[data_monthly_Aug_.YEAR < 2025]

existing_glaciers = set(data_monthly_train_Aug_.GLACIER.unique())
train_glaciers = existing_glaciers
data_train_Aug_ = data_monthly_train_Aug_[data_monthly_train_Aug_.GLACIER.isin(
    train_glaciers)].copy()
print('Size of monthly train data:', len(data_train_Aug_))

# Target column 'y' (copy of POINT_BALANCE)
data_train_Aug_['y'] = data_train_Aug_['POINT_BALANCE']

MONTHLY_COLS = [
    't2m',
    'tp',
    'slhf',
    'sshf',
    'ssrd',
    'fal',
    'str',
    'pcsr',
    'ELEVATION_DIFFERENCE',
]
STATIC_COLS = ['aspect_sgi', 'slope_sgi', 'svf']

feature_columns = MONTHLY_COLS + STATIC_COLS

seed_all(cfg.seed)
# normalize_target=False: targets stay in original units, matching the
# original-target checkpoint LSTM_IS_ORIGIN_Y_PAST loaded below.
ds_train = build_combined_LSTM_dataset(df_loss=data_train,
                                       df_full=data_train_Aug_,
                                       monthly_cols=MONTHLY_COLS,
                                       static_cols=STATIC_COLS,
                                       months_head_pad=months_head_pad_Aug_,
                                       months_tail_pad=months_tail_pad_Aug_,
                                       normalize_target=False,
                                       expect_target=True)
train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
    len(ds_train), val_ratio=0.2, seed=cfg.seed)

# --- loaders (fit scalers on TRAIN, apply to whole ds_train) ---
ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train)

# The loaders themselves are not used below; the call is kept for its side
# effect of fitting the scalers on ds_train_copy (reused by make_test_loader).
train_dl, val_dl = ds_train_copy.make_loaders(
    train_idx=train_idx,
    val_idx=val_idx,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=
    True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True  # use weighted sampler for training
)

# --- build model, resolve loss, reload trained weights ---
model = mbm.models.LSTM_MB.build_model_from_params(cfg, PARAMS_LSTM_IS_past,
                                                   device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(PARAMS_LSTM_IS_past)

ds_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train)

test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test_copy, ds_train_copy, batch_size=128, seed=cfg.seed)

state = torch.load(LSTM_IS_ORIGIN_Y_PAST, map_location=device)
model.load_state_dict(state)
test_metrics, test_df_preds = model.evaluate_with_preds(
    device, test_dl, ds_test_copy)
test_rmse_a, test_rmse_w = test_metrics['RMSE_annual'], test_metrics[
    'RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

In [None]:
from joblib import Parallel, delayed
from tqdm.auto import tqdm
import numpy as np
import torch
import pandas as pd


@torch.no_grad()
def eval_rmse_pfi(model, device, dl, ds):
    """
    Compute (RMSE_winter, RMSE_annual) the same way evaluate_with_preds does:
    - ds.keys routes each sample to winter or annual (last key element equal
      to "winter" -> winter, anything else -> annual)
    - targets and predictions are de-normalized with ds.y_std / ds.y_mean
    Assumes dl iterates ds in order (shuffle=False): batch rows are matched to
    ds.keys purely by position.

    NOTE(review): this redefines eval_rmse_pfi from the aggregated-PFI cell;
    keep the two definitions identical (or define it only once).

    Returns
    -------
    (rmse_winter, rmse_annual) as Python floats; NaN for a period without samples.
    """
    model.eval()

    y_std = ds.y_std.to(device)
    y_mean = ds.y_mean.to(device)

    # Period label of every sample, in dataset order.
    periods = np.asarray([key[-1] for key in ds.keys], dtype=object)

    true_w, pred_w = [], []
    true_a, pred_a = [], []
    offset = 0

    def _to_host(t):
        # One device->host copy per batch instead of one .cpu() per sample;
        # flattened to one value per sample, float64 for the RMSE.
        return t.reshape(-1).cpu().double().numpy()

    for batch in dl:
        bs = batch["x_m"].shape[0]
        is_winter = periods[offset:offset + bs] == "winter"
        offset += bs

        _, y_w, y_a = model(
            batch["x_m"].to(device),
            batch["x_s"].to(device),
            batch["mv"].to(device),
            batch["mw"].to(device),
            batch["ma"].to(device),
        )

        # denormalize
        y_true = _to_host(batch["y"].to(device) * y_std + y_mean)
        y_w = _to_host(y_w * y_std + y_mean)
        y_a = _to_host(y_a * y_std + y_mean)

        true_w.append(y_true[is_winter])
        pred_w.append(y_w[is_winter])
        true_a.append(y_true[~is_winter])
        pred_a.append(y_a[~is_winter])

    def rmse(t_parts, p_parts):
        t = np.concatenate(t_parts) if t_parts else np.empty(0)
        p = np.concatenate(p_parts) if p_parts else np.empty(0)
        if t.size == 0:
            return float("nan")
        return float(np.sqrt(np.mean((p - t)**2)))

    return rmse(true_w, pred_w), rmse(true_a, pred_a)


def _pfi_month_worker(task, model, device, ds_test, base_w, base_a, n_repeats,
                      seed, valid_only=True):
    """Estimate the RMSE increase caused by permuting one feature (at one month).

    Runs ``n_repeats`` permutations of a single column of ``ds_test`` and
    re-evaluates ``model`` with ``eval_rmse_pfi`` after each one. ``ds_test``
    is modified in place, so callers should pass a private copy.

    Parameters
    ----------
    task : tuple
        ``(kind, k, t)``.
        - ``kind == "monthly"``: permute ``Xm[:, t, k]``.
        - Any other kind (static): permute ``Xs[:, k]``; ``t`` is then only
          used for the RNG seed.
    model, device
        Forwarded to ``eval_rmse_pfi``.
    ds_test
        Dataset exposing ``Xm`` / ``Xs`` tensors, ``mv`` and ``len()``.
    base_w, base_a : float
        Unpermuted winter / annual RMSE, subtracted from every repeat.
    n_repeats : int
        Number of independent permutations.
    seed : int
        Base seed; the task RNG is seeded with ``seed + k * 1000 + t``.
    valid_only : bool, default True
        For monthly features, shuffle only among rows with ``mv[:, t] > 0``.
        These are presumably the sequences in which month ``t`` is observed
        (confirm). This keeps values from masked/padded positions out of real
        months. ``False`` reproduces the previous behaviour: shuffle across
        all rows, with the same RNG stream.

    Returns
    -------
    tuple
        ``(kind, k, t, mean_dw, std_dw, mean_da, std_da)``. ``dw`` / ``da`` are
        the per-repeat winter / annual RMSE differences to the baseline.
    """
    kind, k, t = task
    rng = np.random.default_rng(seed + k * 1000 + t)

    Xm0 = ds_test.Xm.clone()
    Xs0 = ds_test.Xs.clone()

    if kind == "monthly" and valid_only:
        valid = torch.as_tensor(ds_test.mv[:, t]).cpu() > 0
        rows = torch.nonzero(valid, as_tuple=True)[0]
    else:
        rows = torch.arange(len(ds_test))

    # The loader reads ds_test lazily, so one loader can be re-iterated after
    # each in-place permutation instead of rebuilding it every repeat.
    dl = torch.utils.data.DataLoader(ds_test, batch_size=128, shuffle=False)

    dw, da = [], []

    for _ in range(n_repeats):
        # With rows == arange(N) this is exactly the previous ``perm``.
        src = rows[torch.as_tensor(rng.permutation(rows.numel()))]

        if kind == "monthly":
            ds_test.Xm[rows, t, k] = Xm0[src, t, k]
        else:  # static
            ds_test.Xs[rows, k] = Xs0[src, k]

        w, a = eval_rmse_pfi(model, device, dl, ds_test)

        dw.append(w - base_w)
        da.append(a - base_a)

    return kind, k, t, np.mean(dw), np.std(dw), np.mean(da), np.std(da)


def _pfi_collect_tasks(ds_test, monthly_cols, static_cols, month_names):
    """Return ``(kind, feature_idx, month_idx)`` tasks for every month with data.

    A month index ``t`` is kept when ``ds_test.mv[:, t]`` has a positive sum.
    Task order is all monthly features (feature-major), then all static features.
    """
    months_with_data = [
        t for t in range(len(month_names)) if ds_test.mv[:, t].sum() > 0
    ]
    tasks = [("monthly", k, t) for k in range(len(monthly_cols))
             for t in months_with_data]
    # NOTE(review): static features have no month axis and _pfi_month_worker
    # ignores ``t`` for them. These tasks therefore repeat the same experiment
    # once per month; only the RNG seed differs. They are kept so the output
    # still has one row per (feature, month).
    tasks += [("static", j, t) for j in range(len(static_cols))
              for t in months_with_data]
    return tasks


def _pfi_baseline_rmse(model, device, ds_test):
    """Evaluate the unpermuted ``ds_test`` and print the baseline RMSEs."""
    base_dl = torch.utils.data.DataLoader(ds_test,
                                          batch_size=128,
                                          shuffle=False)
    base_w, base_a = eval_rmse_pfi(model, device, base_dl, ds_test)
    print(f"[Baseline RMSE] winter={base_w:.3f} | annual={base_a:.3f}")
    return base_w, base_a


def _pfi_run_tasks(model, device, ds_test, tasks, base_w, base_a, n_repeats,
                   n_jobs, seed, desc):
    """Run every task through ``_pfi_month_worker`` in parallel with joblib.

    Each task gets its own ``_clone_for_permutation`` copy of ``ds_test``,
    because the worker permutes the dataset in place.
    """
    # NOTE: tqdm wraps the task generator, so the bar tracks dispatch to the
    # joblib workers rather than completion.
    return Parallel(n_jobs=n_jobs)(delayed(_pfi_month_worker)(
        task,
        model,
        device,
        mbm.data_processing.MBSequenceDataset._clone_for_permutation(ds_test),
        base_w,
        base_a,
        n_repeats,
        seed,
    ) for task in tqdm(tasks, desc=desc))


def _pfi_results_frame(results, monthly_cols, static_cols, month_names,
                       base_w, base_a, relative):
    """Turn worker results into one row per (feature, month).

    With ``relative=True``, the ``*_rel`` columns (delta / baseline RMSE) are
    inserted before the baseline columns.
    """
    columns = [
        "feature", "month", "mean_delta_winter", "std_delta_winter",
        "mean_delta_annual", "std_delta_annual"
    ]
    if relative:
        columns += ["mean_delta_winter_rel", "mean_delta_annual_rel"]
    columns += ["baseline_winter", "baseline_annual"]

    rows = []
    for kind, k, t, mw, sw, ma, sa in results:
        row = {
            "feature": monthly_cols[k] if kind == "monthly" else static_cols[k],
            "month": month_names[t],
            "mean_delta_winter": mw,
            "std_delta_winter": sw,
            "mean_delta_annual": ma,
            "std_delta_annual": sa,
        }
        if relative:
            row["mean_delta_winter_rel"] = (mw / base_w
                                            if np.isfinite(mw) else np.nan)
            row["mean_delta_annual_rel"] = (ma / base_a
                                            if np.isfinite(ma) else np.nan)
        row["baseline_winter"] = base_w
        row["baseline_annual"] = base_a
        rows.append(row)

    # Explicit columns keep the schema even when there are no results.
    return pd.DataFrame(rows, columns=columns)


def _pfi_weighted_global(delta_w, delta_a, n_w, n_a):
    """Sample-count-weighted mean of winter and annual values, row by row.

    Non-finite entries are skipped. Rows with no finite entry, or with zero
    total weight, get NaN.
    """
    w = np.asarray(delta_w, dtype=float)
    a = np.asarray(delta_a, dtype=float)
    finite_w = np.isfinite(w)
    finite_a = np.isfinite(a)
    wt_w = np.where(finite_w, float(n_w), 0.0)
    wt_a = np.where(finite_a, float(n_a), 0.0)
    num = np.where(finite_w, w, 0.0) * wt_w + np.where(finite_a, a, 0.0) * wt_a
    den = wt_w + wt_a
    with np.errstate(invalid="ignore", divide="ignore"):
        out = num / den
    return np.where(den > 0, out, np.nan)


def _pfi_monthly_frame(model, device, ds_test, monthly_cols, static_cols,
                       month_names, n_repeats, n_jobs, seed, desc, relative):
    """Run the baseline and all permutation tasks; return the per-row frame."""
    base_w, base_a = _pfi_baseline_rmse(model, device, ds_test)
    tasks = _pfi_collect_tasks(ds_test, monthly_cols, static_cols, month_names)
    results = _pfi_run_tasks(model, device, ds_test, tasks, base_w, base_a,
                             n_repeats, n_jobs, seed, desc)
    return _pfi_results_frame(results, monthly_cols, static_cols, month_names,
                              base_w, base_a, relative)


def permutation_importance_LSTM_MB_monthly_parallel(
    model,
    device,
    ds_test,
    monthly_cols,
    static_cols,
    month_names,
    n_repeats=5,
    n_jobs=12,
    seed=0,
):
    """Monthly permutation feature importance, relative to the baseline RMSE.

    Every (feature, month) with data is permuted ``n_repeats`` times in
    parallel. Deltas are reported in three forms:
    - absolute winter / annual deltas;
    - the same deltas divided by the baseline RMSE (``*_rel``);
    - ``mean_delta_global_rel``, the relative winter and annual deltas
      averaged with weights ``ds_test.iw.sum()`` / ``ds_test.ia.sum()``.

    Returns
    -------
    pandas.DataFrame
        One row per task (feature, month).
    """
    print("\n▶️ Running monthly permutation feature importance (PFI)...")

    df = _pfi_monthly_frame(model, device, ds_test, monthly_cols, static_cols,
                            month_names, n_repeats, n_jobs, seed,
                            desc="Monthly permutation importance",
                            relative=True)

    df["mean_delta_global_rel"] = _pfi_weighted_global(
        df["mean_delta_winter_rel"], df["mean_delta_annual_rel"],
        int(ds_test.iw.sum()), int(ds_test.ia.sum()))

    return df


def permutation_importance_LSTM_MB_monthly_parallel_absolute(
    model,
    device,
    ds_test,
    monthly_cols,
    static_cols,
    month_names,
    n_repeats=5,
    n_jobs=12,
    seed=0,
):
    """Monthly permutation feature importance as absolute ΔRMSE.

    Same experiment as ``permutation_importance_LSTM_MB_monthly_parallel``,
    with two differences:
    - the relative columns are omitted;
    - ``mean_delta_global`` averages the absolute winter / annual deltas with
      weights ``ds_test.iw.sum()`` / ``ds_test.ia.sum()``.

    Returns
    -------
    pandas.DataFrame
        One row per task (feature, month).
    """
    print("\n▶️ Running monthly permutation feature importance (PFI – absolute ΔRMSE)...")

    df = _pfi_monthly_frame(model, device, ds_test, monthly_cols, static_cols,
                            month_names, n_repeats, n_jobs, seed,
                            desc="Monthly permutation importance (absolute)",
                            relative=False)

    # ----- absolute sample-weighted global ΔRMSE -----
    df["mean_delta_global"] = _pfi_weighted_global(
        df["mean_delta_winter"], df["mean_delta_annual"],
        int(ds_test.iw.sum()), int(ds_test.ia.sum()))

    return df

In [None]:
month_names = [
    "aug_", "sep_", "oct", "nov", "dec", "jan", "feb", "mar", "apr", "may",
    "jun", "jul", "aug", "sep", "oct_"
]

# The PFI run evaluates the model n_repeats times for every (feature, month)
# task. It is recomputed only when requested or when no cached result exists;
# otherwise the cache is read.
RUN_PFI_MONTHLY = False
PFI_MONTHLY_CACHE = 'cache/pfi_LSTM_IS_monthly_absolute.csv'

if RUN_PFI_MONTHLY or not os.path.exists(PFI_MONTHLY_CACHE):

    pfi_monthly = permutation_importance_LSTM_MB_monthly_parallel_absolute(
        model=model,
        device=device,
        ds_test=ds_test_copy,
        monthly_cols=MONTHLY_COLS,
        static_cols=STATIC_COLS,
        month_names=month_names,
        n_repeats=8,
        n_jobs=12,
        seed=cfg.seed,
    )

    # save to cache (to_csv does not create missing directories)
    os.makedirs(os.path.dirname(PFI_MONTHLY_CACHE), exist_ok=True)
    pfi_monthly.to_csv(PFI_MONTHLY_CACHE, index=False)

else:
    pfi_monthly = pd.read_csv(PFI_MONTHLY_CACHE)

In [None]:
from scipy.ndimage import gaussian_filter1d

def plot_monthly_pfi_ridges(
    pfi_monthly,
    MONTHLY_COLS,
    vois_climate_long_name,
    months_tail_pad,
    months_head_pad,
    metric="global",          # "winter", "annual", "global"
    importance="relative",    # "relative" or "absolute"
    drop_padded_months=True,
    fname=None,
    title=None,
):
    """Draw a ridge plot of monthly permutation feature importance.

    One ridge is drawn per feature in ``MONTHLY_COLS``, labelled with its long
    name from ``vois_climate_long_name``, with months on the x-axis. Curves
    are smoothed with ``gaussian_filter1d(sigma=1)``. The annotation on each
    ridge reports the *unsmoothed* maximum, placed at the smoothed peak.

    Parameters
    ----------
    pfi_monthly : pandas.DataFrame
        Monthly PFI output with ``feature``, ``month`` and the ΔRMSE column
        selected by ``metric`` / ``importance``.
    MONTHLY_COLS : list of str
        Features to plot; rows for other features are ignored.
    vois_climate_long_name : dict
        Short -> display name; unknown names are shown as-is.
    months_tail_pad, months_head_pad : iterable of str
        Month labels treated as padding.
    metric : str
        "winter", "annual" or "global". Any other value falls back to "global".
        For "winter", months after "may" are zeroed after smoothing.
    importance : str
        "relative" selects the ``*_rel`` columns. Any other value selects the
        absolute ΔRMSE columns.
    drop_padded_months : bool
        Remove padding months from the x-axis. If False, all 15 months of the
        hydrological window are shown.
    fname : str, optional
        Save the figure here at 300 dpi.
    title : str, optional
        Overrides the default title.

    Raises
    ------
    KeyError
        If the selected value column is missing from ``pfi_monthly``.
    ValueError
        If no rows remain to plot.
    """
    if importance == "relative":
        if metric == "winter":
            value_col = "mean_delta_winter_rel"
            label = "Relative ΔWinter RMSE"
        elif metric == "annual":
            value_col = "mean_delta_annual_rel"
            label = "Relative ΔAnnual RMSE"
        else:
            value_col = "mean_delta_global_rel"
            label = "Relative ΔGlobal RMSE"
        annot_fmt = "ΔRMSE_rel={:.3f}"
    else:  # absolute
        if metric == "winter":
            value_col = "mean_delta_winter"
            label = "ΔWinter RMSE"
        elif metric == "annual":
            value_col = "mean_delta_annual"
            label = "ΔAnnual RMSE"
        else:
            value_col = "mean_delta_global"
            label = "ΔGlobal RMSE"
        annot_fmt = "ΔRMSE={:.3f}"

    if value_col not in pfi_monthly.columns:
        raise KeyError(
            f"Column '{value_col}' not found in pfi_monthly "
            f"(metric={metric!r}, importance={importance!r}).")

    full_month_order = [
        "aug_","sep_","oct","nov","dec","jan","feb","mar","apr","may",
        "jun","jul","aug","sep","oct_"
    ]

    # .copy() so adding feature_long does not write into a view of the input.
    df = pfi_monthly.loc[pfi_monthly.feature.isin(MONTHLY_COLS)].copy()
    df["feature_long"] = df["feature"].apply(lambda x: vois_climate_long_name.get(x, x))

    if drop_padded_months:
        padded = set(months_tail_pad) | set(months_head_pad)
        df = df[~df.month.isin(padded)]
        month_order = [m for m in full_month_order if m not in padded]
    else:
        # Keep every month. The previous hard-coded list here silently dropped "aug_".
        month_order = list(full_month_order)

    if df.empty:
        raise ValueError("No PFI rows left to plot for the requested features/months.")

    df = df.groupby(["feature_long", "month"], as_index=False).mean(numeric_only=True)

    # Give every feature a value for every month (missing -> 0).
    all_idx = pd.MultiIndex.from_product(
        [df.feature_long.unique(), month_order],
        names=["feature_long", "month"]
    )

    df = (
        df.set_index(["feature_long", "month"])
          .reindex(all_idx)
          .fillna(0.0)
          .reset_index()
    )

    piv = df.pivot(index="feature_long", columns="month", values=value_col)[month_order]

    feat_order = piv.mean(axis=1).sort_values(ascending=True).index
    piv = piv.loc[feat_order]

    piv_smooth = pd.DataFrame(
        np.vstack([gaussian_filter1d(piv.loc[f], sigma=1) for f in feat_order]),
        index=feat_order,
        columns=piv.columns,
    )

    if metric == "winter":
        winter_months = ["aug_","sep_","oct","nov","dec","jan","feb","mar","apr","may"]
        invalid = [m for m in piv_smooth.columns if m not in winter_months]
        if invalid:
            piv_smooth[invalid] = 0.0

    fig, ax = plt.subplots(figsize=(10, 10))
    palette = sns.color_palette("magma", n_colors=len(feat_order))
    month_idx = np.arange(len(piv_smooth.columns))

    # Vertical spacing between ridges. Fall back to the largest magnitude (or
    # 1.0) when no value is positive, so the stack neither collapses nor inverts.
    peak = np.nanmax(piv_smooth.values)
    if not (np.isfinite(peak) and peak > 0):
        peak = np.nanmax(np.abs(piv_smooth.values))
        if not (np.isfinite(peak) and peak > 0):
            peak = 1.0
    offset_step = peak * 0.7
    current_offset = 0.0
    max_importance = piv.max(axis=1)

    for feat, color in zip(feat_order, palette):
        y = piv_smooth.loc[feat].values

        ax.plot(month_idx, y + current_offset, color=color, lw=2)
        ax.fill_between(month_idx, current_offset, y + current_offset, color=color, alpha=0.4)

        ax.text(-0.6, current_offset, feat, va="center", ha="right", fontsize=13)

        max_idx = np.argmax(y)
        ax.text(
            month_idx[max_idx],
            y[max_idx] + current_offset + 0.05 * offset_step,
            annot_fmt.format(max_importance[feat]),
            ha="center", va="bottom", fontsize=11,
            bbox=dict(facecolor="white", alpha=0.7, edgecolor="none", pad=1.2)
        )

        current_offset += offset_step

    ax.set_yticks([])
    ax.set_xlim(0, len(month_idx) - 1)
    ax.set_xticks(month_idx)
    ax.set_xticklabels([m.strip("_").capitalize() for m in piv_smooth.columns], rotation=45, ha="right")
    ax.set_xlabel("Month")
    ax.set_title(title or f"Monthly Permutation Feature Importance – {label}")

    for spine in ["top", "right", "left"]:
        ax.spines[spine].set_visible(False)

    plt.tight_layout()
    if fname:
        fig.savefig(fname, dpi=300, bbox_inches="tight")
    plt.show()

In [None]:
# Month labels treated as padding (tail pad first, then head pad). This is the
# same set plot_monthly_pfi_ridges drops when drop_padded_months=True.
padded = np.concatenate((months_tail_pad, months_head_pad))
padded

In [None]:
# Climate variables (short names) drawn as ridges; display names come from
# vois_climate_long_name.
vars_to_plot = ['t2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str', 'pcsr']

# One absolute-ΔRMSE ridge figure per target: (metric, output path, title).
PFI_RIDGE_FIGURES = [
    ("winter", "figures/paper/CH_LSTM_monthly_PFI_winter.png",
     "Monthly Permutation Feature Importance – Winter"),
    ("annual", "figures/paper/CH_LSTM_monthly_PFI_annual.png",
     "Monthly Permutation Feature Importance – Annual"),
    ("global", "figures/paper/fig9_CH_LSTM_monthly_PFI_global.png",
     "Monthly Permutation Feature Importance"),
]

for pfi_metric, pfi_fig_path, pfi_fig_title in PFI_RIDGE_FIGURES:
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(pfi_fig_path), exist_ok=True)
    plot_monthly_pfi_ridges(
        pfi_monthly,
        vars_to_plot,
        vois_climate_long_name,
        months_tail_pad,
        months_head_pad,
        metric=pfi_metric,
        drop_padded_months=True,
        importance="absolute",
        fname=pfi_fig_path,
        title=pfi_fig_title,
    )

In [None]:
# Legacy alternative view (disabled): a single heatmap of the sample-weighted
# global ΔRMSE ("mean_delta_global") per feature and month.
# NOTE(review): if re-enabled, the lines below modify `pfi_monthly` in place:
# they add a "feature_long" column and filter out padded months. Work on a
# copy instead so earlier cells keep seeing the unmodified PFI table.
# import matplotlib.pyplot as plt
# import seaborn as sns

# # --- Month order ---
# month_order = [
#     "aug_", "sep_", "oct", "nov", "dec", "jan", "feb", "mar", "apr", "may",
#     "jun", "jul", "aug", "sep", "oct_"
# ]
# # pfi_monthly = pfi_monthly[pfi_monthly.feature.isin(MONTHLY_COLS)]

# # --- Map features to long names ---
# pfi_monthly["feature_long"] = pfi_monthly["feature"].apply(
#     lambda x: vois_climate_long_name.get(x, x))
# pfi_monthly = pfi_monthly[~pfi_monthly.month.isin(
#     np.concatenate([months_tail_pad, months_head_pad]))]

# # --- Prepare pivot table for global ΔRMSE ---
# piv_global = pfi_monthly.pivot(index="feature_long",
#                                columns="month",
#                                values="mean_delta_global")

# # --- Reorder columns (months) ---
# piv_global = piv_global[[m for m in month_order if m in piv_global.columns]]

# # --- Order features by average global importance (optional, makes it clean) ---
# feat_order = (pfi_monthly.groupby("feature_long")
#               ["mean_delta_global"].mean().sort_values(ascending=False).index)
# piv_global = piv_global.loc[feat_order]

# # --- Plot single heatmap ---
# plt.figure(figsize=(10, 6))
# sns.heatmap(piv_global,
#             cmap="magma",
#             linewidths=0.3,
#             cbar_kws={"label": "ΔRMSE (global)"})
# plt.xlabel("Month")
# plt.ylabel("Feature")
# plt.title("Monthly Permutation Feature Importance – Global RMSE Δ")
# plt.tight_layout()
# plt.show()
