## Setting Up:

In [None]:
# --- Standard library
from concurrent.futures import ProcessPoolExecutor, as_completed
from contextlib import redirect_stdout
from datetime import datetime
import io
import logging
import multiprocessing as mp
import os
import sys
import warnings
import matplotlib.gridspec as gridspec
import calendar
# Make repo root importable (for MBM & scripts/*)
sys.path.append(os.path.join(os.getcwd(), '../../'))

# --- Third-party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from cmcrameri import cm
import torch
from tqdm.auto import tqdm
import xarray as xr
from matplotlib.lines import Line2D

import massbalancemachine as mbm

# --- Project-local
from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.nn_helpers import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.NN_networks import *
from scripts.geodata_plots import *

# --- Notebook settings
# NOTE(review): this silences *all* warnings for the session (e.g. pandas
# SettingWithCopyWarning, torch deprecation warnings), which can hide bugs.
warnings.filterwarnings('ignore')
# IPython magics: automatically reload edited project modules between cells.
%load_ext autoreload
%autoreload 2

# Switzerland configuration object; provides seed, dataPath and metaData used below.
cfg = mbm.SwitzerlandConfig()

In [None]:
seed_all(cfg.seed)
print("Using seed:", cfg.seed)

# Pick the compute device first, then report it (and clear GPU memory if used).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
    print("CUDA is NOT available")
else:
    print("CUDA is available")
    free_up_cuda()

In [None]:
# Plot styles: project style sheet plus a 10-colour batlow palette.
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)

colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue, color_pink = colors[0], '#c51b7d'

## Input data:

In [None]:
# Full stake dataset; used further down for stake coordinates and point balances.
stake_file = os.path.join(
    cfg.dataPath,
    path_PMB_GLAMOS_csv,
    "CH_wgms_dataset_all.csv",
)
df_stakes = pd.read_csv(filepath_or_buffer=stake_file)

In [None]:
# ERA5 climate variables extracted for every month of a stake sequence.
vois_climate = [
    't2m',
    'tp',
    'slhf',
    'sshf',
    'ssrd',
    'fal',
    'str',
]

# Topographical variables attached to each stake.
vois_topographical = [
    "aspect_sgi", "slope_sgi", "hugonnet_dhdt", "consensus_ice_thickness",
    "millan_v", "svf"
]

# Read GLAMOS stake data
data_glamos = getStakesData(cfg)

# Number of padding months before/after the hydrological year, computed from
# data_glamos by a (private) mbm helper.
months_head_pad, months_tail_pad = mbm.data_processing.utils._compute_head_tail_pads_from_df(
    data_glamos)

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Transform data to monthly format (RUN=True recomputes, otherwise the cached
# CSV named by output_file is loaded):
paths = {
    'csv_path': cfg.dataPath + path_PMB_GLAMOS_csv,
    'era5_climate_data':
    cfg.dataPath + path_ERA5_raw + 'era5_monthly_averaged_data.nc',
    'geopotential_data':
    cfg.dataPath + path_ERA5_raw + 'era5_geopotential_pressure.nc',
    'radiation_save_path': cfg.dataPath + path_pcsr + 'zarr/'
}
RUN = False
data_monthly = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_LSTM_svf_IS.csv')

# Create DataLoader
dataloader_gl = mbm.dataloader.DataLoader(cfg,
                                          data=data_monthly,
                                          random_seed=cfg.seed,
                                          meta_data_columns=cfg.metaData)

# All glaciers are used for training ("within sample").
existing_glaciers = set(data_monthly.GLACIER.unique())
train_glaciers = existing_glaciers
# Explicit copy: the target column is added below, and writing into a
# boolean-indexed slice triggers pandas' SettingWithCopy ambiguity (whose
# warning is silenced by warnings.filterwarnings('ignore') above).
data_train = data_monthly[data_monthly.GLACIER.isin(train_glaciers)].copy()
print('Size of monthly train data:', len(data_train))

# Target column used by the LSTM dataset builders (no split happens here; the
# train/validation split is done on sequence indices later).
data_train['y'] = data_train['POINT_BALANCE']

In [None]:
# Shift every measurement's start date to 1 August of the year given by the
# first 4 characters of FROM_DATE (assumed YYYYMMDD — TODO confirm), so the
# monthly sequences always begin in August.
data_glamos_Aug_ = data_glamos.copy()
data_glamos_Aug_["FROM_DATE"] = (
    data_glamos_Aug_["FROM_DATE"].astype(str).str.slice(0, 4) + "0801"
).astype(int)

# Same for full temporal resolution (run or load data):
# Compute padding for monthly data
months_head_pad_Aug_, months_tail_pad_Aug_ = mbm.data_processing.utils._compute_head_tail_pads_from_df(
    data_glamos_Aug_)

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

RUN = False
data_monthly_Aug_ = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos_Aug_,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_LSTM_gs_no_oggm_Aug_.csv')

# Create DataLoader
dataloader_gl_Aug_ = mbm.dataloader.DataLoader(cfg,
                                               data=data_monthly_Aug_,
                                               random_seed=cfg.seed,
                                               meta_data_columns=cfg.metaData)

# Blocking on glaciers:
# Model is trained on all glaciers --> "Within sample"
existing_glaciers = set(data_monthly_Aug_.GLACIER.unique())
train_glaciers = existing_glaciers
# Explicit copy so adding the target column does not write into a slice.
data_train_Aug_ = data_monthly_Aug_[data_monthly_Aug_.GLACIER.isin(
    train_glaciers)].copy()
print('Size of monthly train data:', len(data_train_Aug_))

# Target column used by the LSTM dataset builders.
data_train_Aug_['y'] = data_train_Aug_['POINT_BALANCE']

## LSTM:

In [None]:
# Per-month (dynamic) inputs: ERA5 climate variables followed by derived ones.
MONTHLY_COLS = ['t2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str'] + [
    'ELEVATION_DIFFERENCE', 'pcsr'
]
# Per-stake (static) inputs.
STATIC_COLS = ['aspect_sgi', 'slope_sgi', "svf"]

feature_columns = [*MONTHLY_COLS, *STATIC_COLS]

### Build LSTM dataloaders:
Two dataset options are available. In the first, the model always sees the months from August of the hydrological year onward, even when they only belong to the padding and are therefore excluded from the loss; for this, the training data are combined with the August-start dataset built above:

In [None]:
# Re-seed before building the dataset and split so this cell is reproducible
# on its own (assumes seed_all seeds all RNGs in use — TODO confirm).
seed_all(cfg.seed)

# True : sequences built from the Aug-start data (data_train_Aug_) so the model
#        always sees Aug/Sep, while the loss is defined by data_train.
# False: sequences built directly from data_train.
HYDR_BEGINNING_PRESENT = False

if HYDR_BEGINNING_PRESENT:
    # Option with padded months completed (not in loss if not in measurement but still seen by model)
    # from nn_helpers.py
    ds_train = build_combined_LSTM_dataset(
        df_loss=data_train,
        df_full=data_train_Aug_,
        monthly_cols=MONTHLY_COLS,
        static_cols=STATIC_COLS,
        months_head_pad=months_head_pad_Aug_,
        months_tail_pad=months_tail_pad_Aug_,
        normalize_target=False,
        expect_target=True)
else:
    # Option with padded months (beginning of hydr year set to 0 if not in measurement)
    df_train = data_train.copy()
    # Normalise PERIOD labels (strip whitespace, lower-case) before grouping.
    df_train['PERIOD'] = df_train['PERIOD'].str.strip().str.lower()
    # --- build train dataset from dataframe ---
    ds_train = mbm.data_processing.MBSequenceDataset.from_dataframe(
        df_train,
        MONTHLY_COLS,
        STATIC_COLS,
        months_tail_pad=months_tail_pad,
        months_head_pad=months_head_pad,
        expect_target=True,
        normalize_target=False)

# Seeded 80/20 split of sequence indices into train / validation.
train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
    len(ds_train), val_ratio=0.2, seed=cfg.seed)

### Define & train model:

In [None]:
# Checkpoint to evaluate; must match the dataset variant selected above.
if HYDR_BEGINNING_PRESENT:
    model_filename = "models/lstm_model_2025-11-28_no_oggm_IS_original_y_past.pt"
else:
    model_filename = "models/lstm_model_2025-12-01_no_oggm_IS_original_y.pt"

# Hyper-parameters of the checkpoint. Both variants share the architecture and
# differ only in weight decay.
custom_params = {
    # Input widths must match the feature lists the dataset was built with.
    'Fm': len(MONTHLY_COLS),
    'Fs': len(STATIC_COLS),
    'hidden_size': 128,
    'num_layers': 2,
    'bidirectional': False,
    'dropout': 0.1,
    'static_layers': 2,
    'static_hidden': [128, 64],
    'static_dropout': 0.1,
    'lr': 0.001,
    'weight_decay': 0.0001 if HYDR_BEGINNING_PRESENT else 0.0,
    'loss_name': 'neutral',
    'two_heads': False,
    'head_dropout': 0.1,
    'loss_spec': None
}

# --- loaders (fit scalers on TRAIN, apply to whole ds_train) ---
ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train)

train_dl, val_dl = ds_train_copy.make_loaders(
    train_idx=train_idx,
    val_idx=val_idx,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True  # use weighted sampler for training
)

# --- build model and load pre-trained weights (no training in this cell) ---
model = mbm.models.LSTM_MB.build_model_from_params(cfg, custom_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(custom_params)

# The "test" loader wraps another untransformed clone of ds_train (presumably
# scaled with the scalers fitted on ds_train_copy), so evaluation is in-sample.
ds_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train)

test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test_copy, ds_train_copy, batch_size=128, seed=cfg.seed)

# Evaluate on test
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state)
test_metrics, test_df_preds = model.evaluate_with_preds(
    device, test_dl, ds_test_copy)
test_rmse_a, test_rmse_w = test_metrics['RMSE_annual'], test_metrics[
    'RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

In [None]:
# Annual and winter skill scores of the (in-sample) test predictions.
scores_annual, scores_winter = compute_seasonal_scores(test_df_preds,
                                                       target_col='target',
                                                       pred_col='pred')

print("Annual scores:", scores_annual)
print("Winter scores:", scores_winter)

# Predicted vs observed summary figure.
# NOTE(review): color_annual / color_winter are not defined in this notebook;
# presumably they come from one of the star-imported scripts.* modules.
fig = plot_predictions_summary(
    grouped_ids=test_df_preds,
    scores_annual=scores_annual,
    scores_winter=scores_winter,
    ax_xlim=(-8, 6),
    ax_ylim=(-8, 6),
    color_annual=color_annual,
    color_winter=color_winter,
)

## Gradient analysis:

In [None]:
# --- month ordering ---
# Time-axis labels of the monthly sequences. The underscore-suffixed labels
# appear to mark padding months outside the Oct–Sep hydrological year.
month_order = (
    ["aug_", "sep_"]
    + ["oct", "nov", "dec", "jan", "feb", "mar",
       "apr", "may", "jun", "jul", "aug", "sep"]
    + ["oct_"]
)
# Label -> position lookup, used to sort months chronologically.
month_order_map = dict(zip(month_order, range(len(month_order))))

In [None]:
def compute_gradient_sensitivity(
    model,
    device,
    dataloader,
    dataset,
    stake_ids,
    target_month: str,
    month_order: list,
    period: str = "annual",
):
    """
    Compute the mean input-gradient sensitivity d(MB_target_month)/d(x_m)
    over the samples whose ID is in ``stake_ids`` and whose period matches.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(x_m, x_s, mv, mw, ma)`` and returning
        ``(y_month, y_w, y_a)``; ``y_month[:, t]`` is the prediction for month t.
    device : torch.device
        Device the batch tensors are moved to.
    dataloader : iterable of dict
        Batches with keys "x_m", "x_s", "mv", "mw", "ma". Batch rows are matched
        to ``dataset.keys`` by position, so the loader is assumed to iterate
        ``dataset`` in order without shuffling or dropping samples.
    dataset : object
        Exposes ``keys``: a sequence of (GLACIER, YEAR, ID, PERIOD) tuples.
    stake_ids : iterable
        IDs (third key element) of the samples to include.
    target_month : str
        Output month whose prediction is differentiated; must be in ``month_order``.
    month_order : list
        Month labels; a label's index is its position on the time axis.
    period : str
        Period label (fourth key element) the samples must match.

    Returns
    -------
    numpy.ndarray of shape (T, Fm)
        Per-sample gradients multiplied by the month mask ``mv`` and averaged
        over all selected samples.

    Raises
    ------
    ValueError
        If ``target_month`` is not in ``month_order`` or no sample matches.

    The model is evaluated in eval mode. This disables every dropout,
    including nn.LSTM's internal inter-layer dropout, which setting
    ``nn.Dropout`` modules to eval does not reach. cuDNN RNN backward only
    runs in training mode, so cuDNN is disabled while gradients are computed.
    The model's training flag and the cuDNN setting are restored on exit.
    """
    stake_ids = set(stake_ids)
    target_month_idx = month_order.index(target_month)

    all_keys = dataset.keys  # list of (GLACIER, YEAR, ID, PERIOD)
    offset = 0
    grads_sum = None
    n_samples = 0

    was_training = model.training
    cudnn_was_enabled = torch.backends.cudnn.enabled
    model.eval()
    torch.backends.cudnn.enabled = False
    try:
        for batch in dataloader:
            bs = batch["x_m"].shape[0]
            batch_keys = all_keys[offset:offset + bs]
            offset += bs

            # Keep only samples of the requested stakes and period.
            mask_idx = [
                j for j, (_g, _yr, mid, per) in enumerate(batch_keys)
                if mid in stake_ids and per == period
            ]
            if not mask_idx:
                continue

            x_m = batch["x_m"][mask_idx].to(device).detach().requires_grad_(True)
            x_s = batch["x_s"][mask_idx].to(device)
            mv = batch["mv"][mask_idx].to(device)
            mw = batch["mw"][mask_idx].to(device)
            ma = batch["ma"][mask_idx].to(device)

            y_month, _y_w, _y_a = model(x_m, x_s, mv, mw, ma)
            y_target = y_month[:, target_month_idx]

            # Samples are processed independently, so the gradient of the SUM
            # w.r.t. x_m[j] equals dy_j/dx_m[j] (a mean would scale each one by
            # 1/batch_size). autograd.grad also leaves parameter .grad untouched.
            (grads,) = torch.autograd.grad(y_target.sum(), x_m)
            grads = grads * mv.unsqueeze(-1)  # zero out invalid months

            batch_sum = grads.sum(dim=0).detach().cpu()  # (T, Fm)
            grads_sum = batch_sum if grads_sum is None else grads_sum + batch_sum
            n_samples += grads.shape[0]
    finally:
        torch.backends.cudnn.enabled = cudnn_was_enabled
        model.train(was_training)

    if n_samples == 0:
        raise ValueError("No samples found for these stake IDs and period.")

    # Average once over all selected samples.
    return (grads_sum / n_samples).numpy()


def plot_sensitivity_grid_compare(
        grads_high,
        grads_low,
        month_order,
        vars_to_plot,
        MONTHLY_COLS,
        title="Gradient Sensitivity Comparison (High vs Low stakes)",
        first_valid_month=None,
        last_valid_month=None,
):
    """
    Plot, one panel per feature, the sensitivity curves of a high and a low
    stake group in a two-column grid, then show the figure.

    Parameters
    ----------
    grads_high, grads_low : array-like of shape (T, Fm)
        Sensitivities; row t is labelled ``month_order[t]`` and column f is
        feature ``MONTHLY_COLS[f]``.
    month_order : list of str
        Month labels of the time axis.
    vars_to_plot : list of str
        Non-empty subset of ``MONTHLY_COLS`` to draw.
    MONTHLY_COLS : list of str
        Feature names in the column order of the gradient arrays.
    title : str
        Figure super-title.
    first_valid_month, last_valid_month : str or None
        Inclusive month window to plot (e.g. "dec" .. "apr"); None means the
        start / end of ``month_order``.

    Raises
    ------
    ValueError
        If a window month is not in ``month_order``, the window is empty,
        ``vars_to_plot`` is empty, or the gradient arrays have fewer rows than
        the window needs.
    """
    grads_high = np.asarray(grads_high)
    grads_low = np.asarray(grads_low)

    # --------------------
    # Determine subset of months to plot (validated before any plotting)
    # --------------------
    start_idx = 0
    end_idx = len(month_order)

    if first_valid_month is not None:
        if first_valid_month not in month_order:
            raise ValueError(f"{first_valid_month} is not in month_order.")
        start_idx = month_order.index(first_valid_month)

    if last_valid_month is not None:
        if last_valid_month not in month_order:
            raise ValueError(f"{last_valid_month} is not in month_order.")
        end_idx = month_order.index(last_valid_month) + 1

    if start_idx >= end_idx:
        raise ValueError(
            f"Empty month window: first_valid_month={first_valid_month!r} "
            f"does not precede last_valid_month={last_valid_month!r}.")

    if not vars_to_plot:
        raise ValueError("vars_to_plot must contain at least one feature.")

    for name, grads in (("grads_high", grads_high), ("grads_low", grads_low)):
        if grads.shape[0] < end_idx:
            raise ValueError(
                f"{name} has {grads.shape[0]} time steps but the month window "
                f"needs at least {end_idx}.")

    month_order_plot = month_order[start_idx:end_idx]
    grads_high_plot = grads_high[start_idx:end_idx, :]
    grads_low_plot = grads_low[start_idx:end_idx, :]

    # --------------------
    # Setup grid
    # --------------------
    num_vars = len(vars_to_plot)
    ncols = 2
    nrows = int(np.ceil(num_vars / ncols))

    fig, axes = plt.subplots(nrows, ncols,
                             figsize=(14, nrows * 3),
                             sharex=False,
                             sharey=True)
    axes = axes.flatten()

    # --------------------
    # Plot each feature
    # --------------------
    for i, var_name in enumerate(vars_to_plot):

        idx_f = MONTHLY_COLS.index(var_name)
        ax = axes[i]

        # HIGH group
        ax.plot(
            month_order_plot,
            grads_high_plot[:, idx_f],
            marker="o",
            label="Accumulation zone (High)",
            color="tab:blue"
        )

        # LOW group
        ax.plot(
            month_order_plot,
            grads_low_plot[:, idx_f],
            marker="o",
            label="Ablation zone (Low)",
            color="tab:red"
        )

        # 0-line
        ax.axhline(0, color="black", linewidth=0.7)

        # Fall back to the raw column name for features without a long name
        # (e.g. derived inputs such as 'pcsr').
        ax.set_title(vois_climate_long_name.get(var_name, var_name))
        ax.grid(True, alpha=0.3)
        ax.tick_params(axis='x', rotation=45)

        if i == 0:
            ax.legend()

    # Remove empty axes
    for j in range(num_vars, len(axes)):
        fig.delaxes(axes[j])

    fig.suptitle(title, fontsize=16, y=1.02)
    plt.tight_layout()
    plt.show()


### Annual:

In [None]:
# Get monthly predictions (denormalised only if the dataset normalised its target):
monthly_pred_df = model.predict_monthly_with_keys(
    device,
    test_dl,
    ds_test_copy,
    denorm=ds_test_copy.normalize_target,
)

# Keep only sequences belonging to annual measurements.
monthly_pred_df_a = monthly_pred_df.loc[monthly_pred_df['PERIOD'].eq('annual')]

#### Gries:

In [None]:
# # Choose a good year
# glacier_name = 'gries'
# for year in range(2010, 2020):
#     # year = 2015
#     file_ann = f"{year}_ann_fix_lv95.grid"

#     stake_coordinates = df_stakes[(df_stakes.GLACIER == glacier_name)
#                                   & (df_stakes.YEAR == year) &
#                                   (df_stakes.PERIOD
#                                    == 'annual')].drop_duplicates()

#     lon_name = "lon"
#     lat_name = "lat"
#     grid_path_ann = os.path.join(cfg.dataPath, path_distributed_MB_glamos,
#                                  'GLAMOS', glacier_name, file_ann)
#     metadata_ann, grid_data_ann = load_grid_file(grid_path_ann)
#     ds_glamos_ann = convert_to_xarray_geodata(grid_data_ann, metadata_ann)

#     ds_glamos_wgs84_ann = transform_xarray_coords_lv95_to_wgs84(ds_glamos_ann)
#     stake_coordinates["GLAMOS_MB"] = stake_coordinates.apply(
#         lambda row: get_predicted_mb_glamos(lon_name, lat_name, row,
#                                             ds_glamos_wgs84_ann),
#         axis=1,
#     )

#     vmin_ann = ds_glamos_wgs84_ann.min().item()
#     vmax_ann = ds_glamos_wgs84_ann.max().item()

#     (
#         cmap,
#         norm,
#     ) = get_color_maps(vmin_ann, vmax_ann)

#     fig = plt.figure(figsize=(12, 6))
#     ax = plt.subplot(1, 2, 1)
#     ds_glamos_wgs84_ann.plot(
#         ax=ax,
#         cmap=cmap,
#         norm=norm,
#     )

#     sns.scatterplot(data=stake_coordinates,
#                     x="POINT_LON",
#                     y="POINT_LAT",
#                     hue='POINT_BALANCE',
#                     s=30,
#                     legend=False,
#                     palette=cmap,
#                     hue_norm=norm,
#                     ax=ax)
#     ax.set_title(
#         f'Glacier: {glacier_name.capitalize()} | Year: {year} | Annual ')

##### Gries 2014:

In [None]:
# Choose stakes in high and low zones:
# Annual stake rows for this glacier/year, sorted from lowest to highest elevation.
glacier_name = 'gries'
year = 2014
stake_coordinates = df_stakes[(df_stakes.GLACIER == glacier_name)
                              & (df_stakes.YEAR == year) &
                              (df_stakes.PERIOD
                               == 'annual')].drop_duplicates().sort_values(
                                   by='POINT_ELEVATION')
# Print three highest and three lowest stakes
print("Three lowest stakes:")
print(stake_coordinates.head(3).POINT_ID)
print("\nThree highest stakes:")
print(stake_coordinates.tail(3).POINT_ID)

# Latest start month and latest end month among these stakes
# (assumes MONTH_START / MONTH_END are integer month numbers 1-12):
print('\nATTENTION: Measurements start in:',
      calendar.month_name[stake_coordinates.MONTH_START.max()])

print('\nATTENTION: Measurements stop in:',
      calendar.month_name[stake_coordinates.MONTH_END.max()])

# Months (labels from month_order) covered by the annual measurements; used
# to restrict the monthly predictions below.
valid_months_ = [
    'sep_',
    'oct',
    'nov',
    'dec',
    'jan',
    'feb',
    'mar',
    'apr',
    'may',
    'jun',
    'jul',
    'aug',
    'sep',
]

In [None]:
# Hand-picked stake names (POINT_ID) for the high (accumulation) and low
# (ablation) groups.
stake_point_ids_high = ['gries_493', 'gries_492', 'gries_463']
stake_point_ids_low = ['gries_3', 'gries_52', 'gries_54']

# stake_point_ids_high = ['gries_463']
# stake_point_ids_low = ['gries_54']

all_stakes_point_ids = stake_point_ids_high + stake_point_ids_low

# Get values for high and low stakes
zones_stakes = stake_coordinates[stake_coordinates.POINT_ID.isin(
    all_stakes_point_ids)]

# Map the stake names to the measurement IDs used by the sequence dataset
# (dataset keys and prediction frames are indexed by ID, not POINT_ID).
IDs_high = data_monthly[data_monthly.POINT_ID.isin(stake_point_ids_high)
                        & (data_monthly.YEAR == year) &
                        (data_monthly.PERIOD
                         == 'annual')].ID.unique().tolist()
IDs_low = data_monthly[data_monthly.POINT_ID.isin(stake_point_ids_low)
                       & (data_monthly.YEAR == year) &
                       (data_monthly.PERIOD == 'annual')].ID.unique().tolist()

# Get monthly predictions for that year and glacier
monthly_pred_gl = monthly_pred_df_a[monthly_pred_df_a.ID.isin(IDs_high +
                                                              IDs_low)]

# Filter to valid months of measurement only
monthly_pred_gl = monthly_pred_gl[monthly_pred_gl.MONTH.isin(valid_months_)]

# Observed point balance, averaged per group.
mean_obs_mb_high = zones_stakes[zones_stakes.POINT_ID.isin(
    stake_point_ids_high)].POINT_BALANCE.mean()
mean_obs_mb_low = zones_stakes[zones_stakes.POINT_ID.isin(
    stake_point_ids_low)].POINT_BALANCE.mean()

print("Mean observed MB high stakes:", np.round(mean_obs_mb_high, 2), 'm w.e.')
print("Mean observed MB low stakes:", np.round(mean_obs_mb_low, 2), 'm w.e.')

# Predicted point balance (sequence-level 'pred'), averaged per group.
mean_pred_mb_high = test_df_preds[test_df_preds.ID.isin(IDs_high)].pred.mean()
mean_pred_mb_low = test_df_preds[test_df_preds.ID.isin(IDs_low)].pred.mean()

print("Mean predicted MB high stakes:", np.round(mean_pred_mb_high, 2),
      'm w.e.')
print("Mean predicted MB low stakes:", np.round(mean_pred_mb_low, 2), 'm w.e.')

In [None]:
# PLOT Stakes

# Load GLAMOS annual grid for that year and glacier (for plotting).
# NOTE: the *_win variable names mirror the winter version of this cell; here
# they hold the ANNUAL grid.
file_win = f"{year}_ann_fix_lv95.grid"
grid_path_win = os.path.join(cfg.dataPath, path_distributed_MB_glamos,
                             'GLAMOS', glacier_name, file_win)
metadata_win, grid_data_win = load_grid_file(grid_path_win)
ds_glamos_win = convert_to_xarray_geodata(grid_data_win, metadata_win)

ds_glamos_wgs84_win = transform_xarray_coords_lv95_to_wgs84(ds_glamos_win)

# One colour scale covering both the gridded MB and the selected stakes.
vmin_win = min(ds_glamos_wgs84_win.min().item(),
               zones_stakes.POINT_BALANCE.min())
vmax_win = max(ds_glamos_wgs84_win.max().item(),
               zones_stakes.POINT_BALANCE.max())

cmap, norm = get_color_maps(vmin_win, vmax_win)

# ---- Create 2×3 layout ----
fig = plt.figure(figsize=(18, 7))
gs = gridspec.GridSpec(
    2,
    3,
    width_ratios=[2.4, 1, 1],  # left big col, 2 smaller columns
    height_ratios=[1, 1])

# Left column spans both rows
ax_map = fig.add_subplot(gs[:, 0])

# Middle column (monthly)
ax_high = fig.add_subplot(gs[0, 1])
ax_low = fig.add_subplot(gs[1, 1])

# Right column (cumulative)
ax_high_cum = fig.add_subplot(gs[0, 2])
ax_low_cum = fig.add_subplot(gs[1, 2])

# ============================================================
# LEFT: MAP
# ============================================================

ds_glamos_wgs84_win.plot(ax=ax_map, cmap=cmap, norm=norm)

sns.scatterplot(
    data=zones_stakes,
    x="POINT_LON",
    y="POINT_LAT",
    hue="POINT_BALANCE",
    s=30,
    legend=False,
    palette=cmap,
    hue_norm=norm,
    ax=ax_map,
    edgecolor='black',
)

ax_map.set_title(
    f"Glacier: {glacier_name.capitalize()} | Year: {year} | Annual MB")

# ============================================================
# MIDDLE COLUMN: MONTHLY MB (High & Low)
# ============================================================
# Series are ordered chronologically via month_order_map; sort=False keeps
# seaborn from re-sorting the string month labels.

# High
mean_monthly_high = (monthly_pred_gl[monthly_pred_gl.ID.isin(
    IDs_high)].groupby("MONTH").pred_raw.mean())
mean_monthly_high = mean_monthly_high.loc[sorted(
    mean_monthly_high.index, key=lambda m: month_order_map[m])]
sns.lineplot(x=mean_monthly_high.index,
             y=mean_monthly_high.values,
             ax=ax_high,
             sort=False)
ax_high.axhline(0, color="black", linewidth=0.7)
ax_high.set_title("Monthly MB (High Group)")
ax_high.tick_params(axis='x', rotation=45)

# Low
mean_monthly_low = (monthly_pred_gl[monthly_pred_gl.ID.isin(IDs_low)].groupby(
    "MONTH").pred_raw.mean())
mean_monthly_low = mean_monthly_low.loc[sorted(
    mean_monthly_low.index, key=lambda m: month_order_map[m])]
sns.lineplot(x=mean_monthly_low.index,
             y=mean_monthly_low.values,
             ax=ax_low,
             color="tab:red",
             sort=False)
ax_low.axhline(0, color="black", linewidth=0.7)
ax_low.set_title("Monthly MB (Low Group)")
ax_low.tick_params(axis='x', rotation=45)

# ============================================================
# RIGHT COLUMN: CUMULATIVE MB (High & Low)
# ============================================================
# Running sum of the monthly means, compared with the observed and predicted
# group-mean point balances.

# High cumulative
cum_high = mean_monthly_high.cumsum()
sns.lineplot(x=cum_high.index, y=cum_high.values, ax=ax_high_cum, sort=False)
ax_high_cum.axhline(0, color="black", linewidth=0.7)
ax_high_cum.set_title("Cumulative MB (High Group)")
ax_high_cum.tick_params(axis='x', rotation=45)
ax_high_cum.axhline(mean_obs_mb_high,
                    color='black',
                    linestyle='--',
                    label='Obs. mean PMB')
ax_high_cum.axhline(mean_pred_mb_high,
                    color='grey',
                    linestyle='--',
                    label='Pred. mean PMB')
ax_high_cum.legend(loc='lower right', frameon=False)

# Low cumulative
cum_low = mean_monthly_low.cumsum()
sns.lineplot(x=cum_low.index,
             y=cum_low.values,
             ax=ax_low_cum,
             color="tab:red",
             sort=False)
ax_low_cum.axhline(0, color="black", linewidth=0.7)
ax_low_cum.set_title("Cumulative MB (Low Group)")
ax_low_cum.tick_params(axis='x', rotation=45)
ax_low_cum.axhline(mean_obs_mb_low,
                   color='black',
                   linestyle='--',
                   label='Obs. mean PMB')
ax_low_cum.axhline(mean_pred_mb_low,
                   color='grey',
                   linestyle='--',
                   label='Pred. mean PMB')
ax_low_cum.legend(loc='lower right', frameon=False)
# ============================================================
plt.tight_layout()
plt.show()

In [None]:
target_month = "jul"

# Mean sensitivity of the July MB prediction to every monthly input, for the
# high and the low stake groups (annual sequences only).
grads_high = compute_gradient_sensitivity(
    model, device, test_dl, ds_test_copy,
    stake_ids=IDs_high,
    target_month=target_month,
    month_order=month_order,
    period="annual",
)
grads_low = compute_gradient_sensitivity(
    model, device, test_dl, ds_test_copy,
    stake_ids=IDs_low,
    target_month=target_month,
    month_order=month_order,
    period="annual",
)

# Subset of MONTHLY_COLS to show, one panel each.
vars_to_plot = 'tp t2m str slhf ssrd fal'.split()

plot_sensitivity_grid_compare(
    grads_high,
    grads_low,
    month_order,
    vars_to_plot,
    MONTHLY_COLS,
    first_valid_month='oct',
    last_valid_month='aug',
    title=f"Sensitivity of {target_month.capitalize()} MB (High vs Low stakes)",
)

In [None]:
target_month = "aug"

# Mean sensitivity of the August MB prediction to every monthly input, for the
# high and the low stake groups (annual sequences only).
grads_high = compute_gradient_sensitivity(
    model, device, test_dl, ds_test_copy,
    stake_ids=IDs_high,
    target_month=target_month,
    month_order=month_order,
    period="annual",
)
grads_low = compute_gradient_sensitivity(
    model, device, test_dl, ds_test_copy,
    stake_ids=IDs_low,
    target_month=target_month,
    month_order=month_order,
    period="annual",
)

# Subset of MONTHLY_COLS to show, one panel each.
vars_to_plot = 'tp t2m str slhf ssrd fal'.split()

plot_sensitivity_grid_compare(
    grads_high,
    grads_low,
    month_order,
    vars_to_plot,
    MONTHLY_COLS,
    first_valid_month='oct',
    last_valid_month='sep',
    title=f"Sensitivity of {target_month.capitalize()} MB (High vs Low stakes)",
)

### Winter:

In [None]:
# Get monthly predictions (denormalised only if the dataset normalised its target):
monthly_pred_df = model.predict_monthly_with_keys(
    device,
    test_dl,
    ds_test_copy,
    denorm=ds_test_copy.normalize_target,
)

# Keep only sequences belonging to winter measurements.
monthly_pred_df_w = monthly_pred_df.loc[monthly_pred_df['PERIOD'].eq('winter')]

#### Gries:

In [None]:
# # Choose a good year
# glacier_name = 'gries'
# for year in range(2010, 2020):
#     # year = 2015
#     file_ann = f"{year}_win_fix_lv95.grid"

#     stake_coordinates = df_stakes[(df_stakes.GLACIER == glacier_name)
#                                   & (df_stakes.YEAR == year) &
#                                   (df_stakes.PERIOD
#                                    == 'winter')].drop_duplicates()

#     lon_name = "lon"
#     lat_name = "lat"
#     grid_path_ann = os.path.join(cfg.dataPath, path_distributed_MB_glamos,
#                                  'GLAMOS', glacier_name, file_ann)
#     metadata_ann, grid_data_ann = load_grid_file(grid_path_ann)
#     ds_glamos_ann = convert_to_xarray_geodata(grid_data_ann, metadata_ann)

#     ds_glamos_wgs84_ann = transform_xarray_coords_lv95_to_wgs84(ds_glamos_ann)
#     stake_coordinates["GLAMOS_MB"] = stake_coordinates.apply(
#         lambda row: get_predicted_mb_glamos(lon_name, lat_name, row,
#                                             ds_glamos_wgs84_ann),
#         axis=1,
#     )

#     vmin_ann = ds_glamos_wgs84_ann.min().item()
#     vmax_ann = ds_glamos_wgs84_ann.max().item()

#     (
#         cmap,
#         norm,
#     ) = get_color_maps(vmin_ann, vmax_ann)

#     fig = plt.figure(figsize=(12, 6))
#     ax = plt.subplot(1, 2, 1)
#     ds_glamos_wgs84_ann.plot(
#         ax=ax,
#         cmap=cmap,
#         norm=norm,
#     )

#     sns.scatterplot(data=stake_coordinates,
#                     x="POINT_LON",
#                     y="POINT_LAT",
#                     hue='POINT_BALANCE',
#                     s=30,
#                     legend=False,
#                     palette=cmap,
#                     hue_norm=norm,
#                     ax=ax)
#     ax.set_title(
#         f'Glacier: {glacier_name.capitalize()} | Year: {year} | Winter ')

##### Gries 2014:

In [None]:
# Choose stakes in high and low zones:
# Winter stake rows for this glacier/year, sorted from lowest to highest elevation.
glacier_name = 'gries'
year = 2014
stake_coordinates = df_stakes[(df_stakes.GLACIER == glacier_name)
                              & (df_stakes.YEAR == year) &
                              (df_stakes.PERIOD
                               == 'winter')].drop_duplicates().sort_values(
                                   by='POINT_ELEVATION')
# Print the three lowest and the three highest stakes (as in the annual
# analysis; the high group chosen below may use only a subset of them).
print("Three lowest stakes:")
print(stake_coordinates.head(3).POINT_ID)
print("\nThree highest stakes:")
print(stake_coordinates.tail(3).POINT_ID)

# Latest start month and latest end month among these stakes
# (assumes MONTH_START / MONTH_END are integer month numbers 1-12):
print('\nATTENTION: Measurements start in:',
      calendar.month_name[stake_coordinates.MONTH_START.max()])

print('\nATTENTION: Measurements stop in:',
      calendar.month_name[stake_coordinates.MONTH_END.max()])

# Months (labels from month_order) covered by the winter measurements; used
# to restrict the monthly predictions below.
valid_months_ = [
    'oct',
    'nov',
    'dec',
    'jan',
    'feb',
    'mar',
    'apr',
]

In [None]:
# Hand-picked stake names (POINT_ID) for the high (accumulation) and low
# (ablation) groups.
stake_point_ids_high = ['gries_492', 'gries_493']
stake_point_ids_low = ['gries_3', 'gries_52', 'gries_54']
all_stakes_point_ids = stake_point_ids_high + stake_point_ids_low

# Get values for high and low stakes
zones_stakes = stake_coordinates[stake_coordinates.POINT_ID.isin(
    all_stakes_point_ids)]

# Map the stake names to the measurement IDs used by the sequence dataset
# (dataset keys and prediction frames are indexed by ID, not POINT_ID).
IDs_high = data_monthly[data_monthly.POINT_ID.isin(stake_point_ids_high)
                        & (data_monthly.YEAR == year) &
                        (data_monthly.PERIOD
                         == 'winter')].ID.unique().tolist()
IDs_low = data_monthly[data_monthly.POINT_ID.isin(stake_point_ids_low)
                       & (data_monthly.YEAR == year) &
                       (data_monthly.PERIOD == 'winter')].ID.unique().tolist()

# Get monthly predictions for that year and glacier
monthly_pred_gl = monthly_pred_df_w[monthly_pred_df_w.ID.isin(IDs_high +
                                                              IDs_low)]

# Filter to valid months of measurement only
monthly_pred_gl = monthly_pred_gl[monthly_pred_gl.MONTH.isin(valid_months_)]

# Observed point balance, averaged per group.
mean_obs_mb_high = zones_stakes[zones_stakes.POINT_ID.isin(
    stake_point_ids_high)].POINT_BALANCE.mean()
mean_obs_mb_low = zones_stakes[zones_stakes.POINT_ID.isin(
    stake_point_ids_low)].POINT_BALANCE.mean()

print("Mean observed MB high stakes:", np.round(mean_obs_mb_high, 2), 'm w.e.')
print("Mean observed MB low stakes:", np.round(mean_obs_mb_low, 2), 'm w.e.')

# Predicted point balance (sequence-level 'pred'), averaged per group.
mean_pred_mb_high = test_df_preds[test_df_preds.ID.isin(IDs_high)].pred.mean()
mean_pred_mb_low = test_df_preds[test_df_preds.ID.isin(IDs_low)].pred.mean()

print("Mean predicted MB high stakes:", np.round(mean_pred_mb_high, 2),
      'm w.e.')
print("Mean predicted MB low stakes:", np.round(mean_pred_mb_low, 2), 'm w.e.')

In [None]:

# PLOT Stakes
# Left: GLAMOS winter MB grid with the stake observations overlaid.
# Middle / right: mean monthly and cumulative predicted MB of the high and
# low stake groups, with the observed and predicted mean PMB as references.

# Load the GLAMOS *winter* grid (`*_win_fix_lv95.grid`) for that year and
# glacier (for plotting)
file_win = f"{year}_win_fix_lv95.grid"
grid_path_win = os.path.join(cfg.dataPath, path_distributed_MB_glamos,
                             'GLAMOS', glacier_name, file_win)
metadata_win, grid_data_win = load_grid_file(grid_path_win)
ds_glamos_win = convert_to_xarray_geodata(grid_data_win, metadata_win)

ds_glamos_wgs84_win = transform_xarray_coords_lv95_to_wgs84(ds_glamos_win)

# One colour scale spanning both the grid and the stake values, so map
# colours and stake markers are directly comparable.
vmin_win = min(ds_glamos_wgs84_win.min().item(),
               zones_stakes.POINT_BALANCE.min())
vmax_win = max(ds_glamos_wgs84_win.max().item(),
               zones_stakes.POINT_BALANCE.max())

cmap, norm = get_color_maps(vmin_win, vmax_win)

# ---- Create 2×3 layout ----
fig = plt.figure(figsize=(18, 7))
gs = gridspec.GridSpec(
    2,
    3,
    width_ratios=[2.4, 1, 1],  # left big col, 2 smaller columns
    height_ratios=[1, 1])

# Left column spans both rows
ax_map = fig.add_subplot(gs[:, 0])

# Middle column (monthly)
ax_high = fig.add_subplot(gs[0, 1])
ax_low = fig.add_subplot(gs[1, 1])

# Right column (cumulative)
ax_high_cum = fig.add_subplot(gs[0, 2])
ax_low_cum = fig.add_subplot(gs[1, 2])

# ============================================================
# LEFT: MAP
# ============================================================

ds_glamos_wgs84_win.plot(ax=ax_map, cmap=cmap, norm=norm)

sns.scatterplot(
    data=zones_stakes,
    x="POINT_LON",
    y="POINT_LAT",
    hue="POINT_BALANCE",
    s=30,
    legend=False,
    palette=cmap,
    hue_norm=norm,
    ax=ax_map,
    edgecolor='black',
)

ax_map.set_title(
    f"Glacier: {glacier_name.capitalize()} | Year: {year} | Winter MB")


# ============================================================
# MIDDLE + RIGHT COLUMNS: MONTHLY & CUMULATIVE MB (High & Low)
# ============================================================
def _sorted_monthly_mean(ids):
    """Mean `pred_raw` per MONTH over `ids`, ordered by `month_order_map`."""
    monthly = (monthly_pred_gl[monthly_pred_gl.ID.isin(ids)].groupby(
        "MONTH").pred_raw.mean())
    return monthly.loc[sorted(monthly.index,
                              key=lambda m: month_order_map[m])]


def _plot_group_panels(group,
                       monthly,
                       ax_monthly,
                       ax_cum,
                       obs_mean,
                       pred_mean,
                       line_color=None):
    """Draw the monthly and cumulative MB panels of one stake group.

    Parameters
    ----------
    group : str
        Group label used in the panel titles ("High" / "Low").
    monthly : pandas.Series
        Mean monthly MB indexed by month, already in plotting order.
    ax_monthly, ax_cum : matplotlib Axes
        Target axes for the monthly and the cumulative panel.
    obs_mean, pred_mean : float
        Observed / predicted mean point MB drawn as dashed reference lines.
    line_color : str or None
        Line colour; None lets seaborn take the next colour of the axes'
        property cycle (same as not passing a colour).

    Returns
    -------
    pandas.Series
        The cumulative sum of `monthly`.
    """
    # Monthly MB
    sns.lineplot(x=monthly.index,
                 y=monthly.values,
                 ax=ax_monthly,
                 color=line_color)
    ax_monthly.axhline(0, color="black", linewidth=0.7)
    ax_monthly.set_title(f"Monthly MB ({group} Group)")
    # Rotate the existing tick labels in place instead of re-setting them via
    # set_xticklabels, which would freeze them with a FixedFormatter on the
    # categorical (non-fixed) locator and triggers a matplotlib warning.
    ax_monthly.tick_params(axis="x", labelrotation=45)

    # Cumulative MB vs. observed / predicted mean PMB
    cumulative = monthly.cumsum()
    sns.lineplot(x=cumulative.index,
                 y=cumulative.values,
                 ax=ax_cum,
                 color=line_color)
    ax_cum.axhline(0, color="black", linewidth=0.7)
    ax_cum.set_title(f"Cumulative MB ({group} Group)")
    ax_cum.tick_params(axis="x", labelrotation=45)
    ax_cum.axhline(obs_mean,
                   color='black',
                   linestyle='--',
                   label='Obs. mean PMB')
    ax_cum.axhline(pred_mean,
                   color='grey',
                   linestyle='--',
                   label='Pred. mean PMB')
    ax_cum.legend(loc='lower right', frameon=False)
    return cumulative


mean_monthly_high = _sorted_monthly_mean(IDs_high)
mean_monthly_low = _sorted_monthly_mean(IDs_low)

cum_high = _plot_group_panels("High", mean_monthly_high, ax_high,
                              ax_high_cum, mean_obs_mb_high,
                              mean_pred_mb_high)
cum_low = _plot_group_panels("Low",
                             mean_monthly_low,
                             ax_low,
                             ax_low_cum,
                             mean_obs_mb_low,
                             mean_pred_mb_low,
                             line_color="tab:red")
# ============================================================
plt.tight_layout()
plt.show()

In [None]:
target_month = "mar"

# Arguments shared by the high- and low-stake computations; only the stake
# IDs differ between the two calls below.
common_grad_kwargs = dict(
    model=model,
    device=device,
    dataloader=test_dl,
    dataset=ds_test_copy,
    target_month=target_month,
    month_order=month_order,
    period="winter",
)

# Gradient sensitivities for `target_month` (semantics defined by
# compute_gradient_sensitivity), one result per stake group.
grads_high = compute_gradient_sensitivity(stake_ids=IDs_high,
                                          **common_grad_kwargs)
grads_low = compute_gradient_sensitivity(stake_ids=IDs_low,
                                         **common_grad_kwargs)

# Subset of the climate features to display in the comparison grid.
vars_to_plot = ['tp', 't2m', 'str', 'slhf', 'ssrd', 'fal']

plot_sensitivity_grid_compare(
    grads_high,
    grads_low,
    month_order,
    vars_to_plot,
    MONTHLY_COLS,
    title=f"Sensitivity of {target_month.capitalize()} MB (High vs Low stakes)",
    first_valid_month="oct",
    last_valid_month="apr",
)

In [None]:
target_month = "apr"

# Arguments shared by the high- and low-stake computations; only the stake
# IDs differ between the two calls below.
common_grad_kwargs = dict(
    model=model,
    device=device,
    dataloader=test_dl,
    dataset=ds_test_copy,
    target_month=target_month,
    month_order=month_order,
    period="winter",
)

# Gradient sensitivities for `target_month` (semantics defined by
# compute_gradient_sensitivity), one result per stake group.
grads_high = compute_gradient_sensitivity(stake_ids=IDs_high,
                                          **common_grad_kwargs)
grads_low = compute_gradient_sensitivity(stake_ids=IDs_low,
                                         **common_grad_kwargs)

# Subset of the climate features to display in the comparison grid.
vars_to_plot = ['tp', 't2m', 'str', 'slhf', 'ssrd', 'fal']

plot_sensitivity_grid_compare(
    grads_high,
    grads_low,
    month_order,
    vars_to_plot,
    MONTHLY_COLS,
    title=f"Sensitivity of {target_month.capitalize()} MB (High vs Low stakes)",
    first_valid_month="oct",
    last_valid_month="may",
)