## Setting Up:

In [None]:
import sys
import os

import logging
import warnings
from datetime import datetime
from collections import defaultdict, Counter
from functools import partial

import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from tqdm.notebook import tqdm
from cmcrameri import cm
from skorch.helper import SliceDataset
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint

import massbalancemachine as mbm

# Add root of repo to the path so the local `scripts` package can be imported
sys.path.append(os.path.join(os.getcwd(), '../../'))

# Local modules
from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.nn_helpers import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.NN_networks import *
from scripts.geodata_plots import *

warnings.filterwarnings('ignore')

%load_ext autoreload
%autoreload 2

cfg = mbm.SwitzerlandConfig()

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from torch.utils.data import WeightedRandomSampler, SubsetRandomSampler
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Seed every RNG from the config so the run is reproducible.
seed_all(cfg.seed)
print("Using seed:", cfg.seed)

# Report GPU availability; when a GPU is present, call the project helper
# free_up_cuda() before anything is allocated on it.
if not torch.cuda.is_available():
    print("CUDA is NOT available")
else:
    print("CUDA is available")
    free_up_cuda()

# Device used for the model and tensors further down.
device = (torch.device("cuda")
          if torch.cuda.is_available() else torch.device("cpu"))


In [None]:
# --- Plot styling ---
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue = colors[0]
color_pink = '#c51b7d'

# --- RGI ids ---
# Read the glacier id table, strip whitespace from the column names and
# index it by the glacier short name (sorted alphabetically).
rgi_df = (
    pd.read_csv(cfg.dataPath + path_glacier_ids, sep=',')
    .rename(columns=lambda name: name.strip())
    .sort_values(by='short_name')
    .set_index('short_name')
)

# Climate variables of interest.
vois_climate = ['t2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str']

# Topographical variables of interest. A longer variant also included
# "hugonnet_dhdt", "consensus_ice_thickness" and "millan_v".
vois_topographical = ["aspect_sgi", "slope_sgi", "svf"]

## Input data:

In [None]:
# Stake measurements.
data_glamos = getStakesData(cfg)

# Head / tail padding (in months) derived from the stake data; reused when
# the LSTM sequence datasets are built.
months_head_pad, months_tail_pad = (
    mbm.data_processing.utils._compute_head_tail_pads_from_df(data_glamos))

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Input / output locations for the monthly transformation.
_era5_dir = cfg.dataPath + path_ERA5_raw
paths = dict(
    csv_path=cfg.dataPath + path_PMB_GLAMOS_csv,
    era5_climate_data=_era5_dir + 'era5_monthly_averaged_data.nc',
    geopotential_data=_era5_dir + 'era5_geopotential_pressure.nc',
    radiation_save_path=cfg.dataPath + path_pcsr + 'zarr/',
)

# Transform data to monthly format (RUN=True) or load it (RUN=False).
RUN = False
data_monthly = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_LSTM_svf.csv',
)

# MBM data loader wrapping the monthly dataset.
dataloader_gl = mbm.dataloader.DataLoader(
    cfg,
    data=data_monthly,
    random_seed=cfg.seed,
    meta_data_columns=cfg.metaData,
)

In [None]:
# Check that all test glaciers exist in the dataset and report any that do
# not, instead of silently computing the list and never looking at it.
existing_glaciers = set(data_monthly.GLACIER.unique())
missing_glaciers = [g for g in TEST_GLACIERS if g not in existing_glaciers]
if missing_glaciers:
    print("WARNING: test glaciers missing from dataset:", missing_glaciers)

# Training glaciers = every glacier in the dataset that is not a test glacier.
train_glaciers = [i for i in existing_glaciers if i not in TEST_GLACIERS]

# Split by glacier. The train / test frames are taken from the split result
# (no separate manual filtering of data_monthly, which was overwritten anyway).
splits, test_set, train_set = get_CV_splits(dataloader_gl,
                                            test_split_on='GLACIER',
                                            test_splits=TEST_GLACIERS,
                                            random_state=cfg.seed)

# Attach the target as column 'y' to the feature frames.
data_train = train_set['df_X']
data_train['y'] = train_set['y']
data_test = test_set['df_X']
data_test['y'] = test_set['y']

### Feature distribution of test set:

In [None]:
# Features that vary month by month.
MONTHLY_COLS = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str',
    'pcsr', 'ELEVATION_DIFFERENCE',
]

# Static features. A longer variant also included 'hugonnet_dhdt',
# 'consensus_ice_thickness' and 'millan_v'.
STATIC_COLS = ['aspect_sgi', 'slope_sgi', 'svf']

# Full feature list registered on the config: monthly first, then static.
feature_columns = [*MONTHLY_COLS, *STATIC_COLS]
cfg.setFeatures(feature_columns)

In [None]:
# Palette: train in the dark blue of the batlow map, test in red.
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue = colors[0]
custom_palette = dict(Train=color_dark_blue, Test='#b2182b')

# t-SNE overlap of train vs. test features.
fig = plot_tsne_overlap(
    data_train,
    data_test,
    STATIC_COLS,
    MONTHLY_COLS,
    sublabels=("a", "b", "c"),
    label_fmt="({})",
    label_xy=(0.02, 0.98),
    label_fontsize=14,
    n_iter=1000,
    random_state=cfg.seed,
    custom_palette=custom_palette,
)

# save figure
fig.savefig(
    'figures/paper/fig_tsne_overlap_train_test_CH.png',
    dpi=300,
    bbox_inches='tight',
)

## LSTM:

In [None]:
import numpy as np
import pandas as pd
from typing import Sequence, Optional, Tuple
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn.decomposition import PCA
from sklearn.metrics import roc_auc_score
from scipy.stats import ks_2samp
import matplotlib.pyplot as plt


# ---------- shared preprocessing (fit on train, apply to both) ----------
def fit_transform_train_test(train_df: pd.DataFrame, test_df: pd.DataFrame,
                             cols: Sequence[str]):
    """Median-impute and standardise ``cols`` of both frames.

    The imputer and the scaler are fitted on ``train_df`` only and then
    applied unchanged to ``test_df``.

    Returns:
        Tuple ``(Xtr, Xte)`` of transformed NumPy arrays.
    """
    median_imputer = SimpleImputer(strategy="median")
    standardiser = StandardScaler()

    # Train: fit + transform in one go.
    train_filled = median_imputer.fit_transform(train_df[cols].to_numpy())
    Xtr = standardiser.fit_transform(train_filled)

    # Test: reuse the statistics learnt on train.
    test_filled = median_imputer.transform(test_df[cols].to_numpy())
    Xte = standardiser.transform(test_filled)
    return Xtr, Xte


# ---------- (1) Per-feature PSI & KS ----------
def population_stability_index(train_vals, test_vals, bins=10) -> float:
    """Population Stability Index (PSI) of ``test_vals`` relative to ``train_vals``.

    Bin edges are the quantiles of the non-NaN training values, so each bin
    holds roughly the same share of the training data. NaNs are ignored in
    both samples. Test values that fall outside the training range are
    counted in the first / last bin rather than dropped, so a test sample
    that extends beyond the training range registers as drift.

    Args:
        train_vals: 1-D array-like of reference (training) values.
        test_vals: 1-D array-like of values to compare against the reference.
        bins: Number of quantile bins requested (duplicates are merged).

    Returns:
        The PSI as a float. ``0.0`` is returned when there are no usable
        training values or when the training values yield fewer than two
        distinct bins (e.g. a constant feature).
    """
    train_vals = np.asarray(train_vals, dtype=float)
    test_vals = np.asarray(test_vals, dtype=float)
    train_ok = train_vals[~np.isnan(train_vals)]
    test_ok = test_vals[~np.isnan(test_vals)]
    if train_ok.size == 0:
        return 0.0

    # Use quantile bins from train to be robust
    qs = np.linspace(0, 1, bins + 1)
    edges = np.unique(np.quantile(train_ok, qs))
    # Avoid degenerate edges
    if len(edges) < 3:
        return 0.0

    # np.histogram discards values outside [edges[0], edges[-1]]; clip the
    # test sample so out-of-range values land in the outermost bins instead.
    test_ok = np.clip(test_ok, edges[0], edges[-1])

    tr_hist, _ = np.histogram(train_ok, bins=edges)
    te_hist, _ = np.histogram(test_ok, bins=edges)
    tr_p = tr_hist / max(tr_hist.sum(), 1)
    te_p = te_hist / max(te_hist.sum(), 1)

    # Stabilize zeros so the log ratio stays finite
    tr_p = np.clip(tr_p, 1e-8, 1)
    te_p = np.clip(te_p, 1e-8, 1)
    return float(np.sum((te_p - tr_p) * np.log(te_p / tr_p)))


def per_feature_drift_table(data_train,
                            data_test,
                            cols,
                            bins=10) -> pd.DataFrame:
    """Build a per-feature drift table (PSI and two-sample KS test).

    Returns:
        DataFrame with columns ``feature``, ``PSI``, ``KS_stat`` and
        ``KS_p``, sorted by ``PSI`` in descending order.
    """

    def _drift_row(col):
        train_values = data_train[col].to_numpy()
        test_values = data_test[col].to_numpy()
        psi_value = population_stability_index(train_values,
                                               test_values,
                                               bins=bins)

        # KS needs finite values
        train_finite = train_values[np.isfinite(train_values)]
        test_finite = test_values[np.isfinite(test_values)]
        if len(train_finite) > 0 and len(test_finite) > 0:
            ks_stat, ks_p = ks_2samp(train_finite, test_finite)
        else:
            ks_stat, ks_p = np.nan, np.nan

        return {
            "feature": col,
            "PSI": psi_value,
            "KS_stat": ks_stat,
            "KS_p": ks_p
        }

    table = pd.DataFrame([_drift_row(c) for c in cols])
    return table.sort_values("PSI", ascending=False)


# ---------- (2) Domain classifier ROC AUC ----------
def domain_auc(data_train,
               data_test,
               cols,
               random_state=42,
               cv_splits=5) -> Tuple[float, np.ndarray]:
    """Cross-validated ROC AUC of a train-vs-test domain classifier.

    A logistic regression is asked to tell training rows (label 0) from test
    rows (label 1) using ``cols``.

    Returns:
        ``(mean_auc, per_fold_aucs)``.
    """
    Xtr, Xte = fit_transform_train_test(data_train, data_test, cols)

    # Stack both domains; 0=train, 1=test
    features = np.concatenate([Xtr, Xte], axis=0)
    domain = np.concatenate(
        [np.zeros(len(Xtr), dtype=int),
         np.ones(len(Xte), dtype=int)])

    folds = StratifiedKFold(n_splits=cv_splits,
                            shuffle=True,
                            random_state=random_state)
    classifier = LogisticRegression(max_iter=200,
                                    random_state=random_state,
                                    n_jobs=None)
    fold_aucs = cross_val_score(classifier,
                                features,
                                domain,
                                scoring="roc_auc",
                                cv=folds)
    return float(fold_aucs.mean()), fold_aucs


# ---------- (3) MMD with RBF kernel (plus permutation p-value) ----------
def _rbf_kernel(X, Y=None, gamma=None):
    if Y is None: Y = X
    if gamma is None:
        # median heuristic
        Z = np.vstack([X, Y])
        dists = np.sum(Z**2, 1,
                       keepdims=True) - 2 * Z @ Z.T + (Z @ Z.T).diagonal()
        med = np.median(dists[dists > 0])
        gamma = 1.0 / (2 * max(med, 1e-12))
    XX = np.exp(-gamma * ((X**2).sum(1, keepdims=True) - 2 * X @ X.T +
                          (X**2).sum(1)))
    YY = np.exp(-gamma * ((Y**2).sum(1, keepdims=True) - 2 * Y @ Y.T +
                          (Y**2).sum(1)))
    XY = np.exp(-gamma * ((X**2).sum(1, keepdims=True) - 2 * X @ Y.T +
                          (Y**2).sum(1)))
    return XX, YY, XY


def mmd_rbf_unbiased(X, Y) -> float:
    """Unbiased estimate of the squared MMD between ``X`` and ``Y`` (RBF kernel).

    The within-sample terms exclude the diagonal (self-similarity); the
    cross term is the plain mean of the cross-kernel matrix.
    """
    n_x = X.shape[0]
    n_y = Y.shape[0]
    k_xx, k_yy, k_xy = _rbf_kernel(X, Y)

    # Drop self-similarities before averaging the within-sample kernels.
    for gram in (k_xx, k_yy):
        np.fill_diagonal(gram, 0.0)

    within_x = k_xx.sum() / (n_x * (n_x - 1))
    within_y = k_yy.sum() / (n_y * (n_y - 1))
    cross = k_xy.mean()
    return float(within_x + within_y - 2 * cross)


def mmd_permutation_test(Xtr,
                         Xte,
                         n_perm=200,
                         random_state=42) -> Tuple[float, float]:
    """Permutation test for the MMD between ``Xtr`` and ``Xte``.

    The pooled rows are reshuffled ``n_perm`` times; the p-value is the
    (add-one smoothed) share of permutations whose MMD is at least as large
    as the observed one.

    Returns:
        ``(observed_mmd, p_value)``.
    """
    rng = np.random.RandomState(random_state)
    observed = mmd_rbf_unbiased(Xtr, Xte)

    pooled = np.vstack([Xtr, Xte])
    n_train = Xtr.shape[0]

    exceed = 0
    for _ in range(n_perm):
        rng.shuffle(pooled)  # in-place row shuffle of the pooled copy
        permuted = mmd_rbf_unbiased(pooled[:n_train], pooled[n_train:])
        if permuted >= observed:
            exceed += 1

    p_value = (exceed + 1) / (n_perm + 1)
    return observed, p_value

In [None]:
feature_cols = [*STATIC_COLS, *MONTHLY_COLS]

# 1) Per-feature drift table (already sorted by PSI, descending)
drift_df = per_feature_drift_table(data_train=data_train,
                                   data_test=data_test,
                                   cols=feature_cols,
                                   bins=10)
print(drift_df.head(10))

# 2) Domain classifier (single scalar)
# ~0.5 means test is representative; >0.65 suggests detectable shift.
auc_mean, aucs = domain_auc(data_train=data_train,
                            data_test=data_test,
                            cols=feature_cols)
print(f"Domain ROC AUC (mean±sd): {auc_mean:.3f} ± {aucs.std():.3f}")

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


# ------------------------------------------------------------
# 1) PSI bar chart (top-N drifted features)
# ------------------------------------------------------------
def plot_psi_bars(drift_df: pd.DataFrame, top_n: int = 20, figsize=(10, 6)):
    """Horizontal bar chart of the ``top_n`` features with the largest PSI.

    ``drift_df`` must provide the columns ``feature`` and ``PSI``. Bars are
    coloured by the thresholds 0.10 and 0.25, which are also drawn as
    dashed vertical lines.
    """

    def _severity_colour(psi_value):
        if psi_value >= 0.25:
            return "#d62728"  # significant
        if psi_value >= 0.10:
            return "#ff7f0e"  # moderate
        return "#2ca02c"  # negligible

    ranked = drift_df.sort_values("PSI", ascending=False).head(top_n).copy()
    positions = np.arange(len(ranked))[::-1]  # highest PSI at the top
    bar_colours = [_severity_colour(v) for v in ranked["PSI"]]

    plt.figure(figsize=figsize)
    plt.barh(positions,
             ranked["PSI"].values,
             tick_label=ranked["feature"].values,
             color=bar_colours,
             alpha=0.9)
    for threshold in (0.10, 0.25):
        plt.axvline(threshold,
                    linestyle="--",
                    linewidth=1,
                    label=f"PSI = {threshold:.2f}")
    plt.xlabel("Population Stability Index (PSI)")
    plt.title(f"Top {min(top_n, len(drift_df))} features by drift (PSI)")
    plt.legend(loc="lower right")
    plt.tight_layout()
    plt.show()


# ------------------------------------------------------------
# 2) ECDF overlay (train vs test) for a single feature
# ------------------------------------------------------------
def _ecdf(x):
    x = np.asarray(x)
    x = x[np.isfinite(x)]
    if x.size == 0:
        return np.array([]), np.array([])
    x_sorted = np.sort(x)
    y = np.arange(1, x_sorted.size + 1) / x_sorted.size
    return x_sorted, y


def plot_feature_ecdf(train_df: pd.DataFrame,
                      test_df: pd.DataFrame,
                      feature: str,
                      figsize=(5, 4)):
    """Draw the train and test ECDFs of ``feature`` in one figure.

    A curve is skipped when its sample has no finite value.
    """
    curves = [
        ("Train", _ecdf(train_df[feature].values)),
        ("Test", _ecdf(test_df[feature].values)),
    ]

    plt.figure(figsize=figsize)
    for curve_label, (xs, ys) in curves:
        if not xs.size:
            continue
        plt.step(xs,
                 ys,
                 where="post",
                 label=curve_label,
                 linewidth=2,
                 alpha=0.8)
    plt.title(f"ECDF: {feature}")
    plt.xlabel(feature)
    plt.ylabel("Cumulative probability")
    plt.legend()
    plt.tight_layout()
    plt.show()


# ------------------------------------------------------------
# 3) Convenience: ECDF overlays for top-k drifted features
# ------------------------------------------------------------
def plot_top_feature_ecdfs(
        data_train: pd.DataFrame,
        data_test: pd.DataFrame,
        drift_df: pd.DataFrame,
        k: int = 6,
        ncols: int = 3,
        figsize=(12, 15),
):
    """Small multiples of train/test ECDF overlays for the top-k PSI features.

    Args:
        data_train: Training frame containing the feature columns.
        data_test: Test frame containing the feature columns.
        drift_df: Drift table with at least the columns ``feature`` and
            ``PSI``.
        k: Number of features (largest PSI first) to plot.
        ncols: Maximum number of panels per row.
        figsize: Size of the whole figure.

    Panels of the grid that have no feature assigned are hidden. If the
    drift table yields no feature at all, a message is printed and nothing
    is drawn.
    """
    top_feats = drift_df.sort_values(
        "PSI", ascending=False)["feature"].head(k).tolist()
    n = len(top_feats)
    if n == 0:
        # plt.subplots cannot build a grid with zero rows.
        print("No features to plot.")
        return
    ncols = min(ncols, n)
    nrows = int(np.ceil(n / ncols))

    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    for i, feat in enumerate(top_feats):
        r, c = divmod(i, ncols)
        ax = axes[r, c]
        x_tr, y_tr = _ecdf(data_train[feat].values)
        x_te, y_te = _ecdf(data_test[feat].values)

        if x_tr.size:
            ax.step(x_tr,
                    y_tr,
                    where="post",
                    label="Train",
                    linewidth=2,
                    alpha=0.8)
        if x_te.size:
            ax.step(x_te,
                    y_te,
                    where="post",
                    label="Test",
                    linewidth=2,
                    alpha=0.8)
        ax.set_title(feat)
        # Axis labels only on the first row / first column of the grid.
        ax.set_xlabel(feat if r == 0 else "")
        ax.set_ylabel("CDF" if c == 0 else "")

    # Hide the trailing panels of the grid that received no feature
    # (axes.flat is row-major, matching the divmod placement above).
    for unused_ax in axes.flat[n:]:
        unused_ax.set_visible(False)

    # tidy up legends (single shared)
    handles, labels = axes[0, 0].get_legend_handles_labels()
    if handles:
        fig.legend(handles, labels, loc="lower center", ncol=2)
    plt.tight_layout()
    plt.show()


# 1) PSI bar chart
plot_psi_bars(drift_df=drift_df, top_n=20)

# 2) ECDF overlays for the top-k drifted features
plot_top_feature_ecdfs(data_train=data_train,
                       data_test=data_test,
                       drift_df=drift_df,
                       k=14,
                       ncols=3)

# 3) Deep-dive a specific feature
plot_feature_ecdf(train_df=data_train,
                  test_df=data_test,
                  feature="ELEVATION_DIFFERENCE")

### Build LSTM dataloaders:

In [None]:
seed_all(cfg.seed)


def _with_normalised_period(frame):
    """Return a copy of ``frame`` whose PERIOD strings are stripped and lower-cased."""
    cleaned = frame.copy()
    cleaned['PERIOD'] = cleaned['PERIOD'].str.strip().str.lower()
    return cleaned


df_train = _with_normalised_period(data_train)
df_test = _with_normalised_period(data_test)

# Arguments shared by the train and test sequence datasets.
_sequence_kwargs = dict(months_tail_pad=months_tail_pad,
                        months_head_pad=months_head_pad,
                        expect_target=True)

# --- build train / test datasets from the dataframes ---
ds_train = mbm.data_processing.MBSequenceDataset.from_dataframe(
    df_train, MONTHLY_COLS, STATIC_COLS, **_sequence_kwargs)

ds_test = mbm.data_processing.MBSequenceDataset.from_dataframe(
    df_test, MONTHLY_COLS, STATIC_COLS, **_sequence_kwargs)

# --- train / validation indices (20% validation) ---
train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
    len(ds_train), val_ratio=0.2, seed=cfg.seed)

### Define & train model:

In [None]:
# Model hyperparameters. 'Fm' / 'Fs' are taken from the lengths of the
# monthly / static feature lists the datasets were built with (9 and 3 for
# the current lists) instead of being hard-coded, so they cannot drift out
# of sync with MONTHLY_COLS / STATIC_COLS.
custom_params = {
    'Fm': len(MONTHLY_COLS),
    'Fs': len(STATIC_COLS),
    'hidden_size': 128,
    'num_layers': 2,
    'bidirectional': False,
    'dropout': 0.2,
    'static_layers': 2,
    'static_hidden': [128, 64],
    'static_dropout': 0.1,
    'lr': 0.0005,
    'weight_decay': 0.0,
    'loss_name': 'neutral',
    'two_heads': True,
    'head_dropout': 0.0,
    'loss_spec': None
}

# --- loaders (fit scalers on TRAIN, apply to whole ds_train) ---
# Work on untransformed clones so ds_train / ds_test themselves stay as built.
ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train)

ds_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_test)

train_dl, val_dl = ds_train_copy.make_loaders(
    train_idx=train_idx,
    val_idx=val_idx,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=
    True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True  # use weighted sampler for training
)

# --- test loader (copies TRAIN scalers into ds_test and transforms it) ---
test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test_copy, ds_train_copy, batch_size=128, seed=cfg.seed)

# --- build model and resolve loss (no training in this cell) ---
model = mbm.models.LSTM_MB.build_model_from_params(cfg, custom_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(custom_params)

# --- load the saved weights and evaluate on the test set ---
model_filename = "models/lstm_model_2025-10-22_two_heads_no_oggm.pt"
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state)
test_metrics, test_df_preds = model.evaluate_with_preds(
    device, test_dl, ds_test_copy)
test_rmse_a, test_rmse_w = test_metrics['RMSE_annual'], test_metrics[
    'RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

## Validation OOS:
Out-of-sample (OOS) evaluation on the held-out test glaciers.

In [None]:
# Colours used for annual vs. winter results in the plots below.
colors = get_cmap_hex(cm.batlow, 10)
color_annual, color_winter = "#c51b7d", colors[0]

# Glacier areas; expose the "claridenL" entry under the alias "clariden" too.
gl_area = get_gl_area(cfg)
gl_area["clariden"] = gl_area["claridenL"]

In [None]:
# Seasonal scores from the test predictions.
scores_annual, scores_winter = compute_seasonal_scores(
    test_df_preds,
    target_col='target',
    pred_col='pred',
)

print("Annual scores:", scores_annual)
print("Winter scores:", scores_winter)

# Predicted-vs-observed summary figure, same limits on both axes.
_axis_limits = (-8, 6)
fig = plot_predictions_summary(
    grouped_ids=test_df_preds,
    scores_annual=scores_annual,
    scores_winter=scores_winter,
    ax_xlim=_axis_limits,
    ax_ylim=_axis_limits,
    color_annual=color_annual,
    color_winter=color_winter,
)

# save figure
fig.savefig('figures/paper/fig_predvsobs.png', dpi=300, bbox_inches='tight')

In [None]:
# Mean elevation of the annual stake measurements per glacier, highest first.
# Computed once (it used to be recomputed identically a few lines later).
gl_per_el = data_glamos[data_glamos.PERIOD == 'annual'].groupby(
    ['GLACIER'])['POINT_ELEVATION'].mean()
gl_per_el = gl_per_el.sort_values(ascending=False)

# Test glaciers ordered by increasing mean elevation -> panel order.
test_gl_per_el = gl_per_el[TEST_GLACIERS].sort_values().index

# Attach the mean glacier elevation to every prediction row.
test_df_preds['gl_elv'] = test_df_preds['GLACIER'].map(gl_per_el)

fig, axs = plt.subplots(3, 3, figsize=(25, 18), sharex=True)

subplot_labels = [f'({letter})' for letter in 'abcdefghi']

axs = PlotIndividualGlacierPredVsTruth(test_df_preds,
                                       axs=axs,
                                       subplot_labels=subplot_labels,
                                       color_annual=color_annual,
                                       color_winter=color_winter,
                                       custom_order=test_gl_per_el,
                                       gl_area=gl_area)

axs[3].set_ylabel("Modelled PMB [m w.e.]", fontsize=20)

fig.supxlabel('Observed PMB [m w.e.]', fontsize=20, y=0.06)

# Two distinct scatter handles (annual / winter) for the shared legend.
legend_scatter_annual, legend_scatter_winter = (Line2D(
    [0], [0],
    marker='o',
    linestyle='None',
    linewidth=0,
    markersize=10,
    markerfacecolor=face_color,
    markeredgecolor='k',
    markeredgewidth=0.8,
    label=legend_label) for legend_label, face_color in (
        ('Annual', color_annual),
        ('Winter', color_winter),
    ))

# if you already have other handles (e.g., bands/means), append these:
# handles = existing_handles + [legend_scatter_annual, legend_scatter_winter]
handles = [legend_scatter_annual, legend_scatter_winter]

# Labels are taken from the handles themselves; no need to pass `labels=...`
fig.legend(handles=handles,
           loc='upper center',
           bbox_to_anchor=(0.5, 0.05),
           ncol=len(handles),
           fontsize=20)

plt.subplots_adjust(hspace=0.25, wspace=0.25)

# Save before showing, so the file is written from the fully laid-out figure
# regardless of what the backend does with the figure on show().
fig.savefig('figures/paper/fig_predvsobs_indv.png',
            dpi=300,
            bbox_inches='tight')
plt.show()

## Intermediate validation, out-of-sample (OOS) and in-sample (IS):

In [None]:
# Geodetic MB + per-glacier periods
geodetic_mb = get_geodetic_MB(cfg)
periods_per_glacier, geoMB_per_glacier = build_periods_per_glacier(geodetic_mb)

In [None]:
# Common parent folder of the distributed mass-balance grids.
_MB_GRIDS_ROOT = os.path.join(cfg.dataPath, "GLAMOS", "distributed_MB_grids")

# Out-of-sample (OOS) LSTM predictions.
PATH_PREDICTIONS_LSTM_OOS = os.path.join(_MB_GRIDS_ROOT,
                                         "MBM/glamos_dems_LSTM_svf_OOS")

# In-sample (IS) LSTM predictions.
# Alternative location: "MBM/glamos_dems_LSTM_svf_IS" under the same root.
PATH_PREDICTIONS_LSTM_IS = os.path.join(
    _MB_GRIDS_ROOT, "MBM/testing_LSTM/glamos_dems_LSTM_no_oggm")

In [None]:
# Available glaciers (those with LSTM predictions)
glaciers_in_glamos = set(os.listdir(PATH_PREDICTIONS_LSTM_OOS))

# Areas (with clariden alias fix)
gl_area = get_gl_area(cfg)
gl_area["clariden"] = gl_area["claridenL"]

# Glaciers present in both geodetic periods and predictions, sorted by area (asc)
glacier_list = sorted(
    (g for g in periods_per_glacier.keys() if g in glaciers_in_glamos),
    key=lambda g: gl_area.get(g, 0))
print("Number of glaciers:", len(glacier_list))
print("Glaciers:", glacier_list)

In [None]:
# Run comparison
ds_lstm_OS = process_geodetic_mass_balance_comparison(
    glacier_list=glacier_list,
    path_SMB_GLAMOS_csv=os.path.join(cfg.dataPath, path_SMB_GLAMOS_csv),
    periods_per_glacier=periods_per_glacier,
    geoMB_per_glacier=geoMB_per_glacier,
    gl_area=gl_area,
    test_glaciers=TEST_GLACIERS,
    path_predictions=PATH_PREDICTIONS_LSTM_OOS,
    cfg=cfg,
)

# Drop rows where any required columns are NaN
ds_lstm_OS = ds_lstm_OS.dropna(subset=['Geodetic MB', 'MBM MB'])
ds_lstm_OS = ds_lstm_OS.sort_values(by='Area')
ds_lstm_OS['GLACIER'] = ds_lstm_OS['GLACIER'].apply(lambda x: x.capitalize())

In [None]:
# Run comparison
ds_lstm_IS = process_geodetic_mass_balance_comparison(
    glacier_list=glacier_list,
    path_SMB_GLAMOS_csv=os.path.join(cfg.dataPath, path_SMB_GLAMOS_csv),
    periods_per_glacier=periods_per_glacier,
    geoMB_per_glacier=geoMB_per_glacier,
    gl_area=gl_area,
    test_glaciers=TEST_GLACIERS,
    path_predictions=PATH_PREDICTIONS_LSTM_IS,
    cfg=cfg,
)

# Drop rows where any required columns are NaN
ds_lstm_IS = ds_lstm_IS.dropna(subset=['Geodetic MB', 'MBM MB'])
ds_lstm_IS = ds_lstm_IS.sort_values(by='Area')
ds_lstm_IS['GLACIER'] = ds_lstm_IS['GLACIER'].apply(lambda x: x.capitalize())

### Mass balance gradients:

In [None]:
# Unused alternative (per-glacier mean area from the OOS comparison table):
# areas_per_gl = ds_lstm_OS.groupby(
#     'GLACIER').Area.mean().reset_index().set_index('GLACIER')

# Stake data
# Load the full stake CSV once here; the plotting loops below receive the
# same dataframe for every glacier instead of reading the file repeatedly.
stake_file = os.path.join(cfg.dataPath, path_PMB_GLAMOS_csv,
                          "CH_wgms_dataset_all.csv")
df_stakes = pd.read_csv(stake_file)

#### Gradients all:

In [None]:
nrows = 4
ncols = 5
# Centimetre -> inch factor for the figure size. Deliberately not called
# `cm`: that name is the colormap module imported from cmcrameri, and
# rebinding it to a float breaks every `cm.batlow` call when earlier cells
# are re-run.
cm_to_inch = 1 / 2.54
fontsize = 7

# Create a figure with the specified number of subplots
fig, axs = plt.subplots(nrows=nrows,
                        ncols=ncols,
                        figsize=(25 * cm_to_inch, 15 * cm_to_inch),
                        dpi=300)
axs = axs.flatten()
gl_list = [
    'schwarzbach', 'murtel', 'plattalva', 'basodino', 'limmern', 'adler',
    'hohlaub', 'albigna', 'tsanfleuron', 'silvretta', 'gries', 'clariden',
    'gietro', 'schwarzberg', 'forno', 'allalin', 'otemma', 'findelen', 'rhone',
    'aletsch'
]
for i, gl in enumerate(gl_list):
    gl_key = gl.lower()

    # Annual
    df_lstm_a, df_glamos_a, df_all_a = build_all_years_df(
        gl_key, PATH_PREDICTIONS_LSTM_IS, cfg, period="annual")

    # Skip glaciers without annual data before touching any column of the
    # (possibly column-less) empty frame.
    if df_all_a.empty:
        print(f"No data for glacier: {gl}")
        continue

    years = df_all_a.YEAR.unique()

    # Winter
    df_lstm_w, df_glamos_w, df_all_w = build_all_years_df(
        gl_key, PATH_PREDICTIONS_LSTM_IS, cfg, period="winter")

    plot_mb_by_elevation_periods_combined(df_all_a,
                                          df_all_w,
                                          df_stakes,
                                          gl_key,
                                          ax=axs[i])

    # Area rounded to 3 decimals for very small glaciers, 1 otherwise.
    area = gl_area.get(gl_key, np.nan)
    area = np.round(area, 3) if area < 0.1 else np.round(area, 1)

    # Test glaciers are marked with a leading asterisk in the title.
    title_prefix = '*' if gl_key in TEST_GLACIERS else ''
    axs[i].set_title(
        f'{title_prefix}{gl} ({area} km2, {years.min()}-{years.max()})',
        fontsize=fontsize,
        pad=2)

    axs[i].grid(alpha=0.2)
    axs[i].tick_params(labelsize=6.5, pad=2)
    axs[i].set_ylabel('')
    axs[i].set_xlabel('')

    # Remove the per-axes legend if the plotting helper added one (a single
    # figure-level legend is drawn below).
    panel_legend = axs[i].get_legend()
    if panel_legend is not None:
        panel_legend.remove()

axs[5].set_ylabel('Elevation (m a.s.l.)', fontsize=fontsize)

# Custom handles (bands, means, and stakes)
handles = [
    # LSTM
    Patch(facecolor=color_annual, alpha=0.25, label="LSTM band (annual)"),
    Line2D([0], [0],
           color=color_annual,
           lw=1.2,
           linestyle='-',
           label="LSTM mean (annual)"),
    Patch(facecolor=color_winter, alpha=0.25, label="LSTM band (winter)"),
    Line2D([0], [0],
           color=color_winter,
           lw=1.2,
           linestyle='-',
           label="LSTM mean (winter)"),

    # GLAMOS (mean only)
    Line2D([0], [0],
           color=color_annual,
           lw=1.2,
           linestyle=':',
           label="GLAMOS mean (annual)"),
    Line2D([0], [0],
           color=color_winter,
           lw=1.2,
           linestyle=':',
           label="GLAMOS mean (winter)"),

    # Stakes means
    Line2D([0], [0],
           marker='o',
           linestyle='None',
           linewidth=0,
           markersize=6,
           markerfacecolor='none',
           markeredgecolor=color_annual,
           markeredgewidth=1.2,
           label="Stakes mean (annual)"),
    Line2D([0], [0],
           marker='s',
           linestyle='None',
           linewidth=0,
           markersize=6,
           markerfacecolor='none',
           markeredgecolor=color_winter,
           markeredgewidth=1.2,
           label="Stakes mean (winter)"),
]

fig.supxlabel('Mass balance (m w.e.)', fontsize=fontsize, y=0.06)

fig.legend(handles=handles,
           loc='upper center',
           bbox_to_anchor=(0.5, 0.05),
           ncol=4,
           fontsize=7)

# Adjust the layout
plt.subplots_adjust(hspace=0.25, wspace=0.25)
plt.show()

#### Gradients comparison:

In [None]:
gl_list = [
    'Hohlaub',
    'Tsanfleuron',
    'Schwarzberg',
    'Forno',
]

nrows = 1  # 0: OOS, 1: IS
ncols = len(gl_list)
cm = 1 / 2.54
fontsize = 7

fig, axs = plt.subplots(nrows=nrows,
                        ncols=ncols,
                        figsize=(25 * cm, 12 * cm),
                        dpi=300)

for c, gl in enumerate(gl_list):  # columns = glaciers
    # Annual
    df_lstm_a_oos, df_glamos_a_oos, df_all_a_oos = build_all_years_df(
        gl.lower(), PATH_PREDICTIONS_LSTM_OOS, cfg, period="annual")
    # Winter
    df_lstm_w_oos, df_glamos_w_oos, df_all_w_oos = build_all_years_df(
        gl.lower(), PATH_PREDICTIONS_LSTM_OOS, cfg, period="winter")

    # Annual
    df_lstm_a_is, df_glamos_a_is, df_all_a_is = build_all_years_df(
        gl.lower(), PATH_PREDICTIONS_LSTM_IS, cfg, period="annual")
    # Winter
    df_lstm_w_is, df_glamos_w_is, df_all_w_is = build_all_years_df(
        gl.lower(), PATH_PREDICTIONS_LSTM_IS, cfg, period="winter")

    # get unique years
    years = df_all_w_oos.YEAR.unique()

    ax = axs[c]

    # OOS: bands + mean + GLAMOS + stakes
    ax = plot_lstm_by_elevation_periods(df_all_a_oos,
                                        df_all_w_oos,
                                        ax=ax,
                                        mean_linestyle='--',
                                        label_prefix='LSTM OOS',
                                        show_band=False,
                                        color_annual=color_annual,
                                        color_winter=color_winter)

    # IS: LSTM mean-only overlay (no band), dashed line to distinguish
    ax = plot_lstm_by_elevation_periods(df_all_a_is,
                                        df_all_w_is,
                                        ax=ax,
                                        mean_linestyle='-',
                                        label_prefix='LSTM IS',
                                        show_band=True,
                                        color_annual=color_annual,
                                        color_winter=color_winter)

    ax = plot_glamos_by_elevation_periods(df_all_a_oos,
                                          df_all_w_oos,
                                          ax=ax,
                                          show_band=False,
                                          label_prefix="GLAMOS",
                                          mean_linestyle=":",
                                          color_annual=color_annual,
                                          color_winter=color_winter)

    # add stakes:
    ax = plot_stakes_by_elevation_periods(df_stakes,
                                          gl.lower(),
                                          valid_bins=None,
                                          ax=ax,
                                          color_annual=color_annual,
                                          color_winter=color_winter,
                                          marker_size=14)

    ax.set_ylabel('')
    ax.set_xlabel('')

    area = gl_area.get(gl.lower(), np.nan)
    area = np.round(area, 3) if area < 0.1 else np.round(area, 1)

    ax.set_title(f'{gl} ({area} km2, {years.min()}-{years.max()})',
                 fontsize=fontsize,
                 pad=2)

    # Row label on the left margin (first column only)
    if c == 0:
        ax.set_ylabel(f'Elevation (m a.s.l.)', fontsize=fontsize)

    ax.grid(alpha=0.2)
    ax.tick_params(labelsize=6.5, pad=2)
    ax.set_xlabel('')  # we use a supxlabel below

    # remove per-axes legend
    leg = ax.legend()
    if leg is not None:
        leg.remove()

# Global x-label
fig.supxlabel('Mass balance (m w.e.)', fontsize=fontsize, y=0.06)

# Custom handles (bands, means, and stakes)
handles = [
    # LSTM
    Patch(facecolor=color_annual,
          alpha=0.25,
          label="LSTM in-sample band (annual)"),
    Line2D([0], [0],
           color=color_annual,
           lw=1.2,
           linestyle='-',
           label="LSTM in-sample mean (annual)"),
    Patch(facecolor=color_winter,
          alpha=0.25,
          label="LSTM in-sample band (winter)"),
    Line2D([0], [0],
           color=color_winter,
           lw=1.2,
           linestyle='-',
           label="LSTM in-sample mean (winter)"),
    # LSTM OOS (mean only)
    Line2D([0], [0],
           color=color_annual,
           lw=1.2,
           linestyle='--',
           label="LSTM out-of-sample mean (annual)"),
    Line2D([0], [0],
           color=color_winter,
           lw=1.2,
           linestyle='--',
           label="LSTM out-of-sample mean (winter)"),
    # GLAMOS (mean only)
    Line2D([0], [0],
           color=color_annual,
           lw=1.2,
           linestyle=':',
           label="GLAMOS mean (annual)"),
    Line2D([0], [0],
           color=color_winter,
           lw=1.2,
           linestyle=':',
           label="GLAMOS mean (winter)"),

    # Stakes means
    Line2D([0], [0],
           marker='o',
           linestyle='None',
           linewidth=0,
           markersize=6,
           markerfacecolor='none',
           markeredgecolor=color_annual,
           markeredgewidth=1.2,
           label="Stakes mean (annual)"),
    Line2D([0], [0],
           marker='s',
           linestyle='None',
           linewidth=0,
           markersize=6,
           markerfacecolor='none',
           markeredgecolor=color_winter,
           markeredgewidth=1.2,
           label="Stakes mean (winter)"),
]

fig.legend(handles=handles,
           loc='upper center',
           bbox_to_anchor=(0.5, 0.05),
           ncol=5,
           fontsize=7)

plt.subplots_adjust(hspace=0.25, wspace=0.25)
plt.show()

# save figure
fig.savefig('figures/paper/fig_mb_gradients_IS_OOS.png',
            dpi=300,
            bbox_inches='tight')

In [None]:
# Restrict gl_area to the test glaciers, then order the entries by their
# gl_area value from largest to smallest.
test_gl_area = {name: gl_area[name] for name in TEST_GLACIERS}
test_gl_area = dict(
    sorted(test_gl_area.items(), key=lambda kv: kv[1], reverse=True))
test_gl_area

## In-sample results:

### Maps:

In [None]:
# Folder with the distributed mass-balance grids of the NN model
# (plotted as "MLP" further down).
PATH_PREDICTIONS_NN = os.path.join(
    cfg.dataPath,
    "GLAMOS",
    "distributed_MB_grids",
    "MBM/glamos_dems_NN_SEB_full_OGGM",
)

# Every entry of the prediction folder is passed on as a glacier name.
nn_glacier_entries = os.listdir(PATH_PREDICTIONS_NN)
smb_glamos_csv_file = os.path.join(cfg.dataPath, path_SMB_GLAMOS_csv)

df_nn = process_geodetic_mass_balance_comparison(
    glacier_list=nn_glacier_entries,
    path_SMB_GLAMOS_csv=smb_glamos_csv_file,
    periods_per_glacier=periods_per_glacier,
    geoMB_per_glacier=geoMB_per_glacier,
    gl_area=gl_area,
    test_glaciers=TEST_GLACIERS,
    path_predictions=PATH_PREDICTIONS_NN,
    cfg=cfg,
)

In [None]:
# Example usage: LSTM vs. MLP scatter comparison for a single glacier
GLACIER_NAME = 'gietro'
# bias_gl = df[df.GLACIER == GLACIER_NAME.capitalize()].bias_gl.unique()[0]

# ds_lstm_IS is filtered with the capitalised name, df_nn with the name as-is.
lstm_rows_for_glacier = ds_lstm_IS.GLACIER == GLACIER_NAME.capitalize()
df_lstm_two_heads_gl = ds_lstm_IS[lstm_rows_for_glacier]
df_nn_gl = df_nn[df_nn.GLACIER == GLACIER_NAME]

fig, axs = plt.subplots(1, 2, figsize=(10, 5), sharex=True, sharey=True)

# One panel per model: (data, suffix appended to the panel title)
scatter_panels = (
    (df_lstm_two_heads_gl, "(LSTM two heads)"),
    (df_nn_gl, "(MLP)"),
)
for panel_ax, (panel_df, panel_suffix) in zip(axs, scatter_panels):
    plot_scatter_comparison(panel_ax,
                            panel_df,
                            GLACIER_NAME,
                            color_mbm=color_annual,
                            color_glamos=color_winter,
                            title_suffix=panel_suffix)

plt.tight_layout()
plt.show()

In [None]:
# Glacier-wide GLAMOS series for the selected glacier
GLAMOS_glwmb = get_GLAMOS_glwmb(GLACIER_NAME, cfg)

# Glacier-wide predictions of both models; the shared "MBM Balance" column
# is renamed per model so the two frames can be joined side by side.
MBM_glwmb_nn = mbm_glwd_pred(PATH_PREDICTIONS_NN, GLACIER_NAME)
MBM_glwmb_nn.rename(columns={"MBM Balance": "MBM Balance MLP"}, inplace=True)

MBM_glwmb_lstm = mbm_glwd_pred(PATH_PREDICTIONS_LSTM_IS, GLACIER_NAME)
MBM_glwmb_lstm.rename(columns={"MBM Balance": "MBM Balance LSTM"},
                      inplace=True)

# Attach GLAMOS to the MLP frame and drop every row containing a NaN,
# then add the LSTM column on the same index.
MBM_glwmb_nn = MBM_glwmb_nn.join(GLAMOS_glwmb).dropna()
MBM_glwmb = MBM_glwmb_nn.join(MBM_glwmb_lstm)

# One panel per model, each showing the model against GLAMOS
fig, axs = plt.subplots(1, 2, figsize=(12, 6), sharey=True)
glacier_title = f"{GLACIER_NAME.capitalize()} Glacier"
model_panels = (("LSTM", axs[0]), ("MLP", axs[1]))

for model_tag, panel in model_panels:
    MBM_glwmb.plot(ax=panel,
                   y=[f"MBM Balance {model_tag}", 'GLAMOS Balance'],
                   marker="o",
                   color=[color_annual, color_winter])

# Shared cosmetics
for ax in axs:
    ax.set_title(glacier_title, fontsize=24)
    ax.set_ylabel("Mass Balance [m w.e.]", fontsize=18)
    ax.set_xlabel("Year", fontsize=18)
    ax.grid(True, linestyle="--", linewidth=0.5)
    ax.legend(fontsize=14)

# Final, model-specific titles (replace the generic one set above)
for model_tag, panel in model_panels:
    panel.set_title(f"{glacier_title} ({model_tag})", fontsize=16)

plt.tight_layout()
plt.show()

In [None]:
# Bind every argument that stays fixed, then draw one annual comparison per
# entry of the MLP/GLAMOS index.
plot_annual_comparison = partial(
    plot_mass_balance_comparison_annual,
    glacier_name=GLACIER_NAME,
    cfg=cfg,
    df_stakes=df_stakes,
    path_distributed_mb=path_distributed_MB_glamos,
    path_pred_lstm=PATH_PREDICTIONS_LSTM_IS,
    path_pred_nn=PATH_PREDICTIONS_NN,
    period='annual',
)

for year in MBM_glwmb_nn.index:
    plot_annual_comparison(year=year)


#### Two glaciers, two years:

In [None]:
# Glaciers to show and the years requested for each of them
# (presumably one pair of years per glacier, in the same order).
glaciers_to_show = ("aletsch", "rhone")
years_to_show = ((2014, 2024), (2009, 2024))

fig = plot_2glaciers_2years_glamos_vs_lstm(
    glacier_names=glaciers_to_show,
    years_by_glacier=years_to_show,
    cfg=cfg,
    df_stakes=df_stakes,
    path_distributed_mb=path_distributed_MB_glamos,
    path_pred_lstm=PATH_PREDICTIONS_LSTM_IS,
    period="annual",
)

# Write the figure to disk
output_png = 'figures/paper/fig_glamos_vs_lstm_aletsch_rhone.png'
fig.savefig(output_png, dpi=300, bbox_inches='tight')

### Gradients:

In [None]:
# In-sample mass-balance gradients: one panel per glacier showing the LSTM
# in-sample mean and band, the GLAMOS mean and the stake means.
gl_list = [
    'Gries',
    'Gietro',
    'Rhone',
    'Aletsch',
]

nrows = 1  # single row: in-sample (IS) results only
ncols = len(gl_list)
# NOTE(review): this rebinds `cm`, which the notebook header imports via
# `from cmcrameri import cm`. Kept so the notebook's global state is
# unchanged; re-import cmcrameri's `cm` before re-running cells that need
# the colormaps.
cm = 1 / 2.54
fontsize = 7

fig, axs = plt.subplots(nrows=nrows,
                        ncols=ncols,
                        figsize=(25 * cm, 12 * cm),
                        dpi=300)
# plt.subplots returns a bare Axes (not an array) for a single panel
axs = np.atleast_1d(axs)

for c, gl in enumerate(gl_list):  # columns = glaciers
    # Annual
    df_lstm_a_is, df_glamos_a_is, df_all_a_is = build_all_years_df(
        gl.lower(), PATH_PREDICTIONS_LSTM_IS, cfg, period="annual")
    # Winter
    df_lstm_w_is, df_glamos_w_is, df_all_w_is = build_all_years_df(
        gl.lower(), PATH_PREDICTIONS_LSTM_IS, cfg, period="winter")

    # Years covered by THIS glacier's in-sample data (used in the title).
    # The previous version read `df_all_w_oos`, a leftover from the
    # out-of-sample cell, so every panel showed the same, unrelated range.
    years = df_all_w_is.YEAR.unique()

    ax = axs[c]

    # IS: LSTM mean (solid line) together with its band
    ax = plot_lstm_by_elevation_periods(df_all_a_is,
                                        df_all_w_is,
                                        ax=ax,
                                        mean_linestyle='-',
                                        label_prefix='LSTM IS',
                                        show_band=True,
                                        color_annual=color_annual,
                                        color_winter=color_winter)

    # GLAMOS: mean only (dotted line, no band)
    ax = plot_glamos_by_elevation_periods(df_all_a_is,
                                          df_all_w_is,
                                          ax=ax,
                                          show_band=False,
                                          label_prefix="GLAMOS",
                                          mean_linestyle=":",
                                          color_annual=color_annual,
                                          color_winter=color_winter)

    # add stakes:
    ax = plot_stakes_by_elevation_periods(df_stakes,
                                          gl.lower(),
                                          valid_bins=None,
                                          ax=ax,
                                          color_annual=color_annual,
                                          color_winter=color_winter,
                                          marker_size=14)

    # Per-axes labels are cleared; a shared x-label is added below
    ax.set_ylabel('')
    ax.set_xlabel('')

    # Small glaciers keep three decimals, the others one
    area = gl_area.get(gl.lower(), np.nan)
    area = np.round(area, 3) if area < 0.1 else np.round(area, 1)

    ax.set_title(f'{gl} ({area} km2, {years.min()}-{years.max()})',
                 fontsize=fontsize,
                 pad=2)

    # Row label on the left margin (first column only)
    if c == 0:
        ax.set_ylabel('Elevation (m a.s.l.)', fontsize=fontsize)

    ax.grid(alpha=0.2)
    ax.tick_params(labelsize=6.5, pad=2)

    # Remove a per-axes legend if one of the helpers created it; the figure
    # gets a single shared legend below.
    leg = ax.get_legend()
    if leg is not None:
        leg.remove()

# Global x-label
fig.supxlabel('Mass balance (m w.e.)', fontsize=fontsize, y=0.06)

# Custom handles (bands, means, and stakes)
handles = [
    # LSTM in-sample (band + mean)
    Patch(facecolor=color_annual,
          alpha=0.25,
          label="LSTM in-sample band (annual)"),
    Line2D([0], [0],
           color=color_annual,
           lw=1.2,
           linestyle='-',
           label="LSTM in-sample mean (annual)"),
    Patch(facecolor=color_winter,
          alpha=0.25,
          label="LSTM in-sample band (winter)"),
    Line2D([0], [0],
           color=color_winter,
           lw=1.2,
           linestyle='-',
           label="LSTM in-sample mean (winter)"),

    # GLAMOS (mean only)
    Line2D([0], [0],
           color=color_annual,
           lw=1.2,
           linestyle=':',
           label="GLAMOS mean (annual)"),
    Line2D([0], [0],
           color=color_winter,
           lw=1.2,
           linestyle=':',
           label="GLAMOS mean (winter)"),

    # Stakes means
    Line2D([0], [0],
           marker='o',
           linestyle='None',
           linewidth=0,
           markersize=6,
           markerfacecolor='none',
           markeredgecolor=color_annual,
           markeredgewidth=1.2,
           label="Stakes mean (annual)"),
    Line2D([0], [0],
           marker='s',
           linestyle='None',
           linewidth=0,
           markersize=6,
           markerfacecolor='none',
           markeredgecolor=color_winter,
           markeredgewidth=1.2,
           label="Stakes mean (winter)"),
]

# Shared legend placed below the panels
fig.legend(handles=handles,
           loc='upper center',
           bbox_to_anchor=(0.5, 0.05),
           ncol=5,
           fontsize=7)

plt.subplots_adjust(hspace=0.25, wspace=0.25)
plt.show()

# save figure
fig.savefig('figures/paper/fig_mb_gradients_IS.png',
            dpi=300,
            bbox_inches='tight')

### Geodetic MB:

In [None]:
# RMSE and Pearson correlation between the geodetic and the modelled
# mass balance in ds_lstm_OS
geodetic_mb_os = ds_lstm_OS["Geodetic MB"]
modelled_mb_os = ds_lstm_OS["MBM MB"]
rmse_nn = root_mean_squared_error(geodetic_mb_os, modelled_mb_os)
corr_nn = np.corrcoef(geodetic_mb_os, modelled_mb_os)[0, 1]

# Bin edges and bin labels handed to the plotting helper
area_bin_edges_os = [0, 1, 5, 10, 100, np.inf]
area_bin_labels_os = ['<1', '1-5', '5–10', '>10', '>100']

fig = plot_mbm_vs_geodetic_by_area_bin(ds_lstm_OS,
                                       bins=area_bin_edges_os,
                                       labels=area_bin_labels_os,
                                       max_bins=4)

In [None]:
# RMSE and Pearson correlation between the geodetic and the modelled
# mass balance in ds_lstm_IS
geodetic_mb_is = ds_lstm_IS["Geodetic MB"]
modelled_mb_is = ds_lstm_IS["MBM MB"]
rmse_nn = root_mean_squared_error(geodetic_mb_is, modelled_mb_is)
corr_nn = np.corrcoef(geodetic_mb_is, modelled_mb_is)[0, 1]

# Bin edges and bin labels handed to the plotting helper
area_bin_edges_is = [0, 1, 5, 10, 100, np.inf]
area_bin_labels_is = ['<1', '1-5', '5–10', '>10', '>100']

fig = plot_mbm_vs_geodetic_by_area_bin(ds_lstm_IS,
                                       bins=area_bin_edges_is,
                                       labels=area_bin_labels_is,
                                       max_bins=4)

# Write the figure to disk
output_png = 'figures/paper/fig_mbm_vs_geodetic_by_area_bin_IS.png'
fig.savefig(output_png, dpi=300, bbox_inches='tight')

## Feature importance:

In [None]:
# Parallel PFI for MBM LSTM (CPU, Linux)
import os, sys, random, numpy as np, pandas as pd, torch, xarray as xr
from typing import Dict, List, Callable, Any, Tuple
from concurrent.futures import ProcessPoolExecutor, as_completed
from contextlib import redirect_stdout
import multiprocessing as mp
from tqdm.auto import tqdm

# ------------------------------- determinism helpers -------------------------------


def _set_cpu_env_threads():
    """Set CPU-only / single-thread defaults for the current process.

    Each environment variable is only written when it is not already
    defined, so values chosen by the user are left untouched. Afterwards
    torch is asked to use a single thread; a failure of that call is
    ignored (best effort).
    """
    env_defaults = (
        ("CUDA_VISIBLE_DEVICES", ""),  # force CPU
        ("OMP_NUM_THREADS", "1"),
        ("MKL_NUM_THREADS", "1"),
        ("OPENBLAS_NUM_THREADS", "1"),
        ("NUMEXPR_MAX_THREADS", "1"),
    )
    for var_name, default_value in env_defaults:
        if var_name not in os.environ:
            os.environ[var_name] = default_value

    try:
        torch.set_num_threads(1)
    except Exception:
        pass


def _set_seeds(seed: int):
    """Seed Python, NumPy and torch and request deterministic execution.

    Args:
        seed: Value written to ``PYTHONHASHSEED`` and passed to the
            ``random``, ``numpy.random`` and ``torch`` seeding functions.

    Besides seeding, the cuDNN flags are switched to their deterministic
    settings and torch's deterministic-algorithms mode is requested; if
    that last request raises, the error is ignored (best effort).
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # deterministic CPU path
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    try:
        torch.use_deterministic_algorithms(True)
    except Exception:
        pass


# ------------------------------- globals for worker -------------------------------
# Per-process state: every slot starts as None and is filled by
# `_pfi_worker_init` once per worker process.
_PFI_G = dict.fromkeys((
    "model",
    "device",
    "ds_train_copy",
    "MONTHLY_COLS",
    "STATIC_COLS",
    "months_head_pad",
    "months_tail_pad",
    "target_col",
    "baseline",
    "df_eval",
    "seed",
    "batch_size",
))


def _pfi_worker_init(cfg, custom_params: Dict[str, Any], model_filename: str,
                     ds_train, train_idx, MONTHLY_COLS: List[str],
                     STATIC_COLS: List[str], months_head_pad: int,
                     months_tail_pad: int, df_eval: pd.DataFrame,
                     target_col: str, seed: int, batch_size: int):
    """Initialise one PFI worker process and cache its state in ``_PFI_G``.

    Runs once per worker: silences stdout/stderr, applies the CPU/thread and
    seed settings, fits the scalers on a clone of ``ds_train`` restricted to
    ``train_idx``, loads the model weights on CPU, predicts on the unpermuted
    ``df_eval`` and stores the resulting baseline RMSE together with
    everything ``_pfi_worker_task`` needs.

    Args:
        cfg: Configuration object forwarded to the model builder.
        custom_params: Model parameters forwarded to the model builder.
        model_filename: Path of the saved state dict to load.
        ds_train: Training dataset; cloned before the scalers are fitted.
        train_idx: Indices used to fit the scalers.
        MONTHLY_COLS: Names of the monthly feature columns.
        STATIC_COLS: Names of the static feature columns.
        months_head_pad: Forwarded to ``MBSequenceDataset.from_dataframe``.
        months_tail_pad: Forwarded to ``MBSequenceDataset.from_dataframe``.
        df_eval: Evaluation dataframe including the target column.
        target_col: Name of the target column.
        seed: Seed for the RNGs and the test loader.
        batch_size: Batch size of the evaluation loader.

    Raises:
        ValueError: If targets have to be merged in by ``ID`` and some
            prediction row ends up without a target.
    """
    # local import avoids pickling a module from parent
    import massbalancemachine as mbm

    # silence worker prints; the handles intentionally stay open for the
    # whole lifetime of the worker process
    sys.stdout = open(os.devnull, "w")
    sys.stderr = open(os.devnull, "w")

    _set_cpu_env_threads()
    _set_seeds(seed)

    # Fit scalers on TRAIN only
    ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
        ds_train)
    ds_train_copy.fit_scalers(train_idx)

    # Load model on CPU
    device = torch.device("cpu")
    model = mbm.models.LSTM_MB.build_model_from_params(cfg,
                                                       custom_params,
                                                       device,
                                                       verbose=False)
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state)
    model.eval()

    # Build eval ds/loader with targets
    ds_eval = mbm.data_processing.MBSequenceDataset.from_dataframe(
        df_eval,
        MONTHLY_COLS,
        STATIC_COLS,
        months_tail_pad=months_tail_pad,
        months_head_pad=months_head_pad,
        expect_target=True,
        show_progress=False)
    dl_eval = mbm.data_processing.MBSequenceDataset.make_test_loader(
        ds_eval, ds_train_copy, seed=seed, batch_size=batch_size)
    with torch.no_grad():
        df_pred_base = model.predict_with_keys(device, dl_eval, ds_eval)

    # Take the targets straight from the prediction frame when it carries
    # them, otherwise look them up in df_eval via the ID column.
    if target_col in df_pred_base.columns:
        y_true = df_pred_base[target_col].to_numpy()
    else:
        merged = df_pred_base.merge(df_eval[["ID", target_col
                                             ]].drop_duplicates("ID"),
                                    on="ID",
                                    how="left")
        if merged[target_col].isna().any():
            raise ValueError(
                "Missing targets after merge; ensure df_eval contains per-ID targets."
            )
        y_true = merged[target_col].to_numpy()

    # Baseline RMSE of the unpermuted evaluation data
    y_pred = df_pred_base["pred"].to_numpy()
    baseline = float(np.sqrt(np.mean((y_true - y_pred)**2)))

    # store globals
    _PFI_G.update(
        dict(model=model,
             device=device,
             ds_train_copy=ds_train_copy,
             MONTHLY_COLS=MONTHLY_COLS,
             STATIC_COLS=STATIC_COLS,
             months_head_pad=months_head_pad,
             months_tail_pad=months_tail_pad,
             target_col=target_col,
             baseline=baseline,
             df_eval=df_eval,
             seed=seed,
             batch_size=batch_size))


def _permute_within_groups(values: np.ndarray, groups: np.ndarray,
                           rng: np.random.Generator) -> np.ndarray:
    """Shuffle ``values`` separately inside each block of equal group labels.

    Args:
        values: Array to permute.
        groups: Group label of every element of ``values``.
        rng: Generator that drives the shuffles; groups are processed in the
            sorted order of their unique labels.

    Returns:
        A new array shaped like ``values`` in which every position holds a
        value taken from a position of the same group.
    """
    permuted = np.empty_like(values)
    # Shuffling only inside a group keeps each group's value distribution
    # (the caller groups by month to preserve the seasonal distribution).
    labels, codes = np.unique(groups, return_inverse=True)
    for code in range(len(labels)):
        positions = np.nonzero(codes == code)[0]
        sources = positions.copy()
        rng.shuffle(sources)
        permuted[positions] = values[sources]
    return permuted


def _pfi_worker_task(task: Tuple[str, str, int]) -> Tuple[str, str, float]:
    """Permute one feature, re-predict and return the change in RMSE.

    Args:
        task: ``(feature, type, repeat_seed_offset)`` where ``type`` is
            ``'monthly'`` or ``'static'``; the offset is added to the seed
            stored in ``_PFI_G`` to seed this task's generator.

    Returns:
        ``(feature, type, delta_rmse)`` with ``delta_rmse`` being the RMSE on
        the permuted data minus the baseline stored in ``_PFI_G``.

    Raises:
        ValueError: For an unknown ``type``, for a monthly feature when the
            ``MONTHS`` column is missing, or when targets merged by ``ID``
            contain missing values.
    """
    # local import
    import massbalancemachine as mbm

    state = _PFI_G
    feat, ftype, seed_offset = task
    rng = np.random.default_rng(int(state["seed"]) + int(seed_offset))

    # Work on a private copy so the cached evaluation frame stays intact
    df = state["df_eval"].copy()
    if ftype not in ("monthly", "static"):
        raise ValueError("ftype must be 'monthly' or 'static'.")

    if ftype == "static":
        # Static features: one shuffle across all rows
        row_order = np.arange(len(df))
        rng.shuffle(row_order)
        df[feat] = df[feat].to_numpy()[row_order]
    else:
        # Monthly features: shuffle only among rows sharing a MONTHS value
        if "MONTHS" not in df.columns:
            raise ValueError(
                "MONTHS column required for monthly feature permutation.")
        feature_values = df[feat].to_numpy()
        month_labels = df["MONTHS"].to_numpy()
        df[feat] = _permute_within_groups(feature_values, month_labels, rng)

    # Rebuild ds/loader for permuted df
    seq_dataset_cls = mbm.data_processing.MBSequenceDataset
    ds_p = seq_dataset_cls.from_dataframe(
        df,
        state["MONTHLY_COLS"],
        state["STATIC_COLS"],
        months_tail_pad=state["months_tail_pad"],
        months_head_pad=state["months_head_pad"],
        expect_target=True,
        show_progress=False)
    dl_p = seq_dataset_cls.make_test_loader(ds_p,
                                            state["ds_train_copy"],
                                            seed=state["seed"],
                                            batch_size=state["batch_size"])
    with torch.no_grad():
        df_pred = state["model"].predict_with_keys(state["device"], dl_p,
                                                   ds_p)

    # Targets: prefer the prediction frame, otherwise merge them in by ID
    tcol = state["target_col"]
    if tcol not in df_pred.columns:
        id_targets = df[["ID", tcol]].drop_duplicates("ID")
        merged = df_pred.merge(id_targets, on="ID", how="left")
        if merged[tcol].isna().any():
            raise ValueError(
                "Missing targets after merge; ensure df_eval contains per-ID targets."
            )
        y_true = merged[tcol].to_numpy()
    else:
        y_true = df_pred[tcol].to_numpy()
    y_pred = df_pred["pred"].to_numpy()

    rmse = float(np.sqrt(np.mean((y_true - y_pred)**2)))
    return feat, ftype, rmse - float(state["baseline"])


# ------------------------------- user-facing function -------------------------------


def permutation_feature_importance_mbm_parallel(
    cfg,
    custom_params: Dict[str, Any],
    model_filename: str,
    df_eval: pd.DataFrame,
    MONTHLY_COLS: List[str],
    STATIC_COLS: List[str],
    ds_train,
    train_idx,
    target_col: str,
    months_head_pad: int,
    months_tail_pad: int,
    seed: int = 42,
    n_repeats: int = 5,
    batch_size: int = 256,
    max_workers: int = None,
) -> pd.DataFrame:
    """Parallel Permutation Feature Importance (CPU).

    Every (feature, repeat) pair is evaluated by ``_pfi_worker_task`` in a
    process pool whose workers are prepared by ``_pfi_worker_init``. The
    baseline RMSE on the unpermuted data is recomputed in the parent process
    because the value computed inside the workers is not shared back.

    Note: the parent-side baseline computation calls
    ``_set_cpu_env_threads`` and ``_set_seeds`` and therefore changes
    environment variables, RNG state and torch settings of the calling
    process.

    Args:
        cfg: Configuration object forwarded to the model builder.
        custom_params: Model parameters forwarded to the model builder.
        model_filename: Path of the saved state dict to load.
        df_eval: Evaluation dataframe including the target column.
        MONTHLY_COLS: Monthly feature columns (permuted within ``MONTHS``).
        STATIC_COLS: Static feature columns (permuted across all rows).
        ds_train: Training dataset used to fit the scalers.
        train_idx: Indices used to fit the scalers.
        target_col: Name of the target column.
        months_head_pad: Forwarded to ``MBSequenceDataset.from_dataframe``.
        months_tail_pad: Forwarded to ``MBSequenceDataset.from_dataframe``.
        seed: Base seed; repeat ``r`` of a feature uses ``seed + r``.
        n_repeats: Number of permutations per feature.
        batch_size: Batch size of the evaluation loaders.
        max_workers: Pool size; ``None`` selects ``cpu_count - 1`` limited to
            the range 1..32.

    Returns:
        DataFrame with columns ['feature', 'type', 'mean_delta', 'std_delta',
        'baseline', 'metric_name'], sorted by ``mean_delta`` in descending
        order. It is empty (but keeps these columns) when there is no
        feature or ``n_repeats`` is not positive.

    Raises:
        ValueError: If targets have to be merged in by ``ID`` and some
            prediction row ends up without a target.
    """
    import collections
    import massbalancemachine as mbm

    # Build list of all tasks (feature x repeat)
    feats = [(f, "monthly") for f in MONTHLY_COLS] + [(f, "static")
                                                      for f in STATIC_COLS]
    tasks = [(feat, ftype, r) for feat, ftype in feats
             for r in range(n_repeats)]

    # Use Linux fork so df_eval stays shared copy-on-write
    ctx = mp.get_context("fork")
    if max_workers is None:
        max_workers = min(max(1, (os.cpu_count() or 2) - 1), 32)

    # Run
    results = []
    with ProcessPoolExecutor(
            max_workers=max_workers,
            mp_context=ctx,
            initializer=_pfi_worker_init,
            initargs=(
                cfg,
                custom_params,
                model_filename,
                ds_train,
                train_idx,
                MONTHLY_COLS,
                STATIC_COLS,
                months_head_pad,
                months_tail_pad,
                df_eval,
                target_col,
                seed,
                batch_size,
            ),
    ) as ex:
        futs = [ex.submit(_pfi_worker_task, t) for t in tasks]
        for fut in tqdm(as_completed(futs),
                        total=len(futs),
                        desc=f"PFI (workers={max_workers})"):
            results.append(fut.result())

    # Baseline: computed inside the workers but not shared with the parent,
    # so it is recomputed here with a single serial pass.
    _set_cpu_env_threads()
    _set_seeds(seed)
    ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
        ds_train)
    ds_train_copy.fit_scalers(train_idx)
    device = torch.device("cpu")
    model = mbm.models.LSTM_MB.build_model_from_params(cfg,
                                                       custom_params,
                                                       device,
                                                       verbose=False)
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state)
    model.eval()
    ds_eval_parent = mbm.data_processing.MBSequenceDataset.from_dataframe(
        df_eval,
        MONTHLY_COLS,
        STATIC_COLS,
        months_tail_pad=months_tail_pad,
        months_head_pad=months_head_pad,
        expect_target=True,
        show_progress=False)
    dl_eval_parent = mbm.data_processing.MBSequenceDataset.make_test_loader(
        ds_eval_parent, ds_train_copy, seed=seed, batch_size=batch_size)
    with torch.no_grad():
        df_pred_base = model.predict_with_keys(device, dl_eval_parent,
                                               ds_eval_parent)
    if target_col in df_pred_base.columns:
        y_true = df_pred_base[target_col].to_numpy()
    else:
        merged = df_pred_base.merge(df_eval[["ID", target_col
                                             ]].drop_duplicates("ID"),
                                    on="ID",
                                    how="left")
        if merged[target_col].isna().any():
            raise ValueError(
                "Missing targets after merge; ensure df_eval has per-ID targets."
            )
        y_true = merged[target_col].to_numpy()
    y_pred = df_pred_base["pred"].to_numpy()
    baseline = float(np.sqrt(np.mean((y_true - y_pred)**2)))

    # Collect the deltas of all repeats per (feature, type)
    bucket: Dict[Tuple[str, str], List[float]] = collections.defaultdict(list)
    for feat, ftype, delta in results:
        bucket[(feat, ftype)].append(float(delta))

    rows = [{
        "feature": feat,
        "type": ftype,
        "mean_delta": float(np.mean(deltas)),
        # sample standard deviation; undefined for a single repeat -> 0.0
        "std_delta": float(np.std(deltas, ddof=1)) if len(deltas) > 1 else 0.0
    } for (feat, ftype), deltas in bucket.items()]

    # Explicit columns keep the schema (and make sort_values work) even
    # when there are no rows at all.
    out = pd.DataFrame(
        rows, columns=["feature", "type", "mean_delta", "std_delta"])
    out = out.sort_values("mean_delta",
                          ascending=False).reset_index(drop=True)
    out["baseline"] = baseline
    out["metric_name"] = "rmse"
    return out

In [None]:
# Run the parallel permutation feature importance on the test dataframe
pfi_parallel = permutation_feature_importance_mbm_parallel(
    cfg=cfg,
    custom_params=custom_params,
    model_filename=model_filename,
    df_eval=df_test,  # evaluation dataframe WITH TARGETS aligned to predictions
    MONTHLY_COLS=MONTHLY_COLS,
    STATIC_COLS=STATIC_COLS,
    ds_train=ds_train,
    train_idx=train_idx,
    target_col="POINT_BALANCE",  # name of the target column
    months_head_pad=months_head_pad,
    months_tail_pad=months_tail_pad,
    seed=cfg.seed,
    n_repeats=5,
    batch_size=256,
    max_workers=None,  # auto: n_cpus-1 (cap 32)
)

# Optional: quick plot
import matplotlib.pyplot as plt

# Metric name and baseline are constant columns; read them from the first row
metric_label = pfi_parallel['metric_name'].iloc[0]
baseline_value = pfi_parallel['baseline'].iloc[0]
# Figure height grows with the number of features, never below 3 inches
figure_height = max(3, 0.35 * len(pfi_parallel))

plt.figure(figsize=(8, figure_height))
plt.barh(pfi_parallel["feature"],
         pfi_parallel["mean_delta"],
         xerr=pfi_parallel["std_delta"])
plt.gca().invert_yaxis()  # first row of the table at the top
plt.title(
    f"Permutation Feature Importance (Δ{metric_label}; baseline={baseline_value:.3f})"
)
plt.xlabel(f"Increase in {metric_label} (higher = more important)")
plt.tight_layout()
plt.show()