## Setting Up:

In [None]:
# --- Standard library
from concurrent.futures import ProcessPoolExecutor, as_completed
from contextlib import redirect_stdout
from datetime import datetime
import io
import logging
import multiprocessing as mp
import os
import sys
import warnings

# Make repo root importable (for MBM & scripts/*)
sys.path.append(os.path.join(os.getcwd(), '../../'))

# --- Third-party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from cmcrameri import cm
import torch
from tqdm.auto import tqdm
import xarray as xr
from matplotlib.lines import Line2D

import massbalancemachine as mbm

# --- Project-local
from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.nn_helpers import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.NN_networks import *
from scripts.geodata_plots import *

# --- Notebook settings
# Silence all Python warnings for the whole session. Note that this also hides
# warnings.warn(...) calls made by project code, so use logging for anything
# that must stay visible.
warnings.filterwarnings('ignore')
# IPython magics (only valid inside IPython/Jupyter): reload edited modules such
# as scripts.* automatically before each cell executes.
%load_ext autoreload
%autoreload 2

# Swiss configuration object; its seed, dataPath, metaData and
# fieldsNotFeatures attributes are used throughout the notebook.
cfg = mbm.SwitzerlandConfig()

In [None]:
seed_all(cfg.seed)
print("Using seed:", cfg.seed)

# Decide once whether a GPU is usable, report it, and pick the torch device.
use_cuda = torch.cuda.is_available()
if not use_cuda:
    print("CUDA is NOT available")
else:
    print("CUDA is available")
    free_up_cuda()

device = torch.device("cuda" if use_cuda else "cpu")

In [None]:
# Plot styles: apply the project's matplotlib style sheet (the path is
# resolved relative to the current working directory).
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)

# Season colours: a fixed magenta for annual, the first of 10 batlow colours
# for winter.
colors = get_cmap_hex(cm.batlow, 10)
color_annual, color_winter = "#c51b7d", colors[0]

## Input data:

In [None]:
# Climate variables (read from the ERA5 files listed in `paths`) and
# topographical variables to attach to each stake measurement.
vois_climate = ['t2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str']
vois_topographical = ["aspect_sgi", "slope_sgi", "svf"]

# GLAMOS point (stake) measurements.
data_glamos = getStakesData(cfg)

# Head/tail month padding used when expanding measurements to monthly rows.
(months_head_pad,
 months_tail_pad) = mbm.data_processing.utils._compute_head_tail_pads_from_df(
     data_glamos)

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Input/output locations for the monthly transformation.
era5_dir = cfg.dataPath + path_ERA5_raw
paths = {
    'csv_path': cfg.dataPath + path_PMB_GLAMOS_csv,
    'era5_climate_data': era5_dir + 'era5_monthly_averaged_data.nc',
    'geopotential_data': era5_dir + 'era5_geopotential_pressure.nc',
    'radiation_save_path': cfg.dataPath + path_pcsr + 'zarr/',
}

# RUN selects between recomputing the monthly table and loading the saved one
# (see process_or_load_data).
RUN = False
data_monthly = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_LSTM_gs_no_oggm.csv',
)

# Create DataLoader
dataloader_gl = mbm.dataloader.DataLoader(cfg,
                                          data=data_monthly,
                                          random_seed=cfg.seed,
                                          meta_data_columns=cfg.metaData)

# Check that every test glacier is present in the monthly dataset; a missing
# one would otherwise silently shrink the test set. Reported via logging
# because Python warnings are globally ignored in this notebook.
existing_glaciers = set(data_monthly.GLACIER.unique())
missing_glaciers = [g for g in TEST_GLACIERS if g not in existing_glaciers]
if missing_glaciers:
    logging.warning("Test glaciers missing from data_monthly: %s",
                    missing_glaciers)

# Training glaciers: every glacier in the dataset that is not a test glacier
# (sorted so the list is identical across sessions).
train_glaciers = sorted(g for g in existing_glaciers if g not in TEST_GLACIERS)

splits, test_set, train_set = get_CV_splits(dataloader_gl,
                                            test_split_on='GLACIER',
                                            test_splits=TEST_GLACIERS,
                                            random_state=cfg.seed)

# Attach the targets to the feature frames returned by the split.
data_train = train_set['df_X']
data_train['y'] = train_set['y']

data_test = test_set['df_X']
data_test['y'] = test_set['y']

seed_all(cfg.seed)

# Working copies with normalized PERIOD labels ("annual"/"winter").
df_train = data_train.copy()
df_train['PERIOD'] = df_train['PERIOD'].str.strip().str.lower()

df_test = data_test.copy()
df_test['PERIOD'] = df_test['PERIOD'].str.strip().str.lower()

In [None]:
# Convert to start of August instead: every measurement is re-anchored to
# start on 1 August of its FROM_DATE year. FROM_DATE is assumed to be an
# int-like YYYYMMDD value: str -> keep YYYY -> append "0801" -> back to int.
data_glamos_Aug_ = data_glamos.copy()
data_glamos_Aug_["FROM_DATE"] = (
    data_glamos_Aug_["FROM_DATE"].astype(str).str.slice(0, 4)  # YYYY
    .astype(int).astype(str) + "0801"
).astype(int)

# Same monthly pipeline as for the original dates (run or load data).
months_head_pad_Aug_, months_tail_pad_Aug_ = mbm.data_processing.utils._compute_head_tail_pads_from_df(
    data_glamos_Aug_)

RUN = False
data_monthly_Aug_ = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos_Aug_,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_LSTM_gs_no_oggm_Aug_.csv')

dataloader_gl_Aug_ = mbm.dataloader.DataLoader(cfg,
                                               data=data_monthly_Aug_,
                                               random_seed=cfg.seed,
                                               meta_data_columns=cfg.metaData)

# Same glacier-based test split as for the original dates.
splits_Aug_, test_set_Aug_, train_set_Aug_ = get_CV_splits(
    dataloader_gl_Aug_,
    test_split_on='GLACIER',
    test_splits=TEST_GLACIERS,
    random_state=cfg.seed)

# Attach the targets to the feature frames of each split.
data_train_Aug_ = train_set_Aug_['df_X']
data_train_Aug_['y'] = train_set_Aug_['y']
data_test_Aug_ = test_set_Aug_['df_X']
data_test_Aug_['y'] = test_set_Aug_['y']

seed_all(cfg.seed)

# Normalize each frame's PERIOD labels from its OWN column. Assigning a column
# taken from a different frame (e.g. the non-Aug df_train) would be aligned by
# index, not by row, and produce wrong or missing labels.
df_train_Aug_ = data_train_Aug_.copy()
df_train_Aug_['PERIOD'] = df_train_Aug_['PERIOD'].str.strip().str.lower()

df_test_Aug_ = data_test_Aug_.copy()
df_test_Aug_['PERIOD'] = df_test_Aug_['PERIOD'].str.strip().str.lower()

## LSTM:

In [None]:
# Time-varying inputs, one value per month of a sequence: the climate
# variables of `vois_climate`, then pcsr and the elevation difference.
MONTHLY_COLS = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str',
    'pcsr',
    'ELEVATION_DIFFERENCE',
]

# Time-invariant topographical inputs (one value per sample).
STATIC_COLS = ['aspect_sgi', 'slope_sgi', 'svf']

# All model inputs: monthly columns first, then static ones.
feature_columns = [*MONTHLY_COLS, *STATIC_COLS]

In [None]:
def build_combined_dataset(
    df_loss,
    df_full,
    monthly_cols,
    static_cols,
    months_head_pad,
    months_tail_pad,
    normalize_target=True,
    expect_target=True,
):
    """Combine full-resolution inputs with loss targets into a sequence dataset.

    ``df_full`` provides the input rows for every month of each sequence,
    including padded months. ``df_loss`` provides ``POINT_BALANCE`` only for
    the rows whose target should be used. The two frames are left-joined on
    ``GLACIER, YEAR, ID, PERIOD, MONTHS``, so rows of ``df_full`` with no match
    in ``df_loss`` get ``POINT_BALANCE = NaN``.

    Args:
        df_loss: Frame with the join keys and a ``POINT_BALANCE`` column. All
            other columns are ignored.
        df_full: Frame with the join keys and the input feature columns. Any
            ``POINT_BALANCE`` or ``y`` column it carries is dropped, so
            targets come exclusively from ``df_loss``.
        monthly_cols, static_cols, months_head_pad, months_tail_pad,
        normalize_target, expect_target: Forwarded unchanged to
            ``mbm.data_processing.MBSequenceDataset.from_dataframe``.

    Returns:
        The dataset built by ``MBSequenceDataset.from_dataframe``.

    Raises:
        pandas.errors.MergeError: If ``df_loss`` has more than one row for the
            same join key. Without this check the left join would silently
            duplicate the matching ``df_full`` rows.
    """
    join_keys = ["GLACIER", "YEAR", "ID", "PERIOD", "MONTHS"]

    # Work on copies and normalize PERIOD identically on both sides so that
    # labels differing only in case/whitespace still join.
    df_loss = df_loss.copy()
    df_full = df_full.copy()
    for frame in (df_loss, df_full):
        frame["PERIOD"] = frame["PERIOD"].str.lower().str.strip()

    # Inputs only: discard any target columns carried by df_full.
    df_full_clean = df_full.drop(columns=["POINT_BALANCE", "y"],
                                 errors="ignore")

    # Targets only: the join keys plus POINT_BALANCE from df_loss.
    df_loss_reduced = df_loss[join_keys + ["POINT_BALANCE"]]

    # Left join: padded / non-loss months end up with POINT_BALANCE = NaN.
    # validate= guards against one-to-many matches duplicating input rows.
    df_combined = df_full_clean.merge(
        df_loss_reduced,
        on=join_keys,
        how="left",
        validate="many_to_one",
    )

    return mbm.data_processing.MBSequenceDataset.from_dataframe(
        df=df_combined,
        monthly_cols=monthly_cols,
        static_cols=static_cols,
        months_head_pad=months_head_pad,
        months_tail_pad=months_tail_pad,
        expect_target=expect_target,
        normalize_target=normalize_target,
    )

### Build LSTM dataloaders:

In [None]:
# Both splits share the feature spec and the August-anchored padding; only the
# (loss-target frame, full-input frame) pair differs.
_shared_ds_kwargs = dict(
    monthly_cols=MONTHLY_COLS,
    static_cols=STATIC_COLS,
    months_head_pad=months_head_pad_Aug_,
    months_tail_pad=months_tail_pad_Aug_,
    normalize_target=True,
    expect_target=True,
)

ds_train_full = build_combined_dataset(df_loss=data_train,
                                       df_full=data_train_Aug_,
                                       **_shared_ds_kwargs)
ds_test_full = build_combined_dataset(df_loss=data_test,
                                      df_full=data_test_Aug_,
                                      **_shared_ds_kwargs)

In [None]:
def inspect_sample(ds, idx, month_labels=None):
    """Print and plot the monthly masks of one dataset sample.

    Prints the sample key, its target ``y`` and a per-month table of the masks
    ``mv`` (valid inputs), ``mw`` (winter loss) and ``ma`` (annual loss), then
    plots the three masks along the month axis. This shows which months count
    toward the loss.

    Args:
        ds: Sequence dataset exposing tensors ``Xm``, ``mv``, ``mw``, ``ma``,
            ``y`` and a ``keys`` list.
        idx: Index of the sample to inspect.
        month_labels: Optional labels, one per time step of ``ds.Xm[idx]``.
            Defaults to the generic ``m0, m1, ...``.

    Raises:
        ValueError: If ``month_labels`` does not have one label per time step.
    """

    def _to_numpy(t):
        # detach/cpu so this also works for tensors on GPU or requiring grad.
        return t.detach().cpu().numpy()

    x_m = _to_numpy(ds.Xm[idx])
    mv = _to_numpy(ds.mv[idx])
    mw = _to_numpy(ds.mw[idx])
    ma = _to_numpy(ds.ma[idx])
    key = ds.keys[idx]
    n_steps = x_m.shape[0]

    if month_labels is None:
        # Positional placeholders; pass the real month order to get names.
        month_labels = [f"m{i}" for i in range(n_steps)]
    elif len(month_labels) != n_steps:
        raise ValueError(f"month_labels has {len(month_labels)} entries but "
                         f"the sample has {n_steps} time steps")

    df = pd.DataFrame({
        "Month": month_labels,
        "mv(valid)": mv,
        "mw(winter)": mw,
        "ma(annual)": ma
    })

    print("=== Sample info ===")
    print("Key:", key)
    print("Target y:", float(_to_numpy(ds.y[idx])))
    print()
    print(df)

    # Plot masks
    fig, ax = plt.subplots(figsize=(10, 3))
    ax.plot(mv, label="mv (valid inputs)")
    ax.plot(mw, label="mw (winter MB loss mask)")
    ax.plot(ma, label="ma (annual MB loss mask)")
    # NOTE(review): assumes key = (GLACIER, YEAR, ID, PERIOD, ...) — confirm.
    ax.set_title(
        f"Masks for sample {idx} (GLACIER={key[0]}, YEAR={key[1]}, PERIOD={key[3]})"
    )
    ax.set_xticks(range(len(month_labels)))
    ax.set_xticklabels(month_labels, rotation=45)
    ax.legend()
    plt.show()

In [None]:
# Month labels ordered by their position in the August-anchored sequences, so
# tick labels line up with the time axis of Xm and the masks.
month_list, month_pos = mbm.data_processing.utils._rebuild_month_index(
    months_head_pad_Aug_, months_tail_pad_Aug_)
month_order = sorted(month_pos, key=month_pos.get)
print("Month order used in sequences:", month_order)

for sample_idx in (0, 10, 150):
    inspect_sample(ds_train_full, sample_idx, month_labels=month_order)

In [None]:
# Mask semantics, per time step of a sample:
# mv = 1   → the input exists (so the LSTM sees the climate features)
# ma = 0   → the month is excluded from the annual loss
# mw = 0   → (also excluded from winter loss)


def inspect_padded_months(ds, n_samples=5, tol_zero=1e-6, seed=None):
    """Print inputs and masks month by month for a few random samples.

    For each selected sample, one line per time step shows the valid mask
    (mv), winter mask (mw), annual mask (ma), the NaN-ignoring mean of the
    monthly features, whether any feature is NaN and whether all features are
    numerically zero. This makes padded vs. true months easy to tell apart.

    Args:
        ds: Dataset exposing tensors ``Xm`` (B, T, Fm), ``mv``/``mw``/``ma``
            (B, T) and a ``keys`` sequence of length B.
        n_samples: Number of samples to show (capped at B).
        tol_zero: Absolute tolerance below which a feature counts as zero.
        seed: Optional seed for picking the samples. ``None`` (default) uses
            the global ``random`` state, so ``seed_all`` still controls it.
    """
    import random  # explicit import; don't rely on a wildcard import

    Xm = ds.Xm.detach().cpu().numpy()  # (B, T, Fm)
    mv = ds.mv.detach().cpu().numpy()  # (B, T)
    mw = ds.mw.detach().cpu().numpy()  # (B, T)
    ma = ds.ma.detach().cpu().numpy()  # (B, T)

    B, T, _ = Xm.shape
    sampler = random if seed is None else random.Random(seed)
    idxs = sampler.sample(range(B), min(n_samples, B))

    print(f"Inspecting {len(idxs)} random samples")
    for i in idxs:
        print("\n────────────────────────────────────")
        print(f"Sample {i} — {ds.keys[i]}")
        print("MONTH | mv | mw | ma | mean(X) | has_NaN | all_zero")
        for t in range(T):
            x = Xm[i, t, :]
            is_nan = np.isnan(x)
            finite = x[~is_nan]
            # Same value as np.nanmean, but without its RuntimeWarning on
            # all-NaN (e.g. padded) months.
            mean_val = float(finite.mean()) if finite.size else float("nan")
            nan_mask = is_nan.any()
            zero_mask = np.all(np.abs(x) < tol_zero)
            print(
                f"{t:2d} | {int(mv[i, t])}  | {int(mw[i, t])}  | {int(ma[i, t])}  | "
                f"{mean_val:7.3f} |  {nan_mask}  |  {zero_mask}")


# Show 5 random training samples month by month.
inspect_padded_months(ds_train_full, n_samples=5)

### Define & train model:

In [None]:
# Hyper-parameter search log; best LSTM configuration by validation loss.
log_path = 'logs/lstm_one_heads_param_search_progress_no_oggm_2025-11-06.csv'
best_params = get_best_params_for_lstm(log_path, select_by='valid_loss')
print(best_params)

# All runs ranked by validation loss, with the mean of the annual and winter
# test RMSE alongside for reference.
df = pd.read_csv(log_path)
df = df.assign(
    avg_test_loss=(df["test_rmse_a"] + df["test_rmse_w"]) / 2).sort_values(
        by="valid_loss")
df.head(10)

In [None]:
# Hand-picked LSTM configuration. 'Fm' / 'Fs' presumably are the numbers of
# monthly and static input features (9 and 3 columns respectively above).
custom_params = dict(
    lr=0.001,
    weight_decay=0.0001,
    hidden_size=128,
    num_layers=2,
    dropout=0.2,
    head_dropout=0.0,
    static_layers=2,
    static_hidden=[128, 64],
    static_dropout=0.1,
    Fm=9,
    Fs=3,
    bidirectional=False,
    loss_name='neutral',
    loss_spec=None,
    two_heads=False,
)

# To use the best configuration from the search log instead:
# custom_params = best_params

# Checkpoint path for this run, stamped with today's date (YYYY-MM-DD).
current_date = datetime.now().strftime("%Y-%m-%d")
model_filename = "models/lstm_model_{}_no_oggm_norm_y_past.pt".format(
    current_date)

# --- Loaders: fit scalers on the TRAIN indices only, then transform the whole
# ds_train_full_copy (train + validation samples). ---
seed_all(cfg.seed)
# Clone the datasets so the transforms below act on the copies;
# ds_train_full / ds_test_full are presumably left untransformed (ds_train_full
# is later handed to MBJobConfig) — NOTE(review): confirm in
# _clone_untransformed_dataset.
ds_train_full_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train_full)
ds_test_full_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_test_full)

# Seeded train/validation split of the sample indices (20% validation).
train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
    len(ds_train_full), val_ratio=0.2, seed=cfg.seed)

train_dl, val_dl = ds_train_full_copy.make_loaders(
    train_idx=train_idx,
    val_idx=val_idx,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=
    True,  # fit scalers on TRAIN and transform Xm/Xs/y of the copy in-place
    shuffle_train=True,
    use_weighted_sampler=True  # use weighted sampler for training
)

# --- Test loader: copies the TRAIN scalers of ds_train_full_copy into
# ds_test_full_copy and transforms it, so test data never affects scaling. ---
test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test_full_copy, ds_train_full_copy, batch_size=128, seed=cfg.seed)

# --- Build model and resolve loss. Re-seed first so model construction does
# not depend on how much randomness the loader setup consumed. ---
seed_all(cfg.seed)
model = mbm.models.LSTM_MB.build_model_from_params(cfg, custom_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(custom_params)

# Set TRAIN = True to (re)train; otherwise an existing checkpoint is evaluated.
TRAIN = False
if TRAIN:
    # Remove any checkpoint with today's name so a stale file from an earlier
    # run cannot be mistaken for this run's best model.
    if os.path.exists(model_filename):
        os.remove(model_filename)

    train_kwargs = dict(
        device=device,
        train_dl=train_dl,
        val_dl=val_dl,
        epochs=150,
        lr=custom_params['lr'],
        weight_decay=custom_params['weight_decay'],
        clip_val=1,
        # Learning-rate scheduler (argument names mirror torch's
        # ReduceLROnPlateau options).
        sched_factor=0.5,
        sched_patience=6,
        sched_threshold=0.01,
        sched_threshold_mode="rel",
        sched_cooldown=1,
        sched_min_lr=1e-6,
        # Early stopping
        es_patience=15,
        es_min_delta=1e-4,
        # Logging
        log_every=5,
        verbose=True,
        # Best-model checkpoint
        save_best_path=model_filename,
        loss_fn=loss_fn,
    )
    history, best_val, best_state = model.train_loop(**train_kwargs)
    plot_history_lstm(history)

# Evaluate on test.
# model_filename is stamped with today's date, so when TRAIN is False it only
# exists if a model was trained today. Fall back to the newest checkpoint that
# follows the same naming scheme instead of crashing in torch.load.
if not TRAIN and not os.path.exists(model_filename):
    _ckpt_dir = os.path.dirname(model_filename) or "."
    _ckpt_prefix, _ckpt_suffix = "lstm_model_", "_no_oggm_norm_y_past.pt"
    # Names start with an ISO date (YYYY-MM-DD) after the prefix, so
    # lexicographic order is chronological and the last entry is the newest.
    _ckpts = sorted(
        name for name in os.listdir(_ckpt_dir)
        if name.startswith(_ckpt_prefix) and name.endswith(_ckpt_suffix))
    if not _ckpts:
        raise FileNotFoundError(
            f"'{model_filename}' not found and no '{_ckpt_prefix}*{_ckpt_suffix}' "
            f"checkpoint in '{_ckpt_dir}'; set TRAIN = True to train a model.")
    model_filename = os.path.join(_ckpt_dir, _ckpts[-1])
    print(f"Loading most recent checkpoint: {model_filename}")

state = torch.load(model_filename, map_location=device)
model.load_state_dict(state)
test_metrics, test_df_preds = model.evaluate_with_preds(
    device, test_dl, ds_test_full_copy)
test_rmse_a = test_metrics['RMSE_annual']
test_rmse_w = test_metrics['RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

In [None]:
def safe_item(x):
    """Return ``x.item()`` for a scalar tensor; pass ``None`` through unchanged."""
    if x is None:
        return None
    return x.item()


# Compare the stored target scaler with the actual statistics of y for the
# transformed train and test copies.
for _header, _ds in (
    ("Train dataset (after make_loaders):", ds_train_full_copy),
    ("\nTest dataset (after make_test_loader):", ds_test_full_copy),
):
    print(_header)
    print(f"  normalize_target = {_ds.normalize_target}")
    print(f"  y_mean (scaler)  = {safe_item(_ds.y_mean)}")
    print(f"  y_std  (scaler)  = {safe_item(_ds.y_std)}")
    print(f"  Actual y.mean()  = {_ds.y.mean().item():.4f}")
    print(f"  Actual y.std()   = {_ds.y.std().item():.4f}")

In [None]:
# Overlay the train/test target distributions in whatever units the datasets
# are in after the loaders were built.
y_train = ds_train_full_copy.y.cpu().numpy()
y_test = ds_test_full_copy.y.cpu().numpy()
units = 'normalized' if ds_train_full_copy.normalize_target else 'physical'

plt.figure(figsize=(6, 4))
for values, split_name in ((y_train, "Train"), (y_test, "Test")):
    plt.hist(values, bins=30, alpha=0.6, label=split_name, density=True)
plt.axvline(y_train.mean(), color='k', linestyle='--', lw=1,
            label='mean (train)')
plt.xlabel("Target (y)")
plt.ylabel("Density")
plt.title(f"Target distribution ({units} units)")
plt.legend()
plt.show()

In [None]:
# Seasonal skill scores on the held-out test glaciers.
scores_annual, scores_winter = compute_seasonal_scores(
    test_df_preds, target_col='target', pred_col='pred')

for season, season_scores in (("Annual", scores_annual),
                              ("Winter", scores_winter)):
    print(f"{season} scores:", season_scores)

axis_limits = (-8, 6)
fig = plot_predictions_summary(
    grouped_ids=test_df_preds,
    scores_annual=scores_annual,
    scores_winter=scores_winter,
    ax_xlim=axis_limits,
    ax_ylim=axis_limits,
    color_annual=color_annual,
    color_winter=color_winter,
)

In [None]:
# Areas (with clariden alias fix): make "clariden" resolve to the "claridenL"
# area entry.
gl_area = get_gl_area(cfg)
gl_area["clariden"] = gl_area["claridenL"]

# Mean elevation of the annual stake measurements of each glacier, highest
# first. Computed once; used both for panel ordering and the gl_elv column.
gl_per_el = (data_glamos[data_glamos.PERIOD == 'annual']
             .groupby(['GLACIER'])['POINT_ELEVATION'].mean()
             .sort_values(ascending=False))

# Test glaciers from lowest to highest mean elevation (passed as custom_order).
test_gl_per_el = gl_per_el[TEST_GLACIERS].sort_values().index
test_df_preds['gl_elv'] = test_df_preds['GLACIER'].map(gl_per_el)

fig, axs = plt.subplots(2, 4, figsize=(32, 15), sharex=True)

subplot_labels = [
    '(a)', '(b)', '(c)', '(d)', '(e)', '(f)', '(g)', '(h)', '(i)'
]

axs = PlotIndividualGlacierPredVsTruth(test_df_preds,
                                       axs=axs,
                                       subplot_labels=subplot_labels,
                                       color_annual=color_annual,
                                       color_winter=color_winter,
                                       custom_order=test_gl_per_el,
                                       gl_area=gl_area)

# NOTE(review): assumes the helper returns a flat sequence of axes whose
# index 3 is the panel that should carry the y-label — confirm.
axs[3].set_ylabel("Modeled PMB [m w.e.]", fontsize=20)
fig.supxlabel('Observed PMB [m w.e.]', fontsize=20, y=0.06)

# One marker-only handle per season for the shared figure legend.
legend_scatter_annual, legend_scatter_winter = (
    Line2D([0], [0],
           marker='o',
           linestyle='None',
           linewidth=0,
           markersize=10,
           markerfacecolor=face,
           markeredgecolor='k',
           markeredgewidth=0.8,
           label=label)
    for label, face in (('Annual', color_annual), ('Winter', color_winter)))
handles = [legend_scatter_annual, legend_scatter_winter]

# Labels are taken from the handles themselves.
fig.legend(handles=handles,
           loc='upper center',
           bbox_to_anchor=(0.5, 0.05),
           ncol=4,
           fontsize=20)

plt.subplots_adjust(hspace=0.25, wspace=0.25)
plt.show()

## Extrapolate in space:

In [None]:
geodetic_mb = get_geodetic_MB(cfg)

# Start / end years of the geodetic periods, per glacier.
years_start_per_gl = geodetic_mb.groupby(
    'glacier_name')['Astart'].unique().apply(list).to_dict()
years_end_per_gl = geodetic_mb.groupby('glacier_name')['Aend'].unique().apply(
    list).to_dict()

periods_per_glacier, geoMB_per_glacier = build_periods_per_glacier(geodetic_mb)

glacier_list = list(data_glamos.GLACIER.unique())
print('Number of glaciers with pcsr:', len(glacier_list))

geodetic_glaciers = periods_per_glacier.keys()
print('Number of glaciers with geodetic MB:', len(geodetic_glaciers))

# Glaciers that have both stake data and geodetic MB. Sorted so the order does
# not depend on set iteration order, which varies between Python processes
# because string hashing is randomized.
common_glaciers = sorted(set(geodetic_glaciers) & set(glacier_list))
print('Number of common glaciers:', len(common_glaciers))

# Glacier areas, with "clariden" aliased to the "claridenL" entry.
gl_area = get_gl_area(cfg)
gl_area['clariden'] = gl_area['claridenL']


# Sort the lists by area if available in gl_area
def sort_by_area(glacier_list, gl_area):
    """Return ``glacier_list`` sorted by ascending glacier area.

    Glaciers missing from ``gl_area`` count as area 0 and therefore come
    first. Ties (equal or missing area) are broken by glacier name, so the
    result does not depend on the order of the input.
    """
    return sorted(glacier_list, key=lambda g: (gl_area.get(g, 0), g))


# De-normalize predictions only if the training dataset normalized the target.
denorm = ds_train_full_copy.normalize_target
print("Denormalize:", denorm)

# Common glaciers, smallest area first.
glacier_list = sort_by_area(glacier_list=common_glaciers, gl_area=gl_area)
glacier_list

In [None]:
from scripts.parallel_mb import MBJobConfig, run_glacier_mb

# Output folder for the distributed (gridded) mass-balance predictions.
path_save_glw = os.path.join(cfg.dataPath, 'GLAMOS', 'distributed_MB_grids',
                             'MBM/testing_LSTM/LSTM_no_oggm_OOS_norm_y_past')

# August-anchored glacier grids, matching the *_Aug_ training data.
PATH_GLACIER_GRIDS = 'GLAMOS/topo/gridded_topo_inputs/GLAMOS_grid_Aug_/'

RUN = True
if RUN:
    job = MBJobConfig(
        cfg=cfg,
        MONTHLY_COLS=MONTHLY_COLS,
        STATIC_COLS=STATIC_COLS,
        fields_not_features=cfg.fieldsNotFeatures,
        model_filename=model_filename,
        custom_params=custom_params,
        ds_train=ds_train_full,
        train_idx=train_idx,
        # ds_train_full (and hence the trained model) uses the
        # August-anchored padding, so inference must use the same pads.
        months_head_pad=months_head_pad_Aug_,
        months_tail_pad=months_tail_pad_Aug_,
        data_path=cfg.dataPath,
        path_glacier_grid_glamos=PATH_GLACIER_GRIDS,
        path_xr_grids=os.path.join(cfg.dataPath, 'GLAMOS', 'topo',
                                   'GLAMOS_DEM', 'xr_masked_grids'),
        path_save_glw=path_save_glw,
        seed=cfg.seed,
        max_workers=16,  # number of parallel worker processes
        cpu_only=True,
        ONLY_GEODETIC=False,
        denorm=ds_train_full_copy.normalize_target,
        save_monthly=True)

    summary = run_glacier_mb(job, glacier_list, periods_per_glacier)
    print("SUMMARY:", summary)

In [None]:
# Monthly LSTM prediction maps for the Rhone glacier in 2008.
monthly_plot_kwargs = dict(
    glacier_name="rhone",
    year=2008,
    path_pred_lstm=path_save_glw,
    apply_smoothing_fn=apply_gaussian_filter,
)
fig = plot_glacier_monthly_series_lstm_sharedcmap_center0(
    **monthly_plot_kwargs)

In [None]:
glacier_name = 'rhone'
year = 2008
# Plot the masked annual prediction grid. The context manager closes the
# underlying store once the plot has pulled the data.
annual_grid_path = os.path.join(path_save_glw, glacier_name,
                                f'{glacier_name}_{year}_annual.zarr')
with xr.open_dataset(annual_grid_path) as ds_annual:
    ds_annual.pred_masked.plot()

In [None]:
# Glaciers for which gridded predictions were written.
glaciers_in_glamos = os.listdir(path_save_glw)

geodetic_mb = get_geodetic_MB(cfg)

# Start / end years of the geodetic periods, per glacier.
by_glacier = geodetic_mb.groupby('glacier_name')
years_start_per_gl = by_glacier['Astart'].unique().apply(list).to_dict()
years_end_per_gl = by_glacier['Aend'].unique().apply(list).to_dict()

periods_per_glacier, geoMB_per_glacier = build_periods_per_glacier(geodetic_mb)

# Glacier areas (used to sort glaciers), with "clariden" aliased to
# "claridenL".
gl_area = get_gl_area(cfg)
gl_area['clariden'] = gl_area['claridenL']


# Sort the lists by area if available in gl_area
def sort_by_area(glacier_list, gl_area):
    """Return ``glacier_list`` sorted by ascending glacier area.

    Glaciers missing from ``gl_area`` count as area 0 and therefore come
    first. Ties (equal or missing area) are broken by glacier name, so the
    result does not depend on the order of the input.
    """
    return sorted(glacier_list, key=lambda g: (gl_area.get(g, 0), g))


# Glaciers that have geodetic periods AND written predictions, smallest first.
available = set(glaciers_in_glamos)
glacier_list = sort_by_area(
    [g for g in periods_per_glacier.keys() if g in available], gl_area)
print('Number of glaciers:', len(glacier_list))
print('Glaciers:', glacier_list)

df_all_nn = process_geodetic_mass_balance_comparison(
    glacier_list=glacier_list,
    path_SMB_GLAMOS_csv=cfg.dataPath + path_SMB_GLAMOS_csv,
    periods_per_glacier=periods_per_glacier,
    geoMB_per_glacier=geoMB_per_glacier,
    gl_area=gl_area,
    test_glaciers=TEST_GLACIERS,
    path_predictions=path_save_glw,
    cfg=cfg)

# Keep only rows with both geodetic and modelled MB, ordered by area, with
# capitalized glacier names for display.
df_all_nn = (df_all_nn.dropna(subset=['Geodetic MB', 'MBM MB'])
             .sort_values(by='Area'))
df_all_nn['GLACIER'] = df_all_nn['GLACIER'].apply(str.capitalize)

# Agreement between geodetic and modelled MB.
geodetic_vals, model_vals = df_all_nn["Geodetic MB"], df_all_nn["MBM MB"]
rmse_nn = root_mean_squared_error(geodetic_vals, model_vals)
corr_nn = np.corrcoef(geodetic_vals, model_vals)[0, 1]

fig = plot_mbm_vs_geodetic_by_area_bin(
    df_all_nn,
    bins=[0, 1, 5, 10, 100, np.inf],
    labels=['<1', '1-5', '5–10', '>10', '>100'],
    max_bins=4)