## Setting Up:

In [None]:
import sys
import os

import logging
import warnings
from datetime import datetime
from collections import defaultdict, Counter
from functools import partial

import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from tqdm.notebook import tqdm
from cmcrameri import cm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import joypy
from pandas.api.types import CategoricalDtype

import massbalancemachine as mbm

# Add root of repo to import MBM
sys.path.append(os.path.join(os.getcwd(), '../../'))

# Local modules
from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.nn_helpers import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.NN_networks import *
from scripts.geodata_plots import *

# Silence all Python warnings for the rest of the notebook.
# NOTE(review): this also hides pandas SettingWithCopy/FutureWarning messages
# emitted by later cells -- consider narrowing the filter.
warnings.filterwarnings('ignore')

# IPython magics: automatically reload edited modules (e.g. the scripts.*
# helpers imported above) before each cell is executed.
%load_ext autoreload
%autoreload 2

# Switzerland-specific MBM configuration; later cells read e.g. cfg.dataPath,
# cfg.seed and cfg.metaData from it.
cfg = mbm.SwitzerlandConfig()

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from torch.utils.data import WeightedRandomSampler, SubsetRandomSampler
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Reproducibility: seed every RNG from the shared configuration seed.
seed_all(cfg.seed)
print("Using seed:", cfg.seed)

# Choose the compute device; when a GPU is present, free cached GPU memory.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("CUDA is available")
    free_up_cuda()
else:
    print("CUDA is NOT available")


In [None]:
# Plot styles:
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)
colors = get_cmap_hex(cm.batlow, 10)  # 10 colours from the batlow colormap
color_dark_blue, color_pink = colors[0], '#c51b7d'

# RGI ids: one row per glacier, sorted and indexed by its short name.
rgi_df = (pd.read_csv(cfg.dataPath + path_glacier_ids, sep=',')
          .rename(columns=lambda col: col.strip())  # drop stray header spaces
          .sort_values(by='short_name')
          .set_index('short_name'))

# Climate variables extracted for every stake measurement.
vois_climate = ['t2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str']

# Static topographic variables. A larger alternative set (disabled) was:
#     "aspect_sgi", "slope_sgi", "hugonnet_dhdt", "consensus_ice_thickness",
#     "millan_v", "svf"
vois_topographical = ["aspect_sgi", "slope_sgi", "svf"]

## Input data:

In [None]:
# Point (stake) mass-balance measurements from GLAMOS.
data_glamos = getStakesData(cfg)

# Number of measurements per period (non-null YEAR entries).
counts_per_period = data_glamos.groupby('PERIOD').count().YEAR
print("Number of winter measurements:", counts_per_period.loc['winter'])
print("Number of annual measurements:", counts_per_period.loc['annual'])

# Head/tail month paddings derived from the stake data.
months_head_pad, months_tail_pad = (
    mbm.data_processing.utils._compute_head_tail_pads_from_df(data_glamos))

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Monthly dataset: recomputed from raw inputs when RUN is True, else loaded.
era5_dir = cfg.dataPath + path_ERA5_raw
paths = {
    'csv_path': cfg.dataPath + path_PMB_GLAMOS_csv,
    'era5_climate_data': era5_dir + 'era5_monthly_averaged_data.nc',
    'geopotential_data': era5_dir + 'era5_geopotential_pressure.nc',
    'radiation_save_path': cfg.dataPath + path_pcsr + 'zarr/',
}
RUN = False
data_monthly = process_or_load_data(run_flag=RUN,
                                    data_glamos=data_glamos,
                                    paths=paths,
                                    cfg=cfg,
                                    vois_climate=vois_climate,
                                    vois_topographical=vois_topographical,
                                    output_file='CH_wgms_dataset_monthly_LSTM_svf.csv')

# Wrap the monthly data in an MBM DataLoader (seeded, with metadata columns).
dataloader_gl = mbm.dataloader.DataLoader(cfg,
                                          data=data_monthly,
                                          random_seed=cfg.seed,
                                          meta_data_columns=cfg.metaData)

In [None]:
# Split glaciers into the held-out TEST_GLACIERS and the training glaciers.
existing_glaciers = set(data_monthly.GLACIER.unique())
missing_glaciers = [g for g in TEST_GLACIERS if g not in existing_glaciers]
if missing_glaciers:
    # Surface test glaciers that have no data instead of silently ignoring them.
    logging.warning("Test glaciers missing from data_monthly: %s",
                    missing_glaciers)

# Sorted so the glacier order does not depend on set iteration order.
train_glaciers = sorted(g for g in existing_glaciers if g not in TEST_GLACIERS)

splits, test_set, train_set = get_CV_splits(dataloader_gl,
                                            test_split_on='GLACIER',
                                            test_splits=TEST_GLACIERS,
                                            random_state=cfg.seed)

# Features + target frames. Copy first so that adding 'y' does not modify
# train_set['df_X'] / test_set['df_X'] in place (they stay feature-only).
data_train = train_set['df_X'].copy()
data_train['y'] = train_set['y']
data_test = test_set['df_X'].copy()
data_test['y'] = test_set['y']

### Input features:

In [None]:
# Monthly-varying model inputs (climate drivers, radiation, elevation offset).
MONTHLY_COLS = ['t2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str',
                'pcsr', 'ELEVATION_DIFFERENCE']
# Static (time-invariant) topographic inputs.
STATIC_COLS = ['aspect_sgi', 'slope_sgi', 'svf']

# Register the full feature list (monthly first, then static) with the config.
feature_columns = [*MONTHLY_COLS, *STATIC_COLS]
cfg.setFeatures(feature_columns)

## LSTM:

### Build LSTM dataloaders:

In [None]:
seed_all(cfg.seed)

_SeqDataset = mbm.data_processing.MBSequenceDataset


def _normalise_period(frame):
    """Return a copy of `frame` with PERIOD labels stripped and lower-cased."""
    out = frame.copy()
    out['PERIOD'] = out['PERIOD'].str.strip().str.lower()
    return out


df_train = _normalise_period(data_train)
df_test = _normalise_period(data_test)

# Options shared by the train and test sequence datasets.
_seq_options = dict(months_tail_pad=months_tail_pad,
                    months_head_pad=months_head_pad,
                    expect_target=True)

# --- build train / test sequence datasets from the dataframes ---
ds_train = _SeqDataset.from_dataframe(df_train, MONTHLY_COLS, STATIC_COLS,
                                      **_seq_options)
ds_test = _SeqDataset.from_dataframe(df_test, MONTHLY_COLS, STATIC_COLS,
                                     **_seq_options)

# Seeded train/validation split of the training sequences (20 % validation).
train_idx, val_idx = _SeqDataset.split_indices(len(ds_train),
                                               val_ratio=0.2,
                                               seed=cfg.seed)

### Define model & load trained weights:

In [None]:
custom_params = {
    # Number of monthly / static input features. Derived from the feature lists
    # so they cannot drift out of sync with MONTHLY_COLS / STATIC_COLS.
    'Fm': len(MONTHLY_COLS),
    'Fs': len(STATIC_COLS),
    'hidden_size': 128,
    'num_layers': 2,
    'bidirectional': False,
    'dropout': 0.2,
    'static_layers': 2,
    'static_hidden': [128, 64],
    'static_dropout': 0.1,
    'lr': 0.0005,
    'weight_decay': 0.0,
    'loss_name': 'neutral',
    'two_heads': True,
    'head_dropout': 0.0,
    'loss_spec': None,
}

# --- loaders (fit scalers on TRAIN, apply to whole ds_train) ---
# Scaling below is applied to these clones; presumably the clones are
# independent copies, so ds_train / ds_test stay untransformed -- TODO confirm.
ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train)
ds_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_test)

train_dl, val_dl = ds_train_copy.make_loaders(
    train_idx=train_idx,
    val_idx=val_idx,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True,  # use weighted sampler for training
)

# --- test loader (copies TRAIN scalers into ds_test and transforms it) ---
test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test_copy, ds_train_copy, batch_size=128, seed=cfg.seed)

# --- build model and resolve loss; no training happens in this cell, the
# weights are restored from a saved checkpoint below ---
model = mbm.models.LSTM_MB.build_model_from_params(cfg, custom_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(custom_params)

# Load the trained weights and evaluate on the held-out test glaciers.
model_filename = "models/lstm_model_2025-10-22_two_heads_no_oggm.pt"
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state)
test_metrics, test_df_preds = model.evaluate_with_preds(
    device, test_dl, ds_test_copy)
test_rmse_a = test_metrics['RMSE_annual']
test_rmse_w = test_metrics['RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

## Monthly distributions:

In [None]:
# Root folder of the distributed monthly MB grids, then one sub-path per model.
_grids_root = os.path.join(cfg.dataPath, "GLAMOS", "distributed_MB_grids")
PATH_PREDICTIONS_LSTM_OOS = os.path.join(
    _grids_root, "MBM/testing_LSTM/glamos_dems_LSTM_no_oggm")
PATH_PREDICTIONS_NN = os.path.join(_grids_root,
                                   'MBM/glamos_dems_NN_SEB_full_OGGM')

# Month abbreviations in hydrological-year order (October -> September).
hydro_months = 'oct nov dec jan feb mar apr may jun jul aug sep'.split()

In [None]:
import re  # file-name pattern matching below

# Collect the per-pixel monthly LSTM predictions for every glacier and year.
# Each glacier has a sub-folder of PATH_PREDICTIONS_LSTM_OOS with files named
# "<glacier>_<YYYY>_<mon>.zarr" holding `pred_masked` and `masked_elev`.
# Result: df_months_LSTM with one column per month found, plus 'year',
# 'glacier' and 'elevation' (one row per valid pixel and year).
glaciers = sorted(os.listdir(PATH_PREDICTIONS_LSTM_OOS))  # deterministic order

# Initialize final storage for all glacier data
all_glacier_data = []

for glacier_name in tqdm(glaciers):
    glacier_path = os.path.join(PATH_PREDICTIONS_LSTM_OOS, glacier_name)
    if not os.path.isdir(glacier_path):
        continue  # skip non-directories

    # re.escape: match the glacier name literally, not as regex syntax
    pattern = re.compile(
        rf'{re.escape(glacier_name)}_(\d{{4}})_[a-z]{{3}}\.zarr')

    # Extract available years from the file names
    years = set()
    for fname in os.listdir(glacier_path):
        match = pattern.match(fname)
        if match:
            years.add(int(match.group(1)))
    years = sorted(years)

    # Collect all year-month data
    all_years_data = []
    for year in years:
        monthly_data = {}
        elevation = None
        for month in hydro_months:
            zarr_path = os.path.join(glacier_path,
                                     f'{glacier_name}_{year}_{month}.zarr')
            if not os.path.exists(zarr_path):
                continue

            # Context manager closes the store once the values are in memory
            with xr.open_dataset(zarr_path) as ds:
                df = ds.pred_masked.to_dataframe().drop(
                    ['x', 'y'], axis=1).reset_index()
                df_el = ds.masked_elev.to_dataframe().drop(
                    ['x', 'y'], axis=1).reset_index()

            # Keep only pixels that carry a prediction
            valid = df.pred_masked.notna()
            monthly_data[month] = df.loc[valid, 'pred_masked'].values
            # Elevation of the same pixels (taken from the last month read)
            elevation = df_el.loc[valid, 'masked_elev'].values

        if monthly_data:
            df_months = pd.DataFrame(monthly_data)
            df_months['year'] = year
            df_months['glacier'] = glacier_name  # add glacier name
            df_months['elevation'] = elevation
            all_years_data.append(df_months)

    # Concatenate this glacier's data
    if all_years_data:
        df_glacier = pd.concat(all_years_data, axis=0, ignore_index=True)
        all_glacier_data.append(df_glacier)

if not all_glacier_data:
    raise ValueError('No monthly LSTM prediction grids found under '
                     f'{PATH_PREDICTIONS_LSTM_OOS}')

# Final full DataFrame for all glaciers
df_months_LSTM = pd.concat(all_glacier_data, axis=0, ignore_index=True)

In [None]:
import re  # file-name pattern matching below

# Collect the per-pixel monthly NN predictions for every glacier and year.
# Each glacier has a sub-folder of PATH_PREDICTIONS_NN with files named
# "<glacier>_<YYYY>_<mon>.zarr" holding `pred_masked` and `masked_elev`.
# Result: df_months_NN with one column per month found, plus 'year',
# 'glacier' and 'elevation' (one row per valid pixel and year).
glaciers = sorted(os.listdir(PATH_PREDICTIONS_NN))  # deterministic order

# Initialize final storage for all glacier data
all_glacier_data = []

for glacier_name in tqdm(glaciers):
    glacier_path = os.path.join(PATH_PREDICTIONS_NN, glacier_name)
    if not os.path.isdir(glacier_path):
        continue  # skip non-directories

    # re.escape: match the glacier name literally, not as regex syntax
    pattern = re.compile(
        rf'{re.escape(glacier_name)}_(\d{{4}})_[a-z]{{3}}\.zarr')

    # Extract available years from the file names
    years = set()
    for fname in os.listdir(glacier_path):
        match = pattern.match(fname)
        if match:
            years.add(int(match.group(1)))
    years = sorted(years)

    # Collect all year-month data
    all_years_data = []
    for year in years:
        monthly_data = {}
        elevation = None
        for month in hydro_months:
            zarr_path = os.path.join(glacier_path,
                                     f'{glacier_name}_{year}_{month}.zarr')
            if not os.path.exists(zarr_path):
                continue

            # Context manager closes the store once the values are in memory
            with xr.open_dataset(zarr_path) as ds:
                df = ds.pred_masked.to_dataframe().drop(
                    ['x', 'y'], axis=1).reset_index()
                df_el = ds.masked_elev.to_dataframe().drop(
                    ['x', 'y'], axis=1).reset_index()

            # Keep only pixels that carry a prediction
            valid = df.pred_masked.notna()
            monthly_data[month] = df.loc[valid, 'pred_masked'].values
            # Elevation of the same pixels (taken from the last month read)
            elevation = df_el.loc[valid, 'masked_elev'].values

        if monthly_data:
            df_months = pd.DataFrame(monthly_data)
            df_months['year'] = year
            df_months['glacier'] = glacier_name  # add glacier name
            df_months['elevation'] = elevation
            all_years_data.append(df_months)

    # Concatenate this glacier's data
    if all_years_data:
        df_glacier = pd.concat(all_years_data, axis=0, ignore_index=True)
        all_glacier_data.append(df_glacier)

if not all_glacier_data:
    raise ValueError('No monthly NN prediction grids found under '
                     f'{PATH_PREDICTIONS_NN}')

# Final full DataFrame for all glaciers
df_months_NN = pd.concat(all_glacier_data, axis=0, ignore_index=True)

In [None]:
def pick_file_glamos(glacier, year, period="winter", base_dir=None):
    """Locate the GLAMOS distributed-MB ``.grid`` file for one glacier-year.

    Parameters
    ----------
    glacier : str
        Glacier folder name.
    year : int
        Year encoded in the file name.
    period : str
        ``"annual"`` selects ``*_ann_*`` files; any other value selects the
        winter (``*_win_*``) files.
    base_dir : str, optional
        Directory holding one sub-folder per glacier. Defaults to
        ``<cfg.dataPath>/<path_distributed_MB_glamos>/GLAMOS``.

    Returns
    -------
    (path, crs) : tuple
        ``(path, "lv95")`` or ``(path, "lv03")`` for the first existing file
        (LV95 is preferred when both exist), otherwise ``(None, None)``.
    """
    suffix = "ann" if period == "annual" else "win"
    if base_dir is None:
        base_dir = os.path.join(cfg.dataPath, path_distributed_MB_glamos,
                                "GLAMOS")
    folder = os.path.join(base_dir, glacier)
    # Checked in this order, so LV95 wins when both projections are present.
    for crs in ("lv95", "lv03"):
        candidate = os.path.join(folder, f"{year}_{suffix}_fix_{crs}.grid")
        if os.path.exists(candidate):
            return candidate, crs
    return None, None


def load_glamos_wgs84(glacier, year, period):
    """Read one GLAMOS ``.grid`` file and reproject it to WGS84.

    The file is located with ``pick_file_glamos``. Returns the reprojected
    xarray object, or None when no file exists for (glacier, year, period)
    or its coordinate-system tag is neither ``"lv95"`` nor ``"lv03"``.
    """
    grid_path, crs_tag = pick_file_glamos(glacier, year, period)
    if grid_path is None:
        return None

    meta, values = load_grid_file(grid_path)
    grid = convert_to_xarray_geodata(values, meta)

    if crs_tag == "lv95":
        return transform_xarray_coords_lv95_to_wgs84(grid)
    if crs_tag == "lv03":
        return transform_xarray_coords_lv03_to_wgs84(grid)
    return None


def load_all_glamos(glacier_years, path_glamos):
    """
    Load GLAMOS winter and annual MB grids for the requested glacier-years.

    Parameters
    ----------
    glacier_years : dict
        Mapping ``glacier name -> iterable of years`` to load.
    path_glamos : str
        Directory with one sub-folder per glacier; glaciers without a folder
        here are skipped. NOTE(review): the grid files themselves are located
        by ``load_glamos_wgs84`` / ``pick_file_glamos``, which build their own
        path from ``cfg`` rather than from this argument.

    Returns
    -------
    df_GLAMOS_w, df_GLAMOS_a : pandas.DataFrame
        One row per non-NaN grid cell. Winter values are in column ``'apr'``,
        annual values in ``'sep'``; both frames also have ``'year'`` and
        ``'glacier'``. When nothing is found, an empty frame with those
        columns is returned so that later ``groupby`` calls still work.
    """

    def _grid_to_frame(ds, month_col, year, glacier_name):
        """Flatten one grid into a frame of its non-NaN cell values."""
        df = ds.to_dataframe().drop(['x', 'y'], axis=1).reset_index()
        values = df.loc[df.grid_data.notna(), 'grid_data'].values
        frame = pd.DataFrame({month_col: values})
        frame['year'] = year
        frame['glacier'] = glacier_name
        return frame

    def _combine(frames, month_col):
        """Concatenate per-glacier frames, or return an empty typed frame."""
        if frames:
            return pd.concat(frames, ignore_index=True)
        return pd.DataFrame(columns=[month_col, 'year', 'glacier'])

    all_glacier_data_w, all_glacier_data_a = [], []

    for glacier_name in tqdm(glacier_years.keys(), desc="Loading GLAMOS data"):
        glacier_path = os.path.join(path_glamos, glacier_name)
        if not os.path.isdir(glacier_path):
            continue

        all_years_w, all_years_a = [], []
        for year in glacier_years[glacier_name]:
            # Winter grid -> column 'apr'
            ds_w = load_glamos_wgs84(glacier_name, year, period="winter")
            if ds_w is not None:
                all_years_w.append(
                    _grid_to_frame(ds_w, 'apr', year, glacier_name))

            # Annual grid -> column 'sep'
            ds_a = load_glamos_wgs84(glacier_name, year, period="annual")
            if ds_a is not None:
                all_years_a.append(
                    _grid_to_frame(ds_a, 'sep', year, glacier_name))

        # Concatenate per glacier
        if all_years_w:
            all_glacier_data_w.append(pd.concat(all_years_w, ignore_index=True))
        if all_years_a:
            all_glacier_data_a.append(pd.concat(all_years_a, ignore_index=True))

    # Final full DataFrames for all glaciers
    df_GLAMOS_w = _combine(all_glacier_data_w, 'apr')
    df_GLAMOS_a = _combine(all_glacier_data_a, 'sep')
    return df_GLAMOS_w, df_GLAMOS_a


PATH_GLAMOS = os.path.join(cfg.dataPath, path_distributed_MB_glamos, 'GLAMOS')
glaciers = os.listdir(PATH_GLAMOS)

# Years with LSTM predictions per glacier: {glacier: sorted list of years}.
glacier_years = {
    name: sorted(yrs)
    for name, yrs in df_months_LSTM.groupby('glacier')['year'].unique().items()
}

# GLAMOS reference grids for exactly those glacier-years.
df_GLAMOS_w, df_GLAMOS_a = load_all_glamos(glacier_years, PATH_GLAMOS)

#### Glacier-wide:

In [None]:
# Glacier-wide mean of every monthly column for each (glacier, year).
# numeric_only=True ignores non-numeric helper columns (e.g. a categorical
# 'elev_band', if one has been added to the monthly frames), matching the
# elevation-band aggregations further down.
_glwd_keys = ['glacier', 'year']
glwd_months_NN = (df_months_NN.groupby(_glwd_keys)
                  .mean(numeric_only=True).reset_index())
glwd_months_LSTM = (df_months_LSTM.groupby(_glwd_keys)
                    .mean(numeric_only=True).reset_index())
glwd_months_GLAMOS_w = (df_GLAMOS_w.groupby(_glwd_keys)
                        .mean(numeric_only=True).reset_index())
glwd_months_GLAMOS_a = (df_GLAMOS_a.groupby(_glwd_keys)
                        .mean(numeric_only=True).reset_index())

# Define the reference set of glacier-year pairs present in GLAMOS (winter)
valid_pairs = set(
    zip(glwd_months_GLAMOS_w['glacier'], glwd_months_GLAMOS_w['year']))


def filter_to_glamos(df):
    """Return the rows of `df` whose (glacier, year) pair is in `valid_pairs`.

    The result is re-indexed from 0. Membership is tested with a single pass
    over the two key columns instead of a row-wise ``apply``.
    """
    keep = np.fromiter(
        ((glacier, year) in valid_pairs
         for glacier, year in zip(df['glacier'], df['year'])),
        dtype=bool,
        count=len(df))
    return df.loc[keep].reset_index(drop=True)


# Apply to both model outputs
glwd_months_LSTM_filtered = filter_to_glamos(glwd_months_LSTM)
glwd_months_NN_filtered = filter_to_glamos(glwd_months_NN)
glwd_months_GLAMOS_filtered_w = glwd_months_GLAMOS_w.copy()  # kept as-is
glwd_months_GLAMOS_filtered_a = glwd_months_GLAMOS_a.copy()

print(len(glwd_months_GLAMOS_filtered_w), len(glwd_months_GLAMOS_a),
      len(glwd_months_NN_filtered), len(glwd_months_LSTM_filtered))

# Prepare the data for plotting
df_months_nn_long = prepare_monthly_long_df(glwd_months_LSTM_filtered,
                                            glwd_months_NN_filtered,
                                            glwd_months_GLAMOS_filtered_w,
                                            glwd_months_GLAMOS_filtered_a)
df_months_nn_long.head(2)

In [None]:
# Plot
min_, max_ = df_months_nn_long.min()[[
    'mb_nn', 'mb_lstm'
]].min(), df_months_nn_long.max()[['mb_nn', 'mb_lstm']].max()
fig = plot_monthly_joyplot(df_months_nn_long,
                           color_annual=color_annual,
                           color_winter=color_winter,
                           x_range=(np.floor(min_), np.ceil(max_)))
# save figure
fig.savefig('figures/CH_LSTM_vs_NN_monthly_joyplot_glwd.png',
            dpi=300,
            bbox_inches='tight')

### Elevation bands:

#### Highest:

In [None]:
# Height (m) of the top band and width of the `elev_band` bins.
# Named bin_width to avoid shadowing the built-in `bin`.
bin_width = 200
bins = np.arange(1200, 4500, bin_width)  # 200 m edges: 1200, ..., 4400
# Each label spans exactly one bin: "<lower>-<lower + bin_width>".
# NOTE: pd.cut gives NaN for elevations outside (1200, 4400].
labels = [f"{b}-{b + bin_width}" for b in bins[:-1]]


def _mean_of_top_band(df, width):
    """Per (glacier, year) mean of the numeric columns over the pixels lying
    within `width` m of that glacier's highest pixel (over all years)."""
    max_elev = df.groupby('glacier')['elevation'].transform('max')
    top_band = df[df['elevation'] >= max_elev - width]
    return (top_band.groupby(['glacier', 'year'])
            .mean(numeric_only=True)
            .reset_index())


# Band labels go on copies so the monthly frames themselves stay unchanged.
df_months_NN_ = df_months_NN.copy()
df_months_LSTM_ = df_months_LSTM.copy()
df_months_NN_['elev_band'] = pd.cut(df_months_NN_['elevation'],
                                    bins=bins,
                                    labels=labels)
df_months_LSTM_['elev_band'] = pd.cut(df_months_LSTM_['elevation'],
                                      bins=bins,
                                      labels=labels)

# Keep only the top `bin_width` metres of each glacier, then average
glwd_high_NN = _mean_of_top_band(df_months_NN_, bin_width)
glwd_high_LSTM = _mean_of_top_band(df_months_LSTM_, bin_width)

# Prepare the data for plotting
df_months_nn_long = prepare_monthly_long_df(glwd_high_LSTM, glwd_high_NN)

# Plot: x-range over both model columns (selected before reducing)
_mb_cols = ['mb_nn', 'mb_lstm']
min_ = df_months_nn_long[_mb_cols].min().min()
max_ = df_months_nn_long[_mb_cols].max().max()
fig = plot_monthly_joyplot(df_months_nn_long,
                           color_annual=color_annual,
                           color_winter=color_winter,
                           x_range=(np.floor(min_), np.ceil(max_)))
# save figure (create the output folder if needed)
os.makedirs('figures', exist_ok=True)
fig.savefig('figures/CH_LSTM_vs_NN_monthly_joyplot_high_elv.png',
            dpi=300,
            bbox_inches='tight')

#### Lowest:

In [None]:
# Height (m) of the bottom band and width of the `elev_band` bins.
# Named bin_width to avoid shadowing the built-in `bin`.
bin_width = 250
bins = np.arange(1200, 4500, bin_width)  # 250 m edges: 1200, ..., 4450
labels = [f"{b}-{b + bin_width}" for b in bins[:-1]]


def _mean_of_bottom_band(df, width):
    """Per (glacier, year) mean of the numeric columns over the pixels lying
    within `width` m of that glacier's lowest pixel (over all years)."""
    min_elev = df.groupby('glacier')['elevation'].transform('min')
    bottom_band = df[df['elevation'] <= min_elev + width]
    return (bottom_band.groupby(['glacier', 'year'])
            .mean(numeric_only=True)
            .reset_index())


# Assign elevation bands on copies (as in the highest-band cell) so that
# df_months_NN / df_months_LSTM do not gain a categorical column, which would
# break re-running the glacier-wide aggregation.
df_months_NN_low = df_months_NN.copy()
df_months_LSTM_low = df_months_LSTM.copy()
df_months_NN_low['elev_band'] = pd.cut(df_months_NN_low['elevation'],
                                       bins=bins,
                                       labels=labels)
df_months_LSTM_low['elev_band'] = pd.cut(df_months_LSTM_low['elevation'],
                                         bins=bins,
                                         labels=labels)

# Keep only the bottom `bin_width` metres of each glacier, then average
glwd_low_NN = _mean_of_bottom_band(df_months_NN_low, bin_width)
glwd_low_LSTM = _mean_of_bottom_band(df_months_LSTM_low, bin_width)

# Prepare for plotting
df_months_nn_long = prepare_monthly_long_df(glwd_low_LSTM, glwd_low_NN)

# Plot: x-range over both model columns (selected before reducing)
_mb_cols = ['mb_nn', 'mb_lstm']
min_ = df_months_nn_long[_mb_cols].min().min()
max_ = df_months_nn_long[_mb_cols].max().max()
fig = plot_monthly_joyplot(df_months_nn_long,
                           color_annual=color_annual,
                           color_winter=color_winter,
                           x_range=(np.floor(min_), np.ceil(max_)))
# save figure (create the output folder if needed)
os.makedirs('figures', exist_ok=True)
fig.savefig('figures/CH_LSTM_vs_NN_monthly_joyplot_low_elv.png',
            dpi=300,
            bbox_inches='tight')
            bbox_inches='tight')

## Feature importance: