# Make predictions for whole CA
## Setting up:

In [None]:
# --- System & utilities ---
import os
import sys
import re
import csv
import ast
import math
import traceback
import itertools
import random
import pickle
import logging
import warnings
from datetime import datetime
from functools import partial
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
import copy

# Add repo root for MBM imports
sys.path.append(os.path.join(os.getcwd(), "../../"))

# --- Data science stack ---
import numpy as np
import pandas as pd
import xarray as xr
import rioxarray
from tqdm.notebook import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from cmcrameri import cm

# --- Machine learning / DL ---
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler, SubsetRandomSampler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from skorch.helper import SliceDataset
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint

# --- Cartography / plotting ---
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER

# --- Custom MBM modules ---
import massbalancemachine as mbm
from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.nn_helpers import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.NN_networks import *
from scripts.geodata_plots import *

# --- Warnings & autoreload (notebook) ---
warnings.filterwarnings("ignore")
%load_ext autoreload
%autoreload 2

# --- Configuration ---
cfg = mbm.SwitzerlandConfig()
seed_all(cfg.seed)
print("Using seed:", cfg.seed)

# --- CUDA / device ---
if torch.cuda.is_available():
    print("CUDA is available")
    free_up_cuda()
else:
    print("CUDA is NOT available")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# RGI Ids:
# Read the glacier-id table, strip whitespace from the column names, then
# sort by and index on the glacier short name.
rgi_df = (
    pd.read_csv(cfg.dataPath + path_glacier_ids, sep=',')
    .rename(columns=lambda col: col.strip())
    .sort_values(by='short_name')
    .set_index('short_name')
)
# Show one entry as a quick sanity check (notebook cell output).
rgi_df.loc['rhone']

In [None]:
# glacier_outline_rgi = gpd.read_file(cfg.dataPath + path_rgi_outlines)

# gdirs, rgidf = initialize_oggm_glacier_directories(
#     cfg,
#     rgi_region="11",
#     rgi_version="62",
#     base_url=
#     "https://cluster.klima.uni-bremen.de/~oggm/gdirs/oggm_v1.6/L1-L2_files/2025.6/elev_bands_w_data/",
#     log_level='WARNING',
#     task_list=None,
# )

# # Save OGGM xr for all needed glaciers in RGI region 11.6:
# df_missing = export_oggm_grids(cfg, gdirs)

# path_rgi = cfg.dataPath + 'GLAMOS/RGI/nsidc0770_11.rgi60.CentralEurope/11_rgi60_CentralEurope.shp'

# # load RGI shapefile
# gdf = gpd.read_file(path_rgi)
# # reproject to a local equal-area projection (example: EPSG:3035 for Europe)
# gdf_proj = gdf.to_crs(3035)
# gdf_proj.rename(columns={"RGIId": "rgi_id"}, inplace=True)
# # gdf_proj.set_index('rgi_id', inplace=True)
# gdf_proj["area_m2"] = gdf_proj.geometry.area
# gdf_proj["area_km2"] = gdf_proj["area_m2"] / 1e6

# df_missing = df_missing.merge(gdf_proj[['area_km2', 'rgi_id']], on="rgi_id")

# # total glacier area
# total_area = gdf_proj["area_km2"].sum()

# # explode the list of missing vars into rows (one var per row)
# df_exploded = df_missing.explode("missing_vars")

# # 1) COUNT: number of glaciers missing each variable
# counts_missing_per_var = (
#     df_exploded.groupby("missing_vars")["rgi_id"].nunique().sort_values(
#         ascending=False))

# # 2) TOTAL % AREA with ANY missing var
# total_missing_area_km2 = df_missing["area_km2"].sum()
# total_missing_area_pct = (total_missing_area_km2 / total_area) * 100

# print(f"Total glacier area with ANY missing variable: "
#       f"{total_missing_area_km2:,.2f} km² "
#       f"({total_missing_area_pct:.2f}%)")

# # Optional: also show % area per variable (kept from your earlier logic)
# area_missing_per_var = (
#     df_exploded.groupby("missing_vars")["area_km2"].sum().sort_values(
#         ascending=False))
# perc_missing_per_var = (area_missing_per_var / total_area) * 100

# print("\n% of total glacier area missing per variable:")
# for var, pct in perc_missing_per_var.items():
#     print(f"  - {var}: {pct:.2f}%")

# # ---- barplot: number of glaciers missing each variable ----
# plt.figure(figsize=(7, 4))
# plt.bar(counts_missing_per_var.index, counts_missing_per_var.values)
# plt.xlabel("Missing variable")
# plt.ylabel("Number of glaciers")
# plt.title("Count of glaciers missing each variable")
# plt.tight_layout()
# plt.show()

## Stakes data:

In [None]:
# Stake measurements loaded through the project helper.
data_glamos = getStakesData(cfg)

# Head/tail month pads computed from the stake data by the MBM utility;
# they are reused later when building the sequence datasets.
(months_head_pad, months_tail_pad
 ) = mbm.data_processing.utils._compute_head_tail_pads_from_df(data_glamos)

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Input/output locations handed to process_or_load_data below.
_era5_dir = cfg.dataPath + path_ERA5_raw
paths = {
    'csv_path': cfg.dataPath + path_PMB_GLAMOS_csv,
    'era5_climate_data': _era5_dir + 'era5_monthly_averaged_data.nc',
    'geopotential_data': _era5_dir + 'era5_geopotential_pressure.nc',
    'radiation_save_path': cfg.dataPath + path_pcsr + 'zarr/',
}

# Climate columns
vois_climate = [
    't2m',
    'tp',
    'slhf',
    'sshf',
    'ssrd',
    'fal',
    'str',
    'u10',
    'v10',
]
# Topographical columns
vois_topographical = [
    "aspect",
    "slope",
    "hugonnet_dhdt",
    "consensus_ice_thickness",
    "millan_v",
    "topo",
    "svf",
]

# Transform data to monthly format; RUN is forwarded as run_flag
# (per the original comment: run the processing or load stored data).
RUN = False
data_monthly = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_LSTM_CA.csv',
)

# Wrap the monthly table in the MBM DataLoader.
dataloader_gl = mbm.dataloader.DataLoader(
    cfg,
    data=data_monthly,
    random_seed=cfg.seed,
    meta_data_columns=cfg.metaData,
)

# Glaciers that actually occur in the monthly table.
existing_glaciers = set(data_monthly.GLACIER.unique())

# Report requested test glaciers that have no rows at all.
missing_glaciers = [
    name for name in TEST_GLACIERS if name not in existing_glaciers
]
if missing_glaciers:
    print(
        f"Warning: The following test glaciers are not in the dataset: {missing_glaciers}"
    )

# Every present glacier that is not a test glacier is a training glacier.
train_glaciers = [
    name for name in existing_glaciers if name not in TEST_GLACIERS
]

data_test = data_monthly[data_monthly.GLACIER.isin(TEST_GLACIERS)]
print('Size of monthly test data:', len(data_test))

data_train = data_monthly[data_monthly.GLACIER.isin(train_glaciers)]
print('Size of monthly train data:', len(data_train))

# The ratio below is test rows relative to train rows, so guard the divisor.
if len(data_train) > 0:
    test_perc = (len(data_test) / len(data_train)) * 100
    print(f'Percentage of test size: {test_perc:.2f}%')
else:
    print("Warning: No training data available!")

# Split on the GLACIER column with TEST_GLACIERS held out; get_CV_splits
# returns the splits plus dicts holding 'df_X', 'y' and 'splits_vals'.
splits, test_set, train_set = get_CV_splits(dataloader_gl,
                                            test_split_on='GLACIER',
                                            test_splits=TEST_GLACIERS,
                                            random_state=cfg.seed)

n_test_rows = len(test_set['df_X'])
n_train_rows = len(train_set['df_X'])

print('Test glaciers: ({}) {}'.format(len(test_set['splits_vals']),
                                      test_set['splits_vals']))
# Same guard as for the monthly frames: an empty train set must not raise
# ZeroDivisionError while printing a purely informational ratio.
if n_train_rows == 0:
    print("Warning: No training data available!")
else:
    test_perc = (n_test_rows / n_train_rows) * 100
    print('Percentage of test size: {:.2f}%'.format(test_perc))
print('Size of test set:', n_test_rows)
print('Train glaciers: ({}) {}'.format(len(train_set['splits_vals']),
                                       train_set['splits_vals']))
print('Size of train set:', n_train_rows)

# Validation and train split:
# NOTE: these are the same DataFrame objects stored in train_set/test_set,
# so adding the 'y' column also modifies train_set['df_X'] / test_set['df_X'].
data_train = train_set['df_X']
data_train['y'] = train_set['y']

data_test = test_set['df_X']
data_test['y'] = test_set['y']

## Extrapolate to all glaciers:

In [None]:
# ---- fixed order to avoid accidental shuffles ----
# Every combo key and static-column list below is derived in this order.
VARS_ORDER = ['hugonnet_dhdt', 'consensus_ice_thickness', 'millan_v']

# Monthly (time-varying) input columns, in the order given to the datasets.
MONTHLY_COLS = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str',
    'ELEVATION_DIFFERENCE',
]

# one canonical list for the partial OGGM combos you trained:
# all 1-variable subsets followed by all 2-variable subsets of VARS_ORDER,
# each tuple keeping the VARS_ORDER ordering.
PARTIAL_COMBOS = [
    subset
    for size in (1, 2)
    for subset in itertools.combinations(VARS_ORDER, size)
]


def combo_key(tup):
    """Join the VARS_ORDER members found in ``tup`` with "__".

    The result follows VARS_ORDER regardless of the order inside ``tup``,
    e.g. ('millan_v', 'hugonnet_dhdt') -> "hugonnet_dhdt__millan_v".
    An empty ``tup`` yields the empty string.
    """
    return "__".join(name for name in VARS_ORDER if name in tup)


def combo_key_from_tuple(tup):
    """Map a combo tuple to its registry key.

    () -> 'simple'; a tuple containing all of VARS_ORDER -> 'full';
    anything else -> the present variables joined with '__' in VARS_ORDER
    order, e.g. ('hugonnet_dhdt', 'millan_v') -> 'hugonnet_dhdt__millan_v'.
    """
    if not tup:
        return "simple"
    selected = [name for name in VARS_ORDER if name in tup]
    return "full" if selected == VARS_ORDER else "__".join(selected)


def make_params(Fm, STATIC_COLS):
    """Return the model parameter dict for one feature configuration.

    Args:
        Fm: number of monthly features.
        STATIC_COLS: static feature column names; only its length is used
            (stored as 'Fs').

    Returns:
        A new dict with 'Fm', 'Fs' and the fixed hyper-parameters shared by
        all combos.
    """
    # Hyper-parameters that do not depend on the feature configuration.
    shared = {
        'hidden_size': 128,
        'num_layers': 1,
        'bidirectional': False,
        'dropout': 0.0,
        'static_layers': 0,
        'static_hidden': None,
        'static_dropout': None,
        'lr': 0.001,
        'weight_decay': 0.0,
        'loss_name': 'neutral',
        'loss_spec': None,
        'two_heads': True,
        'head_dropout': 0.0,
    }
    return {'Fm': Fm, 'Fs': len(STATIC_COLS), **shared}


# ---- params for all combos (simple/partials/full) built once ----
Fm = len(MONTHLY_COLS)

STATIC_SIMPLE = ['aspect', 'slope', 'svf']  # no OGGM vars
STATIC_FULL = STATIC_SIMPLE + VARS_ORDER  # all OGGM vars

params_simple_model = make_params(Fm, STATIC_SIMPLE)
params_full_model = make_params(Fm, STATIC_FULL)

# partials: one params dict per partial combo, keyed by its combo key
params_by_key = {}
for combo in PARTIAL_COMBOS:
    ordered = [v for v in VARS_ORDER if v in combo]
    static_cols = STATIC_SIMPLE + ordered
    key = combo_key(combo)
    params_by_key[key] = make_params(Fm, static_cols)

# Independent copies under explicit names, so editing one of them never
# touches params_by_key.
(
    params_hugonnet_only,
    params_consensus_only,
    params_millan_only,
    params_hugonnet_consensus,
    params_hugonnet_millan,
    params_consensus_millan,
) = (copy.deepcopy(params_by_key[name]) for name in (
    'hugonnet_dhdt',
    'consensus_ice_thickness',
    'millan_v',
    'hugonnet_dhdt__consensus_ice_thickness',
    'hugonnet_dhdt__millan_v',
    'consensus_ice_thickness__millan_v',
))

# single authoritative mapping (includes convenient aliases "simple"/"full")
PARAMS_BY_COMBO = {
    "simple": params_simple_model,
    "hugonnet_dhdt": params_hugonnet_only,
    "consensus_ice_thickness": params_consensus_only,
    "millan_v": params_millan_only,
    "hugonnet_dhdt__consensus_ice_thickness": params_hugonnet_consensus,
    "hugonnet_dhdt__millan_v": params_hugonnet_millan,
    "consensus_ice_thickness__millan_v": params_consensus_millan,
    "full": params_full_model,
}

# ---- config / constants ----
# Include simple and full for completeness
ALL_COMBOS = [(), *PARTIAL_COMBOS, tuple(VARS_ORDER)]

# ---- normalize input dataframes once ----
# Work on copies and bring PERIOD to a stripped, lower-case form.
df_train = data_train.copy()
df_test = data_test.copy()  # optional, if you also want test ds per combo
for frame in (df_train, df_test):
    frame['PERIOD'] = frame['PERIOD'].str.strip().str.lower()

# ---- builders ----
DS_TRAIN_BY_COMBO = {}
SPLIT_IDXS_BY_COMBO = {}  # (train_idx, val_idx)
STATIC_COLS_BY_COMBO = {}

seed_all(cfg.seed)

# The two extreme combos reuse the static-column lists defined above.
_FIXED_STATIC_COLS = {"simple": STATIC_SIMPLE, "full": STATIC_FULL}

for combo in ALL_COMBOS:
    key = combo_key_from_tuple(combo)

    # Build static columns in stable order
    if key in _FIXED_STATIC_COLS:
        STATIC_COLS = _FIXED_STATIC_COLS[key]
    else:
        ordered = [v for v in VARS_ORDER if v in combo]
        STATIC_COLS = ['aspect', 'slope', 'svf', *ordered]

    STATIC_COLS_BY_COMBO[key] = STATIC_COLS

    # --- build train dataset from dataframe ---
    ds_train = mbm.data_processing.MBSequenceDataset.from_dataframe(
        df_train,
        MONTHLY_COLS,
        STATIC_COLS,
        months_tail_pad=months_tail_pad,
        months_head_pad=months_head_pad,
        expect_target=True,
    )
    DS_TRAIN_BY_COMBO[key] = ds_train

    # keep per-combo split (use your preferred val_ratio/seed)
    SPLIT_IDXS_BY_COMBO[key] = (
        mbm.data_processing.MBSequenceDataset.split_indices(
            len(ds_train), val_ratio=0.2, seed=cfg.seed))
    train_idx, val_idx = SPLIT_IDXS_BY_COMBO[key]

# convenience names: one variable per combo dataset
(
    ds_train_simple,
    ds_train_full,
    ds_train_hugonnet_only,
    ds_train_consensus_only,
    ds_train_millan_only,
    ds_train_hugonnet_consensus,
    ds_train_hugonnet_millan,
    ds_train_consensus_millan,
) = (DS_TRAIN_BY_COMBO[name] for name in (
    'simple',
    'full',
    'hugonnet_dhdt',
    'consensus_ice_thickness',
    'millan_v',
    'hugonnet_dhdt__consensus_ice_thickness',
    'hugonnet_dhdt__millan_v',
    'consensus_ice_thickness__millan_v',
))

# train/validation index pairs for the two extreme combos
(train_idx_simple, val_idx_simple), (train_idx_full, val_idx_full) = (
    SPLIT_IDXS_BY_COMBO['simple'],
    SPLIT_IDXS_BY_COMBO['full'],
)

# List *exactly* the combos you trained (keys must match your saved filenames/params):
# simple (no OGGM), every partial combo, then full (all OGGM).
TRAINED_COMBOS = [(), *PARTIAL_COMBOS, tuple(VARS_ORDER)]

# Build mapping keys
COMBO_KEYS = list(map(combo_key_from_tuple, TRAINED_COMBOS))

# ---- checkpoints for each combo (update paths to match your training) ----
# NOTE(review): keys use "__" between variables while the file names use a
# single "_" -- confirm this matches the files written by the training run.
MODEL_PATHS = {
    "simple":
    "models/lstm_model_2025-10-09_CA_simple.pt",
    "hugonnet_dhdt":
    "models/lstm_model_2025-10-09_CA_hugonnet_dhdt.pt",
    "consensus_ice_thickness":
    "models/lstm_model_2025-10-09_CA_consensus_ice_thickness.pt",
    "millan_v":
    "models/lstm_model_2025-10-09_CA_millan_v.pt",
    "hugonnet_dhdt__consensus_ice_thickness":
    "models/lstm_model_2025-10-09_CA_hugonnet_dhdt_consensus_ice_thickness.pt",
    "hugonnet_dhdt__millan_v":
    "models/lstm_model_2025-10-09_CA_hugonnet_dhdt_millan_v.pt",
    "consensus_ice_thickness__millan_v":
    "models/lstm_model_2025-10-09_CA_consensus_ice_thickness_millan_v.pt",
    "full":
    "models/lstm_model_2025-10-09_CA_full.pt",
}

# ---- params per combo (Fs must reflect STATIC_COLS length!) ----
# Use the single authoritative PARAMS_BY_COMBO from above

# ---- training datasets (for scalers) ----
# Same dataset objects as the ds_train_* convenience names, keyed by combo.
TRAIN_DS_BY_COMBO = dict(
    zip(
        (
            "simple",
            "hugonnet_dhdt",
            "consensus_ice_thickness",
            "millan_v",
            "hugonnet_dhdt__consensus_ice_thickness",
            "hugonnet_dhdt__millan_v",
            "consensus_ice_thickness__millan_v",
            "full",
        ),
        (
            ds_train_simple,
            ds_train_hugonnet_only,
            ds_train_consensus_only,
            ds_train_millan_only,
            ds_train_hugonnet_consensus,
            ds_train_hugonnet_millan,
            ds_train_consensus_millan,
            ds_train_full,
        ),
    ))

In [None]:
# Cache for loaded models (combo key -> model in eval mode)
_MODEL_CACHE = {}

def get_model_by_combo(combo_key_str: str, cfg, device):
    """Return the trained model for ``combo_key_str``, loading it at most once.

    Args:
        combo_key_str: key into MODEL_PATHS / PARAMS_BY_COMBO.
        cfg: configuration forwarded to ``build_model_from_params``.
        device: torch device used to build the model and map the checkpoint.

    Returns:
        The model with the checkpoint weights loaded, switched to eval mode.

    Raises:
        ValueError: no checkpoint path or no params registered for the key.
        FileNotFoundError: the registered checkpoint file does not exist.

    Note: the cache is keyed by combo only, so a cached model is returned
    even if ``device`` differs from the one used on first load.
    """
    try:
        return _MODEL_CACHE[combo_key_str]
    except KeyError:
        pass  # not loaded yet

    # Both registries must know the combo before anything is touched on disk.
    for registry, what in ((MODEL_PATHS, "model path"),
                           (PARAMS_BY_COMBO, "params")):
        if combo_key_str not in registry:
            raise ValueError(f"No {what} for combo '{combo_key_str}'")

    checkpoint = MODEL_PATHS[combo_key_str]
    if not os.path.exists(checkpoint):
        raise FileNotFoundError(
            f"Checkpoint not found for '{combo_key_str}': {checkpoint}")

    net = mbm.models.LSTM_MB.build_model_from_params(
        cfg,
        PARAMS_BY_COMBO[combo_key_str],
        device,
        verbose=False,
    )
    net.load_state_dict(torch.load(checkpoint, map_location=device))
    net.eval()

    _MODEL_CACHE[combo_key_str] = net
    return net


def detect_available_combo_key(df: pd.DataFrame) -> str:
    """
    Pick the largest trained combo whose variables are all columns of ``df``.

    Candidates are tried from most to fewest variables (full > 2-var >
    1-var > simple); among equally sized combos the TRAINED_COMBOS order
    decides, because the sort is stable. Only column presence is checked,
    not whether the column holds any values.
    """
    available = {var for var in VARS_ORDER if var in df.columns}
    for candidate in sorted(TRAINED_COMBOS, key=len, reverse=True):
        if available.issuperset(candidate):
            return combo_key_from_tuple(candidate)
    # Only reached when TRAINED_COMBOS has no matching entry.
    return "simple"

In [None]:
# # Cache for loaded models
# _MODEL_CACHE = {}

# def get_model_by_combo(combo_key_str: str, cfg, device):
#     """Generalized model loader with cache."""
#     if combo_key_str in _MODEL_CACHE:
#         return _MODEL_CACHE[combo_key_str]

#     if combo_key_str not in MODEL_PATHS:
#         raise ValueError(f"No model path for combo '{combo_key_str}'")
#     if combo_key_str not in PARAMS_BY_COMBO:
#         raise ValueError(f"No params for combo '{combo_key_str}'")

#     ckpt_path = MODEL_PATHS[combo_key_str]
#     if not os.path.exists(ckpt_path):
#         raise FileNotFoundError(
#             f"Checkpoint not found for '{combo_key_str}': {ckpt_path}")

#     params = PARAMS_BY_COMBO[combo_key_str]
#     model = mbm.models.LSTM_MB.build_model_from_params(cfg,
#                                                        params,
#                                                        device,
#                                                        verbose=False)
#     state = torch.load(ckpt_path, map_location=device)
#     model.load_state_dict(state)
#     model.eval()
#     _MODEL_CACHE[combo_key_str] = model
#     return model


# def detect_available_combo_key(df: pd.DataFrame) -> str:
#     """
#     Return the most specific trained combo that matches columns present in df.
#     Priority: full > any 2-var combo > any 1-var combo > simple.
#     """
#     present_set = set(c for c in VARS_ORDER if c in df.columns)
#     # try exact matches by decreasing size
#     candidates = sorted(TRAINED_COMBOS, key=lambda t: len(t), reverse=True)
#     for tup in candidates:
#         if set(tup).issubset(present_set):
#             return combo_key_from_tuple(tup)
#     return "simple"

# # Required by the dataset builder regardless of your feature list
# REQUIRED = ['GLACIER', 'YEAR', 'ID', 'PERIOD', 'MONTHS']

# current_date = datetime.now().strftime("%Y-%m-%d")
# path_rgi_alps = os.path.join(cfg.dataPath, 'GLAMOS/topo/gridded_topo_inputs/RGI_v6_11_svf/')
# rgi_id_list = os.listdir(path_rgi_alps)

# # Safe rename helper (wire up if your files have alternate column names)
# def safe_rename(df, mapping):
#     present = {k: v for k, v in mapping.items() if k in df.columns}
#     return df.rename(columns=present) if present else df

# # Robust year regex: RGI60-11.01238_grid_2003.parquet
# YEAR_RE = re.compile(r"_grid_(\d{4})\.parquet$")

# # --- main run loop (replaces your has_oggm/simple branch) ---
# RUN = True
# if RUN:
#     os.makedirs("logs", exist_ok=True)  # ensure folder exists
#     output_file = os.path.join("logs", f"glacier_mean_MB_{current_date}.csv")
#     with open(output_file, 'w') as f:
#         f.write("Index,RGIId,Year,Mean_MB\n")

#     index_counter = 0

#     for rgi_gl in tqdm(rgi_id_list):
#         seed_all(cfg.seed)
#         glacier_path = os.path.join(path_rgi_alps, rgi_gl)
#         if not os.path.exists(glacier_path):
#             print(f"Folder not found for {rgi_gl}, skipping...")
#             continue

#         # collect available years
#         years = []
#         for fname in os.listdir(glacier_path):
#             if not (fname.endswith(".parquet") and rgi_gl in fname):
#                 continue
#             m = YEAR_RE.search(fname)
#             if m:
#                 years.append(int(m.group(1)))
#         years = sorted(set(years))
#         if not years:
#             print(f"No parquet years for {rgi_gl}, skipping...")
#             continue

#         # clear model cache per glacier (optional)
#         _MODEL_CACHE.clear()

#         for year in years:
#             file_name = f"{rgi_gl}_grid_{year}.parquet"
#             file_path = os.path.join(glacier_path, file_name)

#             try:
#                 df_grid_monthly = pd.read_parquet(file_path)
#             except Exception as e:
#                 print(f"[{rgi_gl} {year}] read error: {e}")
#                 continue

#             df_grid_monthly.drop_duplicates(inplace=True)

#             # Normalize required fields
#             if 'PERIOD' in df_grid_monthly.columns:
#                 df_grid_monthly['PERIOD'] = df_grid_monthly['PERIOD'].astype(str).str.strip().str.lower()

#             # --- choose the best available trained combo for this file ---
#             selected_combo_key = detect_available_combo_key(df_grid_monthly)
#             # build STATIC_COLS in stable order matching training schema
#             if selected_combo_key == "simple":
#                 oggm_cols = []
#             elif selected_combo_key == "full":
#                 oggm_cols = VARS_ORDER[:]  # ['hugonnet_dhdt','consensus_ice_thickness','millan_v']
#             else:
#                 raw = selected_combo_key.split("__")
#                 oggm_cols = [v for v in VARS_ORDER if v in raw]

#             STATIC_COLS = ['aspect', 'slope', 'svf'] + oggm_cols
#             feature_columns = MONTHLY_COLS + STATIC_COLS
#             all_columns = feature_columns + cfg.fieldsNotFeatures

#             # retain only required + feature columns (preserve original order)
#             need = set(all_columns) | set(REQUIRED)
#             keep = [c for c in df_grid_monthly.columns if c in need]
#             df_grid_monthly = df_grid_monthly[keep]

#             # required checks
#             missing_req = [c for c in REQUIRED if c not in df_grid_monthly.columns]
#             if missing_req:
#                 print(f"[{rgi_gl} {year}] missing required cols: {missing_req}, skipping...")
#                 continue

#             # minimal NaN cleanup
#             df_grid_monthly_a = df_grid_monthly.dropna(subset=['ID', 'MONTHS'])
#             if df_grid_monthly_a.empty:
#                 print(f"[{rgi_gl} {year}] empty after NaN cleanup, skipping...")
#                 continue

#             # --- build inference dataset with the combo's schema ---
#             ds_gl_a = mbm.data_processing.MBSequenceDataset.from_dataframe(
#                 df_grid_monthly_a,
#                 MONTHLY_COLS,
#                 STATIC_COLS,
#                 months_tail_pad=months_tail_pad,
#                 months_head_pad=months_head_pad,
#                 expect_target=False,
#                 show_progress=False,
#             )

#             # bring the *matching* training scalers for this combo
#             if selected_combo_key not in DS_TRAIN_BY_COMBO:
#                 print(f"[{rgi_gl} {year}] no training dataset registered for combo '{selected_combo_key}', skipping...")
#                 continue

#             ds_train_for_combo = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
#                 DS_TRAIN_BY_COMBO[selected_combo_key]
#             )

#             if selected_combo_key not in SPLIT_IDXS_BY_COMBO:
#                 print(f"[{rgi_gl} {year}] no split indices registered for combo '{selected_combo_key}', skipping...")
#                 continue

#             train_idx_for_combo, _ = SPLIT_IDXS_BY_COMBO[selected_combo_key]
#             ds_train_for_combo.fit_scalers(train_idx_for_combo)

#             test_gl_dl_a = mbm.data_processing.MBSequenceDataset.make_test_loader(
#                 ds_gl_a, ds_train_for_combo, seed=cfg.seed, batch_size=128
#             )

#             # --- load + predict with the appropriate model for this combo ---
#             try:
#                 model = get_model_by_combo(selected_combo_key, cfg, device)
#             except Exception as e:
#                 print(f"[{rgi_gl} {year}] model load error for '{selected_combo_key}': {e}")
#                 continue

#             try:
#                 df_preds_a = model.predict_with_keys(device, test_gl_dl_a, ds_gl_a)
#             except Exception as e:
#                 print(f"[{rgi_gl} {year}] prediction error for '{selected_combo_key}': {e}")
#                 continue

#             # --- aggregate to annual glacier mean and write ---
#             data_a = df_preds_a[['ID', 'pred']].set_index('ID')
#             meta_cols = [c for c in ['YEAR', 'POINT_LAT', 'POINT_LON', 'GLWD_ID'] if c in df_grid_monthly_a.columns]

#             grouped_ids_a = (
#                 df_grid_monthly_a.groupby('ID')[meta_cols].first()
#                 .merge(data_a, left_index=True, right_index=True, how='left')
#             )
#             months_per_id_a = df_grid_monthly_a.groupby('ID')['MONTHS'].unique()
#             grouped_ids_a = grouped_ids_a.merge(months_per_id_a, left_index=True, right_index=True)

#             grouped_ids_a.reset_index(inplace=True)
#             grouped_ids_a.sort_values(by='ID', inplace=True)

#             pred_y_annual = grouped_ids_a.copy()
#             pred_y_annual['PERIOD'] = 'annual'
#             pred_y_annual = pred_y_annual.drop(columns=['YEAR'], errors='ignore')

#             mean_MB = pred_y_annual['pred'].mean()

#             with open(output_file, 'a') as f:
#                 f.write(f"{index_counter},{rgi_gl},{year},{mean_MB:.4f}\n")

#             index_counter += 1

#### Parallel version:

In [None]:
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing as mp
import os, re
import torch
import pandas as pd
from datetime import datetime

# -------- utilities for workers --------

def _worker_init(seed: int):
    """Per-process initializer for the worker pool.

    Restricts the process to a single compute thread (env defaults are only
    set when not already defined) and seeds it via ``seed_all``.
    """
    # Avoid BLAS thread storms in multi-proc
    for env_var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS"):
        os.environ.setdefault(env_var, "1")
    torch.set_num_threads(1)
    seed_all(seed)

def _process_one_file(args):
    """
    Args is a dict to keep signature pickle-friendly.
    Returns: (status, payload) where status in {"ok","skip","err"}
    payload for "ok": (rgi_gl, year, mean_MB)
    """
    (rgi_gl, year, file_path, cfg, device_str) = (
        args["rgi_gl"], args["year"], args["file_path"], args["cfg"], args["device_str"]
    )

    try:
        df_grid_monthly = pd.read_parquet(file_path)
    except Exception as e:
        return ("err", f"[{rgi_gl} {year}] read error: {e}")

    df_grid_monthly.drop_duplicates(inplace=True)
    if 'PERIOD' in df_grid_monthly.columns:
        df_grid_monthly['PERIOD'] = (
            df_grid_monthly['PERIOD'].astype(str).str.strip().str.lower()
        )

    # Decide combo
    selected_combo_key = detect_available_combo_key(df_grid_monthly)
    if selected_combo_key == "simple":
        oggm_cols = []
    elif selected_combo_key == "full":
        oggm_cols = VARS_ORDER[:]
    else:
        raw = selected_combo_key.split("__")
        oggm_cols = [v for v in VARS_ORDER if v in raw]

    STATIC_COLS = ['aspect', 'slope', 'svf'] + oggm_cols
    feature_columns = MONTHLY_COLS + STATIC_COLS
    all_columns = feature_columns + cfg.fieldsNotFeatures

    REQUIRED = ['GLACIER', 'YEAR', 'ID', 'PERIOD', 'MONTHS']
    need = set(all_columns) | set(REQUIRED)
    keep = [c for c in df_grid_monthly.columns if c in need]
    df_grid_monthly = df_grid_monthly[keep]

    # Required checks
    missing_req = [c for c in REQUIRED if c not in df_grid_monthly.columns]
    if missing_req:
        return ("skip", f"[{rgi_gl} {year}] missing required cols: {missing_req}")

    df_grid_monthly_a = df_grid_monthly.dropna(subset=['ID', 'MONTHS'])
    if df_grid_monthly_a.empty:
        return ("skip", f"[{rgi_gl} {year}] empty after NaN cleanup")

    # Build inference dataset
    ds_gl_a = mbm.data_processing.MBSequenceDataset.from_dataframe(
        df_grid_monthly_a,
        MONTHLY_COLS,
        STATIC_COLS,
        months_tail_pad=months_tail_pad,
        months_head_pad=months_head_pad,
        expect_target=False,
        show_progress=False,
    )

    # Bring matching training scalers for this combo
    if selected_combo_key not in DS_TRAIN_BY_COMBO:
        return ("skip", f"[{rgi_gl} {year}] no train DS for '{selected_combo_key}'")

    ds_train_for_combo = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
        DS_TRAIN_BY_COMBO[selected_combo_key]
    )

    if selected_combo_key not in SPLIT_IDXS_BY_COMBO:
        return ("skip", f"[{rgi_gl} {year}] no split idxs for '{selected_combo_key}'")

    train_idx_for_combo, _ = SPLIT_IDXS_BY_COMBO[selected_combo_key]
    ds_train_for_combo.fit_scalers(train_idx_for_combo)

    test_gl_dl_a = mbm.data_processing.MBSequenceDataset.make_test_loader(
        ds_gl_a, ds_train_for_combo, seed=cfg.seed, batch_size=128
    )

    # Load model in worker
    device = torch.device(device_str)
    try:
        model = get_model_by_combo(selected_combo_key, cfg, device)
    except Exception as e:
        return ("err", f"[{rgi_gl} {year}] model load '{selected_combo_key}': {e}")

    # Predict
    try:
        df_preds_a = model.predict_with_keys(device, test_gl_dl_a, ds_gl_a)
    except Exception as e:
        return ("err", f"[{rgi_gl} {year}] predict '{selected_combo_key}': {e}")

    # Aggregate mean
    data_a = df_preds_a[['ID', 'pred']].set_index('ID')
    meta_cols = [c for c in ['YEAR', 'POINT_LAT', 'POINT_LON', 'GLWD_ID'] if c in df_grid_monthly_a.columns]
    grouped_ids_a = (
        df_grid_monthly_a.groupby('ID')[meta_cols].first()
        .merge(data_a, left_index=True, right_index=True, how='left')
    )
    months_per_id_a = df_grid_monthly_a.groupby('ID')['MONTHS'].unique()
    grouped_ids_a = grouped_ids_a.merge(months_per_id_a, left_index=True, right_index=True)

    grouped_ids_a.reset_index(inplace=True)
    grouped_ids_a.sort_values(by='ID', inplace=True)

    pred_y_annual = grouped_ids_a.copy()
    pred_y_annual['PERIOD'] = 'annual'
    pred_y_annual = pred_y_annual.drop(columns=['YEAR'], errors='ignore')

    mean_MB = float(pred_y_annual['pred'].mean())
    return ("ok", (rgi_gl, year, mean_MB))

In [None]:
current_date = datetime.now().strftime("%Y-%m-%d")
path_rgi_alps = os.path.join(cfg.dataPath, 'GLAMOS/topo/gridded_topo_inputs/RGI_v6_11_svf/')
# Sorted so the task list is reproducible (os.listdir order is arbitrary).
rgi_id_list = sorted(os.listdir(path_rgi_alps))

# --- build flat task list ---
# One task per (glacier folder, year) parquet file named
# "<rgi id>..._grid_<YYYY>.parquet".
tasks = []
YEAR_RE = re.compile(r"_grid_(\d{4})\.parquet$")
for rgi_gl in rgi_id_list:
    glacier_path = os.path.join(path_rgi_alps, rgi_gl)
    # Entries come from listdir, so they always exist; what matters is that
    # they are directories. A stray file here would make the inner
    # os.listdir raise NotADirectoryError.
    if not os.path.isdir(glacier_path):
        continue
    cand_years = []
    for fname in os.listdir(glacier_path):
        if not (fname.endswith(".parquet") and rgi_gl in fname):
            continue
        m = YEAR_RE.search(fname)
        if m:
            cand_years.append((int(m.group(1)), os.path.join(glacier_path, fname)))
    # Chronological order within each glacier
    for year, file_path in sorted(cand_years):
        tasks.append({
            "rgi_gl": rgi_gl,
            "year": year,
            "file_path": file_path,
            "cfg": cfg,                     # must be picklable
            #"device_str": str(device),      # e.g. "cpu" for CPU-parallel
            "device_str": 'cpu'
        })

# --- run in parallel (CPU) ---
os.makedirs("logs", exist_ok=True)
current_date = datetime.now().strftime("%Y-%m-%d")
output_file = os.path.join("logs", f"glacier_mean_MB_{current_date}.csv")

index_counter = 0
# os.cpu_count() may return None; fall back to a single worker in that case.
# Leave one core free and cap at 32 so we don't overload the machine.
max_workers = min(max(1, (os.cpu_count() or 1) - 1), 32)

# The file is opened once (header first) and flushed after every row, so
# partial results survive an interrupted run.
with open(output_file, "w") as f:
    f.write("Index,RGIId,Year,Mean_MB\n")
    with ProcessPoolExecutor(
        max_workers=max_workers,
        initializer=_worker_init,
        initargs=(cfg.seed,),
    ) as ex:
        futures = [ex.submit(_process_one_file, t) for t in tasks]
        task_by_future = dict(zip(futures, tasks))
        # Rows are written in completion order, not task order.
        for fut in as_completed(futures):
            try:
                status, payload = fut.result()
            except Exception as e:
                # fut.result() re-raises whatever escaped the worker; report
                # it and keep going instead of losing the whole run.
                t = task_by_future[fut]
                print(f"[{t['rgi_gl']} {t['year']}] worker failed: {e}")
                continue
            if status == "ok":
                rgi_gl, year, mean_MB = payload
                f.write(f"{index_counter},{rgi_gl},{year},{mean_MB:.4f}\n")
                f.flush()
                index_counter += 1
            else:
                # "skip" or "err"
                print(payload)

## Results:

In [None]:
# Load the RGI v6 Central Europe glacier outlines and compute each glacier's
# planimetric area, indexed by RGIId.
# Alternative dataset (RGI v7), kept for reference:
#   GLAMOS/RGI/RGI2000-v7.0-G-11_central_europe/RGI2000-v7.0-G-11_central_europe.shp
# os.path.join keeps the path valid whether or not cfg.dataPath ends with a
# separator, and matches how the other data paths in this notebook are built.
path_rgi = os.path.join(
    cfg.dataPath,
    'GLAMOS/RGI/nsidc0770_11.rgi60.CentralEurope/11_rgi60_CentralEurope.shp')

# load RGI shapefile
gdf = gpd.read_file(path_rgi)

# show the CRS the outlines were stored in
print(gdf.crs)

# Reproject to EPSG:3035 (an equal-area projection for Europe) so that
# geometry.area is expressed in square metres, then index by RGIId so areas
# can be looked up per glacier.
gdf_proj = gdf.to_crs(3035).set_index('RGIId')
gdf_proj["area_m2"] = gdf_proj.geometry.area
gdf_proj["area_km2"] = gdf_proj["area_m2"] / 1e6

In [None]:
# Aggregate the per-glacier predictions into region-wide yearly series:
#   annual_gt       - total annual and cumulative mass change in Gt
#   yearly_weighted - area-weighted mean mass balance per year

# open output file (results of the prediction run of that date)
output_df = pd.read_csv("logs/glacier_mean_MB_2025-10-08.csv").drop(['Index'],
                                                                    axis=1)

# Attach each glacier's area (km²) from the RGI outlines. Mapping through the
# Series is vectorised; unknown ids become NaN, so they are checked explicitly
# to keep failing loudly (KeyError) instead of silently dropping glaciers.
output_df['area_gl'] = output_df['RGIId'].map(gdf_proj['area_km2'])
missing_ids = output_df.loc[output_df['area_gl'].isna(), 'RGIId'].unique()
if len(missing_ids) > 0:
    raise KeyError(
        f"{len(missing_ids)} RGIId(s) have no area in the RGI outlines, "
        f"e.g. {list(missing_ids[:5])}")

df = output_df.copy()

# annual change per glacier in Gt
# Mean_MB is taken to be in m w.e. (as labelled on the plots) and area_gl is
# in km²:  m w.e. * km² * 1e6 m²/km² = m³ of water = tonnes, and 1 Gt = 1e9 t.
# Dividing km² directly by 1e9 would underestimate the mass by a factor 1e6.
df["annual_change_gt"] = df["Mean_MB"] * df["area_gl"] * 1e6 / 1e9

# total annual change in Gt (sum across glaciers)
annual_gt = df.groupby("Year")["annual_change_gt"].sum().reset_index(
    name="Annual_MB_Gt")

# cumulative MB in Gt
annual_gt["Cumulative_MB_Gt"] = annual_gt["Annual_MB_Gt"].cumsum()

# compute area-weighted mean MB per year: sum(MB * area) / sum(area)
mb_times_area = output_df["Mean_MB"] * output_df["area_gl"]
weighted_sum = mb_times_area.groupby(output_df["Year"]).sum()
total_area = output_df.groupby("Year")["area_gl"].sum()
yearly_weighted = (weighted_sum / total_area).reset_index(name="Weighted_MB")

print(yearly_weighted.head())

In [None]:
# Display the per-glacier, per-year results table for a quick visual check.
output_df

In [None]:
# Read the GlaMBIE reference values and normalise their date columns.
glambie_df = pd.read_csv('glambie_values.csv')
date_columns = [
    'central_europe_dates',
    'central_europe_start_dates',
    'central_europe_end_dates',
]

# Round every date column to the nearest whole number and shift it back by one.
# NOTE(review): the reason for the -1 shift is not visible here - presumably
# it aligns the GlaMBIE dates with the model's year labels; confirm.
for date_col in date_columns:
    glambie_df[date_col] = glambie_df[date_col].round() - 1

glambie_df.head()

In [None]:
# --- plotting ---
# Two side-by-side panels (LSTM on the left, GLAMBIE on the right). Each panel
# shows annual MB as bars on the primary y-axis and cumulative MB in Gt as a
# line on a twin y-axis. Both panels are drawn by the same loop from a spec.
fig, axs = plt.subplots(1, 2, figsize=(15, 6), sharey=True)
years = yearly_weighted['Year']

panel_specs = [
    # Left: LSTM results
    {
        "bar_x": years,
        "bar_y": yearly_weighted['Weighted_MB'],
        "bar_color": "skyblue",
        "bar_label": "Area-weighted annual MB",
        "line_x": annual_gt['Year'],
        "line_y": annual_gt['Cumulative_MB_Gt'],
        "line_color": "red",
        "line_marker": "o",
        "line_label": "Cumulative MB",
        "title": "Central Alps annual MB (LSTM)",
    },
    # Right: GLAMBIE results
    {
        "bar_x": glambie_df['central_europe_end_dates'],
        "bar_y": glambie_df['central_europe_annual_change_mwe'],
        "bar_color": "lightgreen",
        "bar_label": "Annual MB (GLAMBIE)",
        "line_x": glambie_df['central_europe_dates'],
        "line_y": glambie_df['central_europe_cumulative_change_gt'],
        "line_color": "darkgreen",
        "line_marker": "s",
        "line_label": "Cumulative MB (GLAMBIE)",
        "title": "Central Europe MB (GLAMBIE)",
    },
]

twin_axes = []
for bar_ax, spec in zip(axs, panel_specs):
    # bars: annual MB (m w.e.) on the primary axis
    bar_ax.bar(spec["bar_x"],
               spec["bar_y"],
               color=spec["bar_color"],
               label=spec["bar_label"])
    bar_ax.set_ylabel("Annual MB (m w.e.)", color=spec["bar_color"])

    # line: cumulative MB in Gt on the secondary axis
    line_ax = bar_ax.twinx()
    line_ax.plot(spec["line_x"],
                 spec["line_y"],
                 color=spec["line_color"],
                 marker=spec["line_marker"],
                 label=spec["line_label"])
    line_ax.set_ylabel("Cumulative MB (Gt)", color=spec["line_color"])

    bar_ax.set_title(spec["title"])
    bar_ax.legend(loc="upper left")
    line_ax.legend(loc="upper right")
    twin_axes.append(line_ax)

# Keep the individual axis handles available under their usual names.
ax1, ax3 = axs
ax2, ax4 = twin_axes

# Formatting: tilt the year labels on both panels
for ax in axs:
    ax.tick_params(axis="x", rotation=45)

plt.tight_layout()
plt.show()


In [None]:
# Grouped bar chart comparing annual MB from the LSTM and from GLAMBIE on a
# single axis; the two series are offset left/right so bars of the same year
# sit next to each other.
years_lstm = yearly_weighted['Year']
years_glambie = glambie_df['central_europe_end_dates']

fig, ax = plt.subplots(figsize=(12, 6))

# bar width
width = 0.4

# (x positions, bar heights, colour, legend label) for each data source
bar_series = (
    # LSTM bars (slightly shifted left)
    (years_lstm - 0.2, yearly_weighted['Weighted_MB'],
     "skyblue", "LSTM Annual MB"),
    # GLAMBIE bars (slightly shifted right)
    (years_glambie + 0.2, glambie_df['central_europe_annual_change_mwe'],
     "lightgreen", "GLAMBIE Annual MB"),
)
for x_positions, heights, bar_color, bar_label in bar_series:
    ax.bar(x_positions, heights, width=width, color=bar_color, label=bar_label)

# formatting
ax.set_ylabel("Annual MB (m w.e.)")
ax.set_title("Annual Mass Balance: LSTM vs GLAMBIE")
ax.legend()
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()