## Setting Up:

In [None]:
import sys, os

sys.path.append(os.path.join(os.getcwd(),
                             '../../'))  # Add root of repo to import MBM
import csv
from functools import partial

import pandas as pd
import warnings
from tqdm.notebook import tqdm
import re
import matplotlib.pyplot as plt
import seaborn as sns
from cmcrameri import cm
import xarray as xr
import massbalancemachine as mbm
from collections import defaultdict
import logging
from skorch.helper import SliceDataset
from datetime import datetime
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint
import itertools
import random
import pickle
from collections import Counter
import ast

from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.nn_helpers import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.NN_networks import *
from scripts.geodata_plots import *

# Silence every warning category for the rest of the session.
# NOTE(review): this also hides deprecation warnings from the libraries used below.
warnings.filterwarnings('ignore')
%load_ext autoreload
%autoreload 2

# Project configuration object; later cells read cfg.seed, cfg.dataPath,
# cfg.metaData and cfg.fieldsNotFeatures from it.
cfg = mbm.SwitzerlandConfig()

In [None]:
# Seed through the project helper and report the value used.
# (seed_all is defined outside this notebook; presumably it seeds the
# Python/NumPy/torch RNGs — confirm in scripts.helpers.)
seed_all(cfg.seed)
print("Using seed:", cfg.seed)

from torch.utils.data import Subset
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset
from torch.utils.data import WeightedRandomSampler, SubsetRandomSampler
import torch.nn as nn

# Probe CUDA once and reuse the answer for both the message and the device choice.
cuda_ok = torch.cuda.is_available()
if not cuda_ok:
    print("CUDA is NOT available")
else:
    print("CUDA is available")
    free_up_cuda()  # project helper, only called when a GPU is present

device = torch.device("cuda") if cuda_ok else torch.device("cpu")

In [None]:
# Plot styles: apply the project style sheet and pick the two colours used in
# the prediction plots further down.
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue, color_pink = colors[0], '#c51b7d'

# RGI ids: read the glacier id table, strip whitespace from the column names,
# then sort by and index on the glacier short name.
rgi_df = (pd.read_csv(cfg.dataPath + path_glacier_ids, sep=',')
          .rename(columns=str.strip)
          .sort_values(by='short_name')
          .set_index('short_name'))

# Climate variables passed to process_or_load_data as `vois_climate`.
vois_climate = ['t2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str']

# Topographical variables passed to process_or_load_data as `vois_topographical`.
vois_topographical = [
    "aspect_sgi", "slope_sgi", "hugonnet_dhdt", "consensus_ice_thickness",
    "millan_v"
]

## Read GL data:

In [None]:
# Load the point data through the project helper (presumably the GLAMOS stake
# measurements — confirm in scripts.helpers).
data_glamos = getStakesData(cfg)

# Head/tail month padding derived from that data; reused below whenever a
# MBSequenceDataset or GeoData object is built.
# NOTE(review): this relies on a private (underscore-prefixed) mbm helper.
months_head_pad, months_tail_pad = mbm.data_processing.utils._compute_head_tail_pads_from_df(
    data_glamos)

## Input data:
### Input dataset:

In [None]:
# Initialize logging (root logger, INFO level, timestamped messages).
# NOTE: because the root logger is configured here, any later plain
# logging.basicConfig(...) call in this session is a no-op unless force=True.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Transform data to monthly format (run or load data):
# input/output locations handed to process_or_load_data, all built by string
# concatenation onto cfg.dataPath.
paths = {
    'csv_path': cfg.dataPath + path_PMB_GLAMOS_csv,
    'era5_climate_data':
    cfg.dataPath + path_ERA5_raw + 'era5_monthly_averaged_data.nc',
    'geopotential_data':
    cfg.dataPath + path_ERA5_raw + 'era5_geopotential_pressure.nc',
    'radiation_save_path': cfg.dataPath + path_pcsr + 'zarr/'
}
# RUN is forwarded as run_flag; presumably False loads the previously saved
# output_file instead of recomputing — confirm in process_or_load_data.
RUN = False
data_monthly = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_LSTM.csv')

# Create DataLoader (the mbm one, not torch.utils.data.DataLoader) over the
# monthly table; it is passed to get_CV_splits below.
dataloader_gl = mbm.dataloader.DataLoader(cfg,
                                          data=data_monthly,
                                          random_seed=cfg.seed,
                                          meta_data_columns=cfg.metaData)

In [None]:
# Sanity check: show the MONTHS entries of one stake/year in the loader's table.
df = dataloader_gl.data
is_target_stake = (df.POINT_ID == 'adler_26') & (df.YEAR == 2006)
df.loc[is_target_stake, 'MONTHS']

## Blocking on glaciers:

In [None]:
# Ensure all test glaciers exist in the dataset
existing_glaciers = set(data_monthly.GLACIER.unique())
missing_glaciers = [g for g in TEST_GLACIERS if g not in existing_glaciers]

if missing_glaciers:
    print(
        f"Warning: The following test glaciers are not in the dataset: {missing_glaciers}"
    )

# Training glaciers = every glacier present in the data that is not a test glacier
train_glaciers = [i for i in existing_glaciers if i not in TEST_GLACIERS]

# Preliminary row-level split, used only for the size report below; data_test
# and data_train are re-assigned from the CV split at the end of this cell.
data_test = data_monthly[data_monthly.GLACIER.isin(TEST_GLACIERS)]
print('Size of monthly test data:', len(data_test))

data_train = data_monthly[data_monthly.GLACIER.isin(train_glaciers)]
print('Size of monthly train data:', len(data_train))

if len(data_train) == 0:
    print("Warning: No training data available!")
else:
    # Note: the percentage is test rows relative to TRAIN rows, not to the total.
    test_perc = (len(data_test) / len(data_train)) * 100
    print('Percentage of test size: {:.2f}%'.format(test_perc))

# Split by glacier through the project helper; test_set/train_set are indexed
# below by 'df_X', 'y' and 'splits_vals'.
splits, test_set, train_set = get_CV_splits(dataloader_gl,
                                            test_split_on='GLACIER',
                                            test_splits=TEST_GLACIERS,
                                            random_state=cfg.seed)

print('Test glaciers: ({}) {}'.format(len(test_set['splits_vals']),
                                      test_set['splits_vals']))
# NOTE(review): unlike the guarded computation above, this division has no
# check for an empty training set.
test_perc = (len(test_set['df_X']) / len(train_set['df_X'])) * 100
print('Percentage of test size: {:.2f}%'.format(test_perc))
print('Size of test set:', len(test_set['df_X']))
print('Train glaciers: ({}) {}'.format(len(train_set['splits_vals']),
                                       train_set['splits_vals']))
print('Size of train set:', len(train_set['df_X']))

# Validation and train split:
# attach the target as column 'y'. No copy is taken, so this also adds the
# column to the frames held inside train_set / test_set.
data_train = train_set['df_X']
data_train['y'] = train_set['y']

data_test = test_set['df_X']
data_test['y'] = test_set['y']

## LSTM:

In [None]:
# Per-month input features of the LSTM (passed as the monthly columns to
# MBSequenceDataset.from_dataframe).
MONTHLY_COLS = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str',
    'pcsr', 'ELEVATION_DIFFERENCE',
]

# Static (per-sequence) input features.
STATIC_COLS = [
    'aspect_sgi', 'slope_sgi', 'hugonnet_dhdt',
    'consensus_ice_thickness', 'millan_v',
]

# Full feature list: monthly features first, then static ones.
feature_columns = [*MONTHLY_COLS, *STATIC_COLS]

### Build LSTM dataloaders:

In [None]:
seed_all(cfg.seed)


def _with_normalised_period(frame):
    """Return a copy of ``frame`` whose PERIOD labels are stripped and lower-cased."""
    out = frame.copy()
    out['PERIOD'] = out['PERIOD'].str.strip().str.lower()
    return out


def _to_sequence_dataset(frame):
    """Build a target-carrying MBSequenceDataset from ``frame`` with the shared column lists and pads."""
    return mbm.data_processing.MBSequenceDataset.from_dataframe(
        frame,
        MONTHLY_COLS,
        STATIC_COLS,
        months_tail_pad=months_tail_pad,
        months_head_pad=months_head_pad,
        expect_target=True)


# Work on copies so the frames coming out of the CV split stay untouched.
df_train = _with_normalised_period(data_train)
df_test = _with_normalised_period(data_test)

# --- build train and test datasets from the dataframes ---
ds_train = _to_sequence_dataset(df_train)
ds_test = _to_sequence_dataset(df_test)

# Hold out 20% of the training sequences for validation.
train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
    len(ds_train), val_ratio=0.2, seed=cfg.seed)

In [None]:
# Inspect one training sequence, addressed by its dataset key.
# NOTE(review): the tuple layout presumably is (glacier, year, stake number,
# period) — confirm against MBSequenceDataset.keys.
key = ('adler', 2009, 11, 'winter')

# find the index of this key (list.index raises ValueError when it is absent)
try:
    idx = ds_train.keys.index(key)
except ValueError:
    # re-raise with a message that names the missing key
    raise ValueError(f"Key {key} not found in dataset.")

# fetch the corresponding sequence and display its 'mv', 'mw' and 'ma' entries
sequence = ds_train[idx]
sequence['mv'], sequence['mw'], sequence['ma']

### Define & train model:

In [None]:
# Progress log (CSV) of the two-head LSTM parameter search.
log_path = 'logs/lstm_two_heads_param_search_progress_2025-10-09.csv'
# Pick the best parameter set from that log, selected by 'avg_test_loss'
# (project helper), and display it.
best_params = get_best_params_for_lstm(log_path, select_by='avg_test_loss')
best_params

In [None]:
# `Path` is not imported explicitly anywhere in this notebook; import it here
# so the cell does not depend on a star import happening to provide it.
from pathlib import Path

# Show the ten parameter-search runs with the lowest average test RMSE,
# i.e. the mean of the annual and winter test RMSE columns.
# NOTE: this rebinds `df`, which an earlier cell used for dataloader_gl.data.
log_path = Path(log_path)
df = pd.read_csv(log_path)
df["avg_test_loss"] = (df["test_rmse_a"] + df["test_rmse_w"]) / 2
df.sort_values(by="avg_test_loss", inplace=True)
df.head(10)

In [None]:
# Plot the parameter distributions of the top 5 runs of the search log,
# ranked by "valid_loss" (project plotting helper).
plot_topk_param_distributions(log_path, k=5, metric="valid_loss")

In [None]:
# Hyper-parameters of the two-head LSTM. 'Fm' / 'Fs' are the numbers of monthly
# and static input features, taken from the column lists above.
custom_params = {
    'Fm': len(MONTHLY_COLS),
    'Fs': len(STATIC_COLS),
    'hidden_size': 128,
    'num_layers': 2,
    'bidirectional': False,
    'dropout': 0.2,
    'static_layers': 0,
    'static_hidden': None,
    'static_dropout': None,
    'lr': 0.001,
    'weight_decay': 0.0001,
    'loss_name': 'neutral',
    'loss_spec': None,
    'two_heads': True,
    'head_dropout': 0.0
}

# Checkpoint written by a fresh training run (date-stamped).
current_date = datetime.now().strftime("%Y-%m-%d")
model_filename = f"models/lstm_model_{current_date}_two_heads.pt"

# Pre-trained checkpoint that is evaluated when no training is run.
pretrained_model_filename = 'models/lstm_model_2025-10-10_two_heads.pt'

# --- loaders (fit scalers on TRAIN, apply to whole ds_train) ---
# Work on untransformed clones so ds_train / ds_test themselves stay unscaled
# and can be cloned again later.
ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train)

ds_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_test)

train_dl, val_dl = ds_train_copy.make_loaders(
    train_idx=train_idx,
    val_idx=val_idx,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=
    True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True  # use weighted sampler for training
)

# --- test loader (copies TRAIN scalers into ds_test and transforms it) ---
test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test_copy, ds_train_copy, batch_size=128, seed=cfg.seed)

# --- build model, resolve loss, train, reload best ---
model = mbm.models.LSTM_MB.build_model_from_params(cfg, custom_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(custom_params)

TRAIN = False
if TRAIN:
    # Start from a clean slate so a stale same-day checkpoint cannot be reloaded.
    if os.path.exists(model_filename):
        os.remove(model_filename)

    history, best_val, best_state = model.train_loop(
        device=device,
        train_dl=train_dl,
        val_dl=val_dl,
        epochs=150,
        lr=custom_params['lr'],
        weight_decay=custom_params['weight_decay'],
        clip_val=1,
        # scheduler
        sched_factor=0.5,
        sched_patience=6,
        sched_threshold=0.01,
        sched_threshold_mode="rel",
        sched_cooldown=1,
        sched_min_lr=1e-6,
        # early stopping
        es_patience=15,
        es_min_delta=1e-4,
        # logging
        log_every=5,
        verbose=True,
        # checkpoint
        save_best_path=model_filename,
        loss_fn=loss_fn,
    )
    plot_history_lstm(history)
else:
    # No training in this run: fall back to the pre-trained checkpoint.
    # (Previously this override was unconditional, so a freshly trained model
    # was never the one that got evaluated.)
    model_filename = pretrained_model_filename

# Evaluate on test: load the selected checkpoint and score the test loader.
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state)
test_metrics, test_df_preds = model.evaluate_with_preds(
    device, test_dl, ds_test_copy)
test_rmse_a, test_rmse_w = test_metrics['RMSE_annual'], test_metrics[
    'RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

In [None]:
# Score the test predictions separately for annual and winter measurements
# (project helper), comparing the 'pred' column against 'target'.
scores_annual, scores_winter = compute_seasonal_scores(test_df_preds,
                                                       target_col='target',
                                                       pred_col='pred')

print("Annual scores:", scores_annual)
print("Winter scores:", scores_winter)

# Summary figure of the predictions, with both axes limited to (-8, 6).
fig = plot_predictions_summary(grouped_ids=test_df_preds,
                               scores_annual=scores_annual,
                               scores_winter=scores_winter,
                               ax_xlim=(-8, 6),
                               ax_ylim=(-8, 6))

In [None]:
from matplotlib.lines import Line2D

# Mean POINT_ELEVATION of the annual measurements per glacier, highest first.
# (Computed once; the original cell repeated this computation verbatim.)
gl_per_el = data_glamos[data_glamos.PERIOD == 'annual'].groupby(
    ['GLACIER'])['POINT_ELEVATION'].mean()
gl_per_el = gl_per_el.sort_values(ascending=False)

# Test glaciers ordered by ascending mean elevation; used as the panel order.
test_gl_per_el = gl_per_el[TEST_GLACIERS].sort_values().index

fig, axs = plt.subplots(3, 3, figsize=(25, 18), sharex=True)

# Attach each glacier's mean elevation to the prediction table.
test_df_preds['gl_elv'] = test_df_preds['GLACIER'].map(gl_per_el)

subplot_labels = [
    '(a)', '(b)', '(c)', '(d)', '(e)', '(f)', '(g)', '(h)', '(i)'
]

axs = PlotIndividualGlacierPredVsTruth(test_df_preds,
                                       axs=axs,
                                       subplot_labels=subplot_labels,
                                       color_annual=color_dark_blue,
                                       color_winter=color_pink,
                                       custom_order=test_gl_per_el)

# NOTE(review): indexing with a single integer assumes the helper returns a
# flat sequence of axes (index 3 = first panel of the middle row) — confirm.
axs[3].set_ylabel("Modelled PMB [m w.e.]", fontsize=20)

fig.supxlabel('Observed PMB [m w.e.]', fontsize=20, y=0.06)

# Two distinct legend handles. They use the same colours that were passed to
# the plotting helper above (color_dark_blue / color_pink); the names
# `color_annual` / `color_winter` only exist as keyword arguments of that
# helper and are not defined in this notebook.
legend_scatter_annual = Line2D([0], [0],
                               marker='o',
                               linestyle='None',
                               linewidth=0,
                               markersize=10,
                               markerfacecolor=color_dark_blue,
                               markeredgecolor='k',
                               markeredgewidth=0.8,
                               label='Annual')

legend_scatter_winter = Line2D([0], [0],
                               marker='o',
                               linestyle='None',
                               linewidth=0,
                               markersize=10,
                               markerfacecolor=color_pink,
                               markeredgecolor='k',
                               markeredgewidth=0.8,
                               label='Winter')

# if you already have other handles (e.g., bands/means), append these:
# handles = existing_handles + [legend_scatter_annual, legend_scatter_winter]
handles = [legend_scatter_annual, legend_scatter_winter]

# The labels are taken from the handles; no need to pass `labels=...`
fig.legend(handles=handles,
           loc='upper center',
           bbox_to_anchor=(0.5, 0.05),
           ncol=4,
           fontsize=20)

plt.subplots_adjust(hspace=0.25, wspace=0.25)
plt.show()

## Extrapolate in space:

In [None]:
# Geodetic mass-balance table (project helper).
geodetic_mb = get_geodetic_MB(cfg)

# get years per glacier: unique 'Astart' / 'Aend' values, as
# {glacier_name: [values]} dictionaries
years_start_per_gl = geodetic_mb.groupby(
    'glacier_name')['Astart'].unique().apply(list).to_dict()
years_end_per_gl = geodetic_mb.groupby('glacier_name')['Aend'].unique().apply(
    list).to_dict()

# Per-glacier periods and geodetic MB (project helper); periods_per_glacier is
# keyed by glacier name and is later reduced with np.min / np.max.
periods_per_glacier, geoMB_per_glacier = build_periods_per_glacier(geodetic_mb)

# NOTE(review): this counts the unique glaciers in data_glamos; the printed
# label mentions "pcsr" — confirm that this is the intended description.
glacier_list = list(data_glamos.GLACIER.unique())
print('Number of glaciers with pcsr:', len(glacier_list))

geodetic_glaciers = periods_per_glacier.keys()
print('Number of glaciers with geodetic MB:', len(geodetic_glaciers))

# Intersection of both
common_glaciers = list(set(geodetic_glaciers) & set(glacier_list))
print('Number of common glaciers:', len(common_glaciers))

# Sort glaciers by area
gl_area = get_gl_area(cfg)
# Make the area stored under 'claridenL' also available under 'clariden'.
gl_area['clariden'] = gl_area['claridenL']


# Sort the lists by area if available in gl_area
def sort_by_area(glacier_list, gl_area):
    """Return the glacier names ordered by ascending area.

    Args:
        glacier_list: iterable of glacier names.
        gl_area: mapping with a ``.get(name, default)`` lookup giving the area
            of a glacier.

    Returns:
        A new list; names missing from ``gl_area`` count as area 0 and
        therefore sort first. Ties keep their input order.
    """

    def _area_of(name):
        return gl_area.get(name, 0)

    return sorted(glacier_list, key=_area_of)


# Glaciers that have both stake data and geodetic MB, smallest area first;
# this list drives the task building for the parallel prediction below.
glacier_list = sort_by_area(common_glaciers, gl_area)
glacier_list

In [None]:
# all_columns = feature_columns + cfg.fieldsNotFeatures

# # Required by the dataset builder regardless of your feature list
# REQUIRED = ['GLACIER', 'YEAR', 'ID', 'PERIOD', 'MONTHS']

# # Paths
# path_save_glw = os.path.join(cfg.dataPath, 'GLAMOS', 'distributed_MB_grids',
#                              'MBM/glamos_dems_LSTM_two_heads_svf')
# os.makedirs(path_save_glw, exist_ok=True)
# path_xr_grids = os.path.join(cfg.dataPath, 'GLAMOS', 'topo', 'GLAMOS_DEM',
#                              'xr_masked_grids')

# # Load model once
# best_state = torch.load(model_filename, map_location=device)
# model.load_state_dict(best_state)
# model.eval()

# RUN = False
# if RUN:
#     emptyfolder(path_save_glw)
#     for glacier_name in glacier_list:
#         glacier_path = os.path.join(cfg.dataPath + path_glacier_grid_glamos,
#                                     glacier_name)
#         if not os.path.exists(glacier_path):
#             print(f"Folder not found for {glacier_name}, skipping...")
#             continue

#         glacier_files = sorted(
#             [f for f in os.listdir(glacier_path) if glacier_name in f])

#         geodetic_range = range(np.min(periods_per_glacier[glacier_name]),
#                                np.max(periods_per_glacier[glacier_name]) + 1)

#         years = [int(f.split('_')[2].split('.')[0]) for f in glacier_files]
#         years = [y for y in years if y in geodetic_range]

#         print(
#             f"Processing {glacier_name} ({len(years)} files): {geodetic_range}"
#         )

#         for year in tqdm(years, desc=f"Processing {glacier_name}",
#                          leave=False):
#             seed_all(cfg.seed)

#             file_name = f"{glacier_name}_grid_{year}.parquet"
#             df_grid_monthly = pd.read_parquet(
#                 os.path.join(cfg.dataPath + path_glacier_grid_glamos,
#                              glacier_name, file_name)).copy()

#             df_grid_monthly.drop_duplicates(inplace=True)

#             # Keep required + feature columns; DON'T drop PERIOD/MONTHS/YEAR/ID/GLACIER
#             keep = [
#                 c for c in (set(all_columns) | set(REQUIRED))
#                 if c in df_grid_monthly.columns
#             ]
#             df_grid_monthly = df_grid_monthly[keep]

#             # --- Build winter subset (Sep–Apr) ---
#             winter_months = [
#                 'sep', 'oct', 'nov', 'dec', 'jan', 'feb', 'mar', 'apr'
#             ]
#             df_grid_monthly_w = (
#                 df_grid_monthly[df_grid_monthly['MONTHS'].str.lower().isin(
#                     winter_months)].copy().dropna(subset=['ID', 'MONTHS']))
#             # tag period
#             df_grid_monthly_w['PERIOD'] = 'winter'

#             # Minimal NaN clean-up for annual too
#             df_grid_monthly_a = df_grid_monthly.dropna(subset=['ID', 'MONTHS'])

#             # --- Build ds_gl_a WITHOUT targets (annual) ---
#             ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
#                 ds_train)
#             ds_train_copy.fit_scalers(train_idx)  # fit only

#             ds_gl_a = mbm.data_processing.MBSequenceDataset.from_dataframe(
#                 df_grid_monthly_a,
#                 MONTHLY_COLS,
#                 STATIC_COLS,
#                 months_tail_pad=months_tail_pad,
#                 months_head_pad=months_head_pad,
#                 expect_target=False,
#                 show_progress=False,
#             )

#             test_gl_dl_a = mbm.data_processing.MBSequenceDataset.make_test_loader(
#                 ds_gl_a, ds_train_copy, seed=cfg.seed, batch_size=128)

#             # Predict annual (no metrics)
#             df_preds_a = model.predict_with_keys(device, test_gl_dl_a, ds_gl_a)

#             # Join preds back to unique cell IDs for saving (annual)
#             data_a = df_preds_a[['ID', 'pred']].set_index('ID')
#             grouped_ids_a = (df_grid_monthly_a.groupby('ID')[[
#                 'YEAR', 'POINT_LAT', 'POINT_LON', 'GLWD_ID'
#             ]].first().merge(data_a,
#                              left_index=True,
#                              right_index=True,
#                              how='left'))
#             months_per_id_a = df_grid_monthly_a.groupby(
#                 'ID')['MONTHS'].unique()
#             grouped_ids_a = grouped_ids_a.merge(months_per_id_a,
#                                                 left_index=True,
#                                                 right_index=True)

#             grouped_ids_a.reset_index(inplace=True)
#             grouped_ids_a.sort_values(by='ID', inplace=True)

#             # Annual output
#             grouped_ids_annual = grouped_ids_a.copy()
#             grouped_ids_annual['PERIOD'] = 'annual'
#             pred_y_annual = grouped_ids_annual.drop(columns=['YEAR'],
#                                                     errors='ignore')

#             # Load geo grid once per year
#             path_glacier_dem = os.path.join(cfg.dataPath, path_xr_grids,
#                                             f"{glacier_name}_{year}.zarr")
#             ds = xr.open_dataset(path_glacier_dem)
#             geoData = mbm.geodata.GeoData(df_grid_monthly_a,
#                                           months_head_pad=months_head_pad,
#                                           months_tail_pad=months_tail_pad)

#             # Save annual
#             geoData._save_prediction(ds, pred_y_annual, glacier_name, year,
#                                      path_save_glw, "annual")

#             # ----------------------------
#             # WINTER PREDICTIONS (new)
#             # ----------------------------
#             if len(df_grid_monthly_w) == 0:
#                 print(
#                     f"[skip-winter] {glacier_name} {year}: no winter months present."
#                 )
#                 continue

#             # Build ds_gl_w WITHOUT targets (winter), reusing same fitted scalers
#             ds_gl_w = mbm.data_processing.MBSequenceDataset.from_dataframe(
#                 df_grid_monthly_w,
#                 MONTHLY_COLS,
#                 STATIC_COLS,
#                 months_tail_pad=months_tail_pad,
#                 months_head_pad=months_head_pad,
#                 expect_target=False,
#                 show_progress=False,
#             )

#             test_gl_dl_w = mbm.data_processing.MBSequenceDataset.make_test_loader(
#                 ds_gl_w, ds_train_copy, seed=cfg.seed, batch_size=128)

#             # Predict winter (no metrics)
#             df_preds_w = model.predict_with_keys(device, test_gl_dl_w, ds_gl_w)

#             # Join preds back to unique cell IDs for saving (winter)
#             data_w = df_preds_w[['ID', 'pred']].set_index('ID')
#             grouped_ids_w = (df_grid_monthly_w.groupby('ID')[[
#                 'YEAR', 'POINT_LAT', 'POINT_LON', 'GLWD_ID'
#             ]].first().merge(data_w,
#                              left_index=True,
#                              right_index=True,
#                              how='left'))
#             months_per_id_w = df_grid_monthly_w.groupby(
#                 'ID')['MONTHS'].unique()
#             grouped_ids_w = grouped_ids_w.merge(months_per_id_w,
#                                                 left_index=True,
#                                                 right_index=True)

#             grouped_ids_w.reset_index(inplace=True)
#             grouped_ids_w.sort_values(by='ID', inplace=True)

#             # Winter output
#             grouped_ids_winter = grouped_ids_w.copy()
#             grouped_ids_winter['PERIOD'] = 'winter'
#             pred_y_winter = grouped_ids_winter.drop(columns=['YEAR'],
#                                                     errors='ignore')

#             # Save winter (reuse ds and geoData)
#             # Note: pass the winter df to GeoData if you need month filtering inside _save_prediction
#             geoData_w = mbm.geodata.GeoData(df_grid_monthly_w,
#                                             months_head_pad=months_head_pad,
#                                             months_tail_pad=months_tail_pad)
#             geoData_w._save_prediction(ds, pred_y_winter, glacier_name, year,
#                                        path_save_glw, "winter")

# # quick viz
# glacier_name = 'aletsch'
# year = 2008
# xr.open_dataset(os.path.join(path_save_glw, f'{glacier_name}/{glacier_name}_{year}_annual.zarr'))\
#   .pred_masked.plot(cmap='RdBu')

In [None]:
# === Parallelized glacier-year inference & save (CPU, Linux) ===
import os, sys, io, logging, multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime
from contextlib import redirect_stdout
from tqdm.auto import tqdm
import numpy as np
import pandas as pd
import xarray as xr
import torch

# ----------------- quiet main logging (optional) -----------------
os.makedirs("logs", exist_ok=True)
LOG_PATH = f"logs/predict_glaciers_{datetime.now():%Y%m%d_%H%M%S}.log"
# force=True is required here: the root logger was already configured by an
# earlier logging.basicConfig(...) call in this notebook, and basicConfig is a
# silent no-op when the root logger already has handlers. Without it nothing
# is ever written to LOG_PATH. force=True replaces the existing root handlers,
# so from here on log records go to the file instead of the notebook output.
logging.basicConfig(
    filename=LOG_PATH, level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S",
    force=True,
)
log = logging.getLogger("predict")

# ----------------- constants & paths -----------------
# Columns that are kept in the grid tables regardless of the feature list.
REQUIRED = ['GLACIER', 'YEAR', 'ID', 'PERIOD', 'MONTHS']
all_columns = MONTHLY_COLS + STATIC_COLS + cfg.fieldsNotFeatures

# Both locations live under <dataPath>/GLAMOS.
_glamos_root = os.path.join(cfg.dataPath, 'GLAMOS')

# Output folder for the predicted grids (created if missing).
path_save_glw = os.path.join(_glamos_root, 'distributed_MB_grids',
                             'MBM/glamos_dems_LSTM_two_heads')
os.makedirs(path_save_glw, exist_ok=True)

# Folder holding the per-glacier, per-year masked DEM grids.
path_xr_grids = os.path.join(_glamos_root, 'topo', 'GLAMOS_DEM',
                             'xr_masked_grids')

# ----------------- worker init (quiet + CPU threads cap) -----------------
def _worker_init_quiet():
    """Initializer for pool workers: silence output and restrict to CPU, one thread.

    Runs once in every worker process. It redirects the worker's stdout and
    stderr to the null device (the parent's streams are separate objects, so
    the progress bar in the main process stays visible), hides CUDA devices,
    and caps the numerical libraries at a single thread.
    """
    # The two handles are intentionally left open for the worker's lifetime.
    sys.stdout = open(os.devnull, "w")
    sys.stderr = open(os.devnull, "w")

    # CPU only: assign instead of setdefault, otherwise a CUDA_VISIBLE_DEVICES
    # value inherited from the parent environment silently wins.
    # NOTE(review): presumably has no effect if CUDA was already initialised in
    # the parent before the fork; the workers build their model on the CPU
    # device regardless.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

    # Thread caps remain defaults so an explicit user setting is respected.
    for var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS",
                "NUMEXPR_MAX_THREADS"):
        os.environ.setdefault(var, "1")

    try:
        torch.set_num_threads(1)
    except Exception:
        # Best effort only: the worker is still usable with torch's default
        # thread count, and stderr is already silenced at this point.
        pass

# ----------------- per-process model cache -----------------
# Process-wide cache: holds the loaded model after the first call.
_MODEL = None


def _get_model_cpu(cfg, params_used, model_filename):
    """Return the CPU model for this process, building and loading it on first use.

    Args:
        cfg: configuration object forwarded to the model builder.
        params_used: parameter dict forwarded to the model builder.
        model_filename: path of the state dict to load.

    Returns:
        The cached model, in eval mode. Once a model is cached, later calls
        return it unchanged whatever arguments they pass.
    """
    global _MODEL
    if _MODEL is not None:
        return _MODEL

    cpu = torch.device("cpu")
    net = mbm.models.LSTM_MB.build_model_from_params(cfg, params_used, cpu, verbose=False)
    weights = torch.load(model_filename, map_location=cpu)
    net.load_state_dict(weights)
    net.eval()

    _MODEL = net
    return _MODEL

# ----------------- one glacier-year task -----------------
def _predict_grid_for_period(df_period, period, ds_train_copy):
    """Predict for every grid cell in ``df_period`` and return one row per cell ID.

    Args:
        df_period: monthly grid table for one glacier-year (already filtered
            to the months of the period of interest).
        period: label written to the 'PERIOD' column of the result
            ('annual' or 'winter').
        ds_train_copy: training dataset with fitted scalers; handed to
            make_test_loader together with the new dataset.

    Returns:
        DataFrame sorted by 'ID' with the available meta columns (minus
        'YEAR'), the 'pred' column, the unique 'MONTHS' per ID and 'PERIOD'.
    """
    ds_gl = mbm.data_processing.MBSequenceDataset.from_dataframe(
        df_period, MONTHLY_COLS, STATIC_COLS,
        months_tail_pad=months_tail_pad, months_head_pad=months_head_pad,
        expect_target=False, show_progress=False
    )
    test_gl_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
        ds_gl, ds_train_copy, seed=cfg.seed, batch_size=128
    )

    # Model is built/loaded once per worker and cached afterwards.
    model = _get_model_cpu(cfg, custom_params, model_filename)
    df_preds = model.predict_with_keys(torch.device("cpu"), test_gl_dl, ds_gl)

    # Join the predictions back onto one row of metadata per cell ID.
    data = df_preds[['ID', 'pred']].set_index('ID')
    meta_cols = [c for c in ['YEAR', 'POINT_LAT', 'POINT_LON', 'GLWD_ID']
                 if c in df_period.columns]
    grouped = (df_period.groupby('ID')[meta_cols].first()
               .merge(data, left_index=True, right_index=True, how='left'))
    months_per_id = df_period.groupby('ID')['MONTHS'].unique()
    grouped = grouped.merge(months_per_id, left_index=True, right_index=True)
    grouped.reset_index(inplace=True)
    grouped.sort_values(by='ID', inplace=True)

    grouped['PERIOD'] = period
    return grouped.drop(columns=['YEAR'], errors='ignore')


def _process_glacier_year(args):
    """Predict and save the annual and winter grids of one glacier-year.

    Everything except the task tuple (cfg, paths, column lists, ds_train,
    train_idx, pads, custom_params, model_filename) is read from module
    globals, which the forked worker inherits.

    Args:
        args: ``(glacier_name, year)`` tuple.

    Returns:
        ``(status, glacier_name, year, message)`` where status is
        ``"ok"`` (message may be "no winter months"), ``"skip"`` (an input
        is missing; message says which) or ``"err"`` (message holds the
        exception type and text). Exceptions are never propagated.
    """
    glacier_name, year = args  # everything else taken from globals via fork
    try:
        # Seed for reproducibility
        seed_all(cfg.seed)

        # --- check all inputs up front, before doing any expensive work ---
        glacier_path = os.path.join(cfg.dataPath + path_glacier_grid_glamos, glacier_name)
        if not os.path.exists(glacier_path):
            return ("skip", glacier_name, year, "glacier folder missing")

        file_name = f"{glacier_name}_grid_{year}.parquet"
        parquet_path = os.path.join(glacier_path, file_name)
        if not os.path.exists(parquet_path):
            return ("skip", glacier_name, year, "parquet missing")

        path_glacier_dem = os.path.join(path_xr_grids, f"{glacier_name}_{year}.zarr")
        if not os.path.exists(path_glacier_dem):
            return ("skip", glacier_name, year, "DEM zarr missing")

        df_grid_monthly = pd.read_parquet(parquet_path).copy()
        df_grid_monthly.drop_duplicates(inplace=True)

        # Keep required + feature columns; preserve order
        needed = set(all_columns) | set(REQUIRED)
        keep = [c for c in df_grid_monthly.columns if c in needed]
        df_grid_monthly = df_grid_monthly[keep]

        # Build winter subset (Sep–Apr)
        winter_months = ['sep', 'oct', 'nov', 'dec', 'jan', 'feb', 'mar', 'apr']
        df_grid_monthly_w = (df_grid_monthly[df_grid_monthly['MONTHS'].str.lower().isin(winter_months)]
                             .copy().dropna(subset=['ID', 'MONTHS']))
        df_grid_monthly_w['PERIOD'] = 'winter'

        # Minimal NaN cleanup for annual
        df_grid_monthly_a = df_grid_monthly.dropna(subset=['ID', 'MONTHS'])

        # Fit scalers on TRAIN only (clone/train ds are global via fork)
        ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(ds_train)
        ds_train_copy.fit_scalers(train_idx)

        pred_y_annual = _predict_grid_for_period(df_grid_monthly_a, 'annual', ds_train_copy)

        # Open the per-year DEM grid once for both saves; the context manager
        # closes it again so a long-lived worker does not leak open datasets.
        with xr.open_dataset(path_glacier_dem) as ds:
            geoData = mbm.geodata.GeoData(df_grid_monthly_a,
                                          months_head_pad=months_head_pad,
                                          months_tail_pad=months_tail_pad)
            geoData._save_prediction(ds, pred_y_annual, glacier_name, year, path_save_glw, "annual")

            # Winter branch (the annual grid has been saved at this point)
            if len(df_grid_monthly_w) == 0:
                return ("ok", glacier_name, year, "no winter months")

            pred_y_winter = _predict_grid_for_period(df_grid_monthly_w, 'winter', ds_train_copy)

            geoData_w = mbm.geodata.GeoData(df_grid_monthly_w,
                                            months_head_pad=months_head_pad,
                                            months_tail_pad=months_tail_pad)
            geoData_w._save_prediction(ds, pred_y_winter, glacier_name, year, path_save_glw, "winter")

        return ("ok", glacier_name, year, "")

    except Exception as e:
        # Task boundary: report instead of raising so one bad glacier-year does
        # not abort the pool. Include the exception type, since str(e) alone is
        # often uninformative (e.g. a KeyError shows only the key).
        return ("err", glacier_name, year, f"{type(e).__name__}: {e}")

# ----------------- build tasks -----------------
import re

# One (glacier_name, year) task per grid file whose year lies inside the
# glacier's geodetic range.
tasks = []
for glacier_name in glacier_list:
    glacier_path = os.path.join(cfg.dataPath + path_glacier_grid_glamos, glacier_name)
    if not os.path.exists(glacier_path):
        continue

    # Match exactly the file name the worker opens,
    # "<glacier_name>_grid_<year>.parquet". Parsing the year by splitting on
    # '_' breaks for glacier names that contain an underscore, and a plain
    # substring test also accepts unrelated files.
    grid_file_re = re.compile(rf"{re.escape(glacier_name)}_grid_(\d+)\.parquet")
    file_years = sorted(
        int(m.group(1))
        for m in map(grid_file_re.fullmatch, os.listdir(glacier_path)) if m)
    if not file_years:
        continue

    # Inclusive range from the smallest to the largest value among the
    # glacier's geodetic periods.
    geodetic_range = range(np.min(periods_per_glacier[glacier_name]),
                           np.max(periods_per_glacier[glacier_name]) + 1)
    years = [y for y in file_years if y in geodetic_range]
    tasks.extend((glacier_name, y) for y in years)

# ----------------- run in parallel (quiet stdout, keep tqdm) -----------------
class _Devnull(io.StringIO):
    def write(self, *args, **kwargs): return 0

# "fork" start method (original note: Linux).
ctx = mp.get_context("fork")  # Linux

# Leave one CPU free, never go below one worker, and cap the pool at 32.
n_cpus = os.cpu_count() or 2
max_workers = min(max(1, n_cpus - 1), 32)

# Outcome counters, keyed by the status string each task returns.
ok = skip = err = 0

# Only stdout is silenced; stderr stays attached so the tqdm bar is visible.
with redirect_stdout(_Devnull()), ProcessPoolExecutor(max_workers=max_workers,
                                                      initializer=_worker_init_quiet,
                                                      mp_context=ctx) as ex:
    futures = [ex.submit(_process_glacier_year, t) for t in tasks]
    progress = tqdm(as_completed(futures), total=len(futures),
                    desc=f"Predicting ({max_workers} workers)")
    for fut in progress:
        # Each task reports a (status, glacier, year, message) tuple.
        status, g, y, msg = fut.result()

        if status == "ok":
            ok += 1
            if msg:  # no winter months, non-fatal
                log.info(f"OK {g} {y}: {msg}")
            continue

        if status == "skip":
            skip += 1
            log.warning(f"SKIP {g} {y}: {msg}")
            continue

        # Any other status is counted as an error.
        err += 1
        log.error(f"ERR {g} {y}: {msg}")

log.info(f"SUMMARY: ok={ok} skip={skip} err={err} total={len(tasks)}")
print(f"Done. Logs → {LOG_PATH}")

In [None]:
# Entries of the prediction output folder (one per glacier with saved results).
glaciers_in_glamos = os.listdir(path_save_glw)

geodetic_mb = get_geodetic_MB(cfg)

# Unique survey start / end years per glacier, as {glacier_name: [years]}.
by_glacier = geodetic_mb.groupby('glacier_name')
years_start_per_gl = by_glacier['Astart'].unique().apply(list).to_dict()
years_end_per_gl = by_glacier['Aend'].unique().apply(list).to_dict()

periods_per_glacier, geoMB_per_glacier = build_periods_per_glacier(geodetic_mb)

# Glacier areas, used below to order the glaciers.
gl_area = get_gl_area(cfg)
# Make the area stored under 'claridenL' available under the key 'clariden' too.
gl_area['clariden'] = gl_area['claridenL']


# Sort the lists by area if available in gl_area
def sort_by_area(glacier_list, gl_area):
    """Return a new list of glacier names ordered by ascending area.

    Names without an entry in ``gl_area`` are treated as having area 0.
    Ties keep their original relative order.
    """

    def _area_of(name):
        return gl_area.get(name, 0)

    ordered = list(glacier_list)
    ordered.sort(key=_area_of)
    return ordered


# Glaciers that have geodetic periods AND an entry in the predictions folder,
# ordered from smallest to largest area.
glacier_list = sort_by_area(
    [name for name in periods_per_glacier.keys() if name in glaciers_in_glamos],
    gl_area,
)
print('Number of glaciers:', len(glacier_list))
print('Glaciers:', glacier_list)

df_all_nn = process_geodetic_mass_balance_comparison(
    glacier_list=glacier_list,
    path_SMB_GLAMOS_csv=cfg.dataPath + path_SMB_GLAMOS_csv,
    periods_per_glacier=periods_per_glacier,
    geoMB_per_glacier=geoMB_per_glacier,
    gl_area=gl_area,
    test_glaciers=TEST_GLACIERS,
    path_predictions=path_save_glw,  # or another path if needed
    cfg=cfg)

# Keep only rows where both the geodetic and the modelled value are present,
# then order by glacier area.
df_all_nn = (
    df_all_nn
    .dropna(subset=['Geodetic MB', 'MBM MB'])
    .sort_values(by='Area')
)
df_all_nn['GLACIER'] = df_all_nn['GLACIER'].apply(lambda name: name.capitalize())

# Agreement between geodetic and modelled mass balance: RMSE and Pearson r.
geodetic_vals = df_all_nn["Geodetic MB"]
modelled_vals = df_all_nn["MBM MB"]
rmse_nn = root_mean_squared_error(geodetic_vals, modelled_vals)
corr_nn = np.corrcoef(geodetic_vals, modelled_vals)[0, 1]

plot_mbm_vs_geodetic_by_area_bin(df_all_nn,
                                 bins=[0, 1, 5, 10, 100, np.inf],
                                 labels=['<1', '1-5', '5–10', '>10', '>100'],
                                 max_bins=4)


## Permutation importance:

In [None]:
import copy
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

def permutation_importance_lstm(
    model,
    device,
    ds_test,                 # MBSequenceDataset (already tensorized)
    MONTHLY_COLS,
    STATIC_COLS,
    *,
    ds_train_with_scalers=None,   # train dataset that has y_mean/y_std set
    batch_size: int = 128,
    n_repeats: int = 10,
    seed: int = 42,
    metric: str = "RMSE_mean",    # "RMSE_annual" | "RMSE_winter" | "RMSE_mean"
    num_workers: int = 0,
    pin_memory: bool = False,
):
    """
    Permutation importance for an LSTM with sequence + static features.

    One feature at a time is shuffled across samples, the model is
    re-evaluated, and the importance is the mean increase of the chosen score
    over the unpermuted baseline (averaged over `n_repeats` shuffles).

      - Static feature j:  the column `ds_test.Xs[:, j]` is shuffled.
      - Monthly feature j: whole per-sample sequences `ds_test.Xm[:, :, j]`
        are shuffled along the sample axis (the order within a sequence is
        left untouched).

    Parameters
    ----------
    model
        Object exposing `eval()` and
        `evaluate_with_preds(device, dataloader, dataset) -> (metrics, _)`,
        where `metrics` is a dict holding the "RMSE_*" entries read here.
    device
        Passed through unchanged to `model.evaluate_with_preds`.
    ds_test
        Dataset with tensor attributes `Xs` and `Xm`. It is never permuted in
        place, but see the side effect below.
    MONTHLY_COLS, STATIC_COLS
        Feature names; position i names the i-th feature slice of `Xm` / `Xs`.
    ds_train_with_scalers
        Optional dataset carrying `y_mean` / `y_std` tensors. They are cloned
        onto `ds_test` only where `ds_test` has none (side effect on the
        caller's object).
    batch_size, num_workers, pin_memory
        Forwarded to `torch.utils.data.DataLoader` (always `shuffle=False`).
    n_repeats
        Number of independent shuffles per feature.
    seed
        Seed of the NumPy generator that draws the permutations.
    metric
        "RMSE_annual", "RMSE_winter", or "RMSE_mean" (mean of all finite
        "RMSE_*" entries of the metrics dict).

    Returns
    -------
    df_imp, baseline_score
        `df_imp` has columns feature, group ("static"/"monthly"), metric,
        mean_importance, std_importance and is sorted by mean_importance
        (descending). `baseline_score` is the unpermuted score.

    Raises
    ------
    ValueError
        If `metric` is not one of the supported names. This is checked before
        any model evaluation is run.

    Notes
    -----
    Every permuted copy uses the same `y_mean` / `y_std` as the baseline
    dataset, so baseline and permuted scores are always de-normalized
    identically.
    """
    # Fail fast: do not pay for a baseline evaluation with a bad metric name.
    if metric not in ("RMSE_annual", "RMSE_winter", "RMSE_mean"):
        raise ValueError(f"Unknown metric '{metric}'")

    rng = np.random.default_rng(seed)
    model.eval()

    # --- ensure target scalers exist on the base test dataset ---
    if ds_train_with_scalers is not None:
        if getattr(ds_test, "y_mean", None) is None:
            ds_test.y_mean = ds_train_with_scalers.y_mean.clone()
        if getattr(ds_test, "y_std", None) is None:
            ds_test.y_std = ds_train_with_scalers.y_std.clone()

    def score_from_metrics(metrics: dict) -> float:
        """Collapse the metrics dict into the single scalar selected by `metric`."""
        if metric == "RMSE_mean":
            vals = [float(v) for k, v in metrics.items()
                    if k.startswith("RMSE_") and np.isfinite(v)]
            return float(np.mean(vals)) if vals else float("nan")
        return float(metrics[metric])

    def evaluate(ds) -> float:
        """Score `ds` with the model, iterating in fixed (unshuffled) order."""
        dl = DataLoader(ds, batch_size=batch_size, shuffle=False,
                        num_workers=num_workers, pin_memory=pin_memory)
        metrics, _ = model.evaluate_with_preds(device, dl, ds)
        return score_from_metrics(metrics)

    def permuted_copy(attr: str, index: tuple):
        """Shallow-copy ds_test with the slice `index` of tensor `attr` shuffled."""
        ds_perm = copy.copy(ds_test)

        # Same scalers as the baseline dataset; tensors are cloned so an
        # evaluation can never modify the originals through the copy.
        for name in ("y_mean", "y_std"):
            val = getattr(ds_test, name, None)
            if val is not None:
                setattr(ds_perm, name, val.clone() if torch.is_tensor(val) else val)

        # Clone only the tensor that gets mutated.
        X = getattr(ds_test, attr).clone()
        # rng.permutation shuffles along the first (sample) axis only.
        shuffled = rng.permutation(X[index].cpu().numpy())
        X[index] = torch.from_numpy(shuffled).to(X.dtype)
        setattr(ds_perm, attr, X)
        return ds_perm

    # --- baseline ---
    baseline_score = evaluate(ds_test)

    # (group label, tensor attribute, feature names, leading index slices).
    # Static features are processed first, then monthly ones.
    feature_groups = (
        ("static", "Xs", STATIC_COLS, (slice(None),)),
        ("monthly", "Xm", MONTHLY_COLS, (slice(None), slice(None))),
    )

    results = []
    for group, attr, cols, lead in feature_groups:
        # enumerate gives the true position even if a name occurs twice.
        for feat_idx, feat in enumerate(cols):
            deltas = [
                evaluate(permuted_copy(attr, lead + (feat_idx,))) - baseline_score
                for _ in range(n_repeats)
            ]
            results.append({
                "feature": feat,
                "group": group,
                "metric": metric,
                "mean_importance": float(np.mean(deltas)) if deltas else float("nan"),
                "std_importance": float(np.std(deltas, ddof=1) if len(deltas) > 1 else 0.0),
            })

    # Explicit columns keep sort_values valid when no features were supplied.
    df_imp = (pd.DataFrame(results,
                           columns=["feature", "group", "metric",
                                    "mean_importance", "std_importance"])
              .sort_values("mean_importance", ascending=False)
              .reset_index(drop=True))
    return df_imp, baseline_score


In [None]:
# Scalers are fitted on the TRAIN split only, on an untransformed clone of
# ds_train (same procedure as for inference).
ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(ds_train)
ds_train_copy.fit_scalers(train_idx)

# NOTE: ds_test has to be tensorized already (Xm / Xs present) and should be
# standardized with the same scalers used at inference. If the test loader is
# normally built via
#   test_dl = MBSequenceDataset.make_test_loader(ds_test, ds_train_copy, ...)
# running that once beforehand ensures ds_test.Xm / Xs are standardized.

perm_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
perm_kwargs = dict(
    device=perm_device,
    ds_test=ds_test,                         # the tensorized/standardized test dataset
    MONTHLY_COLS=MONTHLY_COLS,
    STATIC_COLS=STATIC_COLS,
    ds_train_with_scalers=ds_train_copy,     # source of y_mean / y_std
    batch_size=128,
    n_repeats=10,
    seed=cfg.seed,
    metric="RMSE_mean",
)
df_imp, baseline = permutation_importance_lstm(model, **perm_kwargs)

print(f"Baseline {df_imp.metric.iloc[0]}: {baseline:.4f}")
print(df_imp.head(20))


In [None]:
# Plot the permutation-importance table `df_imp` computed in the previous cell.
# NOTE(review): `top_n=20` presumably limits the plot to the 20 highest-ranked
# features — confirm against plot_permutation_importance.
plot_permutation_importance(df_imp, top_n=20)