## Setting Up:

In [None]:
import sys
import os

import logging
import warnings
from datetime import datetime
from collections import defaultdict, Counter
from functools import partial

import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from tqdm.notebook import tqdm
from cmcrameri import cm
from skorch.helper import SliceDataset
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint

import massbalancemachine as mbm

# Add root of repo to import MBM
sys.path.append(os.path.join(os.getcwd(), '../../'))

# Local modules
from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.nn_helpers import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.NN_networks import *
from scripts.geodata_plots import *

# Silence every Python warning for the rest of the session. Note that this
# also hides deprecation warnings raised by the cells below.
warnings.filterwarnings('ignore')

# IPython magics: load the autoreload extension and reload imported modules
# automatically before each cell executes (mode 2).
%load_ext autoreload
%autoreload 2

# Shared configuration object; the cells below read cfg.seed, cfg.dataPath
# and cfg.metaData from it.
cfg = mbm.SwitzerlandConfig()

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from torch.utils.data import WeightedRandomSampler, SubsetRandomSampler
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Seed through the project helper and report the seed that was used.
seed_all(cfg.seed)
print("Using seed:", cfg.seed)

# Pick the compute device first, then report on it; the project helper
# free_up_cuda() is only called when a GPU is actually present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("CUDA is available")
    free_up_cuda()
else:
    print("CUDA is NOT available")


In [None]:
# Plot styles: apply the project style sheet and derive the shared colours.
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue, color_pink = colors[0], '#c51b7d'

# RGI ids table: strip whitespace from the column names, then sort and
# index the rows by `short_name`.
rgi_df = (
    pd.read_csv(cfg.dataPath + path_glacier_ids, sep=',')
    .rename(columns=lambda col: col.strip())
    .sort_values(by='short_name')
    .set_index('short_name')
)

# Climate variables handed to process_or_load_data as `vois_climate`.
vois_climate = ['t2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str']

# Topographical variables handed to process_or_load_data as
# `vois_topographical`.
vois_topographical = [
    "aspect_sgi", "slope_sgi", "hugonnet_dhdt", "consensus_ice_thickness",
    "millan_v"
]

## Input data:

In [None]:
# Stake data loaded through the project helper.
data_glamos = getStakesData(cfg)

# Head/tail padding lengths derived from the stake data; they are reused
# when the LSTM sequence datasets are built.
months_head_pad, months_tail_pad = (
    mbm.data_processing.utils._compute_head_tail_pads_from_df(data_glamos))

# Timestamped INFO-level logging for the processing step below.
logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)

# Input/output locations for the monthly transformation.
_era5_dir = cfg.dataPath + path_ERA5_raw
paths = dict(
    csv_path=cfg.dataPath + path_PMB_GLAMOS_csv,
    era5_climate_data=_era5_dir + 'era5_monthly_averaged_data.nc',
    geopotential_data=_era5_dir + 'era5_geopotential_pressure.nc',
    radiation_save_path=cfg.dataPath + path_pcsr + 'zarr/',
)

# Transform data to monthly format (run or load data, chosen by RUN).
RUN = False
data_monthly = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_LSTM.csv',
)

# MBM data loader wrapping the monthly dataset.
dataloader_gl = mbm.dataloader.DataLoader(
    cfg,
    data=data_monthly,
    random_seed=cfg.seed,
    meta_data_columns=cfg.metaData,
)

In [None]:
# Check that all test glaciers exist in the dataset and report any that do
# not. print() is used on purpose: warnings are globally silenced in this
# notebook, so warnings.warn() would not be shown.
existing_glaciers = set(data_monthly.GLACIER.unique())
missing_glaciers = [g for g in TEST_GLACIERS if g not in existing_glaciers]
if missing_glaciers:
    print(f"Warning: test glaciers missing from dataset: {missing_glaciers}")

# Training glaciers are all glaciers in the dataset that are not held out.
train_glaciers = [g for g in existing_glaciers if g not in TEST_GLACIERS]

# Split by glacier. The train/test frames are taken from the split result
# (the earlier isin()-based frames were immediately overwritten, so they
# are no longer built).
splits, test_set, train_set = get_CV_splits(dataloader_gl,
                                            test_split_on='GLACIER',
                                            test_splits=TEST_GLACIERS,
                                            random_state=cfg.seed)

# Attach the target as column 'y' to the feature frames.
data_train = train_set['df_X']
data_train['y'] = train_set['y']
data_test = test_set['df_X']
data_test['y'] = test_set['y']

## LSTM:

In [None]:
# Per-month input columns of the sequence model (passed as the monthly
# columns to MBSequenceDataset.from_dataframe).
MONTHLY_COLS = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str', 'pcsr',
    'ELEVATION_DIFFERENCE'
]

# Static input columns (passed as the static columns to
# MBSequenceDataset.from_dataframe).
STATIC_COLS = [
    'aspect_sgi', 'slope_sgi', 'hugonnet_dhdt', 'consensus_ice_thickness',
    'millan_v'
]

# Register the full feature list (monthly first, then static) on the config.
feature_columns = [*MONTHLY_COLS, *STATIC_COLS]
cfg.setFeatures(feature_columns)

### Build LSTM dataloaders:

In [None]:
seed_all(cfg.seed)

# Copies of the train/test frames with PERIOD stripped and lower-cased
# (assign() returns a new frame, so data_train/data_test stay untouched).
df_train = data_train.assign(
    PERIOD=data_train['PERIOD'].str.strip().str.lower())
df_test = data_test.assign(
    PERIOD=data_test['PERIOD'].str.strip().str.lower())

# --- build train and test datasets from the dataframes ---
# Both datasets are built with identical padding/target settings.
_seq_kwargs = dict(
    months_tail_pad=months_tail_pad,
    months_head_pad=months_head_pad,
    expect_target=True,
)
_from_dataframe = mbm.data_processing.MBSequenceDataset.from_dataframe

ds_train = _from_dataframe(df_train, MONTHLY_COLS, STATIC_COLS,
                           **_seq_kwargs)
ds_test = _from_dataframe(df_test, MONTHLY_COLS, STATIC_COLS, **_seq_kwargs)

# Seeded train/validation index split over the training dataset.
train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
    len(ds_train), val_ratio=0.2, seed=cfg.seed)

### Define & train model:

In [None]:
# Hyper-parameters; consumed by build_model_from_params / resolve_loss_fn
# below, and 'lr' / 'weight_decay' are also passed to train_loop.
custom_params = {
    'Fm': len(MONTHLY_COLS),
    'Fs': len(STATIC_COLS),
    'hidden_size': 128,
    'num_layers': 2,
    'bidirectional': False,
    'dropout': 0.0,
    'static_layers': 2,
    'static_hidden': [128, 64],
    'static_dropout': 0.1,
    'lr': 0.0005,
    'weight_decay': 0.0001,
    'loss_name': 'neutral',
    'loss_spec': None,
    'two_heads': True,
    'head_dropout': 0.0
}

# --- checkpoint paths ---
# Dated checkpoint written by a training run started from this cell, and
# the previously trained checkpoint that is used when training is skipped.
current_date = datetime.now().strftime("%Y-%m-%d")
model_filename = f"models/lstm_model_{current_date}_two_heads.pt"
pretrained_model_filename = 'models/lstm_model_2025-10-09_two_heads.pt'

# --- loaders (fit scalers on TRAIN, apply to whole ds_train) ---
ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train)

ds_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_test)

train_dl, val_dl = ds_train_copy.make_loaders(
    train_idx=train_idx,
    val_idx=val_idx,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=
    True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True  # use weighted sampler for training
)

# --- test loader (copies TRAIN scalers into ds_test and transforms it) ---
test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test_copy, ds_train_copy, batch_size=128, seed=cfg.seed)

# --- build model, resolve loss, train, reload best ---
model = mbm.models.LSTM_MB.build_model_from_params(cfg, custom_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(custom_params)

TRAIN = True
if TRAIN:
    # Start from a clean checkpoint file for today's run.
    if os.path.exists(model_filename):
        os.remove(model_filename)

    history, best_val, best_state = model.train_loop(
        device=device,
        train_dl=train_dl,
        val_dl=val_dl,
        epochs=150,
        lr=custom_params['lr'],
        weight_decay=custom_params['weight_decay'],
        clip_val=1,
        # scheduler
        sched_factor=0.5,
        sched_patience=6,
        sched_threshold=0.01,
        sched_threshold_mode="rel",
        sched_cooldown=1,
        sched_min_lr=1e-6,
        # early stopping
        es_patience=15,
        es_min_delta=1e-4,
        # logging
        log_every=5,
        verbose=True,
        # checkpoint
        save_best_path=model_filename,
        loss_fn=loss_fn,
    )
    plot_history_lstm(history)
else:
    # No training in this session: evaluate the stored checkpoint instead.
    # (Previously this path was applied unconditionally, so a freshly
    # trained model was never the one being evaluated.)
    model_filename = pretrained_model_filename

# Evaluate on test, using the checkpoint selected above.
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state)
test_metrics, test_df_preds = model.evaluate_with_preds(
    device, test_dl, ds_test_copy)
test_rmse_a, test_rmse_w = test_metrics['RMSE_annual'], test_metrics[
    'RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))


## Validation on PMB:

In [None]:
# Colours used for the annual and winter series in the plots below.
colors = get_cmap_hex(cm.batlow, 10)
color_annual, color_winter = colors[0], "#c51b7d"

In [None]:
# Score the test predictions separately for the annual and winter periods.
scores_annual, scores_winter = compute_seasonal_scores(
    test_df_preds, target_col='target', pred_col='pred')

print("Annual scores:", scores_annual)
print("Winter scores:", scores_winter)

# Summary figure; both axes share the same limits.
_pmb_axis_limits = (-8, 6)
fig = plot_predictions_summary(
    grouped_ids=test_df_preds,
    scores_annual=scores_annual,
    scores_winter=scores_winter,
    ax_xlim=_pmb_axis_limits,
    ax_ylim=_pmb_axis_limits,
)

In [None]:
# Mean elevation of the annual stake measurements per glacier, highest
# first. Computed once and reused for both the panel order and the
# 'gl_elv' column (it used to be recomputed identically a few lines later).
gl_per_el = data_glamos[data_glamos.PERIOD == 'annual'].groupby(
    ['GLACIER'])['POINT_ELEVATION'].mean()
gl_per_el = gl_per_el.sort_values(ascending=False)

# Test glaciers ordered by increasing mean elevation; used as panel order.
test_gl_per_el = gl_per_el[TEST_GLACIERS].sort_values().index

fig, axs = plt.subplots(3, 3, figsize=(25, 18), sharex=True)

test_df_preds['gl_elv'] = test_df_preds['GLACIER'].map(gl_per_el)

subplot_labels = [
    '(a)', '(b)', '(c)', '(d)', '(e)', '(f)', '(g)', '(h)', '(i)'
]

axs = PlotIndividualGlacierPredVsTruth(test_df_preds,
                                       axs=axs,
                                       subplot_labels=subplot_labels,
                                       color_annual=color_dark_blue,
                                       color_winter=color_pink,
                                       custom_order=test_gl_per_el)

axs[3].set_ylabel("Modelled PMB [m w.e.]", fontsize=20)

fig.supxlabel('Observed PMB [m w.e.]', fontsize=20, y=0.06)

# Two distinct legend handles. They use the same colour variables that were
# passed to the plotting call above, so legend and markers cannot drift
# apart.
legend_scatter_annual = Line2D([0], [0],
                               marker='o',
                               linestyle='None',
                               linewidth=0,
                               markersize=10,
                               markerfacecolor=color_dark_blue,
                               markeredgecolor='k',
                               markeredgewidth=0.8,
                               label='Annual')

legend_scatter_winter = Line2D([0], [0],
                               marker='o',
                               linestyle='None',
                               linewidth=0,
                               markersize=10,
                               markerfacecolor=color_pink,
                               markeredgecolor='k',
                               markeredgewidth=0.8,
                               label='Winter')

handles = [legend_scatter_annual, legend_scatter_winter]

# Labels are taken from the handles, so no `labels=...` is passed.
fig.legend(handles=handles,
           loc='upper center',
           bbox_to_anchor=(0.5, 0.05),
           ncol=4,
           fontsize=20)

plt.subplots_adjust(hspace=0.25, wspace=0.25)
plt.show()

## Glacier-wide MB:

In [None]:
# Directories holding the distributed MB predictions of the two models.
PATH_PREDICTIONS_LSTM_two_heads = os.path.join(
    cfg.dataPath, "GLAMOS", "distributed_MB_grids",
    "MBM/glamos_dems_LSTM_two_heads")
PATH_PREDICTIONS_NN = os.path.join(
    cfg.dataPath, 'GLAMOS', 'distributed_MB_grids',
    'MBM/testing_combis/glamos_dems_NN_SEB_full_OGGM')

# Entries of the LSTM prediction directory (treated as glacier names).
glaciers_in_glamos = set(os.listdir(PATH_PREDICTIONS_LSTM_two_heads))

# Geodetic MB and the per-glacier periods derived from it.
geodetic_mb = get_geodetic_MB(cfg)
periods_per_glacier, geoMB_per_glacier = build_periods_per_glacier(geodetic_mb)

# Glacier areas; "clariden" is added as an alias of "claridenL".
gl_area = get_gl_area(cfg)
gl_area["clariden"] = gl_area["claridenL"]

# Glaciers with both geodetic periods and predictions, by ascending area
# (glaciers without a known area sort first with area 0).
glacier_list = sorted(
    [name for name in periods_per_glacier if name in glaciers_in_glamos],
    key=lambda name: gl_area.get(name, 0),
)
print("Number of glaciers:", len(glacier_list))
print("Glaciers:", glacier_list)

# Run the same comparison for both models; only the prediction path differs.
_run_comparison = partial(
    process_geodetic_mass_balance_comparison,
    glacier_list=glacier_list,
    path_SMB_GLAMOS_csv=os.path.join(cfg.dataPath, path_SMB_GLAMOS_csv),
    periods_per_glacier=periods_per_glacier,
    geoMB_per_glacier=geoMB_per_glacier,
    gl_area=gl_area,
    test_glaciers=TEST_GLACIERS,
    cfg=cfg,
)

df_lstm_two_heads = _run_comparison(
    path_predictions=PATH_PREDICTIONS_LSTM_two_heads)

df_nn = _run_comparison(path_predictions=PATH_PREDICTIONS_NN)

### Geodetic:

In [None]:
# Keep rows that have both MB values, order them by area and capitalise
# the glacier names.
df_lstm_two_heads = (
    df_lstm_two_heads
    .dropna(subset=['Geodetic MB', 'MBM MB'])
    .sort_values(by='Area')
)
df_lstm_two_heads['GLACIER'] = df_lstm_two_heads['GLACIER'].apply(
    str.capitalize)

# RMSE and Pearson correlation between geodetic and modelled MB.
_geodetic_mb = df_lstm_two_heads["Geodetic MB"]
_modelled_mb = df_lstm_two_heads["MBM MB"]
rmse_nn = root_mean_squared_error(_geodetic_mb, _modelled_mb)
corr_nn = np.corrcoef(_geodetic_mb, _modelled_mb)[0, 1]

plot_mbm_vs_geodetic_by_area_bin(
    df_lstm_two_heads,
    bins=[0, 1, 5, 10, 100, np.inf],
    labels=['<1', '1-5', '5–10', '>10', '>100'],
    max_bins=4,
)

#### Correct for bias:

In [None]:
# --- Prep: work on an independent, area-sorted copy ---
df = (
    df_lstm_two_heads
    .dropna(subset=["Geodetic MB", "MBM MB"])
    .copy()
)
df["GLACIER"] = df["GLACIER"].str.capitalize()
df = df.sort_values(by="Area")

# --- Per-glacier bias and corrected predictions ---
# bias_gl is the mean of (MBM MB - Geodetic MB) over each glacier's rows;
# it is subtracted from every prediction of that glacier.
_residual = df["MBM MB"] - df["Geodetic MB"]
df["bias_gl"] = _residual.groupby(df["GLACIER"]).transform("mean")
df["MBM MB_corr"] = df["MBM MB"] - df["bias_gl"]

# --- Metrics (original vs corrected) ---
_observed = df["Geodetic MB"]
rmse_nn = root_mean_squared_error(_observed, df["MBM MB"])
corr_nn = np.corrcoef(_observed, df["MBM MB"])[0, 1]

rmse_corr = root_mean_squared_error(_observed, df["MBM MB_corr"])
corr_corr = np.corrcoef(_observed, df["MBM MB_corr"])[0, 1]

print(f"Original  RMSE={rmse_nn:.3f}, r={corr_nn:.3f}")
print(f"Corrected RMSE={rmse_corr:.3f}, r={corr_corr:.3f}")

# --- Replot with the existing function ---
# The plotting call below is given a frame whose "MBM MB" column holds the
# corrected series, so the same function can be reused unchanged.
df_corr = df.assign(**{"MBM MB": df["MBM MB_corr"]})

plot_mbm_vs_geodetic_by_area_bin(
    df_corr,
    bins=[0, 1, 5, 10, 100, np.inf],
    labels=["<1", "1-5", "5–10", ">10", ">100"],
    max_bins=4,
)

In [None]:
# Ordered categorical area bins (left-closed intervals, see right=False).
bins = [0, 1, 5, 10, 100, np.inf]
labels = ['<1', '1-5', '5–10', '>10', '>100']
df_lstm_two_heads = df_lstm_two_heads.replace([np.inf, -np.inf], np.nan)

df_lstm_two_heads["Area_bin"] = pd.cut(
    df_lstm_two_heads["Area"],
    bins=bins,
    labels=labels,
    right=False,
    include_lowest=True,
    ordered=True,
)
categories = list(df_lstm_two_heads["Area_bin"].cat.categories)
bins_in_use = [
    b for b in categories if (df_lstm_two_heads["Area_bin"] == b).any()
]

# Unique glaciers per area bin, grouped once instead of once per printed
# bin. observed=False keeps one group per category (empty bins included),
# independent of the pandas default for categorical groupers.
glaciers_per_bin = df_lstm_two_heads.groupby(
    by="Area_bin", observed=False).GLACIER.unique()

# Print the first four bins; slicing avoids an IndexError should fewer
# groups be available.
for i, glaciers in enumerate(glaciers_per_bin.iloc[:4]):
    print(f'bin {i}:', glaciers)

In [None]:
# Mean area per glacier as a one-column frame ('Area') indexed by GLACIER.
areas_per_gl = df_lstm_two_heads.groupby('GLACIER')['Area'].mean().to_frame()
areas_per_gl

### Mass balance gradients:

In [None]:
# Mean point elevation per measurement ID, attached to the test predictions
# with a left join so every prediction row is kept.
elv_per_id = data_monthly.groupby('ID')['POINT_ELEVATION'].mean()
df_pred = test_df_preds.merge(
    elv_per_id,
    how='left',
    left_on='ID',
    right_index=True,
)

# Stake data: read once here and reused by the plotting cells below,
# rather than being loaded again for every glacier.
stake_file = os.path.join(cfg.dataPath, path_PMB_GLAMOS_csv,
                          "CH_wgms_dataset_all.csv")
df_stakes = pd.read_csv(stake_file)

In [None]:
# Glaciers per area bin, in the order in which the panels are drawn.
bin0 = ['Schwarzbach', 'Sexrouge', 'Murtel']
bin1 = ['Basodino', 'Adler', 'Hohlaub', 'Silvretta', 'Gries', 'Clariden']
bin2 = ['Gietro', 'Schwarzberg', 'Allalin']
bin3 = ['Findelen', 'Rhone', 'Aletsch']

nrows = 3
ncols = 5
# Centimetre -> inch factor for figsize. Deliberately NOT named `cm`: that
# name is the cmcrameri colormap module imported at the top of the
# notebook, and rebinding it to a float breaks every later (or re-run)
# `get_cmap_hex(cm.batlow, ...)` call.
cm_to_inch = 1 / 2.54
fontsize = 7

# Create a figure with the specified number of subplots
fig, axs = plt.subplots(nrows=nrows,
                        ncols=ncols,
                        figsize=(25 * cm_to_inch, 15 * cm_to_inch),
                        dpi=300)
axs = axs.flatten()
for i, gl in enumerate(bin0 + bin1 + bin2 + bin3):
    # Annual
    df_lstm_a, df_glamos_a, df_all_a = build_all_years_df(
        gl.lower(), PATH_PREDICTIONS_LSTM_two_heads, cfg, period="annual")

    # Winter
    df_lstm_w, df_glamos_w, df_all_w = build_all_years_df(
        gl.lower(), PATH_PREDICTIONS_LSTM_two_heads, cfg, period="winter")

    # Skip glaciers without annual data (their panel stays empty)
    if df_all_a.empty:
        print(f"No data for glacier: {gl}")
        continue

    ax = plot_mb_by_elevation_periods(df_all_a,
                                      df_all_w,
                                      df_stakes,
                                      gl.lower(),
                                      ax=axs[i])

    # Area shown in the title: 3 decimals below 0.1, otherwise 1 decimal
    area = areas_per_gl.loc[gl].Area
    if area < 0.1:
        area = np.round(area, 3)
    else:
        area = np.round(area, 1)
    axs[i].set_title(f'{gl} ({area} km2)', fontsize=fontsize, pad=2)
    axs[i].grid(alpha=0.2)
    axs[i].tick_params(labelsize=6.5, pad=2)
    axs[i].set_ylabel('')
    axs[i].set_xlabel('')
    # remove the per-panel legend; one shared legend is added below
    axs[i].legend().remove()

# Single y-label on the first panel of the middle row
axs[5].set_ylabel('Elevation (m a.s.l.)', fontsize=fontsize)

# Custom handles (bands, means, and stakes) for the shared figure legend
handles = [
    # LSTM
    Patch(facecolor=color_annual, alpha=0.25, label="LSTM band (annual)"),
    Line2D([0], [0],
           color=color_annual,
           lw=1.2,
           linestyle='-',
           label="LSTM mean (annual)"),
    Patch(facecolor=color_winter, alpha=0.25, label="LSTM band (winter)"),
    Line2D([0], [0],
           color=color_winter,
           lw=1.2,
           linestyle='-',
           label="LSTM mean (winter)"),

    # GLAMOS (mean only)
    Line2D([0], [0],
           color=color_annual,
           lw=1.2,
           linestyle=':',
           label="GLAMOS mean (annual)"),
    Line2D([0], [0],
           color=color_winter,
           lw=1.2,
           linestyle=':',
           label="GLAMOS mean (winter)"),

    # Stakes means
    Line2D([0], [0],
           marker='o',
           linestyle='None',
           linewidth=0,
           markersize=6,
           markerfacecolor='none',
           markeredgecolor=color_annual,
           markeredgewidth=1.2,
           label="Stakes mean (annual)"),
    Line2D([0], [0],
           marker='s',
           linestyle='None',
           linewidth=0,
           markersize=6,
           markerfacecolor='none',
           markeredgecolor=color_winter,
           markeredgewidth=1.2,
           label="Stakes mean (winter)"),
]

fig.supxlabel('Mass balance (m w.e.)', fontsize=fontsize, y=0.06)

fig.legend(handles=handles,
           loc='upper center',
           bbox_to_anchor=(0.5, 0.05),
           ncol=4,
           fontsize=7)

# Adjust the layout
plt.subplots_adjust(hspace=0.25, wspace=0.25)
plt.show()

### Maps:

In [None]:
# Example usage
GLACIER_NAME = 'rhone'
# bias_gl = df[df.GLACIER == GLACIER_NAME.capitalize()].bias_gl.unique()[0]
df_lstm_two_heads_gl = df_lstm_two_heads[df_lstm_two_heads.GLACIER ==
                                         GLACIER_NAME.capitalize()]
df_nn_gl = df_nn[df_nn.GLACIER == GLACIER_NAME]

fig, axs = plt.subplots(1, 2, figsize=(10, 5), sharex=True, sharey=True)

plot_scatter_comparison(axs[0],
                        df_lstm_two_heads_gl,
                        GLACIER_NAME,
                        color_mbm=color_annual,
                        color_glamos=color_winter,
                        title_suffix="(LSTM two heads)")
plot_scatter_comparison(axs[1],
                        df_nn_gl,
                        GLACIER_NAME,
                        color_mbm=color_annual,
                        color_glamos=color_winter,
                        title_suffix="(MLP)")

plt.tight_layout()
plt.show()

In [None]:
# GLAMOS glacier-wide MB for the selected glacier
GLAMOS_glwmb = get_GLAMOS_glwmb(GLACIER_NAME, cfg)

# Glacier-wide predictions of both models, with model-specific column names
MBM_glwmb_nn = mbm_glwd_pred(PATH_PREDICTIONS_NN, GLACIER_NAME)
MBM_glwmb_nn.rename(columns={"MBM Balance": "MBM Balance MLP"}, inplace=True)

MBM_glwmb_lstm = mbm_glwd_pred(PATH_PREDICTIONS_LSTM_two_heads, GLACIER_NAME)
MBM_glwmb_lstm.rename(columns={"MBM Balance": "MBM Balance LSTM"},
                      inplace=True)

# Join the MLP predictions with GLAMOS and drop incomplete rows, then add
# the LSTM predictions on the same index.
MBM_glwmb_nn = MBM_glwmb_nn.join(GLAMOS_glwmb).dropna()
MBM_glwmb = MBM_glwmb_nn.join(MBM_glwmb_lstm)

# One panel per model, each plotted against the GLAMOS balance
fig, axs = plt.subplots(1, 2, figsize=(12, 6), sharey=True)
_series_colors = [color_annual, color_winter]
for _panel_ax, _model_column in zip(axs,
                                    ('MBM Balance LSTM', 'MBM Balance MLP')):
    MBM_glwmb.plot(ax=_panel_ax,
                   y=[_model_column, 'GLAMOS Balance'],
                   marker="o",
                   color=_series_colors)

# Shared formatting for both panels
for ax in axs:
    ax.set_title(f"{GLACIER_NAME.capitalize()} Glacier", fontsize=24)
    ax.set_ylabel("Mass Balance [m w.e.]", fontsize=18)
    ax.set_xlabel("Year", fontsize=18)
    ax.grid(True, linestyle="--", linewidth=0.5)
    ax.legend(fontsize=14)

# Model-specific titles (these replace the generic title set in the loop)
axs[0].set_title(f"{GLACIER_NAME.capitalize()} Glacier (LSTM)", fontsize=16)
axs[1].set_title(f"{GLACIER_NAME.capitalize()} Glacier (MLP)", fontsize=16)

plt.tight_layout()
plt.show()

In [None]:
# Arguments shared by every yearly comparison plot; only `year` varies.
_annual_map_kwargs = dict(
    glacier_name=GLACIER_NAME,
    cfg=cfg,
    df_stakes=df_stakes,
    path_distributed_mb=path_distributed_MB_glamos,
    path_pred_lstm=PATH_PREDICTIONS_LSTM_two_heads,
    path_pred_nn=PATH_PREDICTIONS_NN,
    period='annual',
    # bias_correction=bias_gl
)

# One comparison figure per year in the index of the joined MLP/GLAMOS frame.
for year in MBM_glwmb_nn.index:
    plot_mass_balance_comparison_annual(year=year, **_annual_map_kwargs)
