## Setting Up:

In [None]:
import os
import warnings
import logging
from collections import defaultdict

import numpy as np
import pandas as pd
import xarray as xr

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from cmcrameri import cm
from matplotlib import gridspec

import massbalancemachine as mbm

from scripts.utils import *
from scripts.glamos import *
from scripts.models import *
from scripts.geo_data import *
from scripts.dataset import *
from scripts.geodetic import *
from scripts.plotting import *


warnings.filterwarnings('ignore')
%load_ext autoreload
%autoreload 2

cfg = mbm.SwitzerlandConfig()

# Plot styles:
use_mbm_style()

seed_all(cfg.seed)
print("Using seed:", cfg.seed)

if torch.cuda.is_available():
    print("CUDA is available")
    free_up_cuda()
else:
    print("CUDA is NOT available")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Read GLAMOS stake data
data_glamos = get_stakes_data(cfg)

# Compute padding for monthly data (head/tail month pads derived from the
# stake dates in data_glamos).
months_head_pad, months_tail_pad = mbm.data_processing.utils._compute_head_tail_pads_from_df(
    data_glamos)

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Transform data to monthly format (run or load data):
paths = {
    'csv_path': cfg.dataPath + path_PMB_GLAMOS_csv,
    'era5_climate_data':
    cfg.dataPath + path_ERA5_raw + 'era5_monthly_averaged_data.nc',
    'geopotential_data':
    cfg.dataPath + path_ERA5_raw + 'era5_geopotential_pressure.nc',
    'radiation_save_path': cfg.dataPath + path_pcsr + 'zarr/'
}
RUN = False
data_monthly = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=VOIS_CLIMATE,
    vois_topographical=VOIS_TOPOGRAPHICAL,
    output_file='CH_wgms_dataset_monthly_LSTM.csv')

# Remove 2025 (held out; used elsewhere for validation).
# This must happen BEFORE the DataLoader is built. The CV splits below are
# drawn from `dataloader_gl`. Filtering `data_monthly` afterwards would only
# rebind the local name and leave the DataLoader (and therefore train_set /
# test_set) holding the unfiltered frame.
# reset_index keeps a contiguous 0..n-1 index, as an unfiltered frame has.
data_monthly = data_monthly[data_monthly['YEAR'] < 2025].reset_index(
    drop=True)

dataloader_gl = mbm.dataloader.DataLoader(cfg,
                                          data=data_monthly,
                                          random_seed=cfg.seed,
                                          meta_data_columns=cfg.metaData)

# Ensure all test glaciers exist in the dataset
existing_glaciers = set(data_monthly.GLACIER.unique())
missing_glaciers = [g for g in TEST_GLACIERS if g not in existing_glaciers]
if missing_glaciers:
    print('WARNING: test glaciers missing from dataset:', missing_glaciers)

# Training glaciers = every glacier that is not a test glacier.
# Sorted so the order does not depend on set iteration order, which varies
# between interpreter sessions because string hashing is randomized.
train_glaciers = sorted(g for g in existing_glaciers
                        if g not in TEST_GLACIERS)

splits, test_set, train_set = get_CV_splits(dataloader_gl,
                                            test_split_on='GLACIER',
                                            test_splits=TEST_GLACIERS,
                                            random_state=cfg.seed)

# Validation and train split:
data_train = train_set['df_X']
data_train['y'] = train_set['y']

data_test = test_set['df_X']
data_test['y'] = test_set['y']

seed_all(cfg.seed)

# Copies with PERIOD normalized (stripped, lower-case).
df_train = data_train.copy()
df_train['PERIOD'] = df_train['PERIOD'].str.strip().str.lower()

df_test = data_test.copy()
df_test['PERIOD'] = df_test['PERIOD'].str.strip().str.lower()

print('Number of unique stake measurements:', dataloader_gl.data.ID.nunique())

# same for winter and annual
data_winter = dataloader_gl.data[dataloader_gl.data.PERIOD == 'winter']
data_annual = dataloader_gl.data[dataloader_gl.data.PERIOD == 'annual']

print('Number of unique stake measurements (winter):',
      data_winter.ID.nunique())
print('Number of unique stake measurements (annual):',
      data_annual.ID.nunique())

print('Ratio of winter/annual stake measurements:',
      data_winter.ID.nunique() / data_annual.ID.nunique())

# print size of test and train sets
print('Size of train set:', data_train.shape)
print('Size of test set:', data_test.shape)

In [None]:
# Variant of the dataset where FROM_DATE is reset to 1 August of its year
# (YYYY0801).
# Convert to str → take the year → append "0801" → convert back to int
data_glamos_Aug_ = data_glamos.copy()
data_glamos_Aug_["FROM_DATE"] = (
    data_glamos_Aug_["FROM_DATE"].astype(str).str.slice(0, 4)  # extract year YYYY
    .astype(int).astype(str) + "0801"  # append "0801"
).astype(int)

# Same for full temporal resolution (run or load data):
# Compute padding for monthly data
months_head_pad_Aug_, months_tail_pad_Aug_ = mbm.data_processing.utils._compute_head_tail_pads_from_df(
    data_glamos_Aug_)

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

RUN = False
data_monthly_Aug_ = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos_Aug_,
    paths=paths,
    cfg=cfg,
    vois_climate=VOIS_CLIMATE,
    vois_topographical=VOIS_TOPOGRAPHICAL,
    output_file='CH_wgms_dataset_monthly_LSTM_Aug_.csv')

# Remove 2025 (held out; used elsewhere for validation).
# This must happen BEFORE the DataLoader is built, because the CV splits
# below are drawn from the DataLoader's data, not from `data_monthly_Aug_`.
data_monthly_Aug_ = data_monthly_Aug_[
    data_monthly_Aug_['YEAR'] < 2025].reset_index(drop=True)

# Create DataLoader
dataloader_gl_Aug_ = mbm.dataloader.DataLoader(cfg,
                                               data=data_monthly_Aug_,
                                               random_seed=cfg.seed,
                                               meta_data_columns=cfg.metaData)

splits_Aug_, test_set_Aug_, train_set_Aug_ = get_CV_splits(
    dataloader_gl_Aug_,
    test_split_on='GLACIER',
    test_splits=TEST_GLACIERS,
    random_state=cfg.seed)

# Validation and train split:
data_train_Aug_ = train_set_Aug_['df_X']
data_train_Aug_['y'] = train_set_Aug_['y']
data_test_Aug_ = test_set_Aug_['df_X']
data_test_Aug_['y'] = test_set_Aug_['y']

seed_all(cfg.seed)

df_train_Aug_ = data_train_Aug_.copy()
# Normalize PERIOD from this frame's own column. `df_train` belongs to a
# different dataset: pandas would align it by index and could write
# mismatched or NaN values here.
df_train_Aug_['PERIOD'] = df_train_Aug_['PERIOD'].str.strip().str.lower()

df_test_Aug_ = data_test_Aug_.copy()
df_test_Aug_['PERIOD'] = df_test_Aug_['PERIOD'].str.strip().str.lower()

### Feature distributions: train vs. test set

In [None]:
# Variables passed as per-month sequence inputs (monthly_cols).
MONTHLY_COLS = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str', 'pcsr',
    'ELEVATION_DIFFERENCE',
]
# Variables passed as static, per-sequence inputs (static_cols).
STATIC_COLS = ['aspect_sgi', 'slope_sgi', 'svf']

# Register the full feature set (monthly first, then static) on the config.
feature_columns = [*MONTHLY_COLS, *STATIC_COLS]
cfg.setFeatures(feature_columns)

In [None]:
# Train/test colours: first colour of the batlow map for train, red for test.
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue = colors[0]
custom_palette = dict(Train=color_dark_blue, Test='#b2182b')

# t-SNE overlap of train vs. test features, panels labelled (a)-(c).
fig = plot_tsne_overlap(
    data_train, data_test, STATIC_COLS, MONTHLY_COLS,
    sublabels=("a", "b", "c"),
    label_fmt="({})",
    label_xy=(0.02, 0.98),
    label_fontsize=14,
    n_iter=1000,
    random_state=cfg.seed,
    custom_palette=custom_palette,
)
fig.savefig('figures/paper/appendix/app_tsne_overlap_train_test_CH.png',
            dpi=300, bbox_inches='tight')

In [None]:
# Per-feature KDE overlap, train vs. test, using the t-SNE figure's colours.
palette = dict(Train=color_dark_blue, Test='#b2182b')
_kde_features = STATIC_COLS + MONTHLY_COLS
fig = plot_feature_kde_overlap(
    data_train, data_test, _kde_features, palette,
    outfile="figures/paper/appendix/app_feature_kde_overlap.png")

## LSTM:

### Build LSTM dataloaders:

In [None]:
# Cached train/test LSTM sequence datasets (torch-serialized).
CACHE_DIR = "cache"
os.makedirs(CACHE_DIR, exist_ok=True)

CACHE_DS_TRAIN = os.path.join(CACHE_DIR, "ds_train_full_LSTM_figures.pt")
CACHE_DS_TEST = os.path.join(CACHE_DIR, "ds_test_full_LSTM_figures.pt")

RUN_CACHE_DS = False  # ← set True to rebuild & overwrite cache
seed_all(cfg.seed)

_cache_complete = all(
    os.path.exists(p) for p in (CACHE_DS_TRAIN, CACHE_DS_TEST))

if _cache_complete and not RUN_CACHE_DS:
    print("Loading train / test datasets from cache...")
    ds_train = torch.load(CACHE_DS_TRAIN, map_location="cpu")
    ds_test = torch.load(CACHE_DS_TEST, map_location="cpu")
else:
    print("Building train / test datasets from scratch...")
    # Settings shared by the train and test builds; only the frames differ.
    _build_kwargs = dict(
        monthly_cols=MONTHLY_COLS,
        static_cols=STATIC_COLS,
        months_head_pad=months_head_pad_Aug_,
        months_tail_pad=months_tail_pad_Aug_,
        normalize_target=True,
        expect_target=True,
    )
    ds_train = build_combined_LSTM_dataset(df_loss=data_train,
                                           df_full=data_train_Aug_,
                                           **_build_kwargs)
    ds_test = build_combined_LSTM_dataset(df_loss=data_test,
                                          df_full=data_test_Aug_,
                                          **_build_kwargs)

    torch.save(ds_train, CACHE_DS_TRAIN)
    torch.save(ds_test, CACHE_DS_TEST)
    print("Datasets cached.")

# 80/20 split of the training sequences into train / validation indices.
train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
    len(ds_train), val_ratio=0.2, seed=cfg.seed)

month_list, month_pos = mbm.data_processing.utils._rebuild_month_index(
    months_head_pad_Aug_, months_tail_pad_Aug_)
# Month names ordered by their position in the sequence.
month_order = sorted(month_pos, key=month_pos.get)
print("Month order used in sequences:", month_order)

inspect_LSTM_sample(ds_train, 0, month_labels=month_order)

### Define & train model:

In [None]:
# Hyper-parameters and checkpoint path of the out-of-sample LSTM
# (both defined in scripts/config_CH.py).
custom_params = PARAMS_LSTM_OOS_PAST
model_filename = LSTM_OOS_NORM_Y_PAST

# Work on untransformed clones; presumably this keeps the cached ds_train /
# ds_test unscaled, because the loaders below transform in place.
seed_all(cfg.seed)
ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train)
ds_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_test)

# Fit scalers on the TRAIN indices and transform Xm/Xs/y in place.
# Training batches are drawn with a weighted sampler.
train_dl, val_dl = ds_train_copy.make_loaders(
    train_idx=train_idx,
    val_idx=val_idx,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=True,
    shuffle_train=True,
    use_weighted_sampler=True,
)

# Test loader: copies the TRAIN scalers into ds_test_copy and transforms it.
test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test_copy, ds_train_copy, batch_size=128, seed=cfg.seed)

# Build the architecture and loss, then load the trained weights.
model = mbm.models.LSTM_MB.build_model_from_params(cfg, custom_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(custom_params)

print('Loading model from:', model_filename)
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state)

# Evaluate on the held-out test glaciers.
test_metrics, test_df_preds = model.evaluate_with_preds(
    device, test_dl, ds_test_copy)
test_rmse_a = test_metrics['RMSE_annual']
test_rmse_w = test_metrics['RMSE_winter']

print(f'Test RMSE annual: {test_rmse_a:.3f} | winter: {test_rmse_w:.3f}')

## Validation OOS:
Out-of-sample evaluation on the test glaciers.

In [None]:
# Glacier areas by name; "clariden" is made an alias of the "claridenL" entry.
gl_area = get_gl_area(cfg)
_clariden_area = gl_area["claridenL"]
gl_area["clariden"] = _clariden_area

### Fig 4 & 5:

In [None]:
# Seasonal skill scores of the LSTM on the test set.
scores_annual, scores_winter = compute_seasonal_scores(
    test_df_preds, target_col='target', pred_col='pred')

for _season, _scores in (('Annual', scores_annual),
                         ('Winter', scores_winter)):
    print(f"{_season} scores:", _scores)

# Fig. 4: predicted vs. observed point mass balance.
fig = plot_predictions_summary(
    grouped_ids=test_df_preds,
    scores_annual=scores_annual,
    scores_winter=scores_winter,
    ax_xlim=(-8, 6),
    ax_ylim=(-8, 6),
    color_annual=COLOR_ANNUAL,
    color_winter=COLOR_WINTER,
)
fig.savefig('figures/paper/fig4_predvsobs.png', dpi=300, bbox_inches='tight')

In [None]:
# Mean elevation of annual stakes per glacier, highest first (computed once).
# Used to order the per-glacier panels and attached to the predictions as
# 'gl_elv'.
gl_per_el = data_glamos[data_glamos.PERIOD == 'annual'].groupby(
    ['GLACIER'])['POINT_ELEVATION'].mean()
gl_per_el = gl_per_el.sort_values(ascending=False)

# Test glaciers ordered from lowest to highest mean stake elevation.
test_gl_per_el = gl_per_el[TEST_GLACIERS].sort_values().index
test_df_preds['gl_elv'] = test_df_preds['GLACIER'].map(gl_per_el)

fig, axs = plt.subplots(2, 4, figsize=(32, 15), sharex=True)

subplot_labels = [
    '(a)', '(b)', '(c)', '(d)', '(e)', '(f)', '(g)', '(h)', '(i)'
]

axs = plot_individual_glacier_pred(test_df_preds,
                                   axs=axs,
                                   subplot_labels=subplot_labels,
                                   color_annual=COLOR_ANNUAL,
                                   color_winter=COLOR_WINTER,
                                   custom_order=test_gl_per_el,
                                   gl_area=gl_area)

fig.supxlabel('Observed PMB [m w.e.]', fontsize=20, y=0.06)
fig.supylabel('Modeled PMB [m w.e.]', fontsize=20, x=0.09)

# Shared legend: one marker handle per season.
legend_scatter_annual = Line2D([0], [0],
                               marker='o',
                               linestyle='None',
                               linewidth=0,
                               markersize=10,
                               markerfacecolor=COLOR_ANNUAL,
                               markeredgecolor='k',
                               markeredgewidth=0.8,
                               label='Annual')

legend_scatter_winter = Line2D([0], [0],
                               marker='o',
                               linestyle='None',
                               linewidth=0,
                               markersize=10,
                               markerfacecolor=COLOR_WINTER,
                               markeredgecolor='k',
                               markeredgewidth=0.8,
                               label='Winter')

handles = [legend_scatter_annual, legend_scatter_winter]

# Labels are taken from the handles themselves.
fig.legend(handles=handles,
           loc='upper center',
           bbox_to_anchor=(0.5, 0.05),
           ncol=4,
           fontsize=20)

plt.subplots_adjust(hspace=0.25, wspace=0.25)
plt.show()

# save figure
fig.savefig('figures/paper/fig5_predvsobs_indv.png',
            dpi=300,
            bbox_inches='tight')

### Appendix: comparison against NN & XGBoost

In [None]:
# Score the saved XGBoost and NN test predictions the same way as the LSTM.
_score_kwargs = dict(target_col='target', pred_col='pred')

grouped_ids_xgb = pd.read_csv('logs/grouped_ids_xgb.csv')
scores_annual_xgb, scores_winter_xgb = compute_seasonal_scores(
    grouped_ids_xgb, **_score_kwargs)

grouped_ids_NN = pd.read_csv('logs/grouped_ids_NN.csv')
scores_annual_NN, scores_winter_NN = compute_seasonal_scores(
    grouped_ids_NN, **_score_kwargs)

for _label, _scores in (
    ("XGB Annual scores:", scores_annual_xgb),
    ("XGB Winter scores:", scores_winter_xgb),
    ("NN Annual scores:", scores_annual_NN),
    ("NN Winter scores:", scores_winter_NN),
):
    print(_label, _scores)

# Appendix figure: LSTM vs. NN vs. XGBoost, predicted vs. observed.
fig = plot_predictions_three_models_side_by_side(
    test_df_preds, grouped_ids_NN, grouped_ids_xgb,
    scores_annual, scores_winter,
    scores_annual_NN, scores_winter_NN,
    scores_annual_xgb, scores_winter_xgb,
)
plt.show()

fig.savefig('figures/paper/appendix/app_predvsobs_three_models.png',
            dpi=300, bbox_inches='tight')

## Intermediate validation OOS and IS:

In [None]:
# Geodetic mass balance and the geodetic periods available per glacier.
geodetic_mb = get_geodetic_MB(cfg)
periods_per_glacier, geoMB_per_glacier = build_periods_per_glacier(geodetic_mb)

# Directories holding the distributed LSTM predictions (OOS and IS models).
_grids_root = os.path.join(cfg.dataPath, "GLAMOS", "distributed_MB_grids")
PATH_PREDICTIONS_LSTM_OOS = os.path.join(_grids_root,
                                         "MBM/paper/LSTM_OOS_NORM_Y_PAST")
PATH_PREDICTIONS_LSTM_IS = os.path.join(_grids_root,
                                        "MBM/paper/LSTM_IS_NORM_Y_PAST")

# Glaciers that have an OOS prediction sub-directory.
glaciers_in_glamos = set(os.listdir(PATH_PREDICTIONS_LSTM_OOS))

# Areas; "clariden" aliases the "claridenL" entry.
gl_area = get_gl_area(cfg)
gl_area["clariden"] = gl_area["claridenL"]

# Glaciers with both geodetic periods and predictions, smallest area first
# (glaciers without a known area sort as 0).
glacier_list = sorted(
    [g for g in periods_per_glacier.keys() if g in glaciers_in_glamos],
    key=lambda g: gl_area.get(g, 0),
)
print("Number of glaciers:", len(glacier_list))
print("Glaciers:", glacier_list)

In [None]:
# For each glacier and period, compute mean MBM and GLAMOS MBs, attach the
# corresponding geodetic mass balance and its uncertainty (sigma), and return
# a DataFrame of results. One table per prediction set (OS and IS), each
# cached to parquet.

CACHE_DIR = "cache"
os.makedirs(CACHE_DIR, exist_ok=True)

CACHE_LSTM_OS = os.path.join(CACHE_DIR, "geodetic_compare_lstm_OS.parquet")
CACHE_LSTM_IS = os.path.join(CACHE_DIR, "geodetic_compare_lstm_IS.parquet")

RUN_CACHE_GEODETIC = False  # ← set True to recompute & overwrite


def _load_or_compute_geodetic(cache_path,
                              path_predictions,
                              label,
                              recompute=False):
    """Return the geodetic comparison table for one set of LSTM predictions.

    The table is recomputed with ``process_geodetic_mb_comparison`` and
    written to ``cache_path`` when ``recompute`` is True or the cache file
    does not exist. Otherwise the cached parquet file is read back.

    Parameters
    ----------
    cache_path : str
        Parquet file used as the cache.
    path_predictions : str
        Directory containing the distributed LSTM predictions to compare.
    label : str
        Tag used in progress messages (e.g. "OS" or "IS").
    recompute : bool, optional
        Force recomputation even if the cache exists.

    Returns
    -------
    pandas.DataFrame
        When computed: rows having both "Geodetic MB" and "MBM MB", sorted
        by "Area", with capitalized "GLACIER" names. When loaded from the
        cache: whatever was cached.
    """
    if not recompute and os.path.exists(cache_path):
        print(f"Loading cached LSTM {label} geodetic comparison...")
        return pd.read_parquet(cache_path)

    print(f"Computing geodetic comparison (LSTM {label})...")
    df = process_geodetic_mb_comparison(
        glacier_list=glacier_list,
        path_SMB_GLAMOS_csv=os.path.join(cfg.dataPath, path_SMB_GLAMOS_csv),
        periods_per_glacier=periods_per_glacier,
        geoMB_per_glacier=geoMB_per_glacier,
        gl_area=gl_area,
        test_glaciers=TEST_GLACIERS,
        path_predictions=path_predictions,
        cfg=cfg,
    )

    df = df.dropna(subset=["Geodetic MB", "MBM MB"]).sort_values(by="Area")
    df["GLACIER"] = df["GLACIER"].str.capitalize()

    df.to_parquet(cache_path)
    print(f"Cached LSTM {label} geodetic comparison.")
    return df


ds_lstm_OS = _load_or_compute_geodetic(CACHE_LSTM_OS,
                                       PATH_PREDICTIONS_LSTM_OOS,
                                       "OS",
                                       recompute=RUN_CACHE_GEODETIC)
ds_lstm_IS = _load_or_compute_geodetic(CACHE_LSTM_IS,
                                       PATH_PREDICTIONS_LSTM_IS,
                                       "IS",
                                       recompute=RUN_CACHE_GEODETIC)

### Mass balance gradients on test:

In [None]:
# Stake data, loaded once and reused by every per-glacier panel.
stake_file = os.path.join(cfg.dataPath, path_PMB_GLAMOS_csv,
                          "CH_wgms_dataset_all.csv")
df_stakes = pd.read_csv(stake_file)

# Areas of the test glaciers, largest first.
test_gl_area = {g: gl_area[g] for g in TEST_GLACIERS}
test_gl_area = dict(
    sorted(test_gl_area.items(), key=lambda kv: kv[1], reverse=True))
test_gl_area

In [None]:
# =====================================================
# --- Configuration ---
# =====================================================
# Mass-balance vs. elevation profiles for these glaciers: LSTM (OOS
# predictions) vs. GLAMOS vs. stakes.
gl_list = [
    'Plattalva',
    'Hohlaub',
    'Tsanfleuron',
    'Schwarzberg',
    'Forno',
]
ncols = len(gl_list)
# Centimetre -> inch factor for figsize. Deliberately NOT named `cm`: that
# name is the cmcrameri colormap module (e.g. cm.batlow). Rebinding it
# breaks any cell that uses the colormaps afterwards.
INCH_PER_CM = 1 / 2.54
fontsize = 9
tick_fontsize = 8

# =====================================================
# --- Figure + GridSpec setup ---
# =====================================================
fig = plt.figure(figsize=(28 * INCH_PER_CM, 13 * INCH_PER_CM), dpi=300)

# You can adjust width_ratios for column flexibility
gs = gridspec.GridSpec(
    nrows=1,
    ncols=ncols,
    figure=fig,
    width_ratios=[1] * ncols,
    wspace=0.4,  # horizontal space between subplots
    hspace=0.25,  # vertical spacing (only one row here)
)

subplot_labels = alpha_labels(ncols)

# =====================================================
# --- Panels ---
# =====================================================
for c, gl in enumerate(gl_list):
    ax = fig.add_subplot(gs[0, c])

    # --- Data loading ---
    df_lstm_a_oos, df_glamos_a_oos, df_all_a_oos = aggregate_gridded_mb_lstm_glamos_by_glacier(
        gl.lower(), PATH_PREDICTIONS_LSTM_OOS, cfg, period="annual")
    df_lstm_w_oos, df_glamos_w_oos, df_all_w_oos = aggregate_gridded_mb_lstm_glamos_by_glacier(
        gl.lower(), PATH_PREDICTIONS_LSTM_OOS, cfg, period="winter")

    # Year span shown in the title (taken from the winter data).
    years = df_all_w_oos.YEAR.unique()
    # --- Plotting ---
    ax = plot_lstm_by_elevation_periods(df_all_a_oos,
                                        df_all_w_oos,
                                        ax=ax,
                                        mean_linestyle='-',
                                        label_prefix='LSTM OOS',
                                        show_band=True,
                                        color_annual=COLOR_ANNUAL,
                                        color_winter=COLOR_WINTER)

    ax = plot_glamos_by_elevation_periods(df_all_a_oos,
                                          df_all_w_oos,
                                          ax=ax,
                                          show_band=False,
                                          label_prefix="GLAMOS",
                                          mean_linestyle=":",
                                          color_annual=COLOR_ANNUAL,
                                          color_winter=COLOR_WINTER)

    ax = plot_stakes_by_elevation_periods(df_stakes,
                                          gl.lower(),
                                          valid_bins=None,
                                          ax=ax,
                                          color_annual=COLOR_ANNUAL,
                                          color_winter=COLOR_WINTER,
                                          marker_size=14)

    # --- Labels, titles, style ---
    ax.set_ylabel("Elevation (m a.s.l.)" if c == 0 else "", fontsize=fontsize)
    ax.set_xlabel("")  # we use a supxlabel below

    # Area for the printout: 3 decimals for very small glaciers, else 1.
    area = gl_area.get(gl.lower(), np.nan)
    area = np.round(area, 3) if area < 0.1 else np.round(area, 1)

    ax.set_title(f"{gl} ({years.min()}–{years.max()})",
                 fontsize=fontsize,
                 pad=2)
    ax.grid(alpha=0.2)
    ax.tick_params(labelsize=tick_fontsize, pad=2)

    # --- Subplot label ---
    ax.text(
        0.02,
        0.98,
        subplot_labels[c],
        transform=ax.transAxes,
        fontsize=10,
        va="top",
        ha="left",
    )

    # --- Clean up legends per axis (a single figure legend is used) ---
    leg = ax.legend()
    if leg is not None:
        leg.remove()

    print(f"{gl}: {years.min()}–{years.max()}, {area} km²")

# =====================================================
# --- Shared label & legend ---
# =====================================================
fig.supxlabel("Mass balance (m w.e.)", fontsize=fontsize, y=0.05)

handles = [
    # LSTM
    Patch(facecolor=COLOR_ANNUAL,
          alpha=0.25,
          label="MBM out-of-sample band (annual)"),
    Line2D([0], [0],
           color=COLOR_ANNUAL,
           lw=1.2,
           linestyle='-',
           label="MBM out-of-sample mean (annual)"),
    Patch(facecolor=COLOR_WINTER,
          alpha=0.25,
          label="MBM out-of-sample band (winter)"),
    Line2D([0], [0],
           color=COLOR_WINTER,
           lw=1.2,
           linestyle='-',
           label="MBM out-of-sample mean (winter)"),

    # GLAMOS
    Line2D([0], [0],
           color=COLOR_ANNUAL,
           lw=1.2,
           linestyle=':',
           label="GLAMOS mean (annual)"),
    Line2D([0], [0],
           color=COLOR_WINTER,
           lw=1.2,
           linestyle=':',
           label="GLAMOS mean (winter)"),

    # Stakes
    Line2D([0], [0],
           marker='o',
           linestyle='None',
           linewidth=0,
           markersize=6,
           markerfacecolor='none',
           markeredgecolor=COLOR_ANNUAL,
           markeredgewidth=1.2,
           label="Stakes mean (annual)"),
    Line2D([0], [0],
           marker='s',
           linestyle='None',
           linewidth=0,
           markersize=6,
           markerfacecolor='none',
           markeredgecolor=COLOR_WINTER,
           markeredgewidth=1.2,
           label="Stakes mean (winter)"),
]

fig.legend(handles=handles,
           loc='upper center',
           bbox_to_anchor=(0.5, 0.05),
           ncol=4,
           fontsize=7)

plt.tight_layout()
plt.show()

## Validation in-sample:

### Maps: Two glaciers, two years

In [None]:
# In-sample maps: GLAMOS vs. LSTM (IS predictions), annual MB, for Gries and
# Rhone in 2008 and 2022, panels labelled from "a".
_maps_kwargs = dict(
    glacier_names=("gries", "rhone"),
    years_by_glacier=((2008, 2022), (2008, 2022)),
    cfg=cfg,
    df_stakes=df_stakes,
    path_distributed_mb=path_distributed_MB_glamos,
    path_pred_lstm=PATH_PREDICTIONS_LSTM_IS,
    period="annual",
    panel_label_start="a",
)
fig, map_axes = plot_2glaciers_2years_glamos_vs_lstm(**_maps_kwargs)

# save figure
# NOTE(review): the filename below says "aletsch" but the glaciers plotted are
# gries and rhone; rename before re-enabling.
# fig.savefig('figures/paper/fig_glamos_vs_lstm_aletsch_rhone.png',
#             dpi=300,
#             bbox_inches='tight')

### Mass balance gradients on training glaciers:

In [None]:
# =====================================================
# --- Configuration ---
# =====================================================
# Mass-balance vs. elevation profiles for these glaciers: LSTM vs. GLAMOS
# vs. stakes.

gl_list = [
    'Gries',
    'Gietro',
    'Findelen',
    'Rhone',
    'Aletsch',
]

ncols = len(gl_list)
# Centimetre -> inch factor for figsize. Deliberately NOT named `cm`: that
# name is the cmcrameri colormap module (e.g. cm.batlow). Rebinding it
# breaks any cell that uses the colormaps afterwards.
INCH_PER_CM = 1 / 2.54
fontsize = 9
tick_fontsize = 8

# =====================================================
# --- Figure + GridSpec setup ---
# =====================================================
fig = plt.figure(figsize=(28 * INCH_PER_CM, 15 * INCH_PER_CM), dpi=300)

# You can adjust width_ratios for column flexibility
gs = gridspec.GridSpec(
    nrows=1,
    ncols=ncols,
    figure=fig,
    width_ratios=[1] * ncols,
    wspace=0.4,  # horizontal space between subplots
    hspace=0.25,  # vertical spacing (only one row here)
)

subplot_labels = alpha_labels(ncols)

# =====================================================
# --- Panels ---
# =====================================================
for c, gl in enumerate(gl_list):
    ax = fig.add_subplot(gs[0, c])

    # --- Data loading ---
    # NOTE(review): this in-sample figure reads PATH_PREDICTIONS_LSTM_OOS
    # (and label_prefix says 'LSTM OOS'), while the legend says "in-sample"
    # and the maps cell uses PATH_PREDICTIONS_LSTM_IS. Confirm which
    # prediction set is intended for the training glaciers.
    df_lstm_a_oos, df_glamos_a_oos, df_all_a_oos = aggregate_gridded_mb_lstm_glamos_by_glacier(
        gl.lower(), PATH_PREDICTIONS_LSTM_OOS, cfg, period="annual")
    df_lstm_w_oos, df_glamos_w_oos, df_all_w_oos = aggregate_gridded_mb_lstm_glamos_by_glacier(
        gl.lower(), PATH_PREDICTIONS_LSTM_OOS, cfg, period="winter")

    # Year span shown in the title (taken from the winter data).
    years = df_all_w_oos.YEAR.unique()

    # --- Plotting ---
    ax = plot_lstm_by_elevation_periods(df_all_a_oos,
                                        df_all_w_oos,
                                        ax=ax,
                                        mean_linestyle='-',
                                        label_prefix='LSTM OOS',
                                        show_band=True,
                                        color_annual=COLOR_ANNUAL,
                                        color_winter=COLOR_WINTER)

    ax = plot_glamos_by_elevation_periods(df_all_a_oos,
                                          df_all_w_oos,
                                          ax=ax,
                                          show_band=False,
                                          label_prefix="GLAMOS",
                                          mean_linestyle=":",
                                          color_annual=COLOR_ANNUAL,
                                          color_winter=COLOR_WINTER)

    ax = plot_stakes_by_elevation_periods(df_stakes,
                                          gl.lower(),
                                          valid_bins=None,
                                          ax=ax,
                                          color_annual=COLOR_ANNUAL,
                                          color_winter=COLOR_WINTER,
                                          marker_size=14)

    # --- Labels, titles, style ---
    ax.set_ylabel("Elevation (m a.s.l.)" if c == 0 else "", fontsize=fontsize)
    ax.set_xlabel("")  # we use a supxlabel below

    # Area for the printout: 3 decimals for very small glaciers, else 1.
    area = gl_area.get(gl.lower(), np.nan)
    area = np.round(area, 3) if area < 0.1 else np.round(area, 1)

    ax.set_title(f"{gl} ({years.min()}–{years.max()})",
                 fontsize=fontsize,
                 pad=2)
    ax.grid(alpha=0.2)
    ax.tick_params(labelsize=tick_fontsize, pad=2)

    # --- Subplot label ---
    ax.text(
        0.02,
        0.98,
        subplot_labels[c],
        transform=ax.transAxes,
        fontsize=10,
        va="top",
        ha="left",
    )

    # --- Clean up legends per axis (a single figure legend is used) ---
    leg = ax.legend()
    if leg is not None:
        leg.remove()

    print(f"{gl}: {years.min()}–{years.max()}, {area} km²")

# =====================================================
# --- Shared label & legend ---
# =====================================================
fig.supxlabel("Mass balance (m w.e.)", fontsize=fontsize, y=0.05)

handles = [
    # LSTM
    Patch(facecolor=COLOR_ANNUAL,
          alpha=0.25,
          label="MBM in-sample band (annual)"),
    Line2D([0], [0],
           color=COLOR_ANNUAL,
           lw=1.2,
           linestyle='-',
           label="MBM in-sample mean (annual)"),
    Patch(facecolor=COLOR_WINTER,
          alpha=0.25,
          label="MBM in-sample band (winter)"),
    Line2D([0], [0],
           color=COLOR_WINTER,
           lw=1.2,
           linestyle='-',
           label="MBM in-sample mean (winter)"),

    # GLAMOS
    Line2D([0], [0],
           color=COLOR_ANNUAL,
           lw=1.2,
           linestyle=':',
           label="GLAMOS mean (annual)"),
    Line2D([0], [0],
           color=COLOR_WINTER,
           lw=1.2,
           linestyle=':',
           label="GLAMOS mean (winter)"),

    # Stakes
    Line2D([0], [0],
           marker='o',
           linestyle='None',
           linewidth=0,
           markersize=6,
           markerfacecolor='none',
           markeredgecolor=COLOR_ANNUAL,
           markeredgewidth=1.2,
           label="Stakes mean (annual)"),
    Line2D([0], [0],
           marker='s',
           linestyle='None',
           linewidth=0,
           markersize=6,
           markerfacecolor='none',
           markeredgecolor=COLOR_WINTER,
           markeredgewidth=1.2,
           label="Stakes mean (winter)"),
]

fig.legend(handles=handles,
           loc='upper center',
           bbox_to_anchor=(0.5, 0.05),
           ncol=4,
           fontsize=7)

plt.tight_layout()
plt.show()

# =====================================================
# --- Save ---
# =====================================================
# save figure
# fig.savefig('figures/paper/fig_mb_gradients_IS.png',
#             dpi=300,
#             bbox_inches='tight')

### Combined figure (Fig. 6):

In [None]:
# =====================================================
# --- Helpers ---
# =====================================================
def pick_file_glamos(glacier,
                     year,
                     cfg,
                     path_distributed_MB_glamos,
                     period="annual"):
    """
    Locate the GLAMOS distributed-MB grid file for one glacier and year.

    Looks in ``<cfg.dataPath>/<path_distributed_MB_glamos>/GLAMOS/<glacier>``
    for ``<year>_<ann|win>_fix_<crs>.grid``. The LV95 file takes precedence
    over the LV03 file when both exist.

    Returns
    -------
    tuple
        ``(path, "lv95")`` or ``(path, "lv03")`` for the first file found,
        otherwise ``(None, None)``.
    """
    if period == "annual":
        suffix = "ann"
    else:
        suffix = "win"

    glacier_dir = os.path.join(cfg.dataPath, path_distributed_MB_glamos,
                               "GLAMOS", glacier)
    for crs in ("lv95", "lv03"):
        candidate = os.path.join(glacier_dir,
                                 f"{year}_{suffix}_fix_{crs}.grid")
        if os.path.exists(candidate):
            return candidate, crs
    return None, None


def load_glamos_wgs84(glacier,
                      year,
                      cfg,
                      path_distributed_MB_glamos,
                      period="annual"):
    """
    Read a GLAMOS mass-balance grid and convert its coordinates to WGS84.

    The grid file is located with ``pick_file_glamos``, parsed with
    ``load_grid_file`` and turned into an xarray object with
    ``convert_to_xarray_geodata``. It is then passed through the LV95 or
    LV03 to WGS84 transform that matches the file's coordinate-system tag.

    Returns
    -------
    xarray object or None
        The reprojected grid, or None when no grid file exists (or the
        coordinate-system tag is not recognised).
    """
    grid_path, crs = pick_file_glamos(glacier, year, cfg,
                                      path_distributed_MB_glamos, period)
    if grid_path is None:
        return None

    header, values = load_grid_file(grid_path)
    grid = convert_to_xarray_geodata(values, header)

    if crs == "lv95":
        return transform_xarray_coords_lv95_to_wgs84(grid)
    if crs == "lv03":
        return transform_xarray_coords_lv03_to_wgs84(grid)
    return None


def load_lstm_ds(glacier,
                 year,
                 path_pred_lstm,
                 period="annual",
                 smoothing_fn=None):
    """
    Open the LSTM prediction store for one glacier, year and period.

    The store is looked up at
    ``<path_pred_lstm>/<glacier>/<glacier>_<year>_<period>.zarr``.

    Parameters
    ----------
    smoothing_fn : callable, optional
        If given, it is called on the opened dataset and its result is
        returned instead.

    Returns
    -------
    xarray.Dataset or None
        The (optionally smoothed) dataset, or None if the store is missing.
    """
    store = os.path.join(path_pred_lstm, glacier,
                         f"{glacier}_{year}_{period}.zarr")
    if os.path.exists(store):
        dataset = xr.open_zarr(store)
        return dataset if smoothing_fn is None else smoothing_fn(dataset)
    return None


def lonlat_names(obj):
    """
    Return the (longitude, latitude) coordinate names used by ``obj``.

    ``("lon", "lat")`` is preferred, then ``("longitude", "latitude")``.
    A pair counts only if both names are in ``obj.coords``. Objects without
    a ``coords`` attribute, or with neither pair, fall back to
    ``("lon", "lat")``.
    """
    coords = getattr(obj, "coords", {})
    for lon_key, lat_key in (("lon", "lat"), ("longitude", "latitude")):
        if lon_key in coords and lat_key in coords:
            return lon_key, lat_key
    return "lon", "lat"


def stake_overlay_rmse(ax,
                       glacier,
                       year,
                       cmap,
                       norm,
                       da_glamos,
                       ds_lstm,
                       df_stakes,
                       period="annual",
                       which="GLAMOS"):
    """
    Overlay stake data and compute RMSE between measured and modeled mass balance.

    Stakes of ``glacier`` in ``year`` are scattered at (POINT_LON, POINT_LAT)
    on ``ax``. They are coloured by POINT_BALANCE (or by the sampled field
    when that column is absent) using ``cmap``/``norm``. Each stake is
    matched to the nearest grid cell of the selected field.

    Parameters
    ----------
    ax : matplotlib axis
        Axis on which to plot.
    glacier : str
        Glacier name, matched against the GLACIER column.
    year : int
        Year, matched against the YEAR column.
    cmap : colormap
        Colormap for plotting.
    norm : Normalize
        Normalization for color mapping.
    da_glamos : xarray.DataArray
        GLAMOS distributed MB data.
    ds_lstm : xarray.Dataset
        LSTM predicted dataset; its ``pred_masked`` variable is sampled.
    df_stakes : pandas.DataFrame
        Stake table with GLACIER, YEAR, POINT_LON, POINT_LAT and optionally
        POINT_BALANCE and PERIOD columns.
    period : str, optional
        Period whose stakes are used ("annual" or "winter"). It is applied
        whenever a PERIOD column is present.
    which : str, optional
        "GLAMOS" samples ``da_glamos``; any other value (e.g. "LSTM")
        samples ``ds_lstm``.

    Returns
    -------
    float or None
        RMSE over stakes whose measured and sampled values are both non-NaN.
        Returns None when there are no matching stakes, no POINT_BALANCE
        column, or no such pairs.
    """
    if df_stakes is None:
        return None

    # Select the glacier-year subset, restricted to the requested period so
    # that e.g. a winter map is not scored against annual stakes.
    sub = df_stakes[(df_stakes.GLACIER == glacier)
                    & (df_stakes.YEAR == year)].copy()
    if "PERIOD" in sub.columns:
        sub = sub[sub.PERIOD == period].copy()
    if sub.empty:
        return None

    lx, ly = lonlat_names(
        ds_lstm if which == "LSTM" and ds_lstm is not None else da_glamos)

    def _sample(row, source, var=None):
        """Nearest-cell value of ``source`` (or ``source[var]``) at a stake; NaN on failure."""
        # NOTE(review): method="nearest" has no tolerance, so stakes outside
        # the grid snap to the closest edge cell rather than raising; the
        # except branch covers other failures (missing coords/variable, None).
        try:
            point = source.sel({
                lx: row.POINT_LON,
                ly: row.POINT_LAT
            },
                               method="nearest")
            if var is not None:
                point = point[var]
            return point.item()  # Convert to scalar
        except Exception:
            print(
                f"Warning: Stake at ({row.POINT_LON}, {row.POINT_LAT}) is out of bounds."
            )
            return np.nan

    # Sample the field being compared at every stake.
    if which == "GLAMOS":
        sub["FIELD"] = sub.apply(lambda r: _sample(r, da_glamos), axis=1)
    elif ds_lstm is not None:
        sub["FIELD"] = sub.apply(lambda r: _sample(r, ds_lstm, "pred_masked"),
                                 axis=1)
    else:
        sub["FIELD"] = np.nan

    hue_col = "POINT_BALANCE" if "POINT_BALANCE" in sub.columns else "FIELD"

    # Scatter overlay
    sns.scatterplot(
        data=sub,
        x="POINT_LON",
        y="POINT_LAT",
        hue=hue_col,
        palette=cmap,
        hue_norm=norm,
        ax=ax,
        s=10,
        legend=False,
    )

    if "POINT_BALANCE" not in sub.columns:
        return None

    # root_mean_squared_error rejects NaN input. Stakes that sampled NaN
    # (e.g. masked pixels) or lack a measurement are therefore left out
    # instead of crashing the whole figure.
    valid = sub["POINT_BALANCE"].notna() & sub["FIELD"].notna()
    if not valid.any():
        return None
    return root_mean_squared_error(sub.loc[valid, "POINT_BALANCE"],
                                   sub.loc[valid, "FIELD"])

In [None]:
def plot_pixel_scatter_glamos_vs_mbm(
    da_glamos,
    ds_mbm,
    ax,
    fontsize=8,
):
    """
    Pixel-wise scatter of GLAMOS vs MBM over overlapping pixels.

    The MBM ``pred_masked`` field is sampled (nearest neighbour) at the
    GLAMOS grid coordinates. Pixels that are finite in both are scattered
    (x = GLAMOS, y = MBM) against a 1:1 line and annotated with their count
    and Pearson r.

    Parameters
    ----------
    da_glamos : xarray.DataArray or None
        GLAMOS grid; defines the pixels compared.
    ds_mbm : xarray.Dataset or None
        MBM dataset; must contain ``pred_masked``.
    ax : matplotlib axis
        Target axis. It is switched off when an input is missing or no
        pixels overlap.
    fontsize : float, optional
        Font size for the annotation and axis labels.
    """

    if da_glamos is None or ds_mbm is None or "pred_masked" not in ds_mbm:
        ax.set_axis_off()
        return

    lon_g, lat_g = lonlat_names(da_glamos)
    lon_m, lat_m = lonlat_names(ds_mbm)

    mbm_interp = ds_mbm["pred_masked"].interp(
        {
            lon_m: da_glamos[lon_g],
            lat_m: da_glamos[lat_g]
        },
        method="nearest",
    )

    # The interpolated array may not come back in da_glamos' dimension
    # order. Align it so the element-wise mask below pairs the same pixels
    # (this is a no-op when the order already matches).
    if set(mbm_interp.dims) == set(da_glamos.dims):
        mbm_interp = mbm_interp.transpose(*da_glamos.dims)

    glamos_vals = da_glamos.values
    mbm_vals = mbm_interp.values
    mask = np.isfinite(glamos_vals) & np.isfinite(mbm_vals)
    x = glamos_vals[mask]
    y = mbm_vals[mask]

    if x.size == 0:
        # Axes coordinates so the message is centred whatever the limits.
        ax.text(0.5,
                0.5,
                "No overlap",
                ha="center",
                va="center",
                transform=ax.transAxes)
        ax.set_axis_off()
        return

    ax.scatter(x, y, s=4, alpha=0.3, color='#bababa')

    # 1:1 reference line spanning both value ranges.
    lims = [min(x.min(), y.min()), max(x.max(), y.max())]
    ax.plot(lims, lims, "k--", lw=0.8)

    r = np.corrcoef(x, y)[0, 1]

    ax.set_aspect("equal", adjustable="datalim")
    ax.grid(alpha=0.3)

    ax.text(
        0.05,
        0.05,
        f"N = {x.size}\nr = {r:.2f}",
        transform=ax.transAxes,
        ha="left",
        va="bottom",
        fontsize=fontsize,
    )

    ax.set_xlabel("GLAMOS MB", fontsize=fontsize)
    ax.tick_params(
        axis="both",
        which="both",
        labelsize=8,
    )

    ax.set_ylabel("MBM MB", fontsize=fontsize)
    ax.yaxis.set_label_coords(-0.28, 0.63)
    ax.set_title('')

In [None]:
import string

# One map row per glacier; each row compares the two listed years.
glacier_names = ("gries", "rhone", "aletsch")
years_by_glacier = ((2008, 2022), (2008, 2022), (2008, 2022))
# Glaciers shown in the bottom row of elevation-gradient panels.
grad_glaciers = ["Gries", "Gietro", "Findelen", "Rhone", "Aletsch"]
period = "annual"
# Centimetre-to-inch factor for figsize. It is deliberately not named ``cm``,
# so it does not overwrite the colormap module from ``from cmcrameri import cm``.
cm_to_in = 1 / 2.54

ylabel_size = 8
title_size = 9
label_size = 9
text_size = 8
tick_size = 8

# =====================================================
# --- Figure + Grid setup ---
# =====================================================
fig = plt.figure(figsize=(34 * cm_to_in, 39 * cm_to_in), dpi=300)

# Top: glacier map rows (taller); bottom: one row of gradient panels.
gs_global = gridspec.GridSpec(2,
                              1,
                              figure=fig,
                              height_ratios=[3.2, 1],
                              hspace=0.15)

# Per glacier row: [GLAMOS, MBM, scatter] for year 1, the same for year 2,
# then a narrow colorbar column.
gs_maps = gs_global[0].subgridspec(
    len(glacier_names),
    7,
    width_ratios=[1, 1, 0.7, 1, 1, 0.7, 0.05],
    wspace=0.35,
    hspace=0.25,
)

# One equal-width gradient panel per glacier in grad_glaciers.
gs_grad = gs_global[1].subgridspec(1,
                                   len(grad_glaciers),
                                   width_ratios=[1] * len(grad_glaciers),
                                   wspace=0.3)

# Panel letters (a), (b), ... consumed in drawing order across the figure.
label_iter = iter(string.ascii_lowercase)

# =====================================================
# --- ROWS 1–3: GLAMOS vs MBM maps ---
# =====================================================

n_glacier_rows = len(glacier_names)

for r, glacier in enumerate(glacier_names):
    row_years = years_by_glacier[r]

    # ---------------------------------------------------
    # --- 1) Load every year once; shared colour range ---
    # ---------------------------------------------------
    # Grids are cached per year so each is read (and smoothed) only once.
    # The same objects are reused for plotting in step 2.
    row_data = {}
    row_vals = []

    for year in row_years:
        da_g = load_glamos_wgs84(
            glacier=glacier,
            year=year,
            cfg=cfg,
            path_distributed_MB_glamos=path_distributed_MB_glamos,
            period=period)

        ds_m = load_lstm_ds(glacier=glacier,
                            year=year,
                            path_pred_lstm=PATH_PREDICTIONS_LSTM_IS,
                            period=period,
                            smoothing_fn=apply_gaussian_filter)

        row_data[year] = (da_g, ds_m)

        if da_g is not None:
            row_vals += [float(da_g.min()), float(da_g.max())]

        if ds_m is not None and "pred_masked" in ds_m:
            row_vals += [
                float(ds_m["pred_masked"].min()),
                float(ds_m["pred_masked"].max())
            ]

    if not row_vals:
        continue  # no data for this glacier row

    # The colour scale is floored at -12. The colorbar display is also
    # capped at +4; arrows mark ends where data exceed the displayed range.
    row_vmin = max(min(row_vals), -12)
    row_vmax = max(row_vals)
    vmin_display = row_vmin
    vmax_display = min(row_vmax, 4.0)

    cmap, norm = get_color_maps(row_vmin, row_vmax)

    # Stats box position: bottom-right for Gries, bottom-left otherwise.
    if glacier.lower() == "gries":
        x_txt, y_txt, ha_txt = 0.96, 0.04, "right"
    else:
        x_txt, y_txt, ha_txt = 0.04, 0.03, "left"

    # ---------------------------------------------------
    # --- 2) Plot the two years ---
    # ---------------------------------------------------
    for j, year in enumerate(row_years):

        # Year 1 -> columns 0-2; year 2 -> columns 3-5 plus colorbar (6).
        if j == 0:
            ax_g = fig.add_subplot(gs_maps[r, 0])
            ax_m = fig.add_subplot(gs_maps[r, 1])
            ax_s = fig.add_subplot(gs_maps[r, 2])
            ax_cb = None
        else:
            ax_g = fig.add_subplot(gs_maps[r, 3])
            ax_m = fig.add_subplot(gs_maps[r, 4])
            ax_s = fig.add_subplot(gs_maps[r, 5])
            ax_cb = fig.add_subplot(gs_maps[r, 6])

        da_g, ds_m = row_data[year]

        # ------------------------------------------
        # --- GLAMOS MAP (single imshow) ---
        # ------------------------------------------
        img_g = None
        if da_g is not None:
            img_g = da_g.plot.imshow(ax=ax_g,
                                     cmap=cmap,
                                     norm=norm,
                                     add_colorbar=False)
            ax_g.set_title(f"{glacier.capitalize()} – GLAMOS ({year})",
                           fontsize=title_size)

            mean_g = float(da_g.mean())
            var_g = float(da_g.var())

            rmse_g = stake_overlay_rmse(ax=ax_g,
                                        glacier=glacier,
                                        year=year,
                                        cmap=cmap,
                                        norm=norm,
                                        da_glamos=da_g,
                                        ds_lstm=ds_m,
                                        df_stakes=df_stakes,
                                        period=period,
                                        which="GLAMOS")

            # `is not None` so that a genuine RMSE of 0.0 is still shown.
            txt = (f"RMSE: {rmse_g:.2f}\n" if rmse_g is not None else "") + \
                  f"mean MB: {mean_g:.2f}\nvar: {var_g:.2f}"

            ax_g.text(x_txt,
                      y_txt,
                      txt,
                      transform=ax_g.transAxes,
                      ha=ha_txt,
                      va="bottom",
                      fontsize=text_size)

            ax_g.grid(alpha=0.2)

        # ------------------------------------------
        # --- MBM MAP (single imshow) ---
        # ------------------------------------------
        img_m = None
        if ds_m is not None and "pred_masked" in ds_m:
            img_m = ds_m["pred_masked"].plot.imshow(ax=ax_m,
                                                    cmap=cmap,
                                                    norm=norm,
                                                    add_colorbar=False)
            ax_m.set_title(f"{glacier.capitalize()} – MBM ({year})",
                           fontsize=title_size)

            mean_m = float(ds_m["pred_masked"].mean())
            var_m = float(ds_m["pred_masked"].var())

            rmse_m = stake_overlay_rmse(ax=ax_m,
                                        glacier=glacier,
                                        year=year,
                                        cmap=cmap,
                                        norm=norm,
                                        da_glamos=da_g,
                                        ds_lstm=ds_m,
                                        df_stakes=df_stakes,
                                        period=period,
                                        which="LSTM")

            txt = (f"RMSE: {rmse_m:.2f}\n" if rmse_m is not None else "") + \
                  f"mean MB: {mean_m:.2f}\nvar: {var_m:.2f}"

            ax_m.text(x_txt,
                      y_txt,
                      txt,
                      transform=ax_m.transAxes,
                      ha=ha_txt,
                      va="bottom",
                      fontsize=text_size)
            ax_m.grid(alpha=0.2)

        # ------------------------------------------
        # --- Scatter panel ---
        # ------------------------------------------
        plot_pixel_scatter_glamos_vs_mbm(
            da_glamos=da_g,
            ds_mbm=ds_m,
            ax=ax_s,
            fontsize=text_size,
        )
        ax_s.set_title("GLAMOS vs MBM", fontsize=title_size)

        # ------------------------------------------
        # --- Shared row colorbar (second year only) ---
        # ------------------------------------------
        if j == 1:
            mappable = img_m if img_m is not None else img_g
            if mappable is None:
                # Neither map was drawn for this year. Use a bare mappable so
                # the row still gets its colour scale instead of crashing.
                mappable = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
            cb = fig.colorbar(mappable, cax=ax_cb)

            # Truncate the displayed range; arrows flag clipped data.
            cb.ax.set_ylim(vmin_display, vmax_display)

            if min(row_vals) < vmin_display:
                ax_cb.plot([0.5], [-0.02],
                           marker="v",
                           color="black",
                           markersize=4,
                           clip_on=False,
                           transform=ax_cb.transAxes)

            if max(row_vals) > vmax_display:
                ax_cb.plot([0.5], [1.02],
                           marker="^",
                           color="black",
                           markersize=4,
                           clip_on=False,
                           transform=ax_cb.transAxes)

            cb.set_label("Mass Balance [m w.e.]", fontsize=ylabel_size)
            cb.ax.tick_params(labelsize=tick_size)

        # ------------------------------------------
        # --- Axis labels & ticks ---
        # ------------------------------------------
        ax_g.set_ylabel("Latitude", fontsize=ylabel_size)

        # The first-year GLAMOS panel sits at the left edge, so its label is
        # pushed further out than the second-year one.
        if j == 0:
            ax_g.yaxis.set_label_coords(-0.35, 0.63)
        else:
            ax_g.yaxis.set_label_coords(-0.08, 0.63)

        ax_m.set_ylabel("")

        # Longitude labels on every row (use r == n_glacier_rows - 1 to keep
        # them on the bottom row only).
        ax_g.set_xlabel("Longitude", fontsize=ylabel_size)
        ax_m.set_xlabel("Longitude", fontsize=ylabel_size)

        ax_g.tick_params(axis="x", labelsize=tick_size)
        ax_m.tick_params(axis="x", labelsize=tick_size)

        if j == 0:
            ax_g.tick_params(axis="y",
                             left=True,
                             labelleft=True,
                             labelsize=tick_size)
        else:
            ax_g.tick_params(axis="y", left=False, labelleft=False)

        ax_m.tick_params(labelleft=False, labelsize=tick_size)

        # panel letters
        for ax in (ax_g, ax_m, ax_s):
            label = next(label_iter)
            ax.text(
                0.04,
                0.98,
                f"({label})",
                transform=ax.transAxes,
                ha="left",
                va="top",
                fontsize=label_size,
            )

# =====================================================
# --- BOTTOM ROW: elevation gradients (below the map rows) ---
# =====================================================
for c, gl in enumerate(grad_glaciers):
    ax = fig.add_subplot(gs_grad[0, c])

    df_lstm_a_is, df_glamos_a_is, df_all_a_is = \
        aggregate_gridded_mb_lstm_glamos_by_glacier(gl.lower(), PATH_PREDICTIONS_LSTM_IS, cfg, "annual")

    df_lstm_w_is, df_glamos_w_is, df_all_w_is = \
        aggregate_gridded_mb_lstm_glamos_by_glacier(gl.lower(), PATH_PREDICTIONS_LSTM_IS, cfg, "winter")

    # Optional: restrict to years >= 2000. The df_stakes line would rebind
    # the global stake table for every later cell, not just this panel.
    # df_all_a_is = df_all_a_is[df_all_a_is.YEAR >= 2000]
    # df_all_w_is = df_all_w_is[df_all_w_is.YEAR >= 2000]
    # df_stakes = df_stakes[df_stakes.YEAR >= 2000]

    # The title spans every year in either the annual or the winter table,
    # since both are plotted. It is omitted when both tables are empty,
    # where min()/max() would raise.
    years = np.union1d(df_all_a_is.YEAR.unique(), df_all_w_is.YEAR.unique())

    ax = plot_lstm_by_elevation_periods(df_all_a_is,
                                        df_all_w_is,
                                        ax=ax,
                                        mean_linestyle='-',
                                        label_prefix='MBM IS',
                                        show_band=True,
                                        color_annual=COLOR_ANNUAL,
                                        color_winter=COLOR_WINTER)

    ax = plot_glamos_by_elevation_periods(df_all_a_is,
                                          df_all_w_is,
                                          ax=ax,
                                          show_band=False,
                                          label_prefix="GLAMOS",
                                          mean_linestyle=":",
                                          color_annual=COLOR_ANNUAL,
                                          color_winter=COLOR_WINTER)

    ax = plot_stakes_by_elevation_periods(df_stakes,
                                          gl.lower(),
                                          valid_bins=None,
                                          ax=ax,
                                          color_annual=COLOR_ANNUAL,
                                          color_winter=COLOR_WINTER,
                                          marker_size=14)

    year_span = f" ({years.min()}–{years.max()})" if years.size else ""
    ax.set_title(f"{gl}{year_span}", fontsize=title_size, pad=2)
    ax.grid(alpha=0.2)
    ax.tick_params(labelsize=tick_size, pad=2)
    ax.text(0.04,
            0.98,
            f"({next(label_iter)})",
            transform=ax.transAxes,
            ha="left",
            va="top",
            fontsize=label_size)
    ax.set_ylabel("Elevation (m a.s.l.)" if c == 0 else "",
                  fontsize=ylabel_size)
    ax.set_xlabel("Mass balance (m w.e.)", fontsize=ylabel_size)

# =====================================================
# --- Shared legend (MBM bands/means, GLAMOS means, stakes) ---
# =====================================================
# Common style of the hollow stake-mean markers.
stake_marker_kw = dict(linestyle='None',
                       linewidth=0,
                       markersize=6,
                       markerfacecolor='none',
                       markeredgewidth=1.2)

handles = [
    # MBM in-sample: spread band + mean line, annual then winter
    Patch(facecolor=COLOR_ANNUAL,
          alpha=0.25,
          label="MBM in-sample band (annual)"),
    Line2D([0], [0],
           color=COLOR_ANNUAL,
           linestyle='-',
           lw=1.2,
           label="MBM in-sample mean (annual)"),
    Patch(facecolor=COLOR_WINTER,
          alpha=0.25,
          label="MBM in-sample band (winter)"),
    Line2D([0], [0],
           color=COLOR_WINTER,
           linestyle='-',
           lw=1.2,
           label="MBM in-sample mean (winter)"),
    # GLAMOS: dotted mean lines
    Line2D([0], [0],
           color=COLOR_ANNUAL,
           linestyle=':',
           lw=1.2,
           label="GLAMOS mean (annual)"),
    Line2D([0], [0],
           color=COLOR_WINTER,
           linestyle=':',
           lw=1.2,
           label="GLAMOS mean (winter)"),
    # Stakes: circle = annual, square = winter
    Line2D([0], [0],
           marker='o',
           markeredgecolor=COLOR_ANNUAL,
           label="Stakes mean (annual)",
           **stake_marker_kw),
    Line2D([0], [0],
           marker='s',
           markeredgecolor=COLOR_WINTER,
           label="Stakes mean (winter)",
           **stake_marker_kw),
]

fig.legend(handles=handles,
           loc='lower center',
           bbox_to_anchor=(0.5, 0.03),
           ncol=4,
           fontsize=10,
           frameon=True)

# Leave room at the bottom for the figure-level legend.
plt.subplots_adjust(bottom=0.10)

# =====================================================
# --- Final Layout ---
# =====================================================
# plt.tight_layout()  # disabled; the layout is set by subplots_adjust above
plt.show()

fig.savefig("figures/paper/fig6_IS_extrapolation_combined.pdf",
            dpi=300,
            bbox_inches="tight")

### Geodetic MB:

In [None]:
# Agreement between geodetic and MBM glacier-wide mass balance (in-sample):
# root-mean-square error and Pearson correlation coefficient.
geodetic_mb_is = ds_lstm_IS["Geodetic MB"]
mbm_mb_is = ds_lstm_IS["MBM MB"]
rmse_nn = root_mean_squared_error(geodetic_mb_is, mbm_mb_is)
corr_nn = np.corrcoef(geodetic_mb_is, mbm_mb_is)[0, 1]

# MBM vs geodetic MB, grouped into area bins.
area_bins = [0, 1, 5, 10, 100, np.inf]
area_bin_labels = ['<1', '1-5', '5–10', '>10', '>100']
fig = plot_mbm_vs_geodetic_by_area_bin(ds_lstm_IS,
                                       bins=area_bins,
                                       labels=area_bin_labels,
                                       max_bins=4)

# save figure
fig.savefig('figures/paper/fig7_geodetic_IS.pdf', dpi=300, bbox_inches='tight')

## Remaining appendix figures:

### All gradients (IS):

In [None]:
nrows = 4
ncols = 5
cm = 1 / 2.54

fontsize = 12
# One elevation-gradient panel per glacier on a 4 x 5 grid.
fig, axs = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=(30 * cm, 24 * cm),  # was (25, 15)
    dpi=300)

axs = axs.flatten()
gl_list = [
    'schwarzbach', 'murtel', 'plattalva', 'basodino', 'limmern', 'adler',
    'hohlaub', 'albigna', 'tsanfleuron', 'silvretta', 'gries', 'clariden',
    'gietro', 'schwarzberg', 'forno', 'allalin', 'otemma', 'findelen', 'rhone',
    'aletsch'
]
for i, gl in enumerate(gl_list):
    # Annual
    df_lstm_a, df_glamos_a, df_all_a = aggregate_gridded_mb_lstm_glamos_by_glacier(
        gl.lower(), PATH_PREDICTIONS_LSTM_IS, cfg, period="annual")

    years = df_all_a.YEAR.unique()

    # Winter
    df_lstm_w, df_glamos_w, df_all_w = aggregate_gridded_mb_lstm_glamos_by_glacier(
        gl.lower(), PATH_PREDICTIONS_LSTM_IS, cfg, period="winter")

    if df_all_a.empty:
        print(f"No data for glacier: {gl}")
        # Hide the unused panel instead of leaving an empty framed axis.
        axs[i].set_axis_off()
        continue

    ax = plot_mb_by_elevation_periods_combined(df_all_a,
                                               df_all_w,
                                               df_stakes,
                                               gl.lower(),
                                               ax=axs[i])

    # area = areas_per_gl.loc[gl].Area
    axs[i].set_title(f'{gl.capitalize()} ({years.min()}-{years.max()})',
                     fontsize=fontsize,
                     pad=2)

    axs[i].grid(alpha=0.2)
    axs[i].tick_params(labelsize=6.5, pad=2)
    axs[i].set_ylabel('')
    axs[i].set_xlabel('')
    # Drop any per-panel legend (the figure gets one shared legend), without
    # creating an empty legend just to remove it.
    panel_legend = axs[i].get_legend()
    if panel_legend is not None:
        panel_legend.remove()

axs[5].set_ylabel('Elevation (m a.s.l.)', fontsize=fontsize)

# Shared legend entries: LSTM bands/means, GLAMOS means, stake means.
# Common style of the hollow stake-mean markers.
stake_marker_kw = dict(linestyle='None',
                       linewidth=0,
                       markersize=6,
                       markerfacecolor='none',
                       markeredgewidth=1.2)

handles = [
    # LSTM: spread band + solid mean, annual then winter
    Patch(facecolor=COLOR_ANNUAL, alpha=0.25, label="LSTM band (annual)"),
    Line2D([0], [0],
           color=COLOR_ANNUAL,
           linestyle='-',
           lw=1.2,
           label="LSTM mean (annual)"),
    Patch(facecolor=COLOR_WINTER, alpha=0.25, label="LSTM band (winter)"),
    Line2D([0], [0],
           color=COLOR_WINTER,
           linestyle='-',
           lw=1.2,
           label="LSTM mean (winter)"),
    # GLAMOS: dotted mean only
    Line2D([0], [0],
           color=COLOR_ANNUAL,
           linestyle=':',
           lw=1.2,
           label="GLAMOS mean (annual)"),
    Line2D([0], [0],
           color=COLOR_WINTER,
           linestyle=':',
           lw=1.2,
           label="GLAMOS mean (winter)"),
    # Stakes: circle = annual, square = winter
    Line2D([0], [0],
           marker='o',
           markeredgecolor=COLOR_ANNUAL,
           label="Stakes mean (annual)",
           **stake_marker_kw),
    Line2D([0], [0],
           marker='s',
           markeredgecolor=COLOR_WINTER,
           label="Stakes mean (winter)",
           **stake_marker_kw),
]

fig.supxlabel('Mass balance (m w.e.)', fontsize=fontsize, y=0.06)

fig.legend(handles=handles,
           loc='upper center',
           bbox_to_anchor=(0.5, 0.05),
           ncol=4,
           fontsize=7)

# Adjust the layout
plt.subplots_adjust(hspace=0.35, wspace=0.35)
plt.show()

# Save figure
fig.savefig('figures/paper/appendix/app_gradients_all.pdf',
            dpi=300,
            bbox_inches='tight')

### Maps:

In [None]:
GLACIER_NAME = 'aletsch'
df_lstm_two_heads_gl = ds_lstm_IS[ds_lstm_IS.GLACIER ==
                                  GLACIER_NAME.capitalize()]

fig, axs = plt.subplots(1, 2, figsize=(15, 5))

# Left panel: scatter comparison for the selected glacier.
plot_scatter_comparison(axs[0],
                        df_lstm_two_heads_gl,
                        GLACIER_NAME,
                        color_mbm=COLOR_ANNUAL,
                        color_glamos=COLOR_WINTER,
                        title_suffix="(LSTM two heads)")

# Right panel: glacier-wide MB time series, MBM (LSTM) vs GLAMOS.
GLAMOS_glwmb = get_GLAMOS_glwmb(GLACIER_NAME, cfg)

# Rename via reassignment rather than inplace=True, so the object returned
# by mbm_glwd_pred is not mutated.
MBM_glwmb_lstm = mbm_glwd_pred(PATH_PREDICTIONS_LSTM_IS, GLACIER_NAME)
MBM_glwmb_lstm = MBM_glwmb_lstm.rename(
    columns={"MBM Balance": "MBM Balance LSTM"})

# Join the GLAMOS series on the shared index.
MBM_glwmb_lstm = MBM_glwmb_lstm.join(GLAMOS_glwmb)

MBM_glwmb_lstm.plot(ax=axs[1],
                    y=['MBM Balance LSTM', 'GLAMOS Balance'],
                    marker="o",
                    color=[COLOR_ANNUAL, COLOR_WINTER])

axs[1].set_title(f"{GLACIER_NAME.capitalize()} Glacier (LSTM)", fontsize=16)
axs[1].set_ylabel("Mass Balance [m w.e.]", fontsize=18)
axs[1].set_xlabel("Year", fontsize=18)
axs[1].grid(True, linestyle="--", linewidth=0.5)
axs[1].legend(fontsize=14)

plt.tight_layout()
plt.show()

#### Gietro 2016:

In [None]:
year = 2016
GLACIER_NAME = 'gietro'

fig, axes = plt.subplots(2, 3, figsize=(20, 10))

# ---- First row: mass balance comparison (top-right panel is left unused) ----
plot_mass_balance_comparison(
    glacier_name=GLACIER_NAME,
    year=year,
    cfg=cfg,
    df_stakes=df_stakes,
    path_distributed_mb=path_distributed_MB_glamos,
    path_pred_lstm=PATH_PREDICTIONS_LSTM_IS,
    period='annual',
    fig=fig,
    axes=axes[0, :]  # <-- first row only
)
# Remove empty top-right axis
fig.delaxes(axes[0, 2])

# ---- Second row: topography ----
PATH_GLAMOS_TOPO = path_GLAMOS_topo
PATH_XR_GRIDS = os.path.join(cfg.dataPath, PATH_GLAMOS_TOPO, "xr_masked_grids")
ds = xr.open_zarr(os.path.join(PATH_XR_GRIDS, f'{GLACIER_NAME}_{year}.zarr'))

p_aspect = ds.masked_aspect.plot(ax=axes[1, 0],
                                 cmap='twilight_shifted',
                                 add_colorbar=True)
p_slope = ds.masked_slope.plot(ax=axes[1, 1],
                               cmap='cividis',
                               add_colorbar=True)
p_elev = ds.masked_elev.plot(ax=axes[1, 2], cmap='terrain', add_colorbar=True)

p_aspect.colorbar.set_label("Aspect (°)", fontsize=16)
p_slope.colorbar.set_label("Slope (°)", fontsize=16)
p_elev.colorbar.set_label("Elevation (m a.s.l.)", fontsize=16)

axes[1, 0].set_title("Aspect", fontsize=20)
axes[1, 1].set_title("Slope", fontsize=20)
axes[1, 2].set_title("DEM", fontsize=20)

# ---- Panel labels ----
# Letter only the axes still attached to the figure, so the deleted
# top-right panel does not leave a gap in the (a), (b), ... sequence.
panel_labels = ['(a)', '(b)', '(c)', '(d)', '(e)', '(f)']
panel_axes = [ax for ax in axes.ravel() if ax in fig.axes]
for ax, label in zip(panel_axes, panel_labels):
    ax.text(0.02,
            0.98,
            label,
            transform=ax.transAxes,
            fontsize=20,
            va='top',
            ha='left',
            bbox=dict(facecolor='white', alpha=0.6, edgecolor='none', pad=2))

plt.suptitle(f"{GLACIER_NAME.capitalize()} Glacier ({year})", fontsize=22)
plt.tight_layout()
plt.show()

fig.savefig('figures/paper/appendix/app_gietro_2016.png',
            dpi=300,
            bbox_inches='tight')


#### Schwarzberg 2021:

In [None]:
year = 2021
GLACIER_NAME = "schwarzberg"

fig, axes = plt.subplots(2, 3, figsize=(20, 10))

# ---- First row: cropped mass balance comparison (top-right panel unused) ----
plot_mass_balance_comparison_cropped(
    glacier_name=GLACIER_NAME,
    year=year,
    cfg=cfg,
    df_stakes=df_stakes,
    path_distributed_mb=path_distributed_MB_glamos,
    path_pred_lstm=PATH_PREDICTIONS_LSTM_IS,
    period="annual",
    fig=fig,
    axes=axes[0, :],  # <-- first row only
)

# Remove empty top-right axis
fig.delaxes(axes[0, 2])

# ---- Second row: topography ----
PATH_GLAMOS_TOPO = path_GLAMOS_topo
PATH_XR_GRIDS = os.path.join(cfg.dataPath, PATH_GLAMOS_TOPO, "xr_masked_grids")
ds = xr.open_zarr(os.path.join(PATH_XR_GRIDS, f"{GLACIER_NAME}_{year}.zarr"))

# --- crop to glacier extent to remove empty space ---
# Use masked_elev as glacier footprint: valid pixels are on-glacier
valid = np.isfinite(ds.masked_elev.values)

if valid.any():
    # Take the row/column dimension names from the array itself instead of
    # assuming they are called (and ordered) lat, lon.
    dim_y, dim_x = ds.masked_elev.dims
    iy, ix = np.where(valid)

    # small padding (in pixels) for nicer framing
    pad = 2
    ymin = max(iy.min() - pad, 0)
    xmin = max(ix.min() - pad, 0)
    ymax = min(iy.max() + pad, valid.shape[0] - 1)
    xmax = min(ix.max() + pad, valid.shape[1] - 1)

    ds_crop = ds.isel({
        dim_y: slice(ymin, ymax + 1),
        dim_x: slice(xmin, xmax + 1)
    })
else:
    # No on-glacier pixels: plot the full grid rather than fail on min().
    ds_crop = ds

# Plot cropped aspect, slope, elevation
p_aspect = ds_crop.masked_aspect.plot(ax=axes[1, 0],
                                      cmap="twilight_shifted",
                                      add_colorbar=True)
p_slope = ds_crop.masked_slope.plot(ax=axes[1, 1],
                                    cmap="cividis",
                                    add_colorbar=True)
p_elev = ds_crop.masked_elev.plot(ax=axes[1, 2],
                                  cmap="terrain",
                                  add_colorbar=True)

# Colorbar labels
p_aspect.colorbar.set_label("Aspect (°)", fontsize=16)
p_slope.colorbar.set_label("Slope (°)", fontsize=16)
p_elev.colorbar.set_label("Elevation (m a.s.l.)", fontsize=16)

# Titles
axes[1, 0].set_title("Aspect", fontsize=20)
axes[1, 1].set_title("Slope", fontsize=20)
axes[1, 2].set_title("DEM", fontsize=20)

# ---- Panel labels ----
# Letter only the axes still attached to the figure, so the deleted
# top-right panel does not leave a gap in the (a), (b), ... sequence.
panel_labels = ["(a)", "(b)", "(c)", "(d)", "(e)", "(f)"]
panel_axes = [ax for ax in axes.ravel() if ax in fig.axes]
for ax, label in zip(panel_axes, panel_labels):
    ax.text(
        0.02,
        0.98,
        label,
        transform=ax.transAxes,
        fontsize=20,
        va="top",
        ha="left",
        bbox=dict(facecolor="white", alpha=0.6, edgecolor="none", pad=2),
    )

plt.suptitle(f"{GLACIER_NAME.capitalize()} Glacier ({year})", fontsize=22)
plt.tight_layout()
plt.show()

fig.savefig(
    "figures/paper/appendix/app_schwarzberg_2021.png",
    dpi=300,
    bbox_inches="tight",
)


In [None]:
period = 'annual'
GLACIER_NAME = 'aletsch'
year = 2001
fig, axes = plt.subplots(1, 2, figsize=(14, 7))
# Pass the same `period` that the suptitle reports, so that changing it
# above updates both the plot and its title.
plot_mass_balance_comparison(glacier_name=GLACIER_NAME,
                             year=year,
                             cfg=cfg,
                             df_stakes=df_stakes,
                             path_distributed_mb=path_distributed_MB_glamos,
                             path_pred_lstm=PATH_PREDICTIONS_LSTM_IS,
                             period=period,
                             fig=fig,
                             axes=axes)
plt.suptitle(
    f"{GLACIER_NAME.capitalize()} Glacier – {period.capitalize()} MB Comparison ({year})",
    fontsize=18,
)
plt.tight_layout()
plt.show()

In [None]:
import math
import string


def _panel_letter(index):
    """Return the panel letter for 0-based ``index``: a..z, then aa, ab, ...

    Unlike iterating ``string.ascii_lowercase``, this never runs out after
    26 panels.
    """
    letters = string.ascii_lowercase
    label = ""
    index += 1  # work 1-based so the spreadsheet-style recurrence is simple
    while index > 0:
        index, rem = divmod(index - 1, len(letters))
        label = letters[rem] + label
    return label


def _collect_lstm_grids(glacier_name, years, path_pred_lstm, period,
                        apply_smoothing):
    """Load the ``pred_masked`` LSTM grid for every year whose zarr store exists.

    Returns two parallel lists (DataArrays, years). Years whose store is
    missing, or whose dataset lacks ``pred_masked``, are skipped silently.
    """
    grids, found_years = [], []
    for year in years:
        zpath = os.path.join(path_pred_lstm, glacier_name,
                             f"{glacier_name}_{year}_{period}.zarr")
        if not os.path.exists(zpath):
            continue
        ds = xr.open_zarr(zpath)
        if apply_smoothing:
            ds = apply_gaussian_filter(ds)
        if "pred_masked" not in ds:
            continue
        grids.append(ds["pred_masked"])
        found_years.append(year)
    return grids, found_years


def _draw_truncated_colorbar(mappable, data_vmin, data_vmax, vmin_display,
                             vmax_display, cm):
    """Draw ``mappable``'s colorbar on a new figure, clipped to the display range.

    A triangle marker is drawn past each end whose data extends beyond the
    displayed range. ``cm`` is the centimetre-to-inch factor. Returns the
    new figure.
    """
    fig_cb = plt.figure(figsize=(1.5 * cm, 12 * cm), dpi=200)
    ax_cb = fig_cb.add_axes([0.25, 0.05, 0.35, 0.90])
    cb = fig_cb.colorbar(mappable, cax=ax_cb)
    cb.ax.set_ylim(vmin_display, vmax_display)

    # Up/down arrows signal that the colour scale was truncated.
    arrow_kw = dict(color="black",
                    markersize=4,
                    transform=ax_cb.transAxes,
                    clip_on=False)
    if data_vmin < vmin_display:
        ax_cb.plot([0.5], [-0.02], marker="v", **arrow_kw)
    if data_vmax > vmax_display:
        ax_cb.plot([0.5], [1.02], marker="^", **arrow_kw)

    cb.set_label("Mass Balance [m w.e.]")
    return fig_cb


def plot_mbm_grids_only(
    glacier_name,
    years,
    cfg,
    path_pred_lstm,
    df_stakes,
    path_distributed_MB_glamos,
    period="annual",
    apply_smoothing=True,
    cm=1 / 2.54,
    ncols=4,
    text_size=8,
    figsize_y=9,
    cb_vmin=-12,
    cb_vmax=5,
):
    """
    Plot every available LSTM (MBM) mass-balance grid of one glacier as panels.

    Each panel shows one year's ``pred_masked`` field with stakes overlaid by
    ``stake_overlay_rmse``, an RMSE / mean-MB annotation and a panel letter.
    The colorbar is drawn on a *separate* figure whose visible range is
    truncated to ``[cb_vmin, cb_vmax]`` (clipped to the data range).

    Parameters
    ----------
    glacier_name : str
        Glacier name, used in the zarr path and in panel titles.
    years : iterable of int
        Candidate years. Years without a prediction store (or without a
        ``pred_masked`` variable) are skipped.
    cfg :
        Configuration passed through to ``load_glamos_wgs84``.
    path_pred_lstm : str
        Root directory containing ``<glacier>/<glacier>_<year>_<period>.zarr``.
    df_stakes :
        Stake data passed through to ``stake_overlay_rmse``.
    path_distributed_MB_glamos : str
        GLAMOS distributed-MB location passed to ``load_glamos_wgs84``.
    period : str
        Period suffix of the zarr stores (default ``"annual"``).
    apply_smoothing : bool
        Apply ``apply_gaussian_filter`` to each grid before plotting.
    cm : float
        Centimetre-to-inch factor for figure sizes. Inside this function it
        shadows the ``cmcrameri.cm`` colormap module imported by the notebook.
    ncols : int
        Number of panel columns.
    text_size : int
        Base font size for titles and annotations.
    figsize_y : float
        Width of one panel column in cm (despite the name).
    cb_vmin, cb_vmax : float
        Lower / upper display limits of the separate colorbar.

    Returns
    -------
    fig, fig_cb : matplotlib.figure.Figure
        The panel figure and the separate colorbar figure.

    Raises
    ------
    RuntimeError
        If no grid could be loaded for any of ``years``.
    """
    datasets, year_list = _collect_lstm_grids(glacier_name, years,
                                              path_pred_lstm, period,
                                              apply_smoothing)
    if not datasets:
        raise RuntimeError("No LSTM MB grids found.")

    # ---- shared colour scale across all panels ----
    global_vmin = min(float(da.min()) for da in datasets)
    global_vmax = max(float(da.max()) for da in datasets)
    cmap, norm = get_color_maps(global_vmin, global_vmax)

    # Truncation applies to the colorbar display only, not the panels.
    vmin_display = max(global_vmin, cb_vmin)
    vmax_display = min(global_vmax, cb_vmax)

    # ---- plot layout ----
    n = len(datasets)
    nrows = math.ceil(n / ncols)
    # squeeze=False guarantees a 2-D axes array for every (nrows, ncols),
    # including single-row, single-column and 1x1 layouts.
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(ncols * figsize_y * cm, nrows * 8 * cm),
        constrained_layout=True,
        dpi=200,
        squeeze=False,
    )

    mappable_for_cb = None

    # ---- per-year panels ----
    for idx, (da_lstm, year) in enumerate(zip(datasets, year_list)):
        ax = axes[idx // ncols, idx % ncols]

        mappable_for_cb = da_lstm.plot.imshow(ax=ax,
                                              cmap=cmap,
                                              norm=norm,
                                              add_colorbar=False)
        ax.set_aspect("auto")
        ax.set_title(f"{glacier_name.capitalize()} – MBM ({year})",
                     fontsize=text_size + 1)

        # Load GLAMOS + the full LSTM dataset for the stake evaluation.
        da_glamos = load_glamos_wgs84(
            glacier=glacier_name,
            year=year,
            cfg=cfg,
            path_distributed_MB_glamos=path_distributed_MB_glamos,
            period=period,
        )
        ds_lstm_full = load_lstm_ds(
            glacier=glacier_name,
            year=year,
            path_pred_lstm=path_pred_lstm,
            period=period,
            smoothing_fn=apply_gaussian_filter if apply_smoothing else None,
        )

        rmse_m = stake_overlay_rmse(
            ax=ax,
            glacier=glacier_name,
            year=year,
            cmap=cmap,
            norm=norm,
            da_glamos=da_glamos,
            ds_lstm=ds_lstm_full,
            df_stakes=df_stakes,
            period=period,
            which="LSTM",
        )

        mean_m = float(da_lstm.mean())
        # A falsy RMSE (e.g. None) omits the RMSE line.
        annotation = ((f"RMSE: {rmse_m:.2f}\n" if rmse_m else "") +
                      f"mean MB: {mean_m:.2f}")

        # Gries is annotated in the lower-right corner, others lower-left.
        if glacier_name.lower() == "gries":
            x_txt, y_txt, ha_txt = 0.96, 0.04, "right"
        else:
            x_txt, y_txt, ha_txt = 0.04, 0.03, "left"
        ax.text(x_txt,
                y_txt,
                annotation,
                transform=ax.transAxes,
                fontsize=text_size,
                ha=ha_txt,
                va="bottom")

        # Panel label: (a), (b), ... in plotting order.
        ax.text(
            0.03,
            0.97,
            f"({_panel_letter(idx)})",
            transform=ax.transAxes,
            ha="left",
            va="top",
            fontsize=text_size,
        )

        ax.grid(alpha=0.2)
        ax.tick_params(axis="both", labelsize=9)
        ax.set_xlabel("")
        ax.set_ylabel("")

    # ---- hide the unused trailing axes of the last row ----
    for j in range(n, nrows * ncols):
        axes.flat[j].set_axis_off()

    # ---- colorbar on its own figure ----
    fig_cb = _draw_truncated_colorbar(mappable_for_cb, global_vmin,
                                      global_vmax, vmin_display,
                                      vmax_display, cm)

    return fig, fig_cb


In [None]:
# Appendix: all available MBM annual grids for Aletsch + stand-alone colorbar.
fig_maps, fig_cb = plot_mbm_grids_only(
    glacier_name="aletsch",
    years=range(2005, 2025),
    cfg=cfg,
    path_pred_lstm=PATH_PREDICTIONS_LSTM_IS,
    df_stakes=df_stakes,
    period="annual",
    path_distributed_MB_glamos=path_distributed_MB_glamos,
    text_size=10,
    ncols=4,
    figsize_y=7,
)
for _figure, _out_path in (
    (fig_maps, "figures/paper/appendix/mbm_all_aletsch_grids.png"),
    (fig_cb, "figures/paper/appendix/mbm_all_aletsch_colorbar.pdf"),
):
    _figure.savefig(_out_path, dpi=300, bbox_inches="tight")

In [None]:
# Appendix: all available MBM annual grids for Rhone + stand-alone colorbar.
fig_maps, fig_cb = plot_mbm_grids_only(
    glacier_name="rhone",
    years=range(2007, 2025),
    cfg=cfg,
    path_pred_lstm=PATH_PREDICTIONS_LSTM_IS,
    df_stakes=df_stakes,
    period="annual",
    path_distributed_MB_glamos=path_distributed_MB_glamos,
    text_size=9,
    ncols=4,
    figsize_y=7,
)
for _figure, _out_path in (
    (fig_maps, "figures/paper/appendix/mbm_all_rhone_grids.png"),
    (fig_cb, "figures/paper/appendix/mbm_all_rhone_colorbar.pdf"),
):
    _figure.savefig(_out_path, dpi=300, bbox_inches="tight")

In [None]:
# Appendix: all available MBM annual grids for Gries + stand-alone colorbar.
fig_maps, fig_cb = plot_mbm_grids_only(
    glacier_name="gries",
    years=range(2005, 2025),
    cfg=cfg,
    path_pred_lstm=PATH_PREDICTIONS_LSTM_IS,
    df_stakes=df_stakes,
    period="annual",
    path_distributed_MB_glamos=path_distributed_MB_glamos,
    text_size=10,
    ncols=4,
    figsize_y=7,
)
for _figure, _out_path in (
    (fig_maps, "figures/paper/appendix/mbm_all_gries_grids.png"),
    (fig_cb, "figures/paper/appendix/mbm_all_gries_colorbar.pdf"),
):
    _figure.savefig(_out_path, dpi=300, bbox_inches="tight")