## Setting Up:

In [None]:
import sys
import os

import logging
import warnings
from datetime import datetime
from collections import defaultdict, Counter
from functools import partial

import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from tqdm.notebook import tqdm
from cmcrameri import cm
from skorch.helper import SliceDataset
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint

import massbalancemachine as mbm

# Add root of repo to import MBM
sys.path.append(os.path.join(os.getcwd(), '../../'))

# Local modules
from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.nn_helpers import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.NN_networks import *
from scripts.geodata_plots import *

warnings.filterwarnings('ignore')

%load_ext autoreload
%autoreload 2

cfg = mbm.SwitzerlandConfig()

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from torch.utils.data import WeightedRandomSampler, SubsetRandomSampler
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Seed through the project helper, using the seed stored on the config.
seed_all(cfg.seed)
print("Using seed:", cfg.seed)

# Query CUDA once; the answer drives both the message and the device choice.
_cuda_ok = torch.cuda.is_available()
if not _cuda_ok:
    print("CUDA is NOT available")
else:
    print("CUDA is available")
    free_up_cuda()

device = torch.device("cuda") if _cuda_ok else torch.device("cpu")


In [None]:
# Plot styles: apply the shared style sheet and pick the base colours.
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue, color_pink = colors[0], '#c51b7d'

# RGI ids: strip whitespace from the column names, then sort by and index on
# the glacier short name.
rgi_df = pd.read_csv(cfg.dataPath + path_glacier_ids, sep=',')
rgi_df = rgi_df.rename(columns=lambda col: col.strip())
rgi_df = rgi_df.sort_values(by='short_name').set_index('short_name')

# Climate variables of interest.
vois_climate = ['t2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str']

# Topographical variables of interest. A larger set that also contained
# "hugonnet_dhdt", "consensus_ice_thickness" and "millan_v" is disabled here.
vois_topographical = ["aspect_sgi", "slope_sgi", "svf"]

## Input data:

In [None]:
# Load the stake measurements through the project helper.
data_glamos = getStakesData(cfg)

# Number of winter and annual measurements (row counts per PERIOD value):
print("Number of winter measurements:",
      data_glamos.groupby('PERIOD').count().YEAR.loc['winter'])
print("Number of annual measurements:",
      data_glamos.groupby('PERIOD').count().YEAR.loc['annual'])

# Head/tail pads derived from the stake data; they are passed to
# MBSequenceDataset.from_dataframe when the LSTM datasets are built later on.
months_head_pad, months_tail_pad = mbm.data_processing.utils._compute_head_tail_pads_from_df(
    data_glamos)

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Transform data to monthly format (run or load data).
# All input/output locations are built relative to cfg.dataPath.
paths = {
    'csv_path': cfg.dataPath + path_PMB_GLAMOS_csv,
    'era5_climate_data':
    cfg.dataPath + path_ERA5_raw + 'era5_monthly_averaged_data.nc',
    'geopotential_data':
    cfg.dataPath + path_ERA5_raw + 'era5_geopotential_pressure.nc',
    'radiation_save_path': cfg.dataPath + path_pcsr + 'zarr/'
}
# NOTE(review): RUN=False presumably loads the cached `output_file` instead of
# recomputing it -- confirm in process_or_load_data.
RUN = False
data_monthly = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_LSTM_svf.csv')

# Create the MBM DataLoader on the monthly data, seeded from the config.
dataloader_gl = mbm.dataloader.DataLoader(cfg,
                                          data=data_monthly,
                                          random_seed=cfg.seed,
                                          meta_data_columns=cfg.metaData)

In [None]:
# Check that every requested test glacier is present in the monthly dataset.
# The check used to be computed but never reported; it is printed (not raised
# through `warnings`) because this notebook silences warnings globally.
existing_glaciers = set(data_monthly.GLACIER.unique())
missing_glaciers = [g for g in TEST_GLACIERS if g not in existing_glaciers]
if missing_glaciers:
    print("WARNING: test glaciers missing from dataset:", missing_glaciers)

# Training glaciers are all glaciers that are not held out for testing.
train_glaciers = [i for i in existing_glaciers if i not in TEST_GLACIERS]

# Split on glacier name. The train/test frames come from this split only; the
# earlier `isin`-based frames were overwritten before use and are dropped.
splits, test_set, train_set = get_CV_splits(dataloader_gl,
                                            test_split_on='GLACIER',
                                            test_splits=TEST_GLACIERS,
                                            random_state=cfg.seed)

# Attach the targets as a 'y' column on the feature frames.
data_train = train_set['df_X']
data_train['y'] = train_set['y']
data_test = test_set['df_X']
data_test['y'] = test_set['y']

### Feature distribution of test set:

In [None]:
# Columns passed to the LSTM datasets as the monthly inputs.
MONTHLY_COLS = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str', 'pcsr',
    'ELEVATION_DIFFERENCE'
]

# Columns passed as the static inputs. A larger set that also contained
# 'hugonnet_dhdt', 'consensus_ice_thickness' and 'millan_v' is disabled here.
STATIC_COLS = ['aspect_sgi', 'slope_sgi', 'svf']

# Register the full feature list (monthly first, then static) on the config.
feature_columns = [*MONTHLY_COLS, *STATIC_COLS]
cfg.setFeatures(feature_columns)

In [None]:
# Colours used to tell the train and test sets apart.
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue = colors[0]
custom_palette = dict(Train=color_dark_blue, Test='#b2182b')

# Keyword options for the t-SNE overlap plot, collected in one place.
_tsne_kwargs = dict(
    sublabels=("a", "b", "c"),
    label_fmt="({})",
    label_xy=(0.02, 0.98),
    label_fontsize=14,
    n_iter=1000,
    random_state=cfg.seed,
    custom_palette=custom_palette,
)
fig = plot_tsne_overlap(data_train, data_test, STATIC_COLS, MONTHLY_COLS,
                        **_tsne_kwargs)

# Save the figure for the paper.
fig.savefig('figures/paper/fig_tsne_overlap_train_test_CH.png',
            dpi=300,
            bbox_inches='tight')

## LSTM:

### Build LSTM dataloaders:

In [None]:
seed_all(cfg.seed)


def _with_clean_period(df):
    """Return a copy of *df* with PERIOD stripped of whitespace and lower-cased."""
    cleaned = df.copy()
    cleaned['PERIOD'] = cleaned['PERIOD'].str.strip().str.lower()
    return cleaned


def _to_sequence_dataset(df):
    """Build an MBSequenceDataset (targets expected) from *df* using the shared pads."""
    return mbm.data_processing.MBSequenceDataset.from_dataframe(
        df,
        MONTHLY_COLS,
        STATIC_COLS,
        months_tail_pad=months_tail_pad,
        months_head_pad=months_head_pad,
        expect_target=True)


# Normalised copies of the train/test frames.
df_train = _with_clean_period(data_train)
df_test = _with_clean_period(data_test)

# Sequence datasets, train first, then test.
ds_train = _to_sequence_dataset(df_train)
ds_test = _to_sequence_dataset(df_test)

# Train/validation index split over the training dataset (20 % validation).
train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
    len(ds_train), val_ratio=0.2, seed=cfg.seed)

### Define & train model:

In [None]:
# Hyper-parameters handed to build_model_from_params / resolve_loss_fn.
# NOTE(review): 'Fm' and 'Fs' presumably are the numbers of monthly and static
# features (they equal len(MONTHLY_COLS) and len(STATIC_COLS)) -- confirm.
custom_params = {
    'Fm': 9,
    'Fs': 3,
    'hidden_size': 128,
    'num_layers': 2,
    'bidirectional': False,
    'dropout': 0.2,
    'static_layers': 2,
    'static_hidden': [128, 64],
    'static_dropout': 0.1,
    'lr': 0.0005,
    'weight_decay': 0.0,
    'loss_name': 'neutral',
    'two_heads': True,
    'head_dropout': 0.0,
    'loss_spec': None
}

# --- loaders (fit scalers on TRAIN, apply to whole ds_train) ---
# Work on clones so that ds_train / ds_test themselves stay untransformed.
ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train)

ds_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_test)

train_dl, val_dl = ds_train_copy.make_loaders(
    train_idx=train_idx,
    val_idx=val_idx,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=
    True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True  # use weighted sampler for training
)

# --- test loader (copies TRAIN scalers into ds_test and transforms it) ---
test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test_copy, ds_train_copy, batch_size=128, seed=cfg.seed)

# --- build model and resolve loss ---
# No training happens in this cell: previously saved weights are loaded below.
model = mbm.models.LSTM_MB.build_model_from_params(cfg, custom_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(custom_params)

# Load the saved weights and evaluate on the test loader.
model_filename = f"models/lstm_model_2025-10-22_two_heads_no_oggm.pt"
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state)
test_metrics, test_df_preds = model.evaluate_with_preds(
    device, test_dl, ds_test_copy)
test_rmse_a, test_rmse_w = test_metrics['RMSE_annual'], test_metrics[
    'RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

## Validation OOS:
Out-of-sample (OOS) evaluation on the held-out test glaciers.

In [None]:
# Season colours: fixed pink for annual, first batlow colour for winter.
colors = get_cmap_hex(cm.batlow, 10)
color_annual, color_winter = "#c51b7d", colors[0]

# Glacier areas; expose the "claridenL" entry under the "clariden" key too.
gl_area = get_gl_area(cfg)
gl_area["clariden"] = gl_area["claridenL"]

In [None]:
# Seasonal scores computed from the test predictions.
scores_annual, scores_winter = compute_seasonal_scores(
    test_df_preds, target_col='target', pred_col='pred')

print("Annual scores:", scores_annual)
print("Winter scores:", scores_winter)

# Both axes share the same limits.
_axis_limits = (-8, 6)
_summary_kwargs = dict(
    grouped_ids=test_df_preds,
    scores_annual=scores_annual,
    scores_winter=scores_winter,
    ax_xlim=_axis_limits,
    ax_ylim=_axis_limits,
    color_annual=color_annual,
    color_winter=color_winter,
)
fig = plot_predictions_summary(**_summary_kwargs)

# Save the figure for the paper.
fig.savefig('figures/paper/fig_predvsobs.png', dpi=300, bbox_inches='tight')

In [None]:
# Mean elevation of the annual stake measurements per glacier, highest first.
# This is computed once; the original cell repeated the identical computation.
gl_per_el = data_glamos[data_glamos.PERIOD == 'annual'].groupby(
    ['GLACIER'])['POINT_ELEVATION'].mean()
gl_per_el = gl_per_el.sort_values(ascending=False)

# Test glaciers ordered by ascending mean elevation (panel order), and the
# same elevation attached to every prediction row.
test_gl_per_el = gl_per_el[TEST_GLACIERS].sort_values().index
test_df_preds['gl_elv'] = test_df_preds['GLACIER'].map(gl_per_el)

fig, axs = plt.subplots(3, 3, figsize=(25, 18), sharex=True)

# One label per panel: (a) ... (i)
subplot_labels = [f'({letter})' for letter in 'abcdefghi']

axs = PlotIndividualGlacierPredVsTruth(test_df_preds,
                                       axs=axs,
                                       subplot_labels=subplot_labels,
                                       color_annual=color_annual,
                                       color_winter=color_winter,
                                       custom_order=test_gl_per_el,
                                       gl_area=gl_area)

# NOTE(review): indexing with a single integer assumes the helper returns the
# axes as a flat sequence -- confirm.
axs[3].set_ylabel("Modeled PMB [m w.e.]", fontsize=20)

fig.supxlabel('Observed PMB [m w.e.]', fontsize=20, y=0.06)

# Two distinct legend handles (annual / winter scatter markers); matplotlib
# takes the legend labels from the handles themselves.
handles = [
    Line2D([0], [0],
           marker='o',
           linestyle='None',
           linewidth=0,
           markersize=10,
           markerfacecolor=face_color,
           markeredgecolor='k',
           markeredgewidth=0.8,
           label=season_label)
    for season_label, face_color in (('Annual', color_annual),
                                     ('Winter', color_winter))
]
legend_scatter_annual, legend_scatter_winter = handles

fig.legend(handles=handles,
           loc='upper center',
           bbox_to_anchor=(0.5, 0.05),
           ncol=4,
           fontsize=20)

plt.subplots_adjust(hspace=0.25, wspace=0.25)
plt.show()

# Save the figure for the paper.
fig.savefig('figures/paper/fig_predvsobs_indv.png',
            dpi=300,
            bbox_inches='tight')

## Intermediate validation OOS and IS:

In [None]:
# Geodetic MB + per-glacier periods.
# Both mappings are passed to process_geodetic_mass_balance_comparison below.
geodetic_mb = get_geodetic_MB(cfg)
periods_per_glacier, geoMB_per_glacier = build_periods_per_glacier(geodetic_mb)

In [None]:
# Leading path components shared by both prediction folders.
_mb_grid_root = (cfg.dataPath, "GLAMOS", "distributed_MB_grids")

# Out-of-sample (OOS) and in-sample (IS) LSTM prediction folders. The earlier
# "MBM/glamos_dems_LSTM_svf_OOS" / "MBM/glamos_dems_LSTM_svf_IS" locations are
# no longer used here.
PATH_PREDICTIONS_LSTM_OOS = os.path.join(
    *_mb_grid_root, "MBM/testing_LSTM/glamos_dems_LSTM_no_oggm")

PATH_PREDICTIONS_LSTM_IS = os.path.join(
    *_mb_grid_root, "MBM/testing_LSTM/glamos_dems_LSTM_no_oggm_IS")

In [None]:
# Entries of the OOS prediction folder, i.e. glaciers with LSTM predictions.
glaciers_in_glamos = set(os.listdir(PATH_PREDICTIONS_LSTM_OOS))

# Glacier areas; expose the "claridenL" entry under the "clariden" key too.
gl_area = get_gl_area(cfg)
gl_area["clariden"] = gl_area["claridenL"]

# Keep glaciers that have both geodetic periods and predictions, then order
# them by ascending area (unknown areas count as 0).
glacier_list = [
    name for name in periods_per_glacier.keys() if name in glaciers_in_glamos
]
glacier_list.sort(key=lambda name: gl_area.get(name, 0))
print("Number of glaciers:", len(glacier_list))
print("Glaciers:", glacier_list)

In [None]:
def _geodetic_comparison(path_predictions):
    """Compare geodetic and modelled MB for the predictions under *path_predictions*.

    Rows lacking either 'Geodetic MB' or 'MBM MB' are dropped, the result is
    sorted by 'Area', and the glacier names are capitalised.
    """
    comparison = process_geodetic_mass_balance_comparison(
        glacier_list=glacier_list,
        path_SMB_GLAMOS_csv=os.path.join(cfg.dataPath, path_SMB_GLAMOS_csv),
        periods_per_glacier=periods_per_glacier,
        geoMB_per_glacier=geoMB_per_glacier,
        gl_area=gl_area,
        test_glaciers=TEST_GLACIERS,
        path_predictions=path_predictions,
        cfg=cfg,
    )
    comparison = comparison.dropna(subset=['Geodetic MB', 'MBM MB'])
    comparison = comparison.sort_values(by='Area')
    comparison['GLACIER'] = comparison['GLACIER'].apply(
        lambda x: x.capitalize())
    return comparison


# Out-of-sample predictions first, then in-sample.
ds_lstm_OS = _geodetic_comparison(PATH_PREDICTIONS_LSTM_OOS)
ds_lstm_IS = _geodetic_comparison(PATH_PREDICTIONS_LSTM_IS)

### Mass balance gradients:

In [None]:
# Stake data
# Load the stake CSV once here; the resulting frame is reused by every
# gradient and map plot below instead of being re-read per glacier.
stake_file = os.path.join(cfg.dataPath, path_PMB_GLAMOS_csv,
                          "CH_wgms_dataset_all.csv")
df_stakes = pd.read_csv(stake_file)

#### Gradients test:

In [None]:
def _alpha_labels(n: int):
    """(a), (b), ... (z), (aa), (ab), ... for n>=1"""

    def to_label(k: int) -> str:
        # 0 -> a, 25 -> z, 26 -> aa, ...
        s = ""
        k += 1
        while k > 0:
            k, r = divmod(k - 1, 26)
            s = chr(97 + r) + s
        return f"({s})"

    return [to_label(i) for i in range(n)]


In [None]:
# Glaciers shown in this figure, one column each.
gl_list = [
    'Plattalva',
    'Hohlaub',
    'Tsanfleuron',
    'Schwarzberg',
    'Forno',
]

nrows = 1  # single row: out-of-sample (OOS) profiles only
ncols = len(gl_list)
# Centimetre -> inch factor for figsize. Deliberately NOT named `cm`: that
# name is the cmcrameri colormap module imported at the top of the notebook,
# and rebinding it breaks every later `cm.batlow` lookup.
cm_to_inch = 1 / 2.54
fontsize = 7


def _gradient_legend_handles(lstm_label):
    """Return legend handles: LSTM band + mean, GLAMOS mean, stake means.

    Each group lists the annual entry before the winter one; *lstm_label*
    is the prefix of the LSTM entries.
    """
    seasons = (("annual", color_annual, 'o'), ("winter", color_winter, 's'))
    lstm_handles, glamos_handles, stake_handles = [], [], []
    for season, color, marker in seasons:
        lstm_handles.append(
            Patch(facecolor=color,
                  alpha=0.25,
                  label=f"{lstm_label} band ({season})"))
        lstm_handles.append(
            Line2D([0], [0],
                   color=color,
                   lw=1.2,
                   linestyle='-',
                   label=f"{lstm_label} mean ({season})"))
        glamos_handles.append(
            Line2D([0], [0],
                   color=color,
                   lw=1.2,
                   linestyle=':',
                   label=f"GLAMOS mean ({season})"))
        stake_handles.append(
            Line2D([0], [0],
                   marker=marker,
                   linestyle='None',
                   linewidth=0,
                   markersize=6,
                   markerfacecolor='none',
                   markeredgecolor=color,
                   markeredgewidth=1.2,
                   label=f"Stakes mean ({season})"))
    return lstm_handles + glamos_handles + stake_handles


fig, axs = plt.subplots(nrows=nrows,
                        ncols=ncols,
                        figsize=(28 * cm_to_inch, 12 * cm_to_inch),
                        dpi=300)

subplot_labels = _alpha_labels(ncols)
for c, gl in enumerate(gl_list):  # columns = glaciers
    # Annual
    df_lstm_a_oos, df_glamos_a_oos, df_all_a_oos = build_all_years_df(
        gl.lower(), PATH_PREDICTIONS_LSTM_OOS, cfg, period="annual")
    # Winter
    df_lstm_w_oos, df_glamos_w_oos, df_all_w_oos = build_all_years_df(
        gl.lower(), PATH_PREDICTIONS_LSTM_OOS, cfg, period="winter")

    # Years covered by this glacier's winter data (printed below)
    years = df_all_w_oos.YEAR.unique()

    ax = axs[c]

    # OOS LSTM: band + mean
    ax = plot_lstm_by_elevation_periods(df_all_a_oos,
                                        df_all_w_oos,
                                        ax=ax,
                                        mean_linestyle='-',
                                        label_prefix='LSTM OOS',
                                        show_band=True,
                                        color_annual=color_annual,
                                        color_winter=color_winter)

    # GLAMOS: mean only, dotted
    ax = plot_glamos_by_elevation_periods(df_all_a_oos,
                                          df_all_w_oos,
                                          ax=ax,
                                          show_band=False,
                                          label_prefix="GLAMOS",
                                          mean_linestyle=":",
                                          color_annual=color_annual,
                                          color_winter=color_winter)

    # Stake measurements
    ax = plot_stakes_by_elevation_periods(df_stakes,
                                          gl.lower(),
                                          valid_bins=None,
                                          ax=ax,
                                          color_annual=color_annual,
                                          color_winter=color_winter,
                                          marker_size=14)

    # Clear per-axes labels; a shared x-label is added below
    ax.set_ylabel('')
    ax.set_xlabel('')

    # More decimals for very small glaciers so the area does not round to 0
    area = gl_area.get(gl.lower(), np.nan)
    area = np.round(area, 3) if area < 0.1 else np.round(area, 1)

    ax.set_title(f'{gl} ({area} km2)', fontsize=fontsize, pad=2)

    # y-label on the first column only
    if c == 0:
        ax.set_ylabel('Elevation (m a.s.l.)', fontsize=fontsize)

    ax.grid(alpha=0.2)
    ax.tick_params(labelsize=6.5, pad=2)

    # Panel label in the upper-left corner
    ax.text(
        0.02,
        0.98,
        subplot_labels[c],
        transform=ax.transAxes,
        fontsize=fontsize,
        va="top",
        ha="left",
    )

    # Remove the per-axes legend; one figure-level legend is drawn below
    leg = ax.legend()
    if leg is not None:
        leg.remove()

    print(f'{gl}: {years.min()}-{years.max()}')

# Global x-label
fig.supxlabel('Mass balance (m w.e.)', fontsize=fontsize, y=0.06)

# Custom handles (bands, means, and stakes)
handles = _gradient_legend_handles("LSTM out-of-sample")

fig.legend(handles=handles,
           loc='upper center',
           bbox_to_anchor=(0.5, 0.05),
           ncol=4,
           fontsize=7)

plt.subplots_adjust(hspace=0.25, wspace=0.25)
plt.show()

# Save the figure for the paper.
fig.savefig('figures/paper/fig_mb_gradients_testgl_OOS.png',
            dpi=300,
            bbox_inches='tight')

In [None]:
# Areas of the test glaciers, largest first (shown as the cell output).
test_gl_area = dict(
    sorted(((name, gl_area[name]) for name in TEST_GLACIERS),
           key=lambda pair: pair[1],
           reverse=True))
test_gl_area

## In-sample results:

### Maps: Two glaciers, two years

In [None]:
# GLAMOS vs. in-sample LSTM maps for two glaciers and two years each;
# `years_by_glacier` lists one pair of years per entry of `glacier_names`.
fig = plot_2glaciers_2years_glamos_vs_lstm(
    glacier_names=("aletsch", "rhone"),
    years_by_glacier=((2014, 2024), (2009, 2024)),
    cfg=cfg,
    df_stakes=df_stakes,
    path_distributed_mb=path_distributed_MB_glamos,
    path_pred_lstm=PATH_PREDICTIONS_LSTM_IS,
    period="annual",
)

# Save the figure for the paper.
fig.savefig('figures/paper/fig_glamos_vs_lstm_aletsch_rhone.png',
            dpi=300,
            bbox_inches='tight')

### Gradients train:

In [None]:
# Glaciers shown in this figure, one column each.
gl_list = [
    'Gries',
    'Gietro',
    'Rhone',
    'Aletsch',
]

nrows = 1  # single row: in-sample (IS) profiles only
ncols = len(gl_list)
# Centimetre -> inch factor for figsize. Deliberately NOT named `cm`: that
# name is the cmcrameri colormap module imported at the top of the notebook,
# and rebinding it breaks every later `cm.batlow` lookup.
cm_to_inch = 1 / 2.54
fontsize = 7


def _gradient_legend_handles(lstm_label):
    """Return legend handles: LSTM band + mean, GLAMOS mean, stake means.

    Each group lists the annual entry before the winter one; *lstm_label*
    is the prefix of the LSTM entries.
    """
    seasons = (("annual", color_annual, 'o'), ("winter", color_winter, 's'))
    lstm_handles, glamos_handles, stake_handles = [], [], []
    for season, color, marker in seasons:
        lstm_handles.append(
            Patch(facecolor=color,
                  alpha=0.25,
                  label=f"{lstm_label} band ({season})"))
        lstm_handles.append(
            Line2D([0], [0],
                   color=color,
                   lw=1.2,
                   linestyle='-',
                   label=f"{lstm_label} mean ({season})"))
        glamos_handles.append(
            Line2D([0], [0],
                   color=color,
                   lw=1.2,
                   linestyle=':',
                   label=f"GLAMOS mean ({season})"))
        stake_handles.append(
            Line2D([0], [0],
                   marker=marker,
                   linestyle='None',
                   linewidth=0,
                   markersize=6,
                   markerfacecolor='none',
                   markeredgecolor=color,
                   markeredgewidth=1.2,
                   label=f"Stakes mean ({season})"))
    return lstm_handles + glamos_handles + stake_handles


fig, axs = plt.subplots(nrows=nrows,
                        ncols=ncols,
                        figsize=(25 * cm_to_inch, 12 * cm_to_inch),
                        dpi=300)

subplot_labels = _alpha_labels(ncols)

for c, gl in enumerate(gl_list):  # columns = glaciers
    # Annual
    df_lstm_a_is, df_glamos_a_is, df_all_a_is = build_all_years_df(
        gl.lower(), PATH_PREDICTIONS_LSTM_IS, cfg, period="annual")
    # Winter
    df_lstm_w_is, df_glamos_w_is, df_all_w_is = build_all_years_df(
        gl.lower(), PATH_PREDICTIONS_LSTM_IS, cfg, period="winter")

    # Years covered by THIS glacier's in-sample winter data. The previous
    # version read `df_all_w_oos`, a leftover from the out-of-sample cell,
    # so every title showed the year range of the last OOS glacier.
    years = df_all_w_is.YEAR.unique()

    ax = axs[c]

    # IS LSTM: band + mean, solid line
    ax = plot_lstm_by_elevation_periods(df_all_a_is,
                                        df_all_w_is,
                                        ax=ax,
                                        mean_linestyle='-',
                                        label_prefix='LSTM IS',
                                        show_band=True,
                                        color_annual=color_annual,
                                        color_winter=color_winter)

    # GLAMOS: mean only, dotted
    ax = plot_glamos_by_elevation_periods(df_all_a_is,
                                          df_all_w_is,
                                          ax=ax,
                                          show_band=False,
                                          label_prefix="GLAMOS",
                                          mean_linestyle=":",
                                          color_annual=color_annual,
                                          color_winter=color_winter)

    # Stake measurements
    ax = plot_stakes_by_elevation_periods(df_stakes,
                                          gl.lower(),
                                          valid_bins=None,
                                          ax=ax,
                                          color_annual=color_annual,
                                          color_winter=color_winter,
                                          marker_size=14)

    # Clear per-axes labels; a shared x-label is added below
    ax.set_ylabel('')
    ax.set_xlabel('')

    # More decimals for very small glaciers so the area does not round to 0
    area = gl_area.get(gl.lower(), np.nan)
    area = np.round(area, 3) if area < 0.1 else np.round(area, 1)

    ax.set_title(f'{gl} ({area} km2, {years.min()}-{years.max()})',
                 fontsize=fontsize,
                 pad=2)

    # y-label on the first column only
    if c == 0:
        ax.set_ylabel('Elevation (m a.s.l.)', fontsize=fontsize)

    ax.grid(alpha=0.2)
    ax.tick_params(labelsize=6.5, pad=2)

    # Panel label in the upper-left corner
    ax.text(
        0.02,
        0.98,
        subplot_labels[c],
        transform=ax.transAxes,
        fontsize=fontsize,
        va="top",
        ha="left",
    )

    # Remove the per-axes legend; one figure-level legend is drawn below
    leg = ax.legend()
    if leg is not None:
        leg.remove()

# Global x-label
fig.supxlabel('Mass balance (m w.e.)', fontsize=fontsize, y=0.06)

# Custom handles (bands, means, and stakes)
handles = _gradient_legend_handles("LSTM in-sample")

fig.legend(handles=handles,
           loc='upper center',
           bbox_to_anchor=(0.5, 0.05),
           ncol=5,
           fontsize=7)

plt.subplots_adjust(hspace=0.25, wspace=0.25)
plt.show()

# Save the figure for the paper.
fig.savefig('figures/paper/fig_mb_gradients_IS.png',
            dpi=300,
            bbox_inches='tight')

### Geodetic MB:

In [None]:
# Compute RMSE and Pearson correlation between geodetic and modelled MB
# for the out-of-sample predictions.
rmse_nn = root_mean_squared_error(ds_lstm_OS["Geodetic MB"],
                                  ds_lstm_OS["MBM MB"])
corr_nn = np.corrcoef(ds_lstm_OS["Geodetic MB"], ds_lstm_OS["MBM MB"])[0, 1]

# NOTE(review): five area bins are defined but max_bins=4; the '>10' label is
# paired with the 10-100 bin edge -- confirm this is the intended labelling.
fig = plot_mbm_vs_geodetic_by_area_bin(
    ds_lstm_OS,
    bins=[0, 1, 5, 10, 100, np.inf],
    labels=['<1', '1-5', '5–10', '>10', '>100'],
    max_bins=4)

# save figure
fig.savefig('figures/paper/fig_mbm_vs_geodetic_by_area_bin_OOS.png',
            dpi=300,
            bbox_inches='tight')

In [None]:
# Compute RMSE and Pearson correlation between geodetic and modelled MB
# for the in-sample predictions (overwrites the out-of-sample values).
rmse_nn = root_mean_squared_error(ds_lstm_IS["Geodetic MB"],
                                  ds_lstm_IS["MBM MB"])
corr_nn = np.corrcoef(ds_lstm_IS["Geodetic MB"], ds_lstm_IS["MBM MB"])[0, 1]

# NOTE(review): five area bins are defined but max_bins=4; the '>10' label is
# paired with the 10-100 bin edge -- confirm this is the intended labelling.
fig = plot_mbm_vs_geodetic_by_area_bin(
    ds_lstm_IS,
    bins=[0, 1, 5, 10, 100, np.inf],
    labels=['<1', '1-5', '5–10', '>10', '>100'],
    max_bins=4)

# save figure
fig.savefig('figures/paper/fig_mbm_vs_geodetic_by_area_bin_IS.png',
            dpi=300,
            bbox_inches='tight')

## Feature importance:

In [None]:
# Permutation feature importance of the saved model on the test dataframe.
pfi_parallel = permutation_feature_importance_mbm_parallel(
    cfg=cfg,
    custom_params=custom_params,
    model_filename=model_filename,
    df_eval=df_test,  # evaluation dataframe, including the targets
    MONTHLY_COLS=MONTHLY_COLS,
    STATIC_COLS=STATIC_COLS,
    ds_train=ds_train,
    train_idx=train_idx,
    target_col="POINT_BALANCE",  # name of the target column
    months_head_pad=months_head_pad,
    months_tail_pad=months_tail_pad,
    seed=cfg.seed,
    n_repeats=5,
    batch_size=256,
    max_workers=None,  # auto: n_cpus-1 (cap 32)
)

# Metric name and baseline are read from the first row of the result.
_pfi_metric = pfi_parallel['metric_name'].iloc[0]
_pfi_baseline = pfi_parallel['baseline'].iloc[0]
_pfi_height = max(3, 0.35 * len(pfi_parallel))

# Horizontal bars with error bars, first row of the table at the top.
plt.figure(figsize=(8, _pfi_height))
plt.barh(pfi_parallel["feature"],
         pfi_parallel["mean_delta"],
         xerr=pfi_parallel["std_delta"])
plt.gca().invert_yaxis()
plt.title(
    f"Permutation Feature Importance (Δ{_pfi_metric}; baseline={_pfi_baseline:.3f})"
)
plt.xlabel(f"Increase in {_pfi_metric} (higher = more important)")
plt.tight_layout()
plt.show()

## Remaining figures:

### Gradients all:

In [None]:
nrows = 4
ncols = 5
# Centimetre -> inch factor for figsize. Deliberately NOT named `cm`: that
# name is the cmcrameri colormap module imported at the top of the notebook,
# and rebinding it breaks every later `cm.batlow` lookup.
cm_to_inch = 1 / 2.54
fontsize = 7


def _gradient_legend_handles(lstm_label):
    """Return legend handles: LSTM band + mean, GLAMOS mean, stake means.

    Each group lists the annual entry before the winter one; *lstm_label*
    is the prefix of the LSTM entries.
    """
    seasons = (("annual", color_annual, 'o'), ("winter", color_winter, 's'))
    lstm_handles, glamos_handles, stake_handles = [], [], []
    for season, color, marker in seasons:
        lstm_handles.append(
            Patch(facecolor=color,
                  alpha=0.25,
                  label=f"{lstm_label} band ({season})"))
        lstm_handles.append(
            Line2D([0], [0],
                   color=color,
                   lw=1.2,
                   linestyle='-',
                   label=f"{lstm_label} mean ({season})"))
        glamos_handles.append(
            Line2D([0], [0],
                   color=color,
                   lw=1.2,
                   linestyle=':',
                   label=f"GLAMOS mean ({season})"))
        stake_handles.append(
            Line2D([0], [0],
                   marker=marker,
                   linestyle='None',
                   linewidth=0,
                   markersize=6,
                   markerfacecolor='none',
                   markeredgecolor=color,
                   markeredgewidth=1.2,
                   label=f"Stakes mean ({season})"))
    return lstm_handles + glamos_handles + stake_handles


# One panel per glacier on a 4 x 5 grid, addressed through a flat axes array.
fig, axs = plt.subplots(nrows=nrows,
                        ncols=ncols,
                        figsize=(25 * cm_to_inch, 15 * cm_to_inch),
                        dpi=300)
axs = axs.flatten()
gl_list = [
    'schwarzbach', 'murtel', 'plattalva', 'basodino', 'limmern', 'adler',
    'hohlaub', 'albigna', 'tsanfleuron', 'silvretta', 'gries', 'clariden',
    'gietro', 'schwarzberg', 'forno', 'allalin', 'otemma', 'findelen', 'rhone',
    'aletsch'
]
for i, gl in enumerate(gl_list):
    # Annual
    df_lstm_a, df_glamos_a, df_all_a = build_all_years_df(
        gl.lower(), PATH_PREDICTIONS_LSTM_IS, cfg, period="annual")

    # Skip glaciers without annual data BEFORE reading any column: the old
    # order accessed `.YEAR` first and so could fail ahead of this check.
    if df_all_a is None or df_all_a.empty:
        print(f"No data for glacier: {gl}")
        continue

    years = df_all_a.YEAR.unique()

    # Winter
    df_lstm_w, df_glamos_w, df_all_w = build_all_years_df(
        gl.lower(), PATH_PREDICTIONS_LSTM_IS, cfg, period="winter")

    ax = plot_mb_by_elevation_periods_combined(df_all_a,
                                               df_all_w,
                                               df_stakes,
                                               gl.lower(),
                                               ax=axs[i])

    # More decimals for very small glaciers so the area does not round to 0
    area = gl_area.get(gl.lower(), np.nan)
    area = np.round(area, 3) if area < 0.1 else np.round(area, 1)

    # Test glaciers are marked with a leading asterisk in the title
    title_prefix = '*' if gl.lower() in TEST_GLACIERS else ''
    axs[i].set_title(
        f'{title_prefix}{gl} ({area} km2, {years.min()}-{years.max()})',
        fontsize=fontsize,
        pad=2)

    axs[i].grid(alpha=0.2)
    axs[i].tick_params(labelsize=6.5, pad=2)
    axs[i].set_ylabel('')
    axs[i].set_xlabel('')
    # Remove the per-axes legend; one figure-level legend is drawn below
    axs[i].legend().remove()

# y-label on the first panel of the second row only
axs[5].set_ylabel('Elevation (m a.s.l.)', fontsize=fontsize)

# Custom handles (bands, means, and stakes)
handles = _gradient_legend_handles("LSTM")

fig.supxlabel('Mass balance (m w.e.)', fontsize=fontsize, y=0.06)

fig.legend(handles=handles,
           loc='upper center',
           bbox_to_anchor=(0.5, 0.05),
           ncol=4,
           fontsize=7)

# Adjust the layout
plt.subplots_adjust(hspace=0.25, wspace=0.25)
plt.show()

### Maps:

In [None]:
# Example: in-sample comparison for a single glacier.
GLACIER_NAME = 'schwarzberg'
# ds_lstm_IS stores capitalised glacier names, so match on the capitalised form.
df_lstm_two_heads_gl = ds_lstm_IS[ds_lstm_IS.GLACIER ==
                                  GLACIER_NAME.capitalize()]

fig, axs = plt.subplots(1, 2, figsize=(15, 5))

# Left panel: scatter comparison for this glacier.
plot_scatter_comparison(axs[0],
                        df_lstm_two_heads_gl,
                        GLACIER_NAME,
                        color_mbm=color_annual,
                        color_glamos=color_winter,
                        title_suffix="(LSTM two heads)")

# Load GLAMOS data
GLAMOS_glwmb = get_GLAMOS_glwmb(GLACIER_NAME, cfg)

# Load the LSTM predictions and rename the balance column so it stays
# distinguishable after the join.
MBM_glwmb_lstm = mbm_glwd_pred(PATH_PREDICTIONS_LSTM_IS, GLACIER_NAME)
MBM_glwmb_lstm.rename(columns={"MBM Balance": "MBM Balance LSTM"},
                      inplace=True)

# Merge with GLAMOS data
MBM_glwmb_lstm = MBM_glwmb_lstm.join(GLAMOS_glwmb)

# Right panel: both balance series against the index.
MBM_glwmb_lstm.plot(ax=axs[1],
                    y=['MBM Balance LSTM', 'GLAMOS Balance'],
                    marker="o",
                    color=[color_annual, color_winter])

# The title is set once; an earlier set_title call was always overridden.
axs[1].set_title(f"{GLACIER_NAME.capitalize()} Glacier (LSTM)", fontsize=16)
axs[1].set_ylabel("Mass Balance [m w.e.]", fontsize=18)
axs[1].set_xlabel("Year", fontsize=18)
axs[1].grid(True, linestyle="--", linewidth=0.5)
axs[1].legend(fontsize=14)

plt.tight_layout()
plt.show()

In [None]:
# Print the index of the glacier-wide prediction table for inspection.
print(MBM_glwmb_lstm.index)

# Arguments common to every call; only `year` changes between iterations.
annual_map_kwargs = dict(glacier_name=GLACIER_NAME,
                         cfg=cfg,
                         df_stakes=df_stakes,
                         path_distributed_mb=path_distributed_MB_glamos,
                         path_pred_lstm=PATH_PREDICTIONS_LSTM_IS,
                         period='annual')

# range end is exclusive: years 1956 through 1959.
year_range = range(1956, 1960)
for year in year_range:
    plot_mass_balance_comparison_annual(year=year, **annual_map_kwargs)

### Monthly distributions:

In [None]:
# Build one long table of per-cell monthly LSTM predictions for every glacier
# directory found under PATH_PREDICTIONS_LSTM_OOS. Files are expected to be
# named "<glacier>_<year>_<mon>.zarr".
import re  # not imported explicitly in the notebook header; import it here

glaciers = os.listdir(PATH_PREDICTIONS_LSTM_OOS)
# Months listed in hydrological-year order (Oct -> Sep).
hydro_months = [
    'oct',
    'nov',
    'dec',
    'jan',
    'feb',
    'mar',
    'apr',
    'may',
    'jun',
    'jul',
    'aug',
    'sep',
]
# Initialize final storage for all glacier data
all_glacier_data = []

# Loop over glaciers
for glacier_name in tqdm(glaciers):
    glacier_path = os.path.join(PATH_PREDICTIONS_LSTM_OOS, glacier_name)
    if not os.path.isdir(glacier_path):
        continue  # skip non-directories

    # The glacier name is escaped so that any regex metacharacter in a
    # directory name is matched literally.
    pattern = re.compile(
        rf'{re.escape(glacier_name)}_(\d{{4}})_[a-z]{{3}}\.zarr')

    # Extract available years (unique, ascending)
    years = sorted({
        int(match.group(1))
        for match in map(pattern.match, os.listdir(glacier_path)) if match
    })

    # Collect all year-month data
    all_years_data = []
    for year in years:
        monthly_data = {}
        elevation = None
        for month in hydro_months:
            zarr_path = os.path.join(glacier_path,
                                     f'{glacier_name}_{year}_{month}.zarr')
            if not os.path.exists(zarr_path):
                continue

            # Context manager closes the store once both variables have
            # been materialised as DataFrames.
            with xr.open_dataset(zarr_path) as ds:
                df = ds.pred_masked.to_dataframe().drop(
                    ['x', 'y'], axis=1).reset_index()
                df_el = ds.masked_elev.to_dataframe().drop(
                    ['x', 'y'], axis=1).reset_index()

            # Keep only cells that carry a prediction (non-NaN pred_masked).
            valid = df.pred_masked.notna()
            monthly_data[month] = df.loc[valid, 'pred_masked'].values
            # Elevation of those same cells; the value kept after the loop is
            # the one from the last month file found for this year.
            # NOTE(review): assumes every month shares the same valid-cell
            # mask (pd.DataFrame below requires equal-length columns).
            elevation = df_el.loc[valid, 'masked_elev'].values

        if monthly_data:
            df_months = pd.DataFrame(monthly_data)
            df_months['year'] = year
            df_months['glacier'] = glacier_name  # add glacier name
            df_months['elevation'] = elevation
            all_years_data.append(df_months)

    # Concatenate this glacier's data
    if all_years_data:
        df_glacier = pd.concat(all_years_data, axis=0, ignore_index=True)
        all_glacier_data.append(df_glacier)

# Final full DataFrame for all glaciers
df_months_LSTM = pd.concat(all_glacier_data, axis=0, ignore_index=True)
df_months_LSTM

In [None]:
# Build one long table of per-cell monthly predictions for every glacier
# directory found under the NN distributed-grid folder. Files are expected to
# be named "<glacier>_<year>_<mon>.zarr".
import re  # not imported explicitly in the notebook header; import it here

path_save_glw = os.path.join(cfg.dataPath, 'GLAMOS', 'distributed_MB_grids',
                             'MBM/glamos_dems_NN_SEB_full_OGGM')

glaciers = os.listdir(path_save_glw)
# Months listed in hydrological-year order (Oct -> Sep).
hydro_months = [
    'oct', 'nov', 'dec', 'jan', 'feb', 'mar', 'apr', 'may', 'jun', 'jul',
    'aug', 'sep'
]
# Initialize final storage for all glacier data
all_glacier_data = []

# Loop over glaciers
for glacier_name in tqdm(glaciers):
    glacier_path = os.path.join(path_save_glw, glacier_name)
    if not os.path.isdir(glacier_path):
        continue  # skip non-directories

    # The glacier name is escaped so that any regex metacharacter in a
    # directory name is matched literally.
    pattern = re.compile(
        rf'{re.escape(glacier_name)}_(\d{{4}})_[a-z]{{3}}\.zarr')

    # Extract available years (unique, ascending)
    years = sorted({
        int(match.group(1))
        for match in map(pattern.match, os.listdir(glacier_path)) if match
    })

    # Collect all year-month data
    all_years_data = []
    for year in years:
        monthly_data = {}
        elevation = None
        for month in hydro_months:
            zarr_path = os.path.join(glacier_path,
                                     f'{glacier_name}_{year}_{month}.zarr')
            if not os.path.exists(zarr_path):
                continue

            # Context manager closes the store once both variables have
            # been materialised as DataFrames.
            with xr.open_dataset(zarr_path) as ds:
                df = ds.pred_masked.to_dataframe().drop(
                    ['x', 'y'], axis=1).reset_index()
                df_el = ds.masked_elev.to_dataframe().drop(
                    ['x', 'y'], axis=1).reset_index()

            # Keep only cells that carry a prediction (non-NaN pred_masked).
            valid = df.pred_masked.notna()
            monthly_data[month] = df.loc[valid, 'pred_masked'].values
            # Elevation of those same cells; the value kept after the loop is
            # the one from the last month file found for this year.
            # NOTE(review): assumes every month shares the same valid-cell
            # mask (pd.DataFrame below requires equal-length columns).
            elevation = df_el.loc[valid, 'masked_elev'].values

        if monthly_data:
            df_months = pd.DataFrame(monthly_data)
            df_months['year'] = year
            df_months['glacier'] = glacier_name  # add glacier name
            df_months['elevation'] = elevation
            all_years_data.append(df_months)

    # Concatenate this glacier's data
    if all_years_data:
        df_glacier = pd.concat(all_years_data, axis=0, ignore_index=True)
        all_glacier_data.append(df_glacier)

# Final full DataFrame for all glaciers
df_months_NN = pd.concat(all_glacier_data, axis=0, ignore_index=True)
df_months_NN

#### Glacier-wide:

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import joypy
from pandas.api.types import CategoricalDtype


# Ridge plot of glacier-wide monthly mass balance: one distribution per
# calendar month, each sample being the mean over one (glacier, year) group.
# NOTE(review): the variables below are named *_NN / *_nn but are computed
# from df_months_LSTM and the legend says 'LSTM' -- names kept unchanged.
glwd_months_NN = df_months_LSTM.groupby(['glacier',
                                         'year']).mean().reset_index()

# Calendar order (Jan -> Dec) used for the ridge rows.
month_order = [
    'jan', 'feb', 'mar', 'apr', 'may', 'jun', 'jul', 'aug', 'sep', 'oct',
    'nov', 'dec'
]
cat_month = CategoricalDtype(month_order, ordered=True)

df_months_nn = glwd_months_NN[month_order]

# Wide -> long: one row per (month, group) value, stacked column by column.
df_months_nn_long = df_months_nn.melt(var_name='Month', value_name='mb_nn')
# Ordered categorical so joypy groups the months in calendar order rather
# than alphabetically.
df_months_nn_long['Month'] = df_months_nn_long['Month'].astype(cat_month)
df_months_nn_long = df_months_nn_long[['mb_nn', 'Month']]

model_colors = color_winter
alpha = 1

# Figure size is given in centimetres. The factor deliberately does not use
# the name `cm`: that name is the cmcrameri colormap module imported at the
# top of the notebook, and rebinding it to a float breaks `cm.batlow`.
cm_to_inch = 1 / 2.54
# joypy.joyplot returns (figure, list of axes), in that order.
fig, axes = joypy.joyplot(df_months_nn_long,
                          by='Month',
                          column='mb_nn',
                          alpha=0.8,
                          overlap=0,
                          fill=False,
                          linewidth=1.5,
                          xlabelsize=8.5,
                          ylabelsize=8.5,
                          x_range=[-2.2, 2.2],
                          grid=False,
                          color=model_colors,
                          figsize=(12 * cm_to_inch, 14 * cm_to_inch),
                          ylim='own')

# Zero mass-balance reference line.
vline_alpha = 0.5
plt.axvline(x=0, color='grey', alpha=vline_alpha, linewidth=1)

plt.xlabel('Mass balance (m w.e.)', fontsize=8.5)
plt.yticks(ticks=range(1, 13), labels=month_order, fontsize=8.5)
plt.gca().set_yticklabels(month_order)

# Single-entry legend for the plotted model.
legend_patches = [
    Patch(facecolor=model_colors, label='LSTM', alpha=alpha, edgecolor='k')
]
plt.legend(handles=legend_patches,
           loc='upper center',
           bbox_to_anchor=(0.48, -0.1),
           ncol=4,
           fontsize=8.5,
           handletextpad=0.5,
           columnspacing=1)