## Setting Up:

In [None]:
# --- Standard library
from concurrent.futures import ProcessPoolExecutor, as_completed
from contextlib import redirect_stdout
from datetime import datetime
import io
import logging
import multiprocessing as mp
import os
import sys
import warnings

# Make repo root importable (for MBM & scripts/*)
sys.path.append(os.path.join(os.getcwd(), '../../'))

# --- Third-party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from cmcrameri import cm
import torch
from tqdm.auto import tqdm
import xarray as xr
from matplotlib.lines import Line2D

import massbalancemachine as mbm

# --- Project-local
from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.nn_helpers import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.NN_networks import *
from scripts.geodata_plots import *

# --- Notebook settings
# Silence every Python warning for the rest of the session. NOTE(review): this
# also hides deprecation notices from the libraries imported above.
warnings.filterwarnings('ignore')
# IPython autoreload, mode 2: modules are re-imported before each cell runs, so
# edits to the project-local scripts/* helpers are picked up without restarting
# the kernel.
%load_ext autoreload
%autoreload 2

# Project configuration object from massbalancemachine; later cells read
# cfg.seed, cfg.dataPath and cfg.metaData from it.
cfg = mbm.SwitzerlandConfig()

In [None]:
# Seed the session from the project configuration (seed_all is a project helper;
# which RNGs it covers is defined there).
seed_all(cfg.seed)
print("Using seed:", cfg.seed)

# Report GPU availability; free_up_cuda() is a project helper, presumably
# releasing cached GPU memory -- TODO confirm.
if not torch.cuda.is_available():
    print("CUDA is NOT available")
else:
    print("CUDA is available")
    free_up_cuda()

# Device used for the model and training below: GPU when present, else CPU.
device = (torch.device("cuda")
          if torch.cuda.is_available() else torch.device("cpu"))

In [None]:
# Plot styles: apply the project's matplotlib style sheet and fix the two
# colours that the per-glacier figure further down uses for annual and winter
# points.
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)

# get_cmap_hex is a project helper; presumably it samples 10 hex colours from
# the batlow colour map -- verify against scripts/.
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue, color_pink = colors[0], '#c51b7d'

## Input data:

In [None]:
# Climate variables handed to process_or_load_data (vois_climate argument).
vois_climate = ['t2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str']

# Topographical variables handed to process_or_load_data
# (vois_topographical argument).
vois_topographical = [
    "aspect_sgi",
    "slope_sgi",
    "hugonnet_dhdt",
    "consensus_ice_thickness",
    "millan_v",
    "svf",
]

# Read GLAMOS stake data
data_glamos = getStakesData(cfg)

# Head/tail padding for the monthly sequences, derived from the stake data.
# NOTE(review): this calls a private (underscore-prefixed) helper of
# massbalancemachine, which may change without notice.
months_head_pad, months_tail_pad = (
    mbm.data_processing.utils._compute_head_tail_pads_from_df(data_glamos))

# Initialize logging (kept after getStakesData, as in the original ordering)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Input/output locations for the monthly transformation, all relative to
# cfg.dataPath.
paths = dict(
    csv_path=cfg.dataPath + path_PMB_GLAMOS_csv,
    era5_climate_data=cfg.dataPath + path_ERA5_raw +
    'era5_monthly_averaged_data.nc',
    geopotential_data=cfg.dataPath + path_ERA5_raw +
    'era5_geopotential_pressure.nc',
    radiation_save_path=cfg.dataPath + path_pcsr + 'zarr/',
)

# Transform data to monthly format. RUN is forwarded as run_flag; judging by the
# helper's name, False loads a previously written output_file instead of
# recomputing -- confirm in scripts/.
RUN = False
data_monthly = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_LSTM_svf.csv',
)

# Create DataLoader over the monthly data, seeded from the configuration.
dataloader_gl = mbm.dataloader.DataLoader(
    cfg,
    data=data_monthly,
    random_seed=cfg.seed,
    meta_data_columns=cfg.metaData,
)

## Cross-testing:

In [None]:
# Model input features, derived from the variable lists that were used to build
# data_monthly so the two definitions cannot drift apart.

# Monthly (sequence) inputs: every climate variable plus ELEVATION_DIFFERENCE.
MONTHLY_COLS = [*vois_climate, 'ELEVATION_DIFFERENCE']

# Static inputs: the topographical variables (copied, so later edits to one
# list do not silently alter the other).
STATIC_COLS = list(vois_topographical)

# All feature columns, monthly first, then static.
feature_columns = MONTHLY_COLS + STATIC_COLS

In [None]:
# Collect the set of glacier names present in the monthly dataset (used below to
# derive the training glaciers of each fold). The bare expression on the last
# line displays how many there are.
existing_glaciers = set(data_monthly["GLACIER"].unique())
len(existing_glaciers)

### Design of folds:

In [None]:
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedGroupKFold, GroupKFold

def _try_make_strata(gl_stats, q_elev, q_year):
    """Return strata labels and counts; may drop duplicate bins if ranges are narrow."""
    elev_bin = pd.qcut(gl_stats["med_elev"], q=q_elev, labels=False, duplicates="drop")
    year_bin = pd.qcut(gl_stats["med_year"], q=q_year, labels=False, duplicates="drop")
    strata = elev_bin.astype(str) + "_" + year_bin.astype(str)
    counts = strata.value_counts()
    return strata, counts

def _try_make_1d_strata(gl_stats, q_bins, on="med_elev"):
    """One-dimensional fallback stratification."""
    one_bin = pd.qcut(gl_stats[on], q=q_bins, labels=False, duplicates="drop")
    strata = one_bin.astype(str)
    counts = strata.value_counts()
    return strata, counts

def make_stratified_glacier_folds(
    data_monthly, n_splits=5, random_state=42, max_bins=5, verbose=True
):
    """Partition glaciers into cross-validation test folds, stratified if possible.

    Every glacier is summarised by its median ``POINT_ELEVATION`` and median
    ``YEAR``. The function then looks for the finest quantile binning (starting
    at ``max_bins`` and coarsening down to 2 bins) in which every stratum holds
    at least ``n_splits`` glaciers: first in 2D (elevation x year), then in 1D
    (elevation, then year). If one is found, folds come from
    ``StratifiedGroupKFold``; otherwise from plain ``GroupKFold``.

    Args:
        data_monthly: DataFrame with columns ``GLACIER``, ``POINT_ELEVATION``
            and ``YEAR``. It is only read, never modified.
        n_splits: Requested number of folds; lowered to the number of glaciers
            when there are fewer glaciers than folds.
        random_state: Seed for the shuffled ``StratifiedGroupKFold``. The
            ``GroupKFold`` fallback is built without it.
        max_bins: Largest number of quantile bins tried per dimension.
        verbose: Print the chosen strategy and a per-fold summary.

    Returns:
        (folds, glacier_to_fold): ``folds`` is a list of sets of glacier names
        (the test glaciers of each fold); ``glacier_to_fold`` maps each glacier
        name to its 0-based fold index.

    Raises:
        ValueError: If ``data_monthly`` contains fewer than 2 glaciers, since no
            train/test split is possible.
    """
    # Per-glacier summaries used for stratification
    gl_stats = (
        data_monthly.groupby("GLACIER", as_index=False)
        .agg(med_elev=("POINT_ELEVATION", "median"),
             med_year=("YEAR", "median"),
             n_samples=("YEAR", "size"))
    )
    n_glaciers = len(gl_stats)
    if verbose:
        print(f"Total glaciers: {n_glaciers}")

    # Fail early with a clear message instead of letting n_splits drop to 0/1
    # and surfacing as an unrelated downstream error.
    if n_glaciers < 2:
        raise ValueError(
            f"Need at least 2 glaciers to build cross-validation folds, got {n_glaciers}."
        )

    # Guard: if very few glaciers, reduce n_splits
    if n_glaciers < n_splits:
        if verbose:
            print(f"Reducing n_splits from {n_splits} to {n_glaciers} (not enough glaciers).")
        n_splits = n_glaciers

    # 2D stratification with automatic coarsening: accept the first (finest)
    # binning whose smallest stratum still has >= n_splits glaciers.
    strata = None
    for q_elev in range(max_bins, 1, -1):
        for q_year in range(max_bins, 1, -1):
            candidate, counts = _try_make_strata(gl_stats, q_elev, q_year)
            if counts.min() >= n_splits:
                strata = candidate
                if verbose:
                    print(f"Using 2D strata: elev_bins={q_elev}, year_bins={q_year}")
                break
        if strata is not None:
            break

    # 1D fallback (first try elevation, then year), with coarsening
    if strata is None:
        for on in ("med_elev", "med_year"):
            for q_bins in range(max_bins, 1, -1):
                candidate, counts = _try_make_1d_strata(gl_stats, q_bins, on=on)
                if counts.min() >= n_splits:
                    strata = candidate
                    if verbose:
                        print(f"Using 1D strata on {on}: bins={q_bins}")
                    break
            if strata is not None:
                break

    # One row per glacier: the splitters only need dummy features plus groups.
    X_dummy = np.zeros((n_glaciers, 1))
    groups = gl_stats["GLACIER"].to_numpy()

    if strata is not None:
        # We have viable strata for StratifiedGroupKFold
        splitter = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
        split_iter = splitter.split(X_dummy, y=np.asarray(strata), groups=groups)
    else:
        # Final fallback: plain GroupKFold (no stratification)
        if verbose:
            print("Falling back to GroupKFold (no valid strata with min class size ≥ n_splits).")
        splitter = GroupKFold(n_splits=n_splits)
        split_iter = splitter.split(X_dummy, groups=groups)

    # test_idx holds positions into gl_stats / groups, so index positionally.
    folds = [set(groups[test_idx]) for _, test_idx in split_iter]

    # Optional: summary per fold
    if verbose:
        print("\nFold summary (glaciers, samples, elev range, year span):")
        for i, test_glaciers in enumerate(folds, start=1):
            sub = data_monthly[data_monthly["GLACIER"].isin(test_glaciers)]
            n_g = len(test_glaciers)
            n_s = len(sub)
            elev = sub["POINT_ELEVATION"]
            elev_min, elev_med, elev_max = elev.min(), elev.median(), elev.max()
            year_min, year_max = int(sub["YEAR"].min()), int(sub["YEAR"].max())
            print(f"Fold {i}: {n_g:2d} glaciers | {n_s:6d} samples | elev [{elev_min:.0f}, {elev_med:.0f}, {elev_max:.0f}] m | years {year_min}–{year_max}")

    glacier_to_fold = {g: i for i, s in enumerate(folds) for g in s}
    return folds, glacier_to_fold

# Build the glacier test folds used for cross-testing (5 folds, fixed seed 42,
# verbose summary printed).
folds, glacier_to_fold = make_stratified_glacier_folds(
    data_monthly,
    n_splits=5,
    random_state=42,
    max_bins=5,
    verbose=True,
)


In [None]:
# # Alternative (currently disabled): balanced GroupKFold split by GLACIER, without stratification
# from sklearn.model_selection import GroupKFold

# # Optional: shuffle rows once for reproducibility of the split order
# dm = data_monthly.sample(frac=1, random_state=42).reset_index(drop=True)

# groups = dm["GLACIER"]
# gkf = GroupKFold(n_splits=5)

# folds = []  # list of sets of glacier names (test fold per split)
# glacier_to_fold = {}  # map glacier -> fold id (0..4)

# for fold_id, (train_idx, test_idx) in enumerate(gkf.split(dm, groups=groups)):
#     test_glaciers = set(dm.loc[test_idx, "GLACIER"].unique())
#     folds.append(test_glaciers)
#     for g in test_glaciers:
#         glacier_to_fold[g] = fold_id

# # Example: get train/test masks for fold 0
# fold = 0
# test_mask = dm["GLACIER"].isin(folds[fold])
# train_mask = ~test_mask

# # --- Print summary per fold ---
# print("Fold summary:\n")
# for fold_id, test_glaciers in enumerate(folds):
#     n_glaciers = len(test_glaciers)
#     n_samples = dm[dm["GLACIER"].isin(test_glaciers)].shape[0]
#     print(
#         f"Fold {fold_id+1}: {n_glaciers:2d} glaciers, {n_samples:5d} samples")

# # Optional: check total consistency
# print("\nTotal glaciers:", len(set.union(*folds)))
# print("Total samples :",
#       sum(dm[dm["GLACIER"].isin(f)].shape[0] for f in folds))

In [None]:
# # Alternative (currently disabled): random split of the shuffled glaciers into 5 folds:
# glaciers = np.array(sorted(data_monthly["GLACIER"].unique()))
# rng = np.random.default_rng(42)  # reproducible
# rng.shuffle(glaciers)

# folds_simple = [set(glaciers[i::5]) for i in range(5)]
# glacier_to_fold_simple = {g: i for i, s in enumerate(folds_simple) for g in s}

### Train on folds:

In [None]:
# Hyper-parameters passed to LSTM_MB.build_model_from_params / resolve_loss_fn,
# and read directly for the optimiser settings (lr, weight_decay).
# NOTE(review): Fm=8 and Fs=6 equal len(MONTHLY_COLS) and len(STATIC_COLS);
# presumably they are the monthly/static feature counts -- keep them in sync.
custom_params = dict(
    Fm=8,
    Fs=6,
    hidden_size=128,
    num_layers=2,
    bidirectional=False,
    dropout=0.1,
    static_layers=2,
    static_hidden=[128, 64],
    static_dropout=0.1,
    lr=0.001,
    weight_decay=0.0,
    loss_name='neutral',
    two_heads=True,
    head_dropout=0.0,
    loss_spec=None,
)

# Today's date (YYYY-MM-DD); used in the checkpoint file names.
current_date = datetime.now().strftime("%Y-%m-%d")


In [None]:
# Cross-testing loop: for every fold, train an LSTM on the remaining glaciers,
# evaluate on the held-out glaciers and collect the test predictions.
# Per-fold prediction frames are gathered in a list and concatenated once at
# the end (instead of growing a DataFrame inside the loop).
fold_preds = []
for i, test_glaciers in enumerate(folds):
    print(f"\n--- Fold {i+1} / {len(folds)} ---")
    print("Test glaciers:", test_glaciers)
    # Re-seed per fold so every fold starts from the same RNG state.
    seed_all(cfg.seed)

    # Training glaciers = all glaciers not held out in this fold.
    train_glaciers = [g for g in existing_glaciers if g not in test_glaciers]

    # Sanity check on the raw row counts before building the actual splits.
    n_train_rows = int(data_monthly.GLACIER.isin(train_glaciers).sum())
    n_test_rows = int(data_monthly.GLACIER.isin(test_glaciers).sum())
    if n_train_rows == 0:
        print("Warning: No training data available!")
    else:
        # Test rows as a percentage of training rows (diagnostic only).
        test_perc = (n_test_rows / n_train_rows) * 100

    splits, test_set, train_set = get_CV_splits(dataloader_gl,
                                                test_split_on='GLACIER',
                                                test_splits=test_glaciers,
                                                random_state=cfg.seed)
    # Train / test frames with the target attached as column 'y'.
    data_train = train_set['df_X']
    data_train['y'] = train_set['y']

    data_test = test_set['df_X']
    data_test['y'] = test_set['y']

    # Work on copies with a normalised PERIOD label (trimmed, lower-case).
    df_train = data_train.copy()
    df_train['PERIOD'] = df_train['PERIOD'].str.strip().str.lower()

    df_test = data_test.copy()
    df_test['PERIOD'] = df_test['PERIOD'].str.strip().str.lower()

    # --- build train / test datasets from the dataframes ---
    ds_train = mbm.data_processing.MBSequenceDataset.from_dataframe(
        df_train,
        MONTHLY_COLS,
        STATIC_COLS,
        months_tail_pad=months_tail_pad,
        months_head_pad=months_head_pad,
        expect_target=True)

    ds_test = mbm.data_processing.MBSequenceDataset.from_dataframe(
        df_test,
        MONTHLY_COLS,
        STATIC_COLS,
        months_tail_pad=months_tail_pad,
        months_head_pad=months_head_pad,
        expect_target=True)

    # Hold out 20% of the training sequences for validation.
    train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
        len(ds_train), val_ratio=0.2, seed=cfg.seed)

    train_dl, val_dl = ds_train.make_loaders(
        train_idx=train_idx,
        val_idx=val_idx,
        batch_size_train=64,
        batch_size_val=128,
        seed=cfg.seed,
        fit_and_transform=
        True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
        shuffle_train=True,
        use_weighted_sampler=True  # use weighted sampler for training
    )

    # --- test loader (copies TRAIN scalers into ds_test and transforms it) ---
    test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
        ds_test, ds_train, batch_size=128, seed=cfg.seed)

    # --- build model, resolve loss, train, reload best ---
    model = mbm.models.LSTM_MB.build_model_from_params(cfg, custom_params,
                                                       device)
    loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(custom_params)

    model_filename = f"models/lstm_model_{current_date}_fold_{i}.pt"
    # Make sure the checkpoint directory exists before training writes to it.
    os.makedirs(os.path.dirname(model_filename), exist_ok=True)

    history, best_val, best_state = model.train_loop(
        device=device,
        train_dl=train_dl,
        val_dl=val_dl,
        epochs=150,
        lr=custom_params['lr'],
        weight_decay=custom_params['weight_decay'],
        clip_val=1,
        # scheduler
        sched_factor=0.5,
        sched_patience=6,
        sched_threshold=0.01,
        sched_threshold_mode="rel",
        sched_cooldown=1,
        sched_min_lr=1e-6,
        # early stopping
        es_patience=15,
        es_min_delta=1e-4,
        # logging
        log_every=5,
        verbose=True,
        # checkpoint
        save_best_path=model_filename,
        loss_fn=loss_fn,
    )

    # Evaluate on test, using the checkpoint written to save_best_path.
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state)
    test_metrics, test_df_preds = model.evaluate_with_preds(
        device, test_dl, ds_test)
    test_rmse_a, test_rmse_w = test_metrics['RMSE_annual'], test_metrics[
        'RMSE_winter']

    print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
        test_rmse_a, test_rmse_w))

    # Tag the predictions with the 0-based fold index and keep them.
    test_df_preds["fold"] = i
    fold_preds.append(test_df_preds)

# Single concatenation of all folds; empty frame when there were no folds.
test_df_preds_all = (pd.concat(fold_preds, axis=0, ignore_index=True)
                     if fold_preds else pd.DataFrame())

In [None]:
# Aggregate scores over the pooled test predictions of all folds, split into
# annual and winter (compute_seasonal_scores is a project helper).
scores_annual, scores_winter = compute_seasonal_scores(
    test_df_preds_all,
    target_col='target',
    pred_col='pred',
)

print("Annual scores:", scores_annual)
print("Winter scores:", scores_winter)

# Summary figure of predictions vs. observations; both axes share the same
# limits.
axis_limits = (-8, 6)
fig = plot_predictions_summary(
    grouped_ids=test_df_preds_all,
    scores_annual=scores_annual,
    scores_winter=scores_winter,
    ax_xlim=axis_limits,
    ax_ylim=axis_limits,
)

In [None]:
# Mean elevation of the annual stake measurements per glacier, highest first.
gl_per_el = (data_glamos[data_glamos.PERIOD == 'annual']
             .groupby(['GLACIER'])['POINT_ELEVATION']
             .mean()
             .sort_values(ascending=False))

# Glaciers of the last fold (fold index 4 with the default 5 folds), ordered by
# ascending mean elevation so the panel order is reproducible; iterating the
# fold's set directly would give an arbitrary order. Glaciers without annual
# measurements get NaN and are placed last.
test_gl_per_el = (gl_per_el.reindex(list(folds[-1]))
                  .sort_values()
                  .index.tolist())

fig, axs = plt.subplots(3, 3, figsize=(25, 18), sharex=True)

# Attach each prediction row's glacier elevation.
test_df_preds_all['gl_elv'] = test_df_preds_all['GLACIER'].map(gl_per_el)

axs = PlotIndividualGlacierPredVsTruth(
    test_df_preds_all,
    axs=axs,
    color_annual=color_dark_blue,
    color_winter=color_pink,
    custom_order=test_gl_per_el)

# NOTE(review): indexing with a single integer assumes the helper returns a
# flat sequence of axes (a 3x3 array has no row 3) -- confirm in scripts/plots.
axs[3].set_ylabel("Modelled PMB [m w.e.]", fontsize=20)

fig.supxlabel('Observed PMB [m w.e.]', fontsize=20, y=0.06)

# Two distinct legend handles, using the same colours that were passed to the
# plotting helper above for annual and winter points.
legend_scatter_annual, legend_scatter_winter = (
    Line2D([0], [0],
           marker='o',
           linestyle='None',
           linewidth=0,
           markersize=10,
           markerfacecolor=face_color,
           markeredgecolor='k',
           markeredgewidth=0.8,
           label=label)
    for face_color, label in ((color_dark_blue, 'Annual'),
                              (color_pink, 'Winter')))

# if you already have other handles (e.g., bands/means), append these:
# handles = existing_handles + [legend_scatter_annual, legend_scatter_winter]
handles = [legend_scatter_annual, legend_scatter_winter]

# Matplotlib takes the legend labels from the handles; no need to pass `labels=...`
fig.legend(handles=handles,
           loc='upper center',
           bbox_to_anchor=(0.5, 0.05),
           ncol=4,
           fontsize=20)

plt.subplots_adjust(hspace=0.25, wspace=0.25)
plt.show()