## Setting Up:

In [None]:
import sys
import os

import logging
import warnings
from datetime import datetime
from collections import defaultdict, Counter
from functools import partial

import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from tqdm.notebook import tqdm
from cmcrameri import cm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import joypy
from pandas.api.types import CategoricalDtype

import massbalancemachine as mbm

# Add root of repo to import MBM!
sys.path.append(os.path.join(os.getcwd(), '../../'))

# Local modules
from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.nn_helpers import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.NN_networks import *
from scripts.geodata_plots import *

warnings.filterwarnings('ignore')

%load_ext autoreload
%autoreload 2

cfg = mbm.SwitzerlandConfig()

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from torch.utils.data import WeightedRandomSampler, SubsetRandomSampler
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Seed the random number generators via the project helper and report the seed.
seed_all(cfg.seed)
print("Using seed:", cfg.seed)

# Choose the compute device, then report it (and free GPU memory if on CUDA).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
    print("CUDA is NOT available")
else:
    print("CUDA is available")
    free_up_cuda()


In [None]:
# Plot styles: project style sheet plus a shared colour palette.
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)

# Ten hex colours sampled from the batlow colormap; first one is the dark blue.
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue, color_pink = colors[0], '#c51b7d'

## Input data:

In [None]:
# Read GLAMOS stake data
data_glamos = getStakesData(cfg)

# Compute padding for monthly data (months added before/after each
# sequence; presumably the "_"-suffixed month labels used further below)
months_head_pad, months_tail_pad = mbm.data_processing.utils._compute_head_tail_pads_from_df(
    data_glamos)

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Transform data to monthly format (run or load data):
# NOTE(review): paths are built by plain string concatenation, so the
# path_* constants are assumed to carry their own trailing separators.
paths = {
    'csv_path': cfg.dataPath + path_PMB_GLAMOS_csv,
    'era5_climate_data':
    cfg.dataPath + path_ERA5_raw + 'era5_monthly_averaged_data.nc',
    'geopotential_data':
    cfg.dataPath + path_ERA5_raw + 'era5_geopotential_pressure.nc',
    'radiation_save_path': cfg.dataPath + path_pcsr + 'zarr/'
}
# RUN=False: presumably reuses the previously written output_file instead of
# recomputing the monthly dataset (see process_or_load_data).
RUN = False
data_monthly = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=VOIS_CLIMATE,
    vois_topographical=VOIS_TOPOGRAPHICAL,
    output_file='CH_wgms_dataset_monthly_LSTM.csv')

# NOTE(review): dataloader_gl is built BEFORE the 2025 rows are removed
# below, so it still contains 2025 data — confirm this is intended.
dataloader_gl = mbm.dataloader.DataLoader(cfg,
                                          data=data_monthly,
                                          random_seed=cfg.seed,
                                          meta_data_columns=cfg.metaData)

# Remove 2025
data_monthly = data_monthly[data_monthly['YEAR']
                            < 2025]  # Used elsewhere for validation

## Monthly distributions:

In [None]:
# Folders with the distributed MB grids predicted by each model for the paper.
PATH_PREDICTIONS_LSTM_IS, PATH_PREDICTIONS_NN, PATH_PREDICTIONS_XGB = (
    os.path.join(cfg.dataPath, "GLAMOS", "distributed_MB_grids", model_dir)
    for model_dir in (
        "MBM/paper/LSTM_IS_original_y",
        "MBM/paper/NN",
        "MBM/paper/XGB",
    ))

# Month abbreviations in hydrological-year order (October -> September).
hydro_months = "oct nov dec jan feb mar apr may jun jul aug sep".split()

In [None]:
# fig = plot_glacier_monthly_series_lstm_sharedcmap_center0(
#     glacier_name="rhone",
#     year=2008,
#     path_pred_lstm=PATH_PREDICTIONS_NN,
#     apply_smoothing_fn=apply_gaussian_filter,
# )

# fig = plot_glacier_monthly_series_lstm_sharedcmap_center0(
#     glacier_name="rhone",
#     year=2008,
#     path_pred_lstm=PATH_PREDICTIONS_LSTM_IS,
#     apply_smoothing_fn=apply_gaussian_filter,
# )

In [None]:
# Build (or reload) the per-model monthly prediction DataFrames and the GLAMOS
# reference DataFrames, caching all five as parquet files in CACHE_DIR.
CACHE_DIR = os.path.join(
    cfg.dataPath,
    "GLAMOS/distributed_MB_grids/MBM/paper/processed_dfs",
)
os.makedirs(CACHE_DIR, exist_ok=True)

REBUILD_CACHE = True  # True forces a rebuild even when all cache files exist; use False to reuse them

# NOTE: this rebinds `paths` (the raw-input paths dict of the input-data cell).
paths = {
    "LSTM": os.path.join(CACHE_DIR, "df_months_LSTM_.parquet"),
    "NN": os.path.join(CACHE_DIR, "df_months_NN_.parquet"),
    "XGB": os.path.join(CACHE_DIR, "df_months_XGB_.parquet"),
    "GW": os.path.join(CACHE_DIR, "df_GLAMOS_w_.parquet"),
    "GA": os.path.join(CACHE_DIR, "df_GLAMOS_a_.parquet"),
}

# Rebuild when forced, or when any of the five cache files is missing.
if REBUILD_CACHE or not all(os.path.exists(p) for p in paths.values()):

    print("Building monthly prediction DataFrames...")

    df_months_LSTM = load_glwd_lstm_predictions(PATH_PREDICTIONS_LSTM_IS,
                                                hydro_months)
    df_months_NN = load_glwd_nn_predictions(PATH_PREDICTIONS_NN, hydro_months)
    # NOTE(review): XGB grids are read with the NN loader — presumably the
    # same on-disk layout; confirm.
    df_months_XGB = load_glwd_nn_predictions(PATH_PREDICTIONS_XGB,
                                             hydro_months)

    PATH_GLAMOS = os.path.join(cfg.dataPath, path_distributed_MB_glamos,
                               "GLAMOS")
    # NOTE(review): `glaciers` is not used in this cell.
    glaciers = os.listdir(PATH_GLAMOS)

    # glacier -> sorted list of years covered by the LSTM predictions; the
    # GLAMOS grids are loaded for exactly these glacier/years.
    glacier_years = (df_months_LSTM.groupby("glacier")["year"].unique().apply(
        sorted).to_dict())

    df_GLAMOS_w, df_GLAMOS_a = load_all_glamos(cfg, glacier_years, PATH_GLAMOS)

    # Save cache
    df_months_LSTM.to_parquet(paths["LSTM"])
    df_months_NN.to_parquet(paths["NN"])
    df_months_XGB.to_parquet(paths["XGB"])
    df_GLAMOS_w.to_parquet(paths["GW"])
    df_GLAMOS_a.to_parquet(paths["GA"])

    print(f"Cached DataFrames to {CACHE_DIR}")

else:
    print("Loading cached monthly prediction DataFrames...")

    df_months_LSTM = pd.read_parquet(paths["LSTM"])
    df_months_NN = pd.read_parquet(paths["NN"])
    df_months_XGB = pd.read_parquet(paths["XGB"])
    df_GLAMOS_w = pd.read_parquet(paths["GW"])
    df_GLAMOS_a = pd.read_parquet(paths["GA"])

    print(f"Loaded DataFrames from {CACHE_DIR}")

### Glacier-wide:

In [None]:
# --- 1. Glacier-wide annual mean MB per year ---
# numeric_only=True averages only numeric columns: with pandas >= 2.0 the
# default (numeric_only=False) raises TypeError on non-numeric columns, while
# older pandas silently dropped them. The elevation-band aggregations further
# down use the same setting.
glwd_months_NN = (df_months_NN.groupby(['glacier', 'year'])
                  .mean(numeric_only=True).reset_index())
glwd_months_XGB = (df_months_XGB.groupby(['glacier', 'year'])
                   .mean(numeric_only=True).reset_index())
glwd_months_LSTM = (df_months_LSTM.groupby(['glacier', 'year'])
                    .mean(numeric_only=True).reset_index())
glwd_months_GLAMOS_w = (df_GLAMOS_w.groupby(['glacier', 'year'])
                        .mean(numeric_only=True).reset_index())
glwd_months_GLAMOS_a = (df_GLAMOS_a.groupby(['glacier', 'year'])
                        .mean(numeric_only=True).reset_index())

# --- 2. Compute the intersection of valid glacier–year pairs across all datasets ---
valid_pairs = set.intersection(*(
    set(zip(frame['glacier'], frame['year'])) for frame in (
        glwd_months_NN,
        glwd_months_XGB,
        glwd_months_LSTM,
        glwd_months_GLAMOS_w,
        glwd_months_GLAMOS_a,
    )))


# --- 3. Helper function for filtering by glacier–year pairs ---
def filter_to_valid(df, pairs=None):
    """Keep only the rows whose (glacier, year) pair is in ``pairs``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``glacier`` and ``year`` columns.
    pairs : set of (glacier, year) tuples, optional
        Pairs to keep. Defaults to the module-level ``valid_pairs``, looked
        up at call time.

    Returns
    -------
    pandas.DataFrame
        The matching rows, in original order, with a fresh 0..n-1 index.
    """
    if pairs is None:
        pairs = valid_pairs
    # Set membership over zipped columns: the same tuple matching as a
    # row-wise ``apply(tuple, axis=1)``, without building a Series per row
    # (and it also behaves well on an empty frame).
    keep = [pair in pairs for pair in zip(df['glacier'], df['year'])]
    return df.loc[keep].reset_index(drop=True)


# --- 4. Restrict every dataset to the shared glacier–year pairs ---
(glwd_months_NN_filtered,
 glwd_months_XGB_filtered,
 glwd_months_LSTM_filtered,
 glwd_months_GLAMOS_filtered_w,
 glwd_months_GLAMOS_filtered_a) = [
     filter_to_valid(frame) for frame in (
         glwd_months_NN,
         glwd_months_XGB,
         glwd_months_LSTM,
         glwd_months_GLAMOS_w,
         glwd_months_GLAMOS_a,
     )
 ]

# Row counts (GLAMOS winter, GLAMOS annual, NN, XGB, LSTM); after filtering to
# the same pairs these are expected to agree.
print(*(len(frame) for frame in (
    glwd_months_GLAMOS_filtered_w,
    glwd_months_GLAMOS_filtered_a,
    glwd_months_NN_filtered,
    glwd_months_XGB_filtered,
    glwd_months_LSTM_filtered,
)))

# --- 5. Long format for the joyplots ---
df_months_nn_long = prepare_monthly_long_df(glwd_months_LSTM_filtered,
                                            glwd_months_NN_filtered,
                                            glwd_months_XGB_filtered,
                                            glwd_months_GLAMOS_filtered_w,
                                            glwd_months_GLAMOS_filtered_a)

df_months_nn_long.head(2)

In [None]:
# ---- Compute min & max across every distribution being plotted ----
# GLAMOS is drawn too (color_glamos), so it must be part of the x-range;
# otherwise its tails can fall outside the axis.
min_ = df_months_nn_long[['mb_nn', 'mb_lstm', 'mb_xgb',
                          'mb_glamos']].min().min()
max_ = df_months_nn_long[['mb_nn', 'mb_lstm', 'mb_xgb',
                          'mb_glamos']].max().max()

# ---- Plot ----
fig = plot_monthly_joyplot(
    df_months_nn_long,
    x_range=(np.floor(min_), np.ceil(max_)),
    color_lstm=color_annual,  # or rename to your liking
    color_nn=color_winter,
    color_xgb="darkgreen",
    color_glamos="gray",
)

In [None]:
# x-range from the two distributions actually drawn: the LSTM ("MBM")
# predictions in `variable` and GLAMOS. Selecting the columns first also
# avoids reducing non-numeric columns.
min_ = df_months_nn_long[['mb_lstm', 'mb_glamos']].min().min()
max_ = df_months_nn_long[['mb_lstm', 'mb_glamos']].max().max()

fig = plot_monthly_joyplot_single(df_months_nn_long,
                                  variable="mb_lstm",
                                  color_model=color_annual,
                                  x_range=(np.floor(min_), np.ceil(max_)),
                                  model_name='MBM')

os.makedirs('figures/paper', exist_ok=True)
fig.savefig('figures/paper/CH_LSTM_vs_GLAMOS_monthly_joyplot_glwd.png',
            dpi=300,
            bbox_inches='tight')

### Elevation bands:

#### Highest:

In [None]:
# Elevation-band width (same units as the 'elevation' column).
# NOTE: `bin` shadows the builtin of the same name; kept because later
# cells refer to it.
bin = 200
# np.arange excludes its stop value: arange(1200, 4500, 200) ends at 4400, so
# any point above 4400 would get NaN. Extend by one bin so the last edge is
# >= 4500.
bins = np.arange(1200, 4500 + bin, bin)
labels = [f"{b}-{b+bin}" for b in bins[:-1]]

# Copy datasets (the band column below is added to the copies only)
df_months_NN_ = df_months_NN.copy()
df_months_XGB_ = df_months_XGB.copy()
df_months_LSTM_ = df_months_LSTM.copy()
df_GLAMOS_a_ = df_GLAMOS_a.copy()
df_GLAMOS_w_ = df_GLAMOS_w.copy()

# Assign elevation bands (categorical labels such as "1200-1400")
for df_ in [
        df_months_NN_,
        df_months_XGB_,
        df_months_LSTM_,
        df_GLAMOS_a_,
        df_GLAMOS_w_,
]:
    df_["elev_band"] = pd.cut(df_["elevation"], bins=bins, labels=labels)


def extract_highest_band(df, bin_width):
    """Average the rows lying in each glacier's highest elevation band.

    A row is in the band when its ``elevation`` is at least the glacier's
    maximum elevation (taken over all of that glacier's rows and years)
    minus ``bin_width``. Numeric columns of the selected rows are then
    averaged per (glacier, year).

    Returns a DataFrame with ``glacier`` and ``year`` as regular columns.
    """
    glacier_top = df.groupby("glacier")["elevation"].transform("max")
    in_top_band = df["elevation"].ge(glacier_top - bin_width)
    per_glacier_year = df.loc[in_top_band].groupby(["glacier", "year"])
    return per_glacier_year.mean(numeric_only=True).reset_index()


# Highest-band glacier/year means for every dataset.
(glwd_high_NN,
 glwd_high_XGB,
 glwd_high_LSTM,
 glwd_high_GLAMOS_a,
 glwd_high_GLAMOS_w) = (extract_highest_band(frame, bin) for frame in (
     df_months_NN_,
     df_months_XGB_,
     df_months_LSTM_,
     df_GLAMOS_a_,
     df_GLAMOS_w_,
 ))

# Glacier–year pairs present in all five datasets.
valid_pairs = set.intersection(*(
    set(zip(frame["glacier"], frame["year"])) for frame in (
        glwd_high_NN,
        glwd_high_XGB,
        glwd_high_LSTM,
        glwd_high_GLAMOS_w,
        glwd_high_GLAMOS_a,
    )))


def filter_to_valid(df, pairs=None):
    """Keep only the rows whose (glacier, year) pair is in ``pairs``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``glacier`` and ``year`` columns.
    pairs : set of (glacier, year) tuples, optional
        Pairs to keep. Defaults to the module-level ``valid_pairs``, looked
        up at call time.

    Returns
    -------
    pandas.DataFrame
        The matching rows, in original order, with a fresh 0..n-1 index.
    """
    if pairs is None:
        pairs = valid_pairs
    # Set membership over zipped columns: the same tuple matching as a
    # row-wise ``apply(tuple, axis=1)``, without building a Series per row
    # (and it also behaves well on an empty frame).
    keep = [pair in pairs for pair in zip(df["glacier"], df["year"])]
    return df.loc[keep].reset_index(drop=True)


# Restrict every highest-band dataset to the shared glacier–year pairs.
glwd_high_NN_filt = filter_to_valid(glwd_high_NN)
glwd_high_XGB_filt = filter_to_valid(glwd_high_XGB)
glwd_high_LSTM_filt = filter_to_valid(glwd_high_LSTM)
glwd_high_GLAMOS_a_filt = filter_to_valid(glwd_high_GLAMOS_a)
glwd_high_GLAMOS_w_filt = filter_to_valid(glwd_high_GLAMOS_w)

print(
    len(glwd_high_GLAMOS_w_filt),
    len(glwd_high_GLAMOS_a_filt),
    len(glwd_high_NN_filt),
    len(glwd_high_XGB_filt),
    len(glwd_high_LSTM_filt),
)

# Long format for plotting, under its own name (mirroring
# df_months_nn_long_low in the lowest-band cell). This keeps the glacier-wide
# df_months_nn_long intact; overwriting it meant that re-running the
# glacier-wide plot cells would silently plot highest-band data.
df_months_nn_long_high = prepare_monthly_long_df(
    glwd_high_LSTM_filt,
    glwd_high_NN_filt,
    glwd_high_XGB_filt,
    glwd_high_GLAMOS_w_filt,
    glwd_high_GLAMOS_a_filt,
)

# x-axis limits over all models and GLAMOS.
min_, max_ = (
    df_months_nn_long_high[["mb_nn", "mb_lstm", "mb_xgb",
                            "mb_glamos"]].min().min(),
    df_months_nn_long_high[["mb_nn", "mb_lstm", "mb_xgb",
                            "mb_glamos"]].max().max(),
)

fig = plot_monthly_joyplot_single(
    df_months_nn_long_high,
    variable="mb_lstm",
    color_model=color_annual,
    x_range=(np.floor(min_), np.ceil(max_)),
    model_name="MBM",
)

os.makedirs("figures/paper", exist_ok=True)
fig.savefig(
    "figures/paper/CH_LSTM_vs_GLAMOS_monthly_joyplot_high_elv.png",
    dpi=300,
    bbox_inches="tight",
)

#### Lowest:

In [None]:
# =========================
# Lowest-elevation band
# =========================

# Elevation-band width (same units as the 'elevation' column).
# NOTE: `bin` shadows the builtin of the same name; kept because later
# statements refer to it.
bin = 250
# np.arange excludes its stop value: arange(1200, 4500, 250) ends at 4450, so
# any point above 4450 would get NaN. Extend by one bin so the last edge is
# >= 4500.
bins = np.arange(1200, 4500 + bin, bin)
labels = [f"{b}-{b+bin}" for b in bins[:-1]]

# --- Copy to avoid modifying originals ---
df_months_NN_ = df_months_NN.copy()
df_months_XGB_ = df_months_XGB.copy()
df_months_LSTM_ = df_months_LSTM.copy()
df_GLAMOS_a_ = df_GLAMOS_a.copy()
df_GLAMOS_w_ = df_GLAMOS_w.copy()

# --- Assign elevation bands ---
for df_ in [
        df_months_NN_,
        df_months_XGB_,
        df_months_LSTM_,
        df_GLAMOS_a_,
        df_GLAMOS_w_,
]:
    df_["elev_band"] = pd.cut(df_["elevation"], bins=bins, labels=labels)


# --- Helper: extract lowest-elevation band per glacier ---
def extract_lowest_band(df, bin_width):
    """Average the rows lying in each glacier's lowest elevation band.

    A row is in the band when its ``elevation`` is at most ``bin_width``
    above the glacier's minimum elevation (taken over all of that glacier's
    rows and years). Numeric columns of the selected rows are then averaged
    per (glacier, year).

    Returns a DataFrame with ``glacier`` and ``year`` as regular columns.
    """
    glacier_floor = df.groupby("glacier")["elevation"].transform("min")
    in_low_band = df["elevation"].le(glacier_floor + bin_width)
    per_glacier_year = df.loc[in_low_band].groupby(["glacier", "year"])
    return per_glacier_year.mean(numeric_only=True).reset_index()


# --- Lowest-band glacier/year means for every dataset ---
(glwd_low_NN,
 glwd_low_XGB,
 glwd_low_LSTM,
 glwd_low_GLAMOS_a,
 glwd_low_GLAMOS_w) = (extract_lowest_band(frame, bin) for frame in (
     df_months_NN_,
     df_months_XGB_,
     df_months_LSTM_,
     df_GLAMOS_a_,
     df_GLAMOS_w_,
 ))

# --- Glacier–year pairs present in all five datasets ---
valid_pairs = set.intersection(*(
    set(zip(frame["glacier"], frame["year"])) for frame in (
        glwd_low_NN,
        glwd_low_XGB,
        glwd_low_LSTM,
        glwd_low_GLAMOS_w,
        glwd_low_GLAMOS_a,
    )))


def filter_to_valid(df, pairs=None):
    """Keep only the rows whose (glacier, year) pair is in ``pairs``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``glacier`` and ``year`` columns.
    pairs : set of (glacier, year) tuples, optional
        Pairs to keep. Defaults to the module-level ``valid_pairs``, looked
        up at call time.

    Returns
    -------
    pandas.DataFrame
        The matching rows, in original order, with a fresh 0..n-1 index.
    """
    if pairs is None:
        pairs = valid_pairs
    # Set membership over zipped columns: the same tuple matching as a
    # row-wise ``apply(tuple, axis=1)``, without building a Series per row
    # (and it also behaves well on an empty frame).
    keep = [pair in pairs for pair in zip(df["glacier"], df["year"])]
    return df.loc[keep].reset_index(drop=True)


# --- Apply consistent filtering ---
# Restrict every lowest-band dataset to the glacier–year pairs shared by all.
glwd_low_NN_filt = filter_to_valid(glwd_low_NN)
glwd_low_XGB_filt = filter_to_valid(glwd_low_XGB)
glwd_low_LSTM_filt = filter_to_valid(glwd_low_LSTM)
glwd_low_GLAMOS_a_filt = filter_to_valid(glwd_low_GLAMOS_a)
glwd_low_GLAMOS_w_filt = filter_to_valid(glwd_low_GLAMOS_w)

# Row counts (GLAMOS winter, GLAMOS annual, NN, XGB, LSTM) as a sanity check.
print(
    len(glwd_low_GLAMOS_w_filt),
    len(glwd_low_GLAMOS_a_filt),
    len(glwd_low_NN_filt),
    len(glwd_low_XGB_filt),
    len(glwd_low_LSTM_filt),
)

# --- Prepare long-format dataframe for plotting ---
df_months_nn_long_low = prepare_monthly_long_df(
    glwd_low_LSTM_filt,
    glwd_low_NN_filt,
    glwd_low_XGB_filt,
    glwd_low_GLAMOS_w_filt,
    glwd_low_GLAMOS_a_filt,
)

# --- Determine x-axis limits ---
# Range over all models and GLAMOS.
min_, max_ = (
    df_months_nn_long_low[["mb_nn", "mb_lstm", "mb_xgb",
                           "mb_glamos"]].min().min(),
    df_months_nn_long_low[["mb_nn", "mb_lstm", "mb_xgb",
                           "mb_glamos"]].max().max(),
)

# Optional manual clamp for ablation-dominated lowest band
# NOTE(review): this unconditionally REPLACES the data-driven minimum. If the
# data go below -10 the lower tail is cut off; if they stay above -10 the axis
# is extended. A true clamp would be `max(min_, -10)` — confirm intent.
min_ = -10

# --- Plot ---
fig = plot_monthly_joyplot_single(df_months_nn_long_low,
                                  variable="mb_lstm",
                                  color_model=color_annual,
                                  x_range=(np.floor(min_), np.ceil(max_)),
                                  model_name="MBM",
                                  y_offset=0.15)

# --- Save figure ---
fig.savefig(
    "figures/paper/CH_LSTM_vs_GLAMOS_monthly_joyplot_low_elv.png",
    dpi=300,
    bbox_inches="tight",
)

## Feature importance:

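Feature importance is computed in-sample: all glaciers are included, and the model is evaluated on the same data it was fitted to, so the scores are not out-of-sample estimates.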

In [None]:
existing_glaciers = set(data_monthly.GLACIER.unique())
# In-sample setting: every glacier is used for training.
train_glaciers = existing_glaciers
# Explicit copy: the 'y' column added below must go into an independent
# frame, not a (potential) view of data_monthly. The SettingWithCopyWarning
# this used to trigger was hidden by the global warnings filter.
data_train = data_monthly[data_monthly.GLACIER.isin(train_glaciers)].copy()
print('Size of monthly train data:', len(data_train))
# Target column. The train/validation split is done below on dataset indices.
data_train['y'] = data_train['POINT_BALANCE']

MONTHLY_COLS = [
    't2m',
    'tp',
    'slhf',
    'sshf',
    'ssrd',
    'fal',
    'str',
    'pcsr',
    'ELEVATION_DIFFERENCE',
]
STATIC_COLS = ['aspect_sgi', 'slope_sgi', 'svf']

feature_columns = MONTHLY_COLS + STATIC_COLS

seed_all(cfg.seed)

df_train = data_train.copy()
# Normalise PERIOD labels (strip surrounding whitespace, lower-case).
df_train['PERIOD'] = df_train['PERIOD'].str.strip().str.lower()

# --- build train dataset from dataframe ---
ds_train = mbm.data_processing.MBSequenceDataset.from_dataframe(
    df_train,
    MONTHLY_COLS,
    STATIC_COLS,
    months_tail_pad=months_tail_pad,
    months_head_pad=months_head_pad,
    expect_target=True,
    normalize_target=True)

train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
    len(ds_train), val_ratio=0.2, seed=cfg.seed)

# --- loaders (fit scalers on TRAIN, apply to whole ds_train) ---
ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train)

train_dl, val_dl = ds_train_copy.make_loaders(
    train_idx=train_idx,
    val_idx=val_idx,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True,  # use weighted sampler for training
)

# --- build model and loss; no training happens here, the saved weights are
# loaded below ---
model = mbm.models.LSTM_MB.build_model_from_params(cfg, PARAMS_LSTM_IS, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(PARAMS_LSTM_IS)

# "Test" set = another untransformed clone of the full training dataset; the
# loader also receives ds_train_copy (presumably for its fitted scalers).
ds_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train)

test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test_copy, ds_train_copy, batch_size=128, seed=cfg.seed)

# Load saved weights and evaluate (in-sample)
state = torch.load(LSTM_IS_NORM_Y, map_location=device)
model.load_state_dict(state)
test_metrics, test_df_preds = model.evaluate_with_preds(
    device, test_dl, ds_test_copy)
test_rmse_a = test_metrics['RMSE_annual']
test_rmse_w = test_metrics['RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

#### Aggregated:

In [None]:
# from scripts.PFI_all import permutation_feature_importance_mbm_parallel

# RUN_PFI = True  # Set False to load existing results

# # Define save path
# save_dir = os.path.join(
#     cfg.dataPath,
#     "GLAMOS/distributed_MB_grids/MBM/testing_LSTM/processed_dfs",
# )
# os.makedirs(save_dir, exist_ok=True)
# pfi_path = os.path.join(save_dir, "pfi_parallel.parquet")

# if RUN_PFI:
#     # --- Compute permutation feature importance ---
#     print("▶️ Running permutation feature importance (PFI)...")
#     pfi_parallel = permutation_feature_importance_mbm_parallel(
#         cfg=cfg,
#         custom_params=PARAMS_LSTM_IS,
#         model_filename=LSTM_IS_NORM_Y,
#         df_eval=
#         df_train,  # evaluation DataFrame WITH TARGETS aligned to predictions
#         MONTHLY_COLS=MONTHLY_COLS,
#         STATIC_COLS=STATIC_COLS,
#         ds_train=ds_train,
#         train_idx=train_idx,
#         target_col="POINT_BALANCE",  # <-- target column name
#         months_head_pad=months_head_pad,
#         months_tail_pad=months_tail_pad,
#         seed=cfg.seed,
#         n_repeats=5,
#         batch_size=256,
#         max_workers=None,  # auto: n_cpus-1 (cap 32)
#     )

#     # Rename features to readable names
#     pfi_parallel["feature"] = pfi_parallel["feature"].apply(
#         lambda x: vois_climate_long_name.get(x, x))

#     # Save
#     pfi_parallel.to_parquet(pfi_path)
#     print(f"PFI results saved to {pfi_path}")

# else:
#     # --- Load previously saved results ---
#     pfi_parallel = pd.read_parquet(pfi_path)
#     print(f"PFI results loaded from {pfi_path}")

# # --- Plot ---
# plt.figure(figsize=(8, max(3, 0.35 * len(pfi_parallel))))
# plt.barh(
#     pfi_parallel["feature"],
#     pfi_parallel["mean_delta"],
#     xerr=pfi_parallel["std_delta"],
# )
# plt.gca().invert_yaxis()
# plt.title(
#     f"Permutation Feature Importance (Δ{pfi_parallel['metric_name'].iloc[0]}; "
#     f"baseline={pfi_parallel['baseline'].iloc[0]:.3f})")
# plt.xlabel(
#     f"Increase in {pfi_parallel['metric_name'].iloc[0]} (higher = more important)"
# )
# plt.tight_layout()
# plt.show()

#### Monthly:

In [None]:
from scripts.PFI_monthly import permutation_feature_importance_mbm_monthly_parallel

# --- Month label -> position in the monthly sequence ---
# The "_"-suffixed labels at either end are presumably the padding months.
month_map = {
    month: position
    for position, month in enumerate([
        "aug_", "sep_", "oct", "nov", "dec", "jan", "feb", "mar",
        "apr", "may", "jun", "jul", "aug", "sep", "oct_",
    ])
}

# --- Evaluation frame: the training data plus a numeric month index ---
df_eval = df_train.copy()
df_eval["MONTH_IDX"] = df_eval["MONTHS"].str.lower().map(month_map)

# --- Run or load ---
RUN_PFI_MONTHLY = True  # set to False to just load saved results

save_dir = os.path.join(
    cfg.dataPath,
    "GLAMOS/distributed_MB_grids/MBM/testing_LSTM/processed_dfs",
)
os.makedirs(save_dir, exist_ok=True)
pfi_monthly_path = os.path.join(save_dir, "pfi_monthly.parquet")

if not RUN_PFI_MONTHLY:
    # Reuse previously saved results.
    pfi_monthly = pd.read_parquet(pfi_monthly_path)
    print(f"Monthly PFI results loaded from {pfi_monthly_path}")
else:
    print("▶️ Running monthly permutation feature importance (PFI)...")
    pfi_monthly = permutation_feature_importance_mbm_monthly_parallel(
        cfg, PARAMS_LSTM_IS, LSTM_IS_NORM_Y, df_eval,
        MONTHLY_COLS, STATIC_COLS, ds_train, train_idx,
        months_head_pad, months_tail_pad,
        seed=cfg.seed,
        n_repeats=3,
        batch_size=256,
        denorm=True,
        max_workers=None,
    )
    pfi_monthly.to_parquet(pfi_monthly_path)
    print(f"Monthly PFI results saved to {pfi_monthly_path}")

# Quick preview
print(pfi_monthly.head())

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# --- Month order (sequence order, including the "_" padding labels) ---
month_order = ("aug_ sep_ oct nov dec jan feb mar apr may "
               "jun jul aug sep oct_").split()

# --- Readable feature names, then drop the padding months ---
pfi_monthly["feature_long"] = pfi_monthly["feature"].apply(
    lambda name: vois_climate_long_name.get(name, name))
pfi_monthly = pfi_monthly.loc[~pfi_monthly["month"].isin(
    np.concatenate([months_tail_pad, months_head_pad]))]

# --- Feature x month matrix of global ΔRMSE, columns in month_order ---
piv_global = pfi_monthly.pivot(index="feature_long",
                               columns="month",
                               values="mean_delta_global")
piv_global = piv_global[[m for m in month_order if m in piv_global.columns]]

# --- Rows sorted by mean global importance, most important first ---
feat_order = (
    pfi_monthly
    .groupby("feature_long")["mean_delta_global"]
    .mean()
    .sort_values(ascending=False)
    .index
)
piv_global = piv_global.loc[feat_order]

# --- Heatmap ---
plt.figure(figsize=(10, 6))
sns.heatmap(
    piv_global,
    cmap="magma",
    linewidths=0.3,
    cbar_kws={"label": "ΔRMSE (global)"},
)
plt.xlabel("Month")
plt.ylabel("Feature")
plt.title("Monthly Permutation Feature Importance – Global RMSE Δ")
plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.ndimage import gaussian_filter1d

# --- Month order ---
month_order = [
    "aug_", "sep_", "oct", "nov", "dec", "jan", "feb", "mar", "apr", "may",
    "jun", "jul", "aug", "sep", "oct_"
]

# --- Map features to long names ---
pfi_monthly["feature_long"] = pfi_monthly["feature"].apply(
    lambda x: vois_climate_long_name.get(x, x))
# Drop the padding months (a no-op if an earlier cell already removed them).
pfi_monthly = pfi_monthly[~pfi_monthly.month.isin(
    np.concatenate([months_tail_pad, months_head_pad]))]

# --- Pivot table for global ΔRMSE ---
piv_global = pfi_monthly.pivot(index="feature_long",
                               columns="month",
                               values="mean_delta_global")

# --- Reorder months ---
piv_global = piv_global[[m for m in month_order if m in piv_global.columns]]

# --- Order features by average importance (ascending: ridges are stacked
# bottom-up, so the most important feature ends up on top) ---
feat_order = piv_global.mean(axis=1).sort_values(ascending=True).index

# --- Smooth for nicer ridges (display only; labels use the raw values) ---
piv_smooth = pd.DataFrame(
    np.vstack(
        [gaussian_filter1d(piv_global.loc[f], sigma=1) for f in feat_order]),
    index=feat_order,
    columns=piv_global.columns,
)

# --- Plot setup ---
fig, ax = plt.subplots(figsize=(10, 10))
palette = sns.color_palette("magma", n_colors=len(feat_order))
month_idx = np.arange(len(piv_smooth.columns))

# Vertical spacing between ridges; the gap above a feature listed in
# big_features is larger.
offset_step_small = np.nanmax(piv_smooth.values) * 0.6
offset_step_big = np.nanmax(piv_smooth.values) * 1.
big_features = ["Temp.", "Precip."]

current_offset = 0.0

# --- Compute max (unsmoothed) ΔRMSE per feature ---
max_importance = piv_global.max(axis=1)

for feat, color in zip(feat_order, palette):
    y = piv_smooth.loc[feat].values

    # Plot ridge
    ax.plot(month_idx, y + current_offset, color=color, lw=2)
    ax.fill_between(month_idx,
                    current_offset,
                    y + current_offset,
                    color=color,
                    alpha=0.4)

    # Feature label, left of the first month
    ax.text(-0.6,
            current_offset,
            feat,
            va='center',
            ha='right',
            fontsize=13,
            color='black')

    # --- Add ΔRMSE annotation at the smoothed ridge's highest point ---
    max_idx = np.argmax(y)
    max_x = month_idx[max_idx]
    max_y = y[max_idx] + current_offset

    # Offset annotation if too close to edges
    if max_x == 0:
        text_x = max_x + 0.1  # shift right if peak is at first position
        ha = 'left'
    elif max_x == len(month_idx) - 1:
        text_x = max_x - 0.4  # shift left if at last position
        ha = 'right'
    else:
        text_x = max_x
        ha = 'center'

    # The label shows the raw (unsmoothed) maximum ΔRMSE of the feature.
    ax.text(text_x,
            max_y + 0.05 * offset_step_small,
            f"ΔRMSE={max_importance[feat]:.3f}",
            ha=ha,
            va='bottom',
            fontsize=11,
            color='black',
            rotation=0,
            bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=1.5))

    # Increment vertical offset
    current_offset += offset_step_big if feat in big_features else offset_step_small

# --- Styling ---
ax.set_yticks([])
ax.set_xlim(0, len(month_idx) - 1)
ax.set_xticks(month_idx)
ax.set_xticklabels([m.strip("_").capitalize() for m in piv_smooth.columns],
                   rotation=45,
                   ha='right')
ax.set_xlabel("Month")
ax.set_title("Monthly Permutation Feature Importance", pad=20)

ax.set_facecolor("white")
fig.patch.set_facecolor("white")
ax.tick_params(colors='black')
for spine in ["top", "right", "left"]:
    ax.spines[spine].set_visible(False)
ax.spines["bottom"].set_color("black")

plt.tight_layout()

# --- Save figure BEFORE show(): some backends close or clear the figure
# when it is shown, which can leave an empty image on disk ---
os.makedirs("figures/paper", exist_ok=True)
fig.savefig("figures/paper/CH_LSTM_monthly_PFI.png",
            dpi=300,
            bbox_inches="tight")
plt.show()