## Setting Up:

In [None]:
# --- Standard library
from concurrent.futures import ProcessPoolExecutor, as_completed
from contextlib import redirect_stdout
from datetime import datetime
import io
import logging
import multiprocessing as mp
import os
import sys
import warnings

# Make repo root importable (for MBM & scripts/*)
sys.path.append(os.path.join(os.getcwd(), '../../'))

# --- Third-party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from cmcrameri import cm
import torch
from tqdm.auto import tqdm
import xarray as xr
from matplotlib.lines import Line2D

import massbalancemachine as mbm

# --- Project-local
from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.nn_helpers import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.NN_networks import *
from scripts.geodata_plots import *

# --- Notebook settings
# Suppress all Python warnings for the rest of the session. Note that this also
# hides pandas chained-assignment (SettingWithCopy) and deprecation warnings.
warnings.filterwarnings('ignore')
# IPython magics (notebook-only): reload edited project modules automatically.
%load_ext autoreload
%autoreload 2

# Swiss configuration object; provides cfg.seed, cfg.dataPath and cfg.metaData
# used by the cells below.
cfg = mbm.SwitzerlandConfig()

In [None]:
seed_all(cfg.seed)
print("Using seed:", cfg.seed)

# Query CUDA once and reuse the answer for both the report and the device.
_has_cuda = torch.cuda.is_available()
if not _has_cuda:
    print("CUDA is NOT available")
else:
    print("CUDA is available")
    free_up_cuda()  # project helper, only called when a GPU is present

device = torch.device("cuda") if _has_cuda else torch.device("cpu")

In [None]:
# Plot styles:
# Plot styles: apply the project-wide matplotlib style sheet.
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)

# Colour palette: 10 hex colours from the batlow colormap (project helper
# get_cmap_hex); the first one is used as the dark-blue accent.
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue, color_pink = colors[0], '#c51b7d'

## Input data:

In [None]:
# ERA5 climate variables extracted per stake and month.
vois_climate = [
    't2m',
    'tp',
    'slhf',
    'sshf',
    'ssrd',
    'fal',
    'str',
]

# Topographical variables attached to each stake.
vois_topographical = [
    "aspect_sgi", "slope_sgi", "hugonnet_dhdt", "consensus_ice_thickness",
    "millan_v", "svf"
]

# Read GLAMOS stake data
data_glamos = getStakesData(cfg)

# Compute padding for monthly data
months_head_pad, months_tail_pad = mbm.data_processing.utils._compute_head_tail_pads_from_df(
    data_glamos)

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Transform data to monthly format (run or load data):
paths = {
    'csv_path': cfg.dataPath + path_PMB_GLAMOS_csv,
    'era5_climate_data':
    cfg.dataPath + path_ERA5_raw + 'era5_monthly_averaged_data.nc',
    'geopotential_data':
    cfg.dataPath + path_ERA5_raw + 'era5_geopotential_pressure.nc',
    'radiation_save_path': cfg.dataPath + path_pcsr + 'zarr/'
}
# run_flag=False presumably loads the cached output_file instead of
# recomputing it (see process_or_load_data).
RUN = False
data_monthly = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_LSTM_svf_IS.csv')

# Create DataLoader
dataloader_gl = mbm.dataloader.DataLoader(cfg,
                                          data=data_monthly,
                                          random_seed=cfg.seed,
                                          meta_data_columns=cfg.metaData)

# "Within sample": every glacier present in the data is a training glacier.
existing_glaciers = set(data_monthly.GLACIER.unique())
train_glaciers = existing_glaciers
# Explicit .copy(): the target column added below then goes into an
# independent frame instead of a boolean-indexed slice of data_monthly
# (avoids pandas chained assignment / SettingWithCopy ambiguity).
data_train = data_monthly[data_monthly.GLACIER.isin(train_glaciers)].copy()
print('Size of monthly train data:', len(data_train))

# Target column used downstream.
data_train['y'] = data_train['POINT_BALANCE']

In [None]:
# Second copy of the stake table with FROM_DATE moved to 1 August of the same
# year (presumably so the input sequences start in August, matching the
# "aug_"-first month order used later).
# FROM_DATE is handled as YYYYMMDD: keep the year YYYY, append "0801".
data_glamos_Aug_ = data_glamos.copy()
data_glamos_Aug_["FROM_DATE"] = (
    data_glamos_Aug_["FROM_DATE"].astype(str).str.slice(0, 4)  # year YYYY
    .astype(int).astype(str) + "0801"  # append "0801"
).astype(int)

# Same for full temporal resolution (run or load data):
# Compute padding for monthly data
months_head_pad_Aug_, months_tail_pad_Aug_ = mbm.data_processing.utils._compute_head_tail_pads_from_df(
    data_glamos_Aug_)

# Initialize logging (no-op if the root logger is already configured)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# NOTE(review): the cache file name differs from the first dataset's
# ('..._svf_IS.csv'); with RUN=False confirm this cached CSV was built with
# the current vois_topographical list (which includes "svf").
RUN = False
data_monthly_Aug_ = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos_Aug_,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_LSTM_gs_no_oggm_Aug_.csv')

# Create DataLoader
dataloader_gl_Aug_ = mbm.dataloader.DataLoader(cfg,
                                               data=data_monthly_Aug_,
                                               random_seed=cfg.seed,
                                               meta_data_columns=cfg.metaData)

# Blocking on glaciers:
# Model is trained on all glaciers --> "Within sample"
existing_glaciers = set(data_monthly_Aug_.GLACIER.unique())
train_glaciers = existing_glaciers
# Explicit .copy() so adding 'y' below does not assign into a slice.
data_train_Aug_ = data_monthly_Aug_[data_monthly_Aug_.GLACIER.isin(
    train_glaciers)].copy()
print('Size of monthly train data:', len(data_train_Aug_))

# Target column used downstream.
data_train_Aug_['y'] = data_train_Aug_['POINT_BALANCE']

## LSTM:

In [None]:
# Time-varying predictors, one value per month of the input sequence.
MONTHLY_COLS = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str',
    'ELEVATION_DIFFERENCE', 'pcsr',
]
# Static predictors (passed to the dataset builder as static_cols).
STATIC_COLS = ['aspect_sgi', 'slope_sgi', "svf"]

# All model inputs: monthly features first, then static ones.
feature_columns = [*MONTHLY_COLS, *STATIC_COLS]

### Build LSTM dataloaders:

In [None]:
# seed_all(cfg.seed)

# df_train = data_train.copy()
# df_train['PERIOD'] = df_train['PERIOD'].str.strip().str.lower()

# # --- build train dataset from dataframe ---
# ds_train = mbm.data_processing.MBSequenceDataset.from_dataframe(
#     df_train,
#     MONTHLY_COLS,
#     STATIC_COLS,
#     months_tail_pad=months_tail_pad,
#     months_head_pad=months_head_pad,
#     expect_target=True,
#     normalize_target=True)

# train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
#     len(ds_train), val_ratio=0.2, seed=cfg.seed)

In [None]:
# Combined LSTM dataset from the stake table (df_loss) and the August-start
# table (df_full), padded with the August-start pads; targets normalized.
# Parameter semantics are defined by build_combined_LSTM_dataset.
seed_all(cfg.seed)
ds_train_combd = build_combined_LSTM_dataset(
    df_full=data_train_Aug_,
    df_loss=data_train,
    static_cols=STATIC_COLS,
    monthly_cols=MONTHLY_COLS,
    months_tail_pad=months_tail_pad_Aug_,
    months_head_pad=months_head_pad_Aug_,
    expect_target=True,
    normalize_target=True,
)

# Seeded 80/20 train/validation split of sample indices.
train_idx_combd, val_idx_combd = (
    mbm.data_processing.MBSequenceDataset.split_indices(
        len(ds_train_combd), seed=cfg.seed, val_ratio=0.2))

### Define & train model:

In [None]:
# log_path = 'logs/lstm_one_head_param_search_progress_no_oggm_IS_2025-11-04.csv'
# best_params = get_best_params_for_lstm(log_path, select_by='test_rmse_a')
# custom_params = best_params
# custom_params['two_heads'] = False

# # --- build model, resolve loss, train, reload best ---
# current_date = datetime.now().strftime("%Y-%m-%d")
# model_filename = f"models/lstm_model_2025-12-01_no_oggm_IS_norm_y.pt"

# # --- loaders (fit scalers on TRAIN, apply to whole ds_train) ---
# ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
#     ds_train)

# train_dl, val_dl = ds_train_copy.make_loaders(
#     train_idx=train_idx,
#     val_idx=val_idx,
#     batch_size_train=64,
#     batch_size_val=128,
#     seed=cfg.seed,
#     fit_and_transform=
#     True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
#     shuffle_train=True,
#     use_weighted_sampler=True  # use weighted sampler for training
# )

# # --- build model, resolve loss, train, reload best ---
# model = mbm.models.LSTM_MB.build_model_from_params(cfg, custom_params, device)
# loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(custom_params)

# ds_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
#     ds_train)

# test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
#     ds_test_copy, ds_train_copy, batch_size=128, seed=cfg.seed)

# # Evaluate on test
# state = torch.load(model_filename, map_location=device)
# model.load_state_dict(state)
# test_metrics, test_df_preds = model.evaluate_with_preds(
#     device, test_dl, ds_test_copy)
# test_rmse_a, test_rmse_w = test_metrics['RMSE_annual'], test_metrics[
#     'RMSE_winter']

# print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
#     test_rmse_a, test_rmse_w))

In [None]:
# Best hyper-parameters from the search log, ranked by avg_test_loss.
log_path = 'logs/lstm_one_head_param_search_progress_no_oggm_IS_2025-11-25.csv'
best_params = get_best_params_for_lstm(log_path, select_by='avg_test_loss')
# NOTE: custom_params aliases best_params, so this also sets best_params['two_heads'].
custom_params = best_params
custom_params['two_heads'] = False

# --- weights file (no training happens in this cell; weights are loaded below) ---
# NOTE(review): current_date is unused here and the f-string has no placeholders.
current_date = datetime.now().strftime("%Y-%m-%d")
model_filename = f"models/lstm_model_2025-11-28_no_oggm_IS_norm_y_past.pt"

# --- loaders (fit scalers on TRAIN, apply to whole ds_train) ---
# Clone first so the in-place transform below does not touch ds_train_combd
# (presumably what _clone_untransformed_dataset guarantees — confirm).
ds_train_combd_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train_combd)

# The returned loaders are not used in this cell; the call is needed for its
# side effect of fitting the scalers on the training indices.
train_dl_combd, val_dl_combd = ds_train_combd_copy.make_loaders(
    train_idx=train_idx_combd,
    val_idx=val_idx_combd,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=
    True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True  # use weighted sampler for training
)

# --- build the model architecture from the hyper-parameters ---
# loss_fn is not used in this cell.
model_combd = mbm.models.LSTM_MB.build_model_from_params(
    cfg, custom_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(custom_params)

# NOTE(review): the "test" set is a clone of the FULL ds_train_combd, so the
# metrics below are in-sample (they include training samples).
ds_test_combd_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train_combd)

# Test loader built from the untransformed clone plus the fitted train copy
# (presumably reusing its scalers — see make_test_loader).
test_dl_combd = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test_combd_copy, ds_train_combd_copy, batch_size=128, seed=cfg.seed)

# Evaluate on test: load saved weights, then compute metrics and predictions.
state = torch.load(model_filename, map_location=device)
model_combd.load_state_dict(state)
test_metrics_combd, test_df_preds_combd = model_combd.evaluate_with_preds(
    device, test_dl_combd, ds_test_combd_copy)
test_rmse_a, test_rmse_w = test_metrics_combd[
    'RMSE_annual'], test_metrics_combd['RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

In [None]:
# Seasonal (annual / winter) scores of the test predictions.
scores_annual, scores_winter = compute_seasonal_scores(
    test_df_preds_combd, target_col='target', pred_col='pred')
print("Annual scores:", scores_annual)
print("Winter scores:", scores_winter)

# Summary figure of predictions (project helper) with identical x/y limits.
# color_annual / color_winter are defined outside this notebook view
# (presumably via the star imports).
fig = plot_predictions_summary(
    grouped_ids=test_df_preds_combd,
    scores_annual=scores_annual,
    scores_winter=scores_winter,
    color_annual=color_annual,
    color_winter=color_winter,
    ax_xlim=(-8, 6),
    ax_ylim=(-8, 6),
)

## Sensitivity to past information:

### Truncation Test (how many months matter?):
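During inference, we reveal only the first K months of the monthly inputs to the model and set the remaining months to 0 (forward test). In the backward test we instead reveal only the last K months, and score only samples that use at least one month inside that visible window.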

In [None]:
@torch.no_grad()
def evaluate_with_truncation(model, device, dl, ds, k):
    """Score ``model`` when only the first ``k`` months of input are visible.

    Monthly inputs at time index ``>= k`` are set to 0 before the forward pass;
    static inputs and masks are passed unchanged.

    Args:
        model: Network exposing ``eval()``, ``to_device(device, batch)`` and a
            forward ``model(x_m, x_s, mv, mw, ma) -> (y_month, y_w, y_a)``.
        device: Device used for the forward pass.
        dl: Iterable of batch dicts with ``x_m``, ``x_s``, ``mv``, ``mw``,
            ``ma`` and ``y``. It must visit ``ds`` in order (no shuffling),
            because batch rows are matched to ``ds.keys`` by position.
        ds: Dataset with ``keys`` (4-tuples whose last item is the period,
            e.g. ``"winter"`` / ``"annual"``) and ``y_mean`` / ``y_std``, which
            are used to de-normalize targets and predictions.
        k: Number of leading months kept visible (``k >= 0``).

    Returns:
        ``(rmse_winter, rmse_annual)``. An entry is NaN when no sample of that
        period was evaluated, including when ``dl`` is empty.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be >= 0, got {k}")

    model.eval()
    # Hoisted: identical for every batch.
    y_mean = ds.y_mean.to(device)
    y_std = ds.y_std.to(device)
    all_keys = ds.keys
    rows = []
    i = 0

    for batch in dl:
        bs = batch["x_m"].shape[0]
        keys = all_keys[i:i + bs]
        i += bs

        batch = model.to_device(device, batch)
        # Clone so masking never modifies the loader's own tensors.
        x_m = batch["x_m"].clone()
        x_m[:, k:, :] = 0.0  # hide months >= k

        _, y_w, y_a = model(x_m, batch["x_s"], batch["mv"], batch["mw"],
                            batch["ma"])

        # Undo target normalization; transfer to CPU once per batch instead
        # of once per sample.
        y_true = (batch["y"] * y_std + y_mean).cpu()
        y_w = (y_w * y_std + y_mean).cpu()
        y_a = (y_a * y_std + y_mean).cpu()

        for j in range(bs):
            _, _, _, per = keys[j]
            pred_source = y_w if per == "winter" else y_a
            rows.append({
                "pred": float(pred_source[j]),
                "target": float(y_true[j]),
                "PERIOD": per,
            })

    # Explicit columns so an empty result still supports df.PERIOD below.
    df = pd.DataFrame(rows, columns=["pred", "target", "PERIOD"])

    def rmse(period):
        sub = df[df.PERIOD == period]
        if sub.empty:
            return np.nan
        return np.sqrt(np.mean((sub.pred - sub.target)**2))

    return rmse("winter"), rmse("annual")


@torch.no_grad()
def evaluate_with_truncation_from_end(model, device, dl, ds, k):
    """Score ``model`` when only the last ``k`` months of input are visible.

    Monthly inputs before index ``T - k`` are set to 0, where ``T`` is the
    sequence length of ``x_m``. Only samples whose valid-month mask ``mv``
    has at least one used month inside the visible window are scored.

    Args:
        model: Network exposing ``eval()``, ``to_device(device, batch)`` and a
            forward ``model(x_m, x_s, mv, mw, ma) -> (y_month, y_w, y_a)``.
        device: Device used for the forward pass.
        dl: Iterable of batch dicts with ``x_m``, ``x_s``, ``mv``, ``mw``,
            ``ma`` and ``y``. It must visit ``ds`` in order (no shuffling),
            because batch rows are matched to ``ds.keys`` by position.
        ds: Dataset with ``keys`` (4-tuples whose last item is the period)
            and ``y_mean`` / ``y_std`` used for de-normalization.
        k: Number of trailing months kept visible (``k >= 0``). Values of
            ``k`` at or above the sequence length leave all months visible.

    Returns:
        ``(rmse_winter, rmse_annual)``. An entry is NaN when no sample of that
        period survived the filtering.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be >= 0, got {k}")

    model.eval()
    # Hoisted: identical for every batch.
    y_mean = ds.y_mean.to(device)
    y_std = ds.y_std.to(device)
    all_keys = ds.keys
    rows = []
    i = 0

    for batch in dl:
        bs = batch["x_m"].shape[0]
        keys = all_keys[i:i + bs]
        i += bs

        batch = model.to_device(device, batch)

        x_m = batch["x_m"].clone()
        mv = batch["mv"]  # valid-month mask (non-zero where a month is used)
        mw = batch["mw"]  # winter mask
        ma = batch["ma"]  # annual mask

        # Sequence length from the data rather than a hard-coded 15, and a
        # start index clamped at 0 so k > T cannot produce a negative slice.
        T = x_m.shape[1]
        idx_start = max(T - k, 0)

        # Hide the early months.
        x_m[:, :idx_start, :] = 0.0

        _, y_w, y_a = model(x_m, batch["x_s"], mv, mw, ma)

        # De-normalize; transfer to CPU once per batch.
        y_true = (batch["y"] * y_std + y_mean).cpu()
        y_w = (y_w * y_std + y_mean).cpu()
        y_a = (y_a * y_std + y_mean).cpu()

        # Keep only samples that use at least one of the visible months.
        valid = (mv[:, idx_start:T].sum(dim=1) > 0).cpu()

        for j in range(bs):
            if not valid[j]:
                continue
            _, _, _, per = keys[j]
            pred_source = y_w if per == "winter" else y_a
            rows.append({
                "pred": float(pred_source[j]),
                "target": float(y_true[j]),
                "PERIOD": per,
            })

    # Explicit columns so an empty result (everything filtered) still works.
    df = pd.DataFrame(rows, columns=["pred", "target", "PERIOD"])

    def rmse(period):
        sub = df[df.PERIOD == period]
        if sub.empty:
            return np.nan
        return np.sqrt(np.mean((sub.pred - sub.target)**2))

    return rmse("winter"), rmse("annual")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# -----------------------------
#  Forward truncation (first K months visible)
# -----------------------------
Ks = list(range(1, 16))  # 1..15 months included
# Winter curves keep only the first 10 plot positions (month_order[:10], i.e.
# aug_ .. may). NOTE(review): the old inline comment said "cap at 9"; the
# value actually used is 10 — confirm which cap is intended.
month_idx_last_winter = 10

rmse_annual = []
rmse_winter = []

for k in Ks:
    rmse_w, rmse_a = evaluate_with_truncation(model_combd, device,
                                              test_dl_combd,
                                              ds_test_combd_copy, k)
    rmse_annual.append(rmse_a)
    rmse_winter.append(rmse_w)
    print(
        f"K(forward)={k:02d} → RMSE annual={rmse_a:.3f}, winter={rmse_w:.3f}")

# -----------------------------
#  Backward truncation (last K months visible; samples with no used month in
#  the visible window are skipped by evaluate_with_truncation_from_end)
# -----------------------------
rmse_annual_end = []
rmse_winter_end = []

for k in Ks:
    rmse_w, rmse_a = evaluate_with_truncation_from_end(model_combd, device,
                                                       test_dl_combd,
                                                       ds_test_combd_copy, k)
    rmse_winter_end.append(rmse_w)
    rmse_annual_end.append(rmse_a)
    print(
        f"K(backward)={k:02d} → RMSE annual={rmse_a:.3f}, winter={rmse_w:.3f}")

# -----------------------------
#  Cap the forward winter curve at month_idx_last_winter positions
#  (the backward winter curve is capped after it is reversed, below)
# -----------------------------
Ks_winter = Ks[:month_idx_last_winter]
rmse_winter_fwd = rmse_winter[:len(Ks_winter)]

# -----------------------------
#  X-labels (month names): position p is labelled month_order[p-1], which
#  for the forward curves is the LAST visible month
# -----------------------------
month_order = [
    "aug_", "sep_", "oct", "nov", "dec", "jan", "feb", "mar", "apr", "may",
    "jun", "jul", "aug", "sep", "oct_"
]
x_labels = [month_order[k - 1] for k in Ks]

# -----------------------------
#  Reverse the backward curves so they share the month-label axis.
#  With 15-month sequences, position p (1-based) then holds k = 16 - p, i.e.
#  months month_order[p-1] .. oct_ are visible: the label is the FIRST
#  visible month. Backward curves therefore have MORE information on the
#  LEFT, while forward curves have more information on the right.
# -----------------------------
rmse_annual_bwd_plot = rmse_annual_end[::-1]
# NOTE: rmse_winter_end is overwritten with its reversed order here.
rmse_winter_end = rmse_winter_end[::-1]
rmse_winter_bwd_plot = rmse_winter_end[:len(Ks_winter)]

In [None]:
# -----------------------------
#  Plot forward vs backward truncation RMSE
# -----------------------------
fig, ax = plt.subplots(figsize=(9, 5))

# Annual curves use the default line style, winter curves are dashed;
# circles mark the forward test, squares the backward test.
for _xs, _ys, _kw in (
    (Ks, rmse_annual,
     dict(marker="o", color="tab:purple", label="Annual MB (forward)")),
    (Ks, rmse_annual_bwd_plot,
     dict(marker="s", color="tab:pink", label="Annual MB (backward)")),
    (Ks_winter, rmse_winter_fwd,
     dict(marker="o", linestyle="--", color="tab:cyan",
          label="Winter MB (forward)")),
    (Ks_winter, rmse_winter_bwd_plot,
     dict(marker="s", linestyle="--", color="tab:blue",
          label="Winter MB (backward)")),
):
    ax.plot(_xs, _ys, linewidth=2, **_kw)

ax.set_xticks(Ks)
ax.set_xticklabels(x_labels, rotation=45, ha="right")
ax.set_xlabel("Months visible to the LSTM")
ax.set_ylabel("RMSE")
ax.set_title("LSTM truncation sensitivity — forward vs backward temporal information")
ax.grid(True, linestyle="--", alpha=0.4)

# Two-column legend placed below the axes.
legend = ax.legend(ncol=2, loc="upper center", bbox_to_anchor=(0.5, -0.40),
                   frameon=True)

plt.tight_layout()
plt.show()

### Sensitivity to input:

In [None]:
# Rebuild the combined dataset WITHOUT target normalization and load the
# model trained on original-unit targets.
# NOTE: this cell overwrites ds_train_combd, model_combd, test_dl_combd,
# ds_test_combd_copy etc. used by the truncation cells above; re-running those
# afterwards evaluates this model instead.
seed_all(cfg.seed)
ds_train_combd = build_combined_LSTM_dataset(
    df_loss=data_train,
    df_full=data_train_Aug_,
    monthly_cols=MONTHLY_COLS,
    static_cols=STATIC_COLS,
    months_head_pad=months_head_pad_Aug_,
    months_tail_pad=months_tail_pad_Aug_,
    normalize_target=False,
    expect_target=True)
train_idx_combd, val_idx_combd = mbm.data_processing.MBSequenceDataset.split_indices(
    len(ds_train_combd), val_ratio=0.2, seed=cfg.seed)

# Best hyper-parameters from the search log, ranked by avg_test_loss.
log_path = 'logs/lstm_one_head_param_search_progress_no_oggm_IS_2025-11-25.csv'
best_params = get_best_params_for_lstm(log_path, select_by='avg_test_loss')

# NOTE: custom_params aliases best_params (the next line mutates both).
custom_params = best_params
custom_params['two_heads'] = False

# --- weights file (no training happens in this cell; weights are loaded below) ---
# NOTE(review): current_date is unused here and the f-string has no placeholders.
current_date = datetime.now().strftime("%Y-%m-%d")
model_filename = f"models/lstm_model_2025-11-28_no_oggm_IS_original_y_past.pt"

# --- loaders (fit scalers on TRAIN, apply to whole ds_train) ---
ds_train_combd_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train_combd)

# Loaders unused here; the call is needed for its scaler-fitting side effect.
# NOTE(review): with normalize_target=False, confirm whether y is still
# transformed and what ds.y_mean / ds.y_std hold, since later code multiplies
# model outputs by ds.y_std.
train_dl_combd, val_dl_combd = ds_train_combd_copy.make_loaders(
    train_idx=train_idx_combd,
    val_idx=val_idx_combd,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=
    True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True  # use weighted sampler for training
)

# --- build the model architecture from the hyper-parameters ---
# loss_fn is not used in this cell.
model_combd = mbm.models.LSTM_MB.build_model_from_params(
    cfg, custom_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(custom_params)

# NOTE(review): "test" data is a clone of the FULL dataset (in-sample metrics).
ds_test_combd_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train_combd)

test_dl_combd = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test_combd_copy, ds_train_combd_copy, batch_size=128, seed=cfg.seed)

# Evaluate on test: load saved weights, then compute metrics and predictions.
state = torch.load(model_filename, map_location=device)
model_combd.load_state_dict(state)
test_metrics_combd, test_df_preds_combd = model_combd.evaluate_with_preds(
    device, test_dl_combd, ds_test_combd_copy)
test_rmse_a, test_rmse_w = test_metrics_combd[
    'RMSE_annual'], test_metrics_combd['RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

In [None]:
# 15-step month sequence used to label the LSTM time axis. The trailing "_"
# disambiguates the repeated month names at the start (aug_, sep_) and the
# end (oct_) of the sequence.
month_order = (
    "aug_ sep_ oct nov dec jan feb mar apr may jun jul aug sep oct_".split())
# Month label -> position along the sequence axis.
month_to_idx = dict(zip(month_order, range(len(month_order))))

In [None]:
from tqdm.auto import tqdm
import pandas as pd
import numpy as np
import torch


@torch.no_grad()
def sensitivity_all_features_to_target_month(
    model,
    device,
    dl,
    ds,
    monthly_cols,
    target_month="jul",
    boost_factor=3.0,
    month_order=None,
):
    """Mean change of one month's predicted MB when an earlier input is boosted.

    For every feature in ``monthly_cols`` and every month strictly before
    ``target_month``, that single (month, feature) input is multiplied by
    ``boost_factor``. The change of the model's monthly prediction at
    ``target_month`` is scaled by ``ds.y_std`` and averaged over samples
    whose period is ``"annual"``.

    NOTE(review): the perturbation is multiplicative on whatever values are
    in ``x_m`` (transformed inputs, if scalers were applied). Values near 0
    then barely change, and negative values move in the opposite direction.

    Args:
        model: Network exposing ``eval()``, ``to_device(device, batch)`` and a
            forward ``model(x_m, x_s, mv, mw, ma) -> (y_month, y_w, y_a)``.
        device: Device used for the forward passes.
        dl: Iterable of batch dicts. It must visit ``ds`` in order and yield
            the same batches on every pass, because rows are matched to
            ``ds.keys`` by position and baselines are reused per batch.
        ds: Dataset with ``keys`` (4-tuples whose last item is the period)
            and ``y_std``.
        monthly_cols: Feature names in the order of the last axis of ``x_m``
            (feature ``i`` perturbs ``x_m[..., i]``).
        target_month: Label in ``month_order`` whose prediction is inspected.
        boost_factor: Multiplier applied to the perturbed input.
        month_order: Month labels of the sequence axis (default: the 15-step
            aug_ .. oct_ order).

    Returns:
        DataFrame indexed by feature (index name ``"feature"``) with one column
        per perturbed month, ``month_order[:idx_target]`` (columns name
        ``"month"``).
    """
    if month_order is None:
        month_order = [
            "aug_", "sep_", "oct", "nov", "dec", "jan", "feb", "mar", "apr",
            "may", "jun", "jul", "aug", "sep", "oct_"
        ]

    month_to_idx = {m: i for i, m in enumerate(month_order)}
    idx_target = month_to_idx[target_month]  # months [0, idx_target) are perturbed

    model.eval()
    y_std = ds.y_std.to(device)
    all_keys = ds.keys

    # The unperturbed prediction does not depend on (feature, month), so
    # compute it once per batch rather than once per combination. Also record
    # which rows of each batch are annual samples.
    baselines = []
    annual_rows = []
    i0 = 0
    for batch in dl:
        bs = batch["x_m"].shape[0]
        keys = all_keys[i0:i0 + bs]
        i0 += bs

        batch = model.to_device(device, batch)
        y_month_base, _, _ = model(batch["x_m"], batch["x_s"], batch["mv"],
                                   batch["mw"], batch["ma"])
        baselines.append(y_month_base[:, idx_target])

        rows = []
        for j in range(bs):
            _, _, _, per = keys[j]
            if per == "annual":  # keep annual MB only
                rows.append(j)
        annual_rows.append(rows)

    # result[feat][month_idx] = mean Δ MB (scaled by y_std)
    results = {feat: [] for feat in monthly_cols}

    # Context manager closes the progress bar even if a forward pass fails.
    with tqdm(total=len(monthly_cols) * idx_target,
              desc="ΔMB sensitivity",
              leave=True) as pbar:
        for feat_i, feat in enumerate(monthly_cols):
            for m_idx in range(idx_target):
                deltas = []
                for b_idx, batch in enumerate(dl):
                    batch = model.to_device(device, batch)
                    x_m_pert = batch["x_m"].clone()
                    x_m_pert[:, m_idx, feat_i] *= boost_factor
                    y_month_pert, _, _ = model(x_m_pert, batch["x_s"],
                                               batch["mv"], batch["mw"],
                                               batch["ma"])

                    delta = ((y_month_pert[:, idx_target] - baselines[b_idx]) *
                             y_std).cpu()
                    deltas.extend(float(delta[j]) for j in annual_rows[b_idx])

                results[feat].append(np.mean(deltas))
                pbar.update(1)

    df = pd.DataFrame(results, index=month_order[:idx_target]).T
    df.columns.name = "month"
    df.index.name = "feature"
    return df

In [None]:
# Range of the model's raw (not de-normalized) monthly predictions over the
# test loader.
model_combd.eval()
y_month_base_all = []

# Inference only: without no_grad every batch output (and its CPU copy)
# would keep an autograd graph alive.
with torch.no_grad():
    for batch in test_dl_combd:
        batch = model_combd.to_device(device, batch)
        x_m, x_s = batch["x_m"], batch["x_s"]
        mv, mw, ma = batch["mv"], batch["mw"], batch["ma"]

        # forward pass
        y_month_base, _, _ = model_combd(x_m, x_s, mv, mw, ma)
        y_month_base_all.append(y_month_base.cpu())  # move to CPU to free GPU memory

# Concatenate along the sample axis (presumably (N_samples, n_months)).
y_month_base_all = torch.cat(y_month_base_all, dim=0)

y_min = float(y_month_base_all.min())
y_max = float(y_month_base_all.max())

print(f"Min monthly MB prediction: {y_min:.3f}")
print(f"Max monthly MB prediction: {y_max:.3f}")

In [None]:
# Climate drivers of interest. NOTE: not passed below — monthly_cols must be
# the full MONTHLY_COLS list, because feature i perturbs x_m[..., i].
feat_columns = ['t2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str', 'pcsr']

# Δ of the "aug" monthly prediction when each earlier month of each monthly
# input is multiplied by 3 (annual samples only).
df_delta_MB = sensitivity_all_features_to_target_month(
    model_combd,
    device,
    test_dl_combd,
    ds_test_combd_copy,
    monthly_cols=MONTHLY_COLS,
    month_order=month_order,  # 15-month hydrological order
    target_month="aug",
    boost_factor=3.0,
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy.ndimage import gaussian_filter1d

# --- Copy sensitivity matrix and map features to long names ---
piv = df_delta_MB.copy()
piv["feature_long"] = piv.index.map(lambda x: vois_climate_long_name.get(x, x))
piv = piv.set_index("feature_long")

# --- Reorder columns according to hydrological month order (up to the target month) ---
piv = piv[[m for m in month_order if m in piv.columns]]

# The sensitivity table covers month_order[:idx_target], so the month whose
# prediction was analysed is the next entry of month_order. Derive it for the
# title instead of hard-coding a month name (the run above targets "aug").
_n_boosted = len(piv.columns)
target_month_label = (month_order[_n_boosted].strip("_").capitalize()
                      if _n_boosted < len(month_order) else "target month")

# --- Order features by average |ΔMB| (ascending = strongest at top visually) ---
feat_order = piv.abs().mean(axis=1).sort_values(ascending=True).index

# --- Smooth individual curves (display only; labels show raw values) ---
piv_smooth = pd.DataFrame(
    np.vstack([gaussian_filter1d(piv.loc[f], sigma=1) for f in feat_order]),
    index=feat_order,
    columns=piv.columns,
)

fig, ax = plt.subplots(figsize=(10, 11))
palette = sns.color_palette("magma", n_colors=len(feat_order))
month_idx = np.arange(len(piv_smooth.columns))

# Vertical spacing between ridges, relative to the largest smoothed |ΔMB|.
offset_step = np.nanmax(np.abs(piv_smooth.values)) * 0.85
offset = 0

for feat, color in zip(feat_order, palette):
    y = piv_smooth.loc[feat].values
    y_raw = piv.loc[feat].to_numpy()  # unsmoothed ΔMB used for the labels

    # ridge curve
    ax.plot(month_idx, y + offset, color=color, lw=2)
    ax.fill_between(month_idx, offset, y + offset, color=color, alpha=0.4)

    # feature label on left
    ax.text(-0.6, offset, feat, va='center', ha='right', fontsize=13)

    # ---- annotate max of the smoothed curve ----
    # "+" format flag prints the sign explicitly (no "+-0.05" for negatives).
    max_i = np.argmax(y)
    ax.text(
        max_i,
        y[max_i] + offset + 0.03 * offset_step,
        f"{y_raw[max_i]:+.2f}",
        ha="center",
        va="bottom",
        fontsize=10,
        bbox=dict(facecolor="white", alpha=0.8, edgecolor="none", pad=1.2),
    )

    # ---- annotate min of the smoothed curve ----
    min_i = np.argmin(y)
    ax.text(
        min_i,
        y[min_i] + offset - 0.04 * offset_step,
        f"{y_raw[min_i]:+.2f}",
        ha="center",
        va="top",
        fontsize=10,
        bbox=dict(facecolor="white", alpha=0.8, edgecolor="none", pad=1.2),
    )

    offset += offset_step

# formatting
ax.set_yticks([])
ax.set_xticks(month_idx)
ax.set_xticklabels([m.strip("_").capitalize() for m in piv_smooth.columns],
                   rotation=45,
                   ha="right")
ax.set_xlabel("Month boosted (×3)")
ax.set_title(
    f"ΔMB Sensitivity — Effect of Monthly Climate Perturbations on {target_month_label} MB",
    pad=20)

ax.set_facecolor("white")
fig.patch.set_facecolor("white")
for spine in ["top", "right", "left"]:
    ax.spines[spine].set_visible(False)
ax.spines["bottom"].set_color("black")
ax.tick_params(colors="black")

plt.tight_layout()
plt.show()