## Setting Up:

In [None]:
import logging
import os
import re
import sys
import warnings
from collections import defaultdict, Counter
from datetime import datetime
from functools import partial

import joypy
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from cmcrameri import cm
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from pandas.api.types import CategoricalDtype
from tqdm.notebook import tqdm

import massbalancemachine as mbm

# Add root of repo to import MBM
# Appends "<current working directory>/../../" to sys.path, so the resulting
# entry depends on the directory the kernel was started from.
# NOTE(review): `massbalancemachine` is already imported above, so this append
# cannot influence that import - confirm whether it is still required.
sys.path.append(os.path.join(os.getcwd(), '../../'))

# Local modules
from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.nn_helpers import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.NN_networks import *
from scripts.geodata_plots import *

# Silence all Python warnings for the rest of the session.
warnings.filterwarnings('ignore')

# IPython autoreload extension: mode 2 reloads imported modules before each
# cell execution, so edits to the local helper modules are picked up.
%load_ext autoreload
%autoreload 2

# Switzerland configuration object; later cells read cfg.dataPath, cfg.seed
# and cfg.metaData from it.
cfg = mbm.SwitzerlandConfig()

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from torch.utils.data import WeightedRandomSampler, SubsetRandomSampler
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Seed the random number generators with the configured seed.
seed_all(cfg.seed)
print("Using seed:", cfg.seed)

# Report accelerator availability; GPU memory is only released when CUDA exists.
if not torch.cuda.is_available():
    print("CUDA is NOT available")
else:
    print("CUDA is available")
    free_up_cuda()

# Device used to build the model and to map the checkpoint when loading it.
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")


In [None]:
# Plot styles:
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue = colors[0]
color_pink = '#c51b7d'

# RGI ids: read the table, strip whitespace from the column names, then sort
# by and index on the glacier short name.
rgi_df = (pd.read_csv(cfg.dataPath + path_glacier_ids, sep=',')
          .rename(columns=lambda name: name.strip())
          .sort_values(by='short_name')
          .set_index('short_name'))

# Climate variables handed to the monthly preprocessing.
vois_climate = ['t2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str']

# Topographical variables handed to the monthly preprocessing.
# "hugonnet_dhdt", "consensus_ice_thickness" and "millan_v" were part of an
# earlier, now disabled, selection.
vois_topographical = ["aspect_sgi", "slope_sgi", "svf"]

## Input data:

In [None]:
data_glamos = getStakesData(cfg)

# Number of winter and annual measurements (row count per PERIOD value):
period_counts = data_glamos.groupby('PERIOD').count().YEAR
print("Number of winter measurements:", period_counts.loc['winter'])
print("Number of annual measurements:", period_counts.loc['annual'])

# Head / tail padding (in months) derived from the stake data; reused when the
# sequence datasets are built.
months_head_pad, months_tail_pad = (
    mbm.data_processing.utils._compute_head_tail_pads_from_df(data_glamos))

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Transform data to monthly format (run or load data):
era5_dir = cfg.dataPath + path_ERA5_raw
paths = dict(
    csv_path=cfg.dataPath + path_PMB_GLAMOS_csv,
    era5_climate_data=era5_dir + 'era5_monthly_averaged_data.nc',
    geopotential_data=era5_dir + 'era5_geopotential_pressure.nc',
    radiation_save_path=cfg.dataPath + path_pcsr + 'zarr/',
)
RUN = False  # passed as run_flag to process_or_load_data
data_monthly = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_LSTM_svf.csv',
)

# Create DataLoader
dataloader_gl = mbm.dataloader.DataLoader(
    cfg,
    data=data_monthly,
    random_seed=cfg.seed,
    meta_data_columns=cfg.metaData,
)

In [None]:
# Check that all test glaciers exist in the dataset and report the ones that
# do not, so a misspelt or unavailable glacier does not go unnoticed.
existing_glaciers = set(data_monthly.GLACIER.unique())
missing_glaciers = [g for g in TEST_GLACIERS if g not in existing_glaciers]
if missing_glaciers:
    print("Warning: test glaciers missing from the dataset:",
          missing_glaciers)

# Every glacier that is not held out for testing.
train_glaciers = [g for g in existing_glaciers if g not in TEST_GLACIERS]

# Split on glacier name: TEST_GLACIERS form the test set.
splits, test_set, train_set = get_CV_splits(dataloader_gl,
                                            test_split_on='GLACIER',
                                            test_splits=TEST_GLACIERS,
                                            random_state=cfg.seed)

# Train / test frames with the target stored in column 'y'.
# .assign() returns a new frame, so the frames kept inside train_set and
# test_set are left unmodified.
data_train = train_set['df_X'].assign(y=train_set['y'])
data_test = test_set['df_X'].assign(y=test_set['y'])

### Feature columns:

In [None]:
# Columns passed as the monthly inputs when the sequence datasets are built.
MONTHLY_COLS = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str',
    'pcsr', 'ELEVATION_DIFFERENCE',
]

# Columns passed as the static inputs.
STATIC_COLS = ['aspect_sgi', 'slope_sgi', 'svf']

# Full feature list (monthly first, then static), registered on the config.
feature_columns = [*MONTHLY_COLS, *STATIC_COLS]
cfg.setFeatures(feature_columns)

## LSTM:

### Build LSTM dataloaders:

In [None]:
seed_all(cfg.seed)

# Work on copies whose PERIOD labels are trimmed and lower-cased.
df_train = data_train.copy()
df_test = data_test.copy()
for frame in (df_train, df_test):
    frame['PERIOD'] = frame['PERIOD'].str.strip().str.lower()

# Settings shared by the train and the test dataset.
sequence_kwargs = dict(
    months_tail_pad=months_tail_pad,
    months_head_pad=months_head_pad,
    expect_target=True,
    normalize_target=True,
)

# --- build train and test datasets from the dataframes ---
ds_train = mbm.data_processing.MBSequenceDataset.from_dataframe(
    df_train, MONTHLY_COLS, STATIC_COLS, **sequence_kwargs)

ds_test = mbm.data_processing.MBSequenceDataset.from_dataframe(
    df_test, MONTHLY_COLS, STATIC_COLS, **sequence_kwargs)

# Index split of the train dataset with a validation ratio of 0.2.
train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
    len(ds_train), val_ratio=0.2, seed=cfg.seed)

### Define & train model:

In [None]:
# Model and optimisation settings handed to LSTM_MB.build_model_from_params.
# 'Fm' / 'Fs' are derived from the feature lists instead of being hard-coded
# (they were 9 and 3, i.e. the lengths of MONTHLY_COLS and STATIC_COLS), so the
# model input sizes cannot drift from the columns fed to the datasets.
# NOTE(review): presumably the number of monthly / static input features -
# confirm against LSTM_MB.
custom_params = {
    'Fm': len(MONTHLY_COLS),
    'Fs': len(STATIC_COLS),
    'hidden_size': 128,
    'num_layers': 2,
    'bidirectional': False,
    'dropout': 0.2,
    'static_layers': 2,
    'static_hidden': [128, 64],
    'static_dropout': 0.1,
    'lr': 0.0005,
    'weight_decay': 0.0,
    'loss_name': 'neutral',
    'two_heads': False,
    'head_dropout': 0.0,
    'loss_spec': None
}

# --- loaders (fit scalers on TRAIN, apply to whole ds_train) ---
# Untransformed clones are used so this cell can be re-run without scaling
# ds_train / ds_test a second time.
ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train)

ds_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_test)

train_dl, val_dl = ds_train_copy.make_loaders(
    train_idx=train_idx,
    val_idx=val_idx,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=
    True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True  # use weighted sampler for training
)

# --- test loader (copies TRAIN scalers into ds_test and transforms it) ---
test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test_copy, ds_train_copy, batch_size=128, seed=cfg.seed)

# --- build model and resolve the loss ---
# No training happens in this cell: the weights come from a saved checkpoint.
model = mbm.models.LSTM_MB.build_model_from_params(cfg, custom_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(custom_params)

# --- load the checkpoint and evaluate on the test set ---
model_filename = "models/lstm_model_2025-11-04_no_oggm_norm_y.pt"
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state)
test_metrics, test_df_preds = model.evaluate_with_preds(
    device, test_dl, ds_test_copy)
test_rmse_a, test_rmse_w = test_metrics['RMSE_annual'], test_metrics[
    'RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

In [None]:
def safe_item(x):
    """Return ``x.item()``, or ``None`` when ``x`` itself is ``None``."""
    if x is None:
        return None
    return x.item()


# Print the stored target scaler and the actual target statistics of both
# datasets, train first and test second.
for header, dataset in (
    ("Train dataset (after make_loaders):", ds_train_copy),
    ("\nTest dataset (after make_test_loader):", ds_test_copy),
):
    print(header)
    print(f"  normalize_target = {dataset.normalize_target}")
    print(f"  y_mean (scaler)  = {safe_item(dataset.y_mean)}")
    print(f"  y_std  (scaler)  = {safe_item(dataset.y_std)}")
    print(f"  Actual y.mean()  = {dataset.y.mean().item():.4f}")
    print(f"  Actual y.std()   = {dataset.y.std().item():.4f}")

## Monthly distributions:

In [None]:
# Folder holding the distributed mass-balance grids.
mb_grids_dir = os.path.join(cfg.dataPath, "GLAMOS", "distributed_MB_grids")

# A disabled alternative for the LSTM predictions was
# "MBM/testing_LSTM/LSTM_no_oggm_OOS_original_y" (PATH_PREDICTIONS_LSTM_OOS).
PATH_PREDICTIONS_LSTM_IS = os.path.join(
    mb_grids_dir, "MBM/testing_LSTM/LSTM_with_oggm_IS_original_y_one_head")

PATH_PREDICTIONS_NN = os.path.join(mb_grids_dir, 'MBM/testing_LSTM/NN')

# Month abbreviations from October to September, in the order the monthly
# prediction files are read.
hydro_months = [
    'oct', 'nov', 'dec', 'jan', 'feb', 'mar',
    'apr', 'may', 'jun', 'jul', 'aug', 'sep',
]

In [None]:
# Monthly series for glacier "rhone", year 2008, read from the NN prediction
# folder: PATH_PREDICTIONS_NN is passed through the `path_pred_lstm` argument.
# NOTE(review): the plotting helper is defined outside this notebook; what it
# draws and how `apply_smoothing_fn` is applied must be checked there.
fig = plot_glacier_monthly_series_lstm_sharedcmap_center0(
    glacier_name="rhone",
    year=2008,
    path_pred_lstm=PATH_PREDICTIONS_NN,
    apply_smoothing_fn=apply_gaussian_filter,
)

In [None]:
# Same plot for glacier "rhone", year 2008, this time read from the LSTM
# prediction folder (PATH_PREDICTIONS_LSTM_IS).
# NOTE(review): the plotting helper is defined outside this notebook; what it
# draws and how `apply_smoothing_fn` is applied must be checked there.
fig = plot_glacier_monthly_series_lstm_sharedcmap_center0(
    glacier_name="rhone",
    year=2008,
    path_pred_lstm=PATH_PREDICTIONS_LSTM_IS,
    apply_smoothing_fn=apply_gaussian_filter,
)

In [None]:
glaciers = os.listdir(PATH_PREDICTIONS_LSTM_IS)

# Initialize final storage for all glacier data
all_glacier_data = []

# Loop over glaciers
for glacier_name in tqdm(glaciers):
    glacier_path = os.path.join(PATH_PREDICTIONS_LSTM_IS, glacier_name)
    if not os.path.isdir(glacier_path):
        continue  # skip non-directories

    # Regex pattern adapted for current glacier name
    pattern = re.compile(rf'{glacier_name}_(\d{{4}})_[a-z]{{3}}\.zarr')

    # Extract available years
    years = set()
    for fname in os.listdir(glacier_path):
        match = pattern.match(fname)
        if match:
            years.add(int(match.group(1)))
    years = sorted(years)

    # Collect all year-month data
    all_years_data = []
    for year in years:
        monthly_data = {}
        for month in hydro_months:
            zarr_path = os.path.join(glacier_path,
                                     f'{glacier_name}_{year}_{month}.zarr')
            if not os.path.exists(zarr_path):
                continue

            ds = xr.open_dataset(zarr_path)
            df = ds.pred_masked.to_dataframe().drop(['x', 'y'],
                                                    axis=1).reset_index()
            df_pred_months = df[df.pred_masked.notna()]

            df_el = ds.masked_elev.to_dataframe().drop(['x', 'y'],
                                                       axis=1).reset_index()
            df_elv_months = df_el[df.pred_masked.notna()]

            df_pred_months['elevation'] = df_elv_months.masked_elev.values

            monthly_data[month] = df_pred_months.pred_masked.values

        if monthly_data:
            df_months = pd.DataFrame(monthly_data)
            df_months['year'] = year
            df_months['glacier'] = glacier_name  # add glacier name
            df_months['elevation'] = df_pred_months.elevation.values
            all_years_data.append(df_months)

    # Concatenate this glacier's data
    if all_years_data:
        df_glacier = pd.concat(all_years_data, axis=0, ignore_index=True)
        all_glacier_data.append(df_glacier)

# Final full DataFrame for all glaciers
df_months_LSTM = pd.concat(all_glacier_data, axis=0, ignore_index=True)

In [None]:
glaciers = os.listdir(PATH_PREDICTIONS_NN)

# Initialize final storage for all glacier data
all_glacier_data = []

# Loop over glaciers
for glacier_name in tqdm(glaciers):
    glacier_path = os.path.join(PATH_PREDICTIONS_NN, glacier_name)
    if not os.path.isdir(glacier_path):
        continue  # skip non-directories

    # Regex pattern adapted for current glacier name
    pattern = re.compile(rf'{glacier_name}_(\d{{4}})_[a-z]{{3}}\.zarr')

    # Extract available years
    years = set()
    for fname in os.listdir(glacier_path):
        match = pattern.match(fname)
        if match:
            years.add(int(match.group(1)))
    years = sorted(years)

    # Collect all year-month data
    all_years_data = []
    for year in years:
        monthly_data = {}
        for month in hydro_months:
            zarr_path = os.path.join(glacier_path,
                                     f'{glacier_name}_{year}_{month}.zarr')
            if not os.path.exists(zarr_path):
                continue

            ds = xr.open_dataset(zarr_path)
            df = ds.pred_masked.to_dataframe().drop(['x', 'y'],
                                                    axis=1).reset_index()
            df_pred_months = df[df.pred_masked.notna()]

            df_el = ds.masked_elev.to_dataframe().drop(['x', 'y'],
                                                       axis=1).reset_index()
            df_elv_months = df_el[df.pred_masked.notna()]

            df_pred_months['elevation'] = df_elv_months.masked_elev.values

            monthly_data[month] = df_pred_months.pred_masked.values

        if monthly_data:
            df_months = pd.DataFrame(monthly_data)
            df_months['year'] = year
            df_months['glacier'] = glacier_name  # add glacier name
            df_months['elevation'] = df_pred_months.elevation.values
            all_years_data.append(df_months)

    # Concatenate this glacier's data
    if all_years_data:
        df_glacier = pd.concat(all_years_data, axis=0, ignore_index=True)
        all_glacier_data.append(df_glacier)

# Final full DataFrame for all glaciers
df_months_NN = pd.concat(all_glacier_data, axis=0, ignore_index=True)

In [None]:
def pick_file_glamos(glacier, year, period="winter"):
    """Locate the GLAMOS grid file of one glacier, year and period.

    The file name suffix is "ann" for ``period == "annual"`` and "win" for any
    other value. An "lv95" file is preferred over an "lv03" file.

    Returns
    -------
    tuple
        ``(path, "lv95")`` or ``(path, "lv03")`` for the first existing file,
        ``(None, None)`` when neither exists.
    """
    suffix = "win" if period != "annual" else "ann"
    glacier_dir = os.path.join(cfg.dataPath, path_distributed_MB_glamos,
                               "GLAMOS", glacier)
    for crs_tag in ("lv95", "lv03"):
        candidate = os.path.join(glacier_dir,
                                 f"{year}_{suffix}_fix_{crs_tag}.grid")
        if os.path.exists(candidate):
            return candidate, crs_tag
    return None, None


def load_glamos_wgs84(glacier, year, period):
    """Load one GLAMOS .grid file and return it as an xarray in WGS84.

    Returns ``None`` when no grid file exists for the request or when the
    coordinate-system tag of the file is neither "lv95" nor "lv03".
    """
    path, cs = pick_file_glamos(glacier, year, period)
    if path is None:
        return None

    meta, arr = load_grid_file(path)
    da = convert_to_xarray_geodata(arr, meta)

    # Pick the coordinate transformation matching the file's tag.
    if cs == "lv95":
        return transform_xarray_coords_lv95_to_wgs84(da)
    if cs == "lv03":
        return transform_xarray_coords_lv03_to_wgs84(da)
    return None


def _glamos_mb_with_elevation(ds_dem, glacier_name, year, period):
    """Return one GLAMOS grid as a table with DEM elevation attached.

    Returns ``None`` when no grid is available for the glacier/year/period.
    The table has the columns ``mb``, ``masked_elev``, ``year``, ``glacier``
    and keeps only cells where both ``mb`` and ``masked_elev`` are present.
    """
    da_mb = load_glamos_wgs84(glacier_name, year, period=period)
    if da_mb is None:
        return None

    # Nearest-neighbour sampling of the DEM on the MB grid, then reuse the MB
    # grid's own coordinates so the merge lines up.
    masked_elev_interp = ds_dem["masked_elev"].interp_like(da_mb,
                                                           method="nearest")
    masked_elev_interp = masked_elev_interp.assign_coords(x=da_mb.x,
                                                          y=da_mb.y)

    ds_merged = xr.merge(
        [
            da_mb.to_dataset(name="mb"),
            masked_elev_interp.to_dataset(name="masked_elev"),
        ],
        compat="override",
    )

    df = ds_merged.to_dataframe().reset_index()
    has_both = df["mb"].notna() & df["masked_elev"].notna()
    # .assign() on the .loc selection builds a new frame (no chained
    # assignment on a filtered slice).
    return df.loc[has_both, ["mb", "masked_elev"]].assign(
        year=year, glacier=glacier_name)


def load_all_glamos(glacier_years, path_glamos):
    """
    Loads both annual and winter GLAMOS grids for all glaciers and years,
    interpolates DEM elevation onto MB grids, and returns two DataFrames:
    df_GLAMOS_w, df_GLAMOS_a.

    Parameters
    ----------
    glacier_years : dict
        Maps a glacier name to the years to load.
    path_glamos : str
        Folder with one sub-directory per glacier; glaciers without a
        sub-directory are skipped.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(df_GLAMOS_w, df_GLAMOS_a)`` with the columns ``apr`` (winter) or
        ``sep`` (annual), ``elevation``, ``year`` and ``glacier``. A frame is
        an empty ``pd.DataFrame()`` when nothing could be loaded for it.
    """
    # Insertion order matters: winter is processed before annual.
    frames = {"winter": [], "annual": []}

    for glacier_name in tqdm(glacier_years.keys(), desc="Loading GLAMOS data"):
        if not os.path.isdir(os.path.join(path_glamos, glacier_name)):
            continue

        for year in glacier_years[glacier_name]:
            # --- Load DEM ---
            dem_path = (cfg.dataPath + path_GLAMOS_topo +
                        f"xr_masked_grids/{glacier_name}_{year}.zarr")
            if not os.path.exists(dem_path):
                continue

            # The DEM store is closed once both periods have been turned into
            # in-memory tables.
            with xr.open_zarr(dem_path) as ds_dem:
                for period, collected in frames.items():
                    df_period = _glamos_mb_with_elevation(
                        ds_dem, glacier_name, year, period)
                    if df_period is not None:
                        collected.append(df_period)

    def _combine(parts, mb_column):
        """Concatenate the per-year tables and apply the final column names."""
        if not parts:
            return pd.DataFrame()
        return pd.concat(parts, ignore_index=True).rename(columns={
            "masked_elev": "elevation",
            "mb": mb_column
        })

    df_GLAMOS_w = _combine(frames["winter"], "apr")
    df_GLAMOS_a = _combine(frames["annual"], "sep")

    return df_GLAMOS_w, df_GLAMOS_a


PATH_GLAMOS = os.path.join(cfg.dataPath, path_distributed_MB_glamos, 'GLAMOS')
glaciers = os.listdir(PATH_GLAMOS)

# glacier name -> sorted list of the years present in the LSTM predictions
years_by_glacier = df_months_LSTM.groupby('glacier')['year'].unique()
glacier_years = {
    name: sorted(years)
    for name, years in years_by_glacier.items()
}

df_GLAMOS_w, df_GLAMOS_a = load_all_glamos(glacier_years, PATH_GLAMOS)

### Glacier-wide:

In [None]:
# --- 1. Glacier-wide mean per glacier and year (computed once) ---
glwd_months_NN = df_months_NN.groupby(['glacier', 'year']).mean().reset_index()
glwd_months_LSTM = df_months_LSTM.groupby(['glacier',
                                           'year']).mean().reset_index()
glwd_months_GLAMOS_w = df_GLAMOS_w.groupby(['glacier',
                                            'year']).mean().reset_index()
glwd_months_GLAMOS_a = df_GLAMOS_a.groupby(['glacier',
                                            'year']).mean().reset_index()


def _filter_to_pairs(df, pairs):
    """Keep the rows of ``df`` whose (glacier, year) tuple is in ``pairs``.

    The membership mask is built in one pass over the two columns instead of
    a row-wise ``apply``; the returned frame has a fresh 0..n-1 index.
    """
    keep = np.fromiter(
        (pair in pairs for pair in zip(df['glacier'], df['year'])),
        dtype=bool,
        count=len(df),
    )
    return df.loc[keep].reset_index(drop=True)


# Glacier-year pairs available in the NN predictions.
valid_pairs_NN = set(zip(glwd_months_NN['glacier'], glwd_months_NN['year']))


def filter_to_NN(df):
    """Keep the rows whose (glacier, year) exists in the NN predictions."""
    return _filter_to_pairs(df, valid_pairs_NN)


# --- 2. Compute the intersection of valid glacier–year pairs across all datasets ---
valid_pairs = (
    set(zip(glwd_months_NN['glacier'], glwd_months_NN['year']))
    & set(zip(glwd_months_LSTM['glacier'], glwd_months_LSTM['year']))
    & set(zip(glwd_months_GLAMOS_w['glacier'], glwd_months_GLAMOS_w['year']))
    & set(zip(glwd_months_GLAMOS_a['glacier'], glwd_months_GLAMOS_a['year'])))


# --- 3. Helper functions for filtering by glacier–year pairs ---
# Both read the notebook-level `valid_pairs` at call time.
def filter_to_glamos(df):
    """Keep the rows whose (glacier, year) is in the current `valid_pairs`."""
    return _filter_to_pairs(df, valid_pairs)


def filter_to_valid(df):
    """Keep the rows whose (glacier, year) is in the current `valid_pairs`."""
    return _filter_to_pairs(df, valid_pairs)


# --- 4. Apply consistent filtering to all datasets ---
glwd_months_LSTM_filtered = filter_to_valid(glwd_months_LSTM)
glwd_months_NN_filtered = filter_to_valid(glwd_months_NN)
glwd_months_GLAMOS_filtered_w = filter_to_valid(glwd_months_GLAMOS_w)
glwd_months_GLAMOS_filtered_a = filter_to_valid(glwd_months_GLAMOS_a)

# Row counts after filtering: GLAMOS winter, GLAMOS annual, NN, LSTM.
print(
    len(glwd_months_GLAMOS_filtered_w),
    len(glwd_months_GLAMOS_filtered_a),
    len(glwd_months_NN_filtered),
    len(glwd_months_LSTM_filtered),
)

# --- 5. Prepare for plotting ---
df_months_nn_long = prepare_monthly_long_df(
    glwd_months_LSTM_filtered,
    glwd_months_NN_filtered,
    glwd_months_GLAMOS_filtered_w,
    glwd_months_GLAMOS_filtered_a,
)

df_months_nn_long.head(2)

In [None]:
# # remove bad glaciers:
# geodetic_mb = get_geodetic_MB(cfg)
# periods_per_glacier, geoMB_per_glacier = build_periods_per_glacier(geodetic_mb)
# geogl = [
#     'Schwarzbach', 'Sexrouge', 'Murtel', 'Basodino', 'Adler', 'Hohlaub',
#     'Tsanfleuron', 'Silvretta', 'Gries', 'Clariden', 'Gietro', 'Schwarzberg',
#     'Allalin', 'Findelen', 'Rhone', 'Corbassiere', 'Aletsch'
# ]
# geogl = [gl.lower() for gl in geogl]
# df_months_nn_long = df_months_nn_long[df_months_nn_long.glacier.isin(geogl)]

# gl_error = ['schwarzberg', 'gietro']

# df_months_nn_long = df_months_nn_long[~df_months_nn_long.glacier.isin(gl_error
#                                                                       )]

In [None]:
# Plot
# The x-range is taken from the two plotted columns only. Reducing the whole
# frame first would also run min()/max() over the non-numeric columns.
mb_cols = ['mb_nn', 'mb_lstm']
min_, max_ = (
    df_months_nn_long[mb_cols].min().min(),
    df_months_nn_long[mb_cols].max().max(),
)
fig = plot_monthly_joyplot(df_months_nn_long,
                           color_annual=color_annual,
                           color_winter=color_winter,
                           x_range=(np.floor(min_), np.ceil(max_)))

In [None]:
# x-range covering every mass-balance column, including the plotted "mb_lstm"
# (previously only 'mb_nn' and 'mb_glamos' were considered, so LSTM values
# outside their range could fall off the axis).
mb_cols = ['mb_nn', 'mb_lstm', 'mb_glamos']
min_, max_ = (
    df_months_nn_long[mb_cols].min().min(),
    df_months_nn_long[mb_cols].max().max(),
)
fig = plot_monthly_joyplot_single(df_months_nn_long,
                                  variable="mb_lstm",
                                  color_model=color_annual,
                                  x_range=(np.floor(min_), np.ceil(max_)))
fig.savefig('figures/CH_LSTM_vs_NN_monthly_joyplot_glwd.png',
            dpi=300,
            bbox_inches='tight')

In [None]:
emptyfolder('figures/joyplots')

# One x-range for all glaciers, computed once from the full frame (it does not
# depend on the glacier) and covering the plotted "mb_lstm" column as well.
mb_cols = ['mb_nn', 'mb_lstm', 'mb_glamos']
min_, max_ = (
    df_months_nn_long[mb_cols].min().min(),
    df_months_nn_long[mb_cols].max().max(),
)
x_range = (np.floor(min_), np.ceil(max_))

for gl in df_months_nn_long.glacier.unique():
    df_gl = df_months_nn_long[df_months_nn_long.glacier == gl]
    fig = plot_monthly_joyplot_single(df_gl,
                                      variable="mb_lstm",
                                      color_model=color_annual,
                                      x_range=x_range,
                                      show=False)

    # save figure, then close it so figures do not pile up in memory
    fig.savefig(f'figures/joyplots/{gl}.png', dpi=300, bbox_inches='tight')
    plt.close(fig)

### Elevation bands:

#### Highest:

In [None]:
# Band width and band edges used for the elevation labels.
bin = 200
bins = np.arange(1200, 4500, bin)
labels = ["{}-{}".format(lower, lower + bin) for lower in bins[:-1]]

# Work on copies so the original frames stay untouched.
df_months_NN_, df_months_LSTM_, df_GLAMOS_a_, df_GLAMOS_w_ = (
    frame.copy()
    for frame in (df_months_NN, df_months_LSTM, df_GLAMOS_a, df_GLAMOS_w))

# Label every row with its elevation band.
for df_ in (df_months_NN_, df_months_LSTM_, df_GLAMOS_a_, df_GLAMOS_w_):
    df_["elev_band"] = pd.cut(df_["elevation"], bins=bins, labels=labels)


# --- Helper to extract highest-elevation band per glacier ---
def extract_highest_band(df, bin_width):
    """Average the numeric columns per (glacier, year) in the top band.

    The top band of a glacier contains every row whose elevation is at most
    ``bin_width`` below that glacier's maximum elevation (the maximum is taken
    over all rows of the glacier, across years).
    """
    glacier_top = df.groupby("glacier")["elevation"].transform("max")
    in_top_band = df["elevation"] >= glacier_top - bin_width
    band_means = df[in_top_band].groupby(["glacier",
                                          "year"]).mean(numeric_only=True)
    return band_means.reset_index()


# --- Highest-band means for every dataset ---
glwd_high_NN = extract_highest_band(df_months_NN_, bin)
glwd_high_LSTM = extract_highest_band(df_months_LSTM_, bin)
glwd_high_GLAMOS_a = extract_highest_band(df_GLAMOS_a_, bin)
glwd_high_GLAMOS_w = extract_highest_band(df_GLAMOS_w_, bin)

# --- Glacier-year pairs present in all four datasets ---
valid_pairs = set.intersection(*(
    set(zip(frame["glacier"], frame["year"]))
    for frame in (glwd_high_NN, glwd_high_LSTM, glwd_high_GLAMOS_w,
                  glwd_high_GLAMOS_a)))


def filter_to_valid(df):
    """Keep rows whose (glacier, year) tuple is in the global `valid_pairs`."""
    pair_per_row = df[["glacier", "year"]].apply(tuple, axis=1)
    selected = df[pair_per_row.isin(valid_pairs)]
    return selected.reset_index(drop=True)


# --- Restrict every dataset to the common pairs ---
glwd_high_NN_filt = filter_to_valid(glwd_high_NN)
glwd_high_LSTM_filt = filter_to_valid(glwd_high_LSTM)
glwd_high_GLAMOS_a_filt = filter_to_valid(glwd_high_GLAMOS_a)
glwd_high_GLAMOS_w_filt = filter_to_valid(glwd_high_GLAMOS_w)

# Row counts after filtering: NN, LSTM, GLAMOS winter, GLAMOS annual.
print(
    len(glwd_high_NN_filt),
    len(glwd_high_LSTM_filt),
    len(glwd_high_GLAMOS_w_filt),
    len(glwd_high_GLAMOS_a_filt),
)

# --- Combined long-format frame for plotting ---
# Argument order: LSTM, NN, GLAMOS winter, GLAMOS annual.
df_months_nn_long = prepare_monthly_long_df(glwd_high_LSTM_filt,
                                            glwd_high_NN_filt,
                                            glwd_high_GLAMOS_w_filt,
                                            glwd_high_GLAMOS_a_filt)

# --- x-axis limits from all three mass-balance columns ---
mb_values_high = df_months_nn_long[["mb_nn", "mb_lstm", "mb_glamos"]]
min_ = mb_values_high.min().min()
max_ = mb_values_high.max().max()

# --- Plot (plot_monthly_joyplot with color_annual / color_winter is the
# disabled alternative) ---
fig = plot_monthly_joyplot_single(
    df_months_nn_long,
    variable="mb_lstm",
    color_model=color_annual,
    x_range=(np.floor(min_), np.ceil(max_)),
)

# --- Save figure ---
fig.savefig("figures/CH_LSTM_vs_NN_GLAMOS_monthly_joyplot_high_elv.png",
            dpi=300,
            bbox_inches="tight")

#### Lowest:

In [None]:
# Band width and band edges used for the elevation labels.
bin = 250
bins = np.arange(1200, 4500, bin)
labels = ["{}-{}".format(lower, lower + bin) for lower in bins[:-1]]

# Copy to avoid modifying original DataFrames
df_months_NN_, df_months_LSTM_, df_GLAMOS_a_, df_GLAMOS_w_ = (
    frame.copy()
    for frame in (df_months_NN, df_months_LSTM, df_GLAMOS_a, df_GLAMOS_w))

# Label every row with its elevation band.
for df_ in (df_months_NN_, df_months_LSTM_, df_GLAMOS_a_, df_GLAMOS_w_):
    df_["elev_band"] = pd.cut(df_["elevation"], bins=bins, labels=labels)


# --- Helper for extracting lowest-elevation band per glacier ---
def extract_lowest_band(df, bin_width):
    """Average the numeric columns per (glacier, year) in the bottom band.

    The bottom band of a glacier contains every row whose elevation is at most
    ``bin_width`` above that glacier's minimum elevation (the minimum is taken
    over all rows of the glacier, across years).
    """
    glacier_bottom = df.groupby("glacier")["elevation"].transform("min")
    in_bottom_band = df["elevation"] <= glacier_bottom + bin_width
    band_means = df[in_bottom_band].groupby(["glacier",
                                             "year"]).mean(numeric_only=True)
    return band_means.reset_index()


# --- Lowest-band means for every dataset ---
glwd_low_NN = extract_lowest_band(df_months_NN_, bin)
glwd_low_LSTM = extract_lowest_band(df_months_LSTM_, bin)
glwd_low_GLAMOS_a = extract_lowest_band(df_GLAMOS_a_, bin)
glwd_low_GLAMOS_w = extract_lowest_band(df_GLAMOS_w_, bin)

# --- Glacier-year pairs present in all four datasets ---
valid_pairs = set.intersection(*(
    set(zip(frame["glacier"], frame["year"]))
    for frame in (glwd_low_NN, glwd_low_LSTM, glwd_low_GLAMOS_w,
                  glwd_low_GLAMOS_a)))


def filter_to_valid(df):
    """Keep rows whose (glacier, year) tuple is in the global `valid_pairs`."""
    pair_per_row = df[["glacier", "year"]].apply(tuple, axis=1)
    selected = df[pair_per_row.isin(valid_pairs)]
    return selected.reset_index(drop=True)


# --- Restrict every dataset to the common pairs ---
glwd_low_NN_filt = filter_to_valid(glwd_low_NN)
glwd_low_LSTM_filt = filter_to_valid(glwd_low_LSTM)
glwd_low_GLAMOS_a_filt = filter_to_valid(glwd_low_GLAMOS_a)
glwd_low_GLAMOS_w_filt = filter_to_valid(glwd_low_GLAMOS_w)

# Row counts after filtering: NN, LSTM, GLAMOS winter, GLAMOS annual.
print(
    len(glwd_low_NN_filt),
    len(glwd_low_LSTM_filt),
    len(glwd_low_GLAMOS_w_filt),
    len(glwd_low_GLAMOS_a_filt),
)

# --- Combined long-format frame for plotting ---
# Argument order: LSTM, NN, GLAMOS winter, GLAMOS annual.
df_months_nn_long_low = prepare_monthly_long_df(
    glwd_low_LSTM_filt,
    glwd_low_NN_filt,
    glwd_low_GLAMOS_w_filt,
    glwd_low_GLAMOS_a_filt,
)

# --- x-axis range from all three mass-balance columns ---
mb_values_low = df_months_nn_long_low[["mb_nn", "mb_lstm", "mb_glamos"]]
min_ = mb_values_low.min().min()
max_ = mb_values_low.max().max()

# The lower limit is then fixed by hand, replacing the computed minimum.
min_ = -10

# --- Plot (plot_monthly_joyplot with color_annual / color_winter is the
# disabled alternative) ---
fig = plot_monthly_joyplot_single(
    df_months_nn_long_low,
    variable="mb_lstm",
    color_model=color_annual,
    x_range=(np.floor(min_), np.ceil(max_)),
)

# --- Save figure ---
fig.savefig("figures/CH_LSTM_vs_NN_GLAMOS_monthly_joyplot_low_elv.png",
            dpi=300,
            bbox_inches="tight")

## Feature importance:

### Aggregated:

In [None]:
from scripts.PFI_all import permutation_feature_importance_mbm_parallel

# Run the permutation feature importance on the test frame. The returned
# table is read below through its "feature", "mean_delta", "std_delta",
# "metric_name" and "baseline" columns.
pfi_parallel = permutation_feature_importance_mbm_parallel(
    cfg=cfg,
    custom_params=custom_params,
    model_filename=model_filename,
    # Evaluation dataframe; must carry the targets aligned to the predictions.
    df_eval=df_test,
    MONTHLY_COLS=MONTHLY_COLS,
    STATIC_COLS=STATIC_COLS,
    ds_train=ds_train,
    train_idx=train_idx,
    target_col="POINT_BALANCE",  # name of the target column
    months_head_pad=months_head_pad,
    months_tail_pad=months_tail_pad,
    seed=cfg.seed,
    n_repeats=5,
    batch_size=256,
    max_workers=None,  # auto: n_cpus-1 (cap 32)
)

# One horizontal bar per row of the table; the figure grows with the number
# of rows but never gets shorter than 3 inches.
fig_pfi, ax_pfi = plt.subplots(figsize=(8, max(3, 0.35 * len(pfi_parallel))))
ax_pfi.barh(pfi_parallel["feature"],
            pfi_parallel["mean_delta"],
            xerr=pfi_parallel["std_delta"])
# Flip the axis so the first row of the table is drawn at the top.
ax_pfi.invert_yaxis()
ax_pfi.set_title(
    f"Permutation Feature Importance (Δ{pfi_parallel['metric_name'].iloc[0]}; baseline={pfi_parallel['baseline'].iloc[0]:.3f})"
)
ax_pfi.set_xlabel(
    f"Increase in {pfi_parallel['metric_name'].iloc[0]} (higher = more important)"
)
fig_pfi.tight_layout()
plt.show()

### Monthly:

In [None]:
from scripts.PFI_monthly import permutation_feature_importance_mbm_monthly_parallel

# Month labels listed in index order (0..14). The underscore-suffixed labels
# are separate keys from their plain counterparts: "aug_"/"sep_" come first
# and "oct_" comes last.
month_sequence = (
    "aug_", "sep_", "oct", "nov", "dec", "jan", "feb", "mar",
    "apr", "may", "jun", "jul", "aug", "sep", "oct_",
)
month_map = {label: position for position, label in enumerate(month_sequence)}

# Add the index column to a copy so df_test itself stays unchanged.
# Labels that are not keys of month_map come out as NaN.
df_eval = df_test.copy()
df_eval["MONTH_IDX"] = df_eval["MONTHS"].str.lower().map(month_map)

# Per-month permutation feature importance; the leading arguments are passed
# positionally in the order the helper expects them.
pfi_monthly = permutation_feature_importance_mbm_monthly_parallel(
    cfg,
    custom_params,
    model_filename,
    df_eval,
    MONTHLY_COLS,
    STATIC_COLS,
    ds_train,
    train_idx,
    months_head_pad,
    months_tail_pad,
    seed=cfg.seed,
    n_repeats=3,
    batch_size=256,
    denorm=True,
    max_workers=None,
)

In [None]:
# This cell uses `sns`; import it here so the cell does not depend on a later
# cell (or a star import) having brought seaborn into the namespace.
import seaborn as sns

# --- Month order for columns ---
month_order = [
    "aug_", "sep_", "oct", "nov", "dec", "jan", "feb", "mar",
    "apr", "may", "jun", "jul", "aug", "sep", "oct_"
]

# --- Map feature names to long names (unknown names are kept as they are) ---
pfi_monthly["feature_long"] = pfi_monthly["feature"].apply(
    lambda x: vois_climate_long_name.get(x, x)
)

# --- Order features by mean global importance, largest first ---
feat_order = (
    pfi_monthly.groupby("feature_long")["mean_delta_global"]
    .mean()
    .sort_values(ascending=False)
    .index
)


def _monthly_pfi_pivot(pfi_df, value_col, months, features):
    """Return a feature-by-month table of ``value_col`` with ordered axes.

    Columns follow ``months`` (months missing from the data are dropped) and
    rows follow ``features``.
    """
    table = pfi_df.pivot(index="feature_long", columns="month", values=value_col)
    table = table[[m for m in months if m in table.columns]]
    return table.loc[features]


# --- Prepare pivot tables using long names ---
piv_winter = _monthly_pfi_pivot(pfi_monthly, "mean_delta_winter", month_order, feat_order)
piv_annual = _monthly_pfi_pivot(pfi_monthly, "mean_delta_annual", month_order, feat_order)
piv_global = _monthly_pfi_pivot(pfi_monthly, "mean_delta_global", month_order, feat_order)

# --- Create figure with three side-by-side heatmaps ---
fig, axes = plt.subplots(1, 3, figsize=(20, 6), sharey=True)

# (title word, colorbar word, table) for each panel, left to right.
heatmap_panels = [
    ("Winter", "winter", piv_winter),
    ("Annual", "annual", piv_annual),
    ("Global", "global", piv_global),
]
for panel_idx, (panel_ax, (title_word, label_word, panel_table)) in enumerate(
        zip(axes, heatmap_panels)):
    # Each panel gets its own color scale and colorbar.
    sns.heatmap(
        panel_table,
        cmap="magma",
        linewidths=0.3,
        cbar_kws={"label": f"ΔRMSE ({label_word})"},
        ax=panel_ax
    )
    panel_ax.set_title(f"Monthly PFI – {title_word} RMSE Δ")
    panel_ax.set_xlabel("Month")
    # The y axis is shared, so only the leftmost panel carries the label.
    panel_ax.set_ylabel("Feature" if panel_idx == 0 else "")

plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# --- Month order ---
month_order = [
    "aug_", "sep_", "oct", "nov", "dec", "jan", "feb", "mar",
    "apr", "may", "jun", "jul", "aug", "sep", "oct_"
]

# --- Long display name per feature; names without an entry stay unchanged ---
pfi_monthly["feature_long"] = pfi_monthly["feature"].apply(
    lambda raw_name: vois_climate_long_name.get(raw_name, raw_name)
)

# --- Row order: features ranked by their mean global delta, largest first ---
feat_order = (
    pfi_monthly.groupby("feature_long")["mean_delta_global"]
    .mean()
    .sort_values(ascending=False)
    .index
)

# --- Feature-by-month table of the global delta ---
piv_global = pfi_monthly.pivot(
    index="feature_long", columns="month", values="mean_delta_global"
)
# Rows follow feat_order; columns follow month_order, limited to the months
# that are actually present in the table.
piv_global = piv_global.loc[
    feat_order,
    [m for m in month_order if m in piv_global.columns],
]

# --- Plot single heatmap ---
fig_global, ax_global = plt.subplots(figsize=(10, 6))
sns.heatmap(
    piv_global,
    cmap="magma",
    linewidths=0.3,
    cbar_kws={"label": "ΔRMSE (global)"},
    ax=ax_global
)
ax_global.set_xlabel("Month")
ax_global.set_ylabel("Feature")
ax_global.set_title("Monthly Permutation Feature Importance – Global RMSE Δ")
fig_global.tight_layout()
plt.show()
