## Setting Up:

In [None]:
# --- Standard library
from concurrent.futures import ProcessPoolExecutor, as_completed
from contextlib import redirect_stdout
from datetime import datetime
import io
import logging
import multiprocessing as mp
import os
import sys
import warnings

# Make repo root importable (for MBM & scripts/*)
sys.path.append(os.path.join(os.getcwd(), '../../'))

# --- Third-party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from cmcrameri import cm
import torch
from tqdm.auto import tqdm
import xarray as xr
from matplotlib.lines import Line2D

import massbalancemachine as mbm

# --- Project-local
from scripts.utils import *
from scripts.glamos import *
from scripts.models import *
from scripts.geo_data import *
from scripts.dataset import *
from scripts.geodetic import *
from scripts.physical import *
from scripts.plotting import *

# --- Notebook settings
warnings.filterwarnings('ignore')
%load_ext autoreload
%autoreload 2

cfg = mbm.SwitzerlandConfig()

# Plot styles:
mbm.utils.seed_all(cfg.seed)
mbm.plots.use_mbm_style()

print("Using seed:", cfg.seed)

if torch.cuda.is_available():
    print("CUDA is available")
    mbm.utils.free_up_cuda()
else:
    print("CUDA is NOT available")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Input data:

In [None]:
# Read GLAMOS stake data
data_glamos = get_stakes_data(cfg)

# Compute head/tail month padding for monthly data
months_head_pad, months_tail_pad = mbm.data_processing.utils._compute_head_tail_pads_from_df(
    data_glamos)

# Initialize logging once: logging.basicConfig does nothing when the root
# logger already has handlers, so a second call later in the cell is useless.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Transform data to monthly format (run or load data):
paths = {
    'csv_path': cfg.dataPath + path_PMB_GLAMOS_csv,
    'era5_climate_data':
    cfg.dataPath + path_ERA5_raw + 'era5_monthly_averaged_data.nc',
    'geopotential_data':
    cfg.dataPath + path_ERA5_raw + 'era5_geopotential_pressure.nc',
    'radiation_save_path': cfg.dataPath + path_pcsr + 'zarr/'
}
RUN = False  # True: recompute the monthly table; False: presumably load output_file
data_monthly = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=VOIS_CLIMATE,
    vois_topographical=VOIS_TOPOGRAPHICAL,
    output_file='CH_wgms_dataset_monthly_LSTM.csv')

dataloader_gl = mbm.dataloader.DataLoader(cfg,
                                          data=data_monthly,
                                          random_seed=cfg.seed,
                                          meta_data_columns=cfg.metaData)

# Blocking on glaciers:
# Model is trained on all glaciers --> "Within sample"
# remove 2025
data_monthly_train = data_monthly[data_monthly.YEAR < 2025]

existing_glaciers = set(data_monthly_train.GLACIER.unique())
train_glaciers = existing_glaciers
# Explicit .copy(): a 'y' column is added below, and assigning into a filtered
# frame is chained assignment (its SettingWithCopyWarning is hidden by the
# notebook-wide warnings filter).
data_train = data_monthly_train[data_monthly_train.GLACIER.isin(
    train_glaciers)].copy()
print('Size of monthly train data:', len(data_train))

# Target column
data_train['y'] = data_train['POINT_BALANCE']

# Test set: Rhone glacier, years >= 2000 (also part of the training glaciers)
data_test = data_monthly_train[(data_monthly_train.GLACIER == 'rhone')
                               & (data_monthly_train.YEAR >= 2000)].copy()
data_test['y'] = data_test['POINT_BALANCE']

# Convert to start of August instead:
# Convert to str → keep year YYYY → append "0801" → convert back to int
data_glamos_Aug_ = data_glamos.copy()
data_glamos_Aug_["FROM_DATE"] = (
    data_glamos_Aug_["FROM_DATE"].astype(str).str.slice(0, 4)  # year YYYY
    .astype(int).astype(str) + "0801"  # append "0801"
).astype(int)

# Same for full temporal resolution (run or load data):
# Compute padding for the August-start monthly data
months_head_pad_Aug_, months_tail_pad_Aug_ = mbm.data_processing.utils._compute_head_tail_pads_from_df(
    data_glamos_Aug_)

RUN = False
data_monthly_Aug_ = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos_Aug_,
    paths=paths,
    cfg=cfg,
    vois_climate=VOIS_CLIMATE,
    vois_topographical=VOIS_TOPOGRAPHICAL,
    output_file='CH_wgms_dataset_monthly_LSTM_Aug_.csv')

# Create DataLoader
dataloader_gl_Aug_ = mbm.dataloader.DataLoader(cfg,
                                               data=data_monthly_Aug_,
                                               random_seed=cfg.seed,
                                               meta_data_columns=cfg.metaData)

# Blocking on glaciers:
# Model is trained on all glaciers --> "Within sample"
# remove 2025
data_monthly_train_Aug_ = data_monthly_Aug_[data_monthly_Aug_.YEAR < 2025]

existing_glaciers = set(data_monthly_train_Aug_.GLACIER.unique())
train_glaciers = existing_glaciers
data_train_Aug_ = data_monthly_train_Aug_[data_monthly_train_Aug_.GLACIER.isin(
    train_glaciers)].copy()
print('Size of monthly train data:', len(data_train_Aug_))

# Target column
data_train_Aug_['y'] = data_train_Aug_['POINT_BALANCE']

# Test (Rhone > 2000)
data_test_Aug_ = data_monthly_train_Aug_[
    (data_monthly_train_Aug_.GLACIER == 'rhone')
    & (data_monthly_train_Aug_.YEAR >= 2000)].copy()
data_test_Aug_['y'] = data_test_Aug_['POINT_BALANCE']

## LSTM:

In [None]:
# Time-varying inputs, passed as `monthly_cols` to the LSTM dataset builders
# (order matters: it fixes the feature index of each variable).
MONTHLY_COLS = ['t2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str', 'pcsr',
                'ELEVATION_DIFFERENCE']

# Static topographic inputs, passed as `static_cols`.
STATIC_COLS = ['aspect_sgi', 'slope_sgi', 'svf']

# All model inputs: monthly columns first, then the static ones.
feature_columns = [*MONTHLY_COLS, *STATIC_COLS]

### Build LSTM dataloaders:

In [None]:
CACHE_TRAIN_DS = "cache/lstm_train_dataset_rhone.pt"
CACHE_TEST_DS = "cache/lstm_test_dataset_rhone.pt"
os.makedirs("cache", exist_ok=True)


def _load_or_build_sequence_dataset(cache_path, rebuild, build_fn, label):
    """Return a dataset read from ``cache_path``, building and caching it if needed.

    Parameters
    ----------
    cache_path : str
        File written with ``torch.save({"dataset": ...})`` and read back here.
    rebuild : bool
        Force a rebuild even when the cache file exists. A missing cache is
        always rebuilt instead of raising.
    build_fn : callable
        Zero-argument function returning a freshly built dataset.
    label : str
        Name used in progress messages (e.g. "TRAIN", "TEST").
    """
    if rebuild or not os.path.exists(cache_path):
        print(f"Building {label} MBSequenceDataset...")
        dataset = build_fn()
        torch.save({"dataset": dataset}, cache_path)
        print(f"Cached {label} dataset.")
        return dataset

    print(f"Loading cached {label} MBSequenceDataset...")
    # The cache holds a pickled Python object, not only tensors. PyTorch >= 2.6
    # defaults to weights_only=True, which refuses to unpickle it. Only use
    # this on cache files this notebook wrote itself.
    ckpt = torch.load(cache_path, map_location="cpu", weights_only=False)
    return ckpt["dataset"]


def _build_lstm_dataset(df_loss, df_full):
    """Build a combined LSTM dataset with the settings shared by train and test."""
    return build_combined_LSTM_dataset(
        df_loss=df_loss,
        df_full=df_full,
        monthly_cols=MONTHLY_COLS,
        static_cols=STATIC_COLS,
        months_head_pad=months_head_pad_Aug_,
        months_tail_pad=months_tail_pad_Aug_,
        normalize_target=True,
        expect_target=True,
    )


# ============================================================
# Load or build TRAIN dataset (all glaciers, all years)
# ============================================================
RUN_CACHE_TRAIN_DS = False
ds_train = _load_or_build_sequence_dataset(
    CACHE_TRAIN_DS,
    RUN_CACHE_TRAIN_DS,
    lambda: _build_lstm_dataset(data_train, data_train_Aug_),
    "TRAIN",
)

# ============================================================
# Load or build TEST dataset (Rhone glacier ≥ 2000)
# ============================================================
RUN_CACHE_TEST_DS = False
ds_test = _load_or_build_sequence_dataset(
    CACHE_TEST_DS,
    RUN_CACHE_TEST_DS,
    lambda: _build_lstm_dataset(data_test, data_test_Aug_),
    "TEST",
)

# val_ratio=0.0: every training sequence goes to train_idx, val_idx is empty.
train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
    len(ds_train), val_ratio=0.0, seed=cfg.seed)

### Define & train model:

In [None]:
# ------------------------------------------------------------------
# Hyper-parameters
# ------------------------------------------------------------------
# An explicit parameter set, kept commented out for reference:
# custom_params = {
#     'Fm': 9,
#     'Fs': 3,
#     'hidden_size': 64,
#     'num_layers': 2,
#     'bidirectional': False,
#     'dropout': 0.1,
#     'static_layers': 2,
#     'static_hidden': 32,
#     'static_dropout': 0.1,
#     'lr': 0.0005,
#     'weight_decay': 1e-05,
#     'loss_name': 'neutral',
#     'two_heads': False,
#     'head_dropout': 0.0,
#     'loss_spec': None
# }
# Active configuration: project constant (presumably from scripts.*).
custom_params = PARAMS_LSTM_IS_past

# Date stamp (not used in this cell) and weights file of the trained LSTM.
current_date = datetime.now().strftime("%Y-%m-%d")
model_filename = LSTM_IS_NORM_Y_PAST

# Re-seed, then work on untransformed clones of the cached datasets so that
# the in-place scaler fitting below operates on the copies.
seed_all(cfg.seed)
clone_untransformed = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset
ds_train_copy = clone_untransformed(ds_train)
ds_test_copy = clone_untransformed(ds_test)

# Fit scalers on TRAIN and transform Xm/Xs/y of ds_train_copy in place.
train_dl, val_dl = ds_train_copy.make_loaders(
    train_idx=train_idx,
    val_idx=val_idx,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=True,
    shuffle_train=True,
    use_weighted_sampler=True,  # weighted sampling of training sequences
)

# Architecture and loss from the parameter dict; weights are replaced below.
model = mbm.models.LSTM_MB.build_model_from_params(cfg, custom_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(custom_params)

# Test loader built from the test copy plus the fitted training dataset.
test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test_copy, ds_train_copy, batch_size=128, seed=cfg.seed)

# Restore the stored weights (no training happens in this cell) and evaluate.
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state)
test_metrics, test_df_preds = model.evaluate_with_preds(
    device, test_dl, ds_test_copy)

test_rmse_a = test_metrics['RMSE_annual']
test_rmse_w = test_metrics['RMSE_winter']
print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

In [None]:
# Seasonal (annual / winter) scores of the test-set predictions.
scores_annual, scores_winter = compute_seasonal_scores(
    test_df_preds, target_col='target', pred_col='pred')
print("Annual scores:", scores_annual)
print("Winter scores:", scores_winter)

# Summary figure of the predictions, both axes on the range (-8, 6).
fig = plot_predictions_summary(
    grouped_ids=test_df_preds,
    scores_annual=scores_annual,
    scores_winter=scores_winter,
    color_annual=mbm.plots.COLOR_ANNUAL,
    color_winter=mbm.plots.COLOR_WINTER,
    ax_xlim=(-8, 6),
    ax_ylim=(-8, 6),
)

# Gridded sensitivity

This section assembles a monthly climate input table for every grid cell of the Rhone glacier, for all years 2007–2024, in exactly the same column format as the stake dataset used to train the LSTM.
Instead of having samples only at stake locations, each (grid cell, year) pair thus becomes one LSTM sequence.

In [None]:
path_glacier_grid_glamos = 'GLAMOS/topo/gridded_topo_inputs/GLAMOS_grid_Aug_/'
glacier_name = 'rhone'
fields_not_features = cfg.fieldsNotFeatures

CACHE_GRID_DF = "cache/rhone_grid_monthly_df.parquet"
CACHE_TRAIN_FULL = "cache/train_full_pristine_ds_rhone.pt"
os.makedirs("cache", exist_ok=True)

# ============================================================
# 1) Cache Rhone glacier grid dataframe
# ============================================================
RUN_CACHE_GRID_DF = False  # True forces a rebuild; a missing cache is always built

if RUN_CACHE_GRID_DF or not os.path.exists(CACHE_GRID_DF):
    # Remove any stale cache first so a failed rebuild cannot leave it behind.
    if os.path.exists(CACHE_GRID_DF):
        os.remove(CACHE_GRID_DF)

    print("Reading Rhone glacier parquet files...")

    glacier_path = os.path.join(cfg.dataPath, path_glacier_grid_glamos,
                                glacier_name)
    dataframes = []
    range_years = range(2007, 2025)  # years 2007-2024 inclusive

    for year in tqdm(range_years):
        parquet_path = os.path.join(glacier_path,
                                    f"{glacier_name}_grid_{year}.parquet")
        if not os.path.exists(parquet_path):
            raise FileNotFoundError(parquet_path)

        df = pd.read_parquet(parquet_path)
        df.drop_duplicates(inplace=True)
        dataframes.append(df)

    df_grid_monthly = pd.concat(dataframes, ignore_index=True)

    # Keep only model features, non-feature fields and identifier columns.
    REQUIRED = ["GLACIER", "YEAR", "ID", "PERIOD", "MONTHS"]
    all_columns = MONTHLY_COLS + STATIC_COLS + fields_not_features
    needed = set(all_columns) | set(REQUIRED)
    # .copy(): columns are assigned below, so work on an independent frame.
    df_grid_monthly = df_grid_monthly[[
        c for c in df_grid_monthly.columns if c in needed
    ]].copy()

    # Grid cells have no observed balance, but the sequence dataset is built
    # with expect_target=True, so add a dummy target if missing.
    if "POINT_BALANCE" not in df_grid_monthly.columns:
        df_grid_monthly["POINT_BALANCE"] = 0.0

    # Mask extrapolated months
    extrapolate_months = ["aug_", "sep_"]
    df_grid_monthly.loc[
        df_grid_monthly["MONTHS"].str.lower().isin(extrapolate_months),
        "POINT_BALANCE",
    ] = np.nan

    df_grid_monthly.to_parquet(CACHE_GRID_DF)
    print("Cached Rhone glacier grid dataframe.")

else:
    print("Loading cached Rhone glacier grid dataframe...")
    df_grid_monthly = pd.read_parquet(CACHE_GRID_DF)

df_grid_monthly_a = df_grid_monthly.dropna(subset=["ID", "MONTHS"])

# ============================================================
# 2) Cache pristine TRAIN dataset for scalers
# ============================================================
# NOTE(review): ds_train_full is built with exactly the same arguments as
# ds_train in the "Build LSTM dataloaders" cell; that cache could be reused.
RUN_CACHE_TRAIN_FULL = False

if RUN_CACHE_TRAIN_FULL or not os.path.exists(CACHE_TRAIN_FULL):
    if os.path.exists(CACHE_TRAIN_FULL):
        os.remove(CACHE_TRAIN_FULL)

    print("Building pristine TRAIN dataset for scalers...")

    ds_train_full = build_combined_LSTM_dataset(
        df_loss=data_train,
        df_full=data_train_Aug_,
        monthly_cols=MONTHLY_COLS,
        static_cols=STATIC_COLS,
        months_head_pad=months_head_pad_Aug_,
        months_tail_pad=months_tail_pad_Aug_,
        normalize_target=True,
        expect_target=True,
    )

    torch.save({"dataset": ds_train_full}, CACHE_TRAIN_FULL)
    print("Cached pristine TRAIN dataset.")

else:
    print("Loading cached pristine TRAIN dataset...")
    # Pickled dataset object: needs weights_only=False on PyTorch >= 2.6.
    # Only load cache files written by this notebook.
    ckpt = torch.load(CACHE_TRAIN_FULL, map_location="cpu", weights_only=False)
    ds_train_full = ckpt["dataset"]

# ============================================================
# 3) Fit scalers (fast)
# ============================================================
# train_idx comes from splitting ds_train; it indexes ds_train_full correctly
# only because both are built identically (same length/order) — keep in sync.
ds_train_full_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train_full)
ds_train_full_copy.fit_scalers(train_idx)

In [None]:
# ============================================================
# Build LSTM dataset for the full Rhone glacier grid (all years)
# ============================================================
# Each grid cell and year becomes one LSTM sequence.
# This produces:
#   x_m : (Ncells × Ny, Nmonths=16, Nfeatures=9)   monthly climate inputs
#   x_s : (Ncells × Ny, 3)                         static topo features
#
# NOTE(review): the grid sequences use months_head_pad / months_tail_pad,
# computed from the original stake dates. The training datasets were built
# with months_head_pad_Aug_ / months_tail_pad_Aug_. Confirm both yield the
# same month axis for this model.
CACHE_GL_DS = "cache/rhone_grid_MBSequenceDataset_norm.pt"
os.makedirs("cache", exist_ok=True)

RUN_CACHE_GL_DS = False  # True forces a rebuild; a missing cache is always built

if RUN_CACHE_GL_DS or not os.path.exists(CACHE_GL_DS):
    print("Building Rhone glacier MBSequenceDataset from dataframe...")

    ds_gl_a = mbm.data_processing.MBSequenceDataset.from_dataframe(
        df_grid_monthly_a,
        MONTHLY_COLS,
        STATIC_COLS,
        months_tail_pad=months_tail_pad,
        months_head_pad=months_head_pad,
        expect_target=True,  # dummy target required
        show_progress=True,
        normalize_target=True,  # same target setting as the training datasets
    )

    torch.save({"dataset": ds_gl_a}, CACHE_GL_DS)
    print("Cached Rhone glacier MBSequenceDataset.")
else:
    print("Loading cached Rhone glacier MBSequenceDataset...")
    # Pickled dataset object: needs weights_only=False on PyTorch >= 2.6.
    # Only load cache files written by this notebook.
    ckpt = torch.load(CACHE_GL_DS, map_location="cpu", weights_only=False)
    ds_gl_a = ckpt["dataset"]


In [None]:
# ============================================================
# Normalize Rhone glacier grid with TRAINING statistics
# ============================================================
# The grid inputs are standardized using the SAME means/stds
# that were fitted on the multi-glacier training set, so grid inference sees
# inputs on the same scale as training.
#
test_gl_dl_a = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_gl_a,
    ds_train_full_copy,  # contains the fitted scalers
    seed=cfg.seed,
    batch_size=128,
)

# ============================================================
# Load trained regional LSTM glacier model
# ============================================================
model = get_lstm_model_cpu_cached(cfg, custom_params, model_filename)
device = torch.device("cpu")

# ============================================================
# Prepare grid dataset for sensitivity analysis
# ============================================================
ds = ds_gl_a
dl = test_gl_dl_a

# Reference climate altitude of the Rhone glacier, used to reconstruct
# absolute elevation from ELEVATION_DIFFERENCE. data_train contains ALL
# training glaciers, so only Rhone records are used here; the second-smallest
# pooled value is not tied to the Rhone glacier.
rhone_alt_climate = data_train.loc[
    (data_train.GLACIER == glacier_name) & (data_train.YEAR >= 2000),
    "ALTITUDE_CLIMATE",
].dropna()
if rhone_alt_climate.empty:
    raise ValueError(
        f"No ALTITUDE_CLIMATE values for {glacier_name!r} since 2000 in data_train")
if rhone_alt_climate.nunique() > 1:
    print(f"Several climate altitudes for {glacier_name}: "
          f"{np.unique(rhone_alt_climate)}; using the most frequent one.")
alt = float(rhone_alt_climate.mode().iloc[0])
print(f"Reference climate altitude = {alt} m")

# Hydrological month axis including padding
months_keys = (months_tail_pad + mbm.data_processing.utils.months_hydro_year +
               months_head_pad)
print(f"{months_keys=}")

# Dataset geometry
Nsamples = len(ds)  # number of gridcell × year sequences
Nmonths = len(months_keys)  # number of monthly time steps (16 expected)
Nfeatures = ds[0]["x_m"].shape[1]  # climate features per month (9 expected)

print(f"{Nsamples=}")
print(f"{Nmonths=}")
print(f"{Nfeatures=}")

In [None]:
# ============================================================
# Allocate storage:
# sensitivity[m][b,t,f] = ∂ b_m / ∂ x_{t,f}
# (gradient of output month m w.r.t. normalized input x at month t, feature f)
# ============================================================
sensitivity = {
    m: torch.zeros(Nsamples, Nmonths, Nfeatures)
    for m in months_keys
}

pred = torch.zeros(Nsamples, Nmonths)
elevation = torch.zeros(Nsamples)

model.eval()  # disable dropout for stable gradients

all_keys = ds.keys
elev_idx = MONTHLY_COLS.index("ELEVATION_DIFFERENCE")
# Training scalers, used to undo the normalization of x_m.
month_std = ds_train_full_copy.month_std.to(device)
month_mean = ds_train_full_copy.month_mean.to(device)
i = 0  # write offset into the per-sample buffers (dl order must match ds)

with tqdm(total=Nsamples, desc="Computing grid sensitivities",
          unit="cells") as pbar:
    for batch in dl:
        bs = batch["x_m"].shape[0]

        batch = model.to_device(device, batch)
        x_m = batch["x_m"].clone().requires_grad_(True)

        # Forward pass once per batch; reused for every output month below.
        y_month, y_w, y_a = model(x_m, batch["x_s"], batch["mv"],
                                  batch["mw"], batch["ma"])
        assert y_month.shape[1] == Nmonths

        pred[i:i + bs] = y_month.detach().cpu()

        # ====================================================
        # Loop over output months
        # ====================================================
        for m_idx, m_name in enumerate(months_keys):
            # Summing over the batch gives per-sample gradients, since samples
            # do not interact in eval mode. autograd.grad only computes the
            # gradient w.r.t. x_m, so nothing accumulates in parameter .grad
            # and no zeroing is needed. The graph is kept until the last month.
            (grad_x,) = torch.autograd.grad(
                y_month[:, m_idx].sum(),
                x_m,
                retain_graph=m_idx < Nmonths - 1,
                allow_unused=True,
            )
            if grad_x is None:  # output month independent of the inputs
                grad_x = torch.zeros_like(x_m)
            sensitivity[m_name][i:i + bs] = grad_x.detach().cpu()

        # ====================================================
        # Recover absolute elevation (denormalize, add climate altitude)
        # ====================================================
        with torch.no_grad():
            x_phys = batch["x_m"] * month_std + month_mean
            elevation[i:i + bs] = (x_phys[:, 0, elev_idx] + alt).cpu()

        i += bs
        pbar.update(bs)

if i != Nsamples:
    raise RuntimeError(
        f"Loader yielded {i} samples but the dataset has {Nsamples}")

In [None]:
# ============================================================
# 1) Check which samples have zero total sensitivity
# ============================================================
for m in months_keys:
    # Flatten (T,F) -> vector and compute L2 norm per sample
    # This measures the total sensitivity magnitude of b_m to all inputs
    norm_per_sample = sensitivity[m].reshape(sensitivity[m].shape[0],
                                             -1).norm(dim=1)

    # Print how many grid cells have exactly zero sensitivity
    # (usually fully masked or invalid sequences)
    print(m, sensitivity[m][norm_per_sample == 0.0].shape)

# ============================================================
# 2) Build elevation bands across the glacier
# ============================================================
n_band_edges = 8  # -> 7 bands
bands = np.linspace(float(elevation.min()), float(elevation.max()),
                    n_band_edges)

print("Bounds of the bands:", bands)
print("diff bands =", np.diff(bands))  # thickness of each elevation band

# ============================================================
# 3) Group sensitivities by elevation band for each output month
# ============================================================
# Bands are half-open [lb, ub) so a cell lying exactly on an interior edge is
# counted once; the last band is closed to include the maximum elevation.
n_bands = bands.shape[0] - 1
band_masks = []
for b_idx in range(n_bands):
    lb, ub = bands[b_idx], bands[b_idx + 1]
    below_ub = (elevation <= ub) if b_idx == n_bands - 1 else (elevation < ub)
    band_masks.append((elevation >= lb) & below_ub)
    print(band_masks[-1].sum())  # number of cells in this band

# sens_bands[m][band] has shape (Ncells_in_band, Nmonths, Nfeatures)
sens_bands = {
    m: [sensitivity[m][mask] for mask in band_masks]
    for m in months_keys
}

##### Sensitivity to individual variables:

In [None]:
# Monthly sensitivity to 't2m', grouped by elevation band
# (padded months dropped, no panel letters).
fig = plot_monthly_sensitivity_elevbands(
    sens_bands=sens_bands,
    plot_var="t2m",
    monthly_cols=MONTHLY_COLS,
    months_keys=months_keys,
    bands=bands,
    glacier_name="rhone",
    vois_climate_long_name=vois_climate_long_name,
    drop_padded_months=True,
    add_panel_labels=False,
)

In [None]:
# Monthly sensitivity to 'tp', grouped by elevation band
# (padded months dropped, no panel letters).
fig = plot_monthly_sensitivity_elevbands(
    sens_bands=sens_bands,
    plot_var="tp",
    monthly_cols=MONTHLY_COLS,
    months_keys=months_keys,
    bands=bands,
    glacier_name="rhone",
    vois_climate_long_name=vois_climate_long_name,
    drop_padded_months=True,
    add_panel_labels=False,
)

In [None]:
# One sensitivity figure per remaining input variable. Each figure is closed
# immediately after creation.
# NOTE(review): with an inline backend a figure closed before the cell ends is
# not rendered unless the helper shows or saves it itself — confirm intent.
remaining_vars = ('slhf', 'sshf', 'ssrd', 'fal', 'str', 'pcsr')
for plot_var in remaining_vars:
    fig = plot_monthly_sensitivity_elevbands(
        sens_bands=sens_bands,
        plot_var=plot_var,
        monthly_cols=MONTHLY_COLS,
        months_keys=months_keys,
        bands=bands,
        glacier_name="rhone",
        vois_climate_long_name=vois_climate_long_name,
        drop_padded_months=True,
        add_panel_labels=False,
    )
    plt.close()

##### For a single output month:

In [None]:
# ============================================================
# Sensitivity of one output month to each input variable, for the
# elevation bands listed in id_elev_bands (with 7 bands, 0 and 6 are
# the lowest and highest).
# ============================================================
selected_month = "jul"
plot_vars = ["tp", "t2m", "str", "slhf", "ssrd", "fal", "pcsr", "sshf"]
id_elev_bands = [0, 6]

drop_padded_months = True
keep_idx = month_keep_idx(months_keys, drop_padded_months)

# Shared y-limits: mean ± std over cells for every band and variable,
# restricted to the plotted months. Empty bands and non-finite values
# (e.g. the std of a one-cell band) are skipped so they cannot make the
# limits NaN.
vals = []
for plot_var in plot_vars:
    f_idx = MONTHLY_COLS.index(plot_var)
    for band in sens_bands[selected_month]:
        if band.shape[0] == 0:
            continue
        mean = band[:, :, f_idx].mean(dim=0)[keep_idx]
        std = band[:, :, f_idx].std(dim=0)[keep_idx]
        for v in ((mean - std).min().item(), (mean + std).max().item()):
            if np.isfinite(v):
                vals.append(v)

if not vals:
    raise ValueError(f"No finite sensitivities to plot for {selected_month!r}")

# Pad 10% away from zero. Multiplying both ends by 1.1 would move a positive
# lower bound (or a negative upper bound) inwards and clip the data.
y_lo, y_hi = min(vals), max(vals)
ylim = (y_lo * 1.1 if y_lo < 0 else y_lo * 0.9,
        y_hi * 1.1 if y_hi > 0 else y_hi * 0.9)
print("Global ylim:", ylim)

titles = [
    "Precip",
    "Temp",
    "Surf net therm radiation",
    "Surface latent heat flux",
    "Surface solar radiation downwards",
    "Albedo",
    "Potential clear sky rad.",
    "Surface sensible heat flux",
]

fig, axs = plt.subplots(4, 2, figsize=(12, 8), sharex=True)
ax_list = axs.ravel()

for ax, plot_var, title in zip(ax_list, plot_vars, titles):
    plot_sensitivity_elev_band(
        sens_bands=sens_bands[selected_month],
        plot_var=plot_var,
        text_var=title,
        id_elev_bands=id_elev_bands,
        month_labels=months_keys,
        monthly_cols=MONTHLY_COLS,
        bands=bands,
        ax=ax,
        ylim=ylim,
        drop_padded_months=drop_padded_months,
    )
    ax.set_ylabel("Sens.")

# Turn off any unused axes (only matters if lists don't match grid size)
for ax in ax_list[len(plot_vars):]:
    ax.axis("off")

plt.suptitle(f"Sensitivity for {selected_month.capitalize()}")
plt.tight_layout()
plt.show()

In [None]:
# ============================================================
# Same panel layout as for July, for February. Padded months are dropped
# consistently in both the y-limits and the plotted curves.
# ============================================================
selected_month = "feb"
plot_vars = ["tp", "t2m", "str", "slhf", "ssrd", "fal", "pcsr", "sshf"]
id_elev_bands = [0, 6]

drop_padded_months = True
keep_idx = month_keep_idx(months_keys, drop_padded_months)

# Shared y-limits over the plotted months. Empty bands and non-finite values
# (e.g. the std of a one-cell band) are skipped so they cannot make the
# limits NaN.
vals = []
for plot_var in plot_vars:
    f_idx = MONTHLY_COLS.index(plot_var)
    for band in sens_bands[selected_month]:
        if band.shape[0] == 0:
            continue
        mean = band[:, :, f_idx].mean(dim=0)[keep_idx]
        std = band[:, :, f_idx].std(dim=0)[keep_idx]
        for v in ((mean - std).min().item(), (mean + std).max().item()):
            if np.isfinite(v):
                vals.append(v)

if not vals:
    raise ValueError(f"No finite sensitivities to plot for {selected_month!r}")

# Pad 10% away from zero so same-sign limits are not pulled inwards.
y_lo, y_hi = min(vals), max(vals)
ylim = (y_lo * 1.1 if y_lo < 0 else y_lo * 0.9,
        y_hi * 1.1 if y_hi > 0 else y_hi * 0.9)
print("Global ylim:", ylim)

titles = [
    "Precip",
    "Temp",
    "Surf net therm radiation",
    "Surface latent heat flux",
    "Surface solar radiation downwards",
    "Albedo",
    "Potential clear sky rad.",
    "Surface sensible heat flux",
]

fig, axs = plt.subplots(4, 2, figsize=(12, 8), sharex=True)
ax_list = axs.ravel()

for ax, plot_var, title in zip(ax_list, plot_vars, titles):
    plot_sensitivity_elev_band(
        sens_bands=sens_bands[selected_month],
        plot_var=plot_var,
        text_var=title,
        id_elev_bands=id_elev_bands,
        month_labels=months_keys,
        monthly_cols=MONTHLY_COLS,
        bands=bands,
        ax=ax,
        ylim=ylim,
        drop_padded_months=drop_padded_months,
    )
    ax.set_ylabel("Sens.")

# Turn off any unused axes (only matters if lists don't match grid size)
for ax in ax_list[len(plot_vars):]:
    ax.axis("off")

plt.suptitle(f"Sensitivity for {selected_month.capitalize()}")
plt.tight_layout()
plt.show()