## Setting Up:

In [None]:
# --- Standard library
from concurrent.futures import ProcessPoolExecutor, as_completed
from contextlib import redirect_stdout
from datetime import datetime
import io
import logging
import multiprocessing as mp
import os
import sys
import warnings
from collections import defaultdict
from typing import List, Dict, Tuple, Optional

# Make repo root importable (for MBM & scripts/*)
sys.path.append(os.path.join(os.getcwd(), '../../'))

# --- Third-party
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np
import pandas as pd
import seaborn as sns
from cmcrameri import cm
import torch
from tqdm.auto import tqdm
import xarray as xr
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import ElasticNetCV
from sklearn.metrics import r2_score, mean_squared_error

import massbalancemachine as mbm

# --- Project-local
from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.nn_helpers import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.NN_networks import *
from scripts.geodata_plots import *
from scripts.probing import *

# --- Notebook settings
# Suppress every Python warning for the rest of the session.
warnings.filterwarnings('ignore')
# IPython magics: autoreload mode 2 re-imports edited modules before each
# cell is executed.
%load_ext autoreload
%autoreload 2

# Project configuration object; later cells read cfg.dataPath, cfg.seed and
# cfg.metaData from it.
cfg = mbm.SwitzerlandConfig()

In [None]:
# Read ERA5-Land data
# This is the same file that is later handed to the monthly preprocessing
# under the 'era5_climate_data' key; the bare expression on the last line
# displays the dataset as the cell output.
era5_ds = xr.open_dataset(cfg.dataPath + path_ERA5_raw +
                          'era5_monthly_averaged_data.nc')
era5_ds

In [None]:
# Apply the configured seed and report it
seed_all(cfg.seed)
print("Using seed:", cfg.seed)

# Report GPU availability; free_up_cuda() is only called when a GPU exists
if not torch.cuda.is_available():
    print("CUDA is NOT available")
else:
    print("CUDA is available")
    free_up_cuda()

# Torch device used for the model: GPU when present, CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Plot styles:
# Relative path, so it is resolved against the current working directory.
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)

# Shared colours for the figures.
# NOTE(review): get_cmap_hex presumably samples 10 hex colours from the
# cmcrameri 'batlow' colormap — confirm in the scripts package.
colors = get_cmap_hex(cm.batlow, 10)
color_annual = "#c51b7d"
color_winter = colors[0]

## Input data:

In [None]:
# Climate variables requested from the monthly preprocessing.
# The five snow variables at the end (sd, smlt, sf, rsn, snowc) are not part
# of MONTHLY_COLS, i.e. they are not fed to the LSTM; they are used as probe
# targets in the "Probing" section.
vois_climate = [
    't2m',
    'tp',
    'slhf',
    'sshf',
    'ssrd',
    'fal',
    'str',
    'sd',  # snow depth
    'smlt',  # snow melt
    'sf',  # snow fall
    'rsn',  # snow density
    'snowc'  # snow cover
]

# Topographical variables; the same three names are used as STATIC_COLS below
vois_topographical = ["aspect_sgi", "slope_sgi", "svf"]

# Read GLAMOS stake data
data_glamos = getStakesData(cfg)

# Compute padding for monthly data
# (both values are passed to every MBSequenceDataset.from_dataframe call below)
# NOTE(review): this calls an underscore-prefixed (private) helper of mbm.
months_head_pad, months_tail_pad = mbm.data_processing.utils._compute_head_tail_pads_from_df(
    data_glamos)

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Transform data to monthly format (run or load data):
# All locations are built by string concatenation on cfg.dataPath and the
# path_* constants star-imported from the scripts package.
paths = {
    'csv_path': cfg.dataPath + path_PMB_GLAMOS_csv,
    'era5_climate_data':
    cfg.dataPath + path_ERA5_raw + 'era5_monthly_averaged_data.nc',
    'geopotential_data':
    cfg.dataPath + path_ERA5_raw + 'era5_geopotential_pressure.nc',
    'radiation_save_path': cfg.dataPath + path_pcsr + 'zarr/'
}
# Forwarded as run_flag.
# NOTE(review): presumably True recomputes the monthly dataset and False loads
# the previously written output_file — confirm in process_or_load_data.
RUN = True
data_monthly = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_LSTM_probing.csv')

# Create DataLoader
# (consumed by get_CV_splits in the "Blocking on glaciers" cell)
dataloader_gl = mbm.dataloader.DataLoader(cfg,
                                          data=data_monthly,
                                          random_seed=cfg.seed,
                                          meta_data_columns=cfg.metaData)

### Blocking on glaciers:

In [None]:
# Sanity check: every held-out glacier should be present in the dataset
existing_glaciers = set(data_monthly.GLACIER.unique())
missing_glaciers = [
    name for name in TEST_GLACIERS if name not in existing_glaciers
]

if missing_glaciers:
    print(
        f"Warning: The following test glaciers are not in the dataset: {missing_glaciers}"
    )

# Training glaciers: everything in the dataset that is not held out
train_glaciers = [
    name for name in existing_glaciers if name not in TEST_GLACIERS
]

# Row counts of a plain glacier-based split of the monthly dataframe
data_test = data_monthly[data_monthly.GLACIER.isin(TEST_GLACIERS)]
print('Size of monthly test data:', len(data_test))

data_train = data_monthly[data_monthly.GLACIER.isin(train_glaciers)]
print('Size of monthly train data:', len(data_train))

if len(data_train) > 0:
    # Test size expressed relative to the train size
    test_perc = len(data_test) / len(data_train) * 100
    print('Percentage of test size: {:.2f}%'.format(test_perc))
else:
    print("Warning: No training data available!")

# Split used from here on: test glaciers held out, CV splits on the rest
splits, test_set, train_set = get_CV_splits(
    dataloader_gl,
    test_split_on='GLACIER',
    test_splits=TEST_GLACIERS,
    random_state=cfg.seed,
)

print('Test glaciers: ({}) {}'.format(len(test_set['splits_vals']),
                                      test_set['splits_vals']))
test_perc = len(test_set['df_X']) / len(train_set['df_X']) * 100
print('Percentage of test size: {:.2f}%'.format(test_perc))
print('Size of test set:', len(test_set['df_X']))
print('Train glaciers: ({}) {}'.format(len(train_set['splits_vals']),
                                       train_set['splits_vals']))
print('Size of train set:', len(train_set['df_X']))

# Replace the quick split above by the frames returned from get_CV_splits and
# attach the target as column 'y'
data_train = train_set['df_X']
data_train['y'] = train_set['y']

data_test = test_set['df_X']
data_test['y'] = test_set['y']

In [None]:
# Monthly (time-varying) input columns; the probing section notes that the
# order of this list matters for the trained LSTM
MONTHLY_COLS = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str', 'pcsr',
    'ELEVATION_DIFFERENCE'
]

# Static (non-monthly) input columns
STATIC_COLS = ['aspect_sgi', 'slope_sgi', 'svf']

# All input columns: monthly first, static last
feature_columns = [*MONTHLY_COLS, *STATIC_COLS]

### Build LSTM dataloaders:

In [None]:
# Re-apply the configured seed before building the datasets and the split
seed_all(cfg.seed)

# Work on copies and normalise the PERIOD labels (strip whitespace, lower-case)
df_train = data_train.copy()
df_train['PERIOD'] = df_train['PERIOD'].str.strip().str.lower()

df_test = data_test.copy()
df_test['PERIOD'] = df_test['PERIOD'].str.strip().str.lower()

# --- build train dataset from dataframe ---
ds_train = mbm.data_processing.MBSequenceDataset.from_dataframe(
    df_train,
    MONTHLY_COLS,
    STATIC_COLS,
    months_tail_pad=months_tail_pad,
    months_head_pad=months_head_pad,
    expect_target=True)

# --- build test dataset with the same columns and padding ---
ds_test = mbm.data_processing.MBSequenceDataset.from_dataframe(
    df_test,
    MONTHLY_COLS,
    STATIC_COLS,
    months_tail_pad=months_tail_pad,
    months_head_pad=months_head_pad,
    expect_target=True)

# Train/validation index split over ds_train (val_ratio=0.2); the same
# train_idx/val_idx are reused for the probing loaders further down
train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
    len(ds_train), val_ratio=0.2, seed=cfg.seed)

## Probing:

In [None]:
# Hyper-parameters handed to build_model_from_params / resolve_loss_fn.
# 'Fm': 9 and 'Fs': 3 coincide with len(MONTHLY_COLS) and len(STATIC_COLS).
custom_params = {
    'Fm': 9,
    'Fs': 3,
    'hidden_size': 128,
    'num_layers': 2,
    'bidirectional': False,
    'dropout': 0.2,
    'static_layers': 2,
    'static_hidden': [128, 64],
    'static_dropout': 0.1,
    'lr': 0.0005,
    'weight_decay': 0.0,
    'loss_name': 'neutral',
    'two_heads': True,
    'head_dropout': 0.0,
    'loss_spec': None
}

# --- build model and resolve loss (no training in this notebook: the weights
# are loaded from a saved state dict further down) ---
model = mbm.models.LSTM_MB.build_model_from_params(cfg, custom_params, device)
# NOTE(review): loss_fn is not used by any of the cells shown here.
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(custom_params)

# --- loaders (fit scalers on TRAIN, apply to whole ds_train) ---
# Work on clones so that ds_train / ds_test themselves stay untransformed.
ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train)

ds_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_test)

train_dl, val_dl = ds_train_copy.make_loaders(
    train_idx=train_idx,
    val_idx=val_idx,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=
    True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True  # use weighted sampler for training
)

# --- test loader (copies TRAIN scalers into ds_test and transforms it) ---
test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test_copy, ds_train_copy, batch_size=128, seed=cfg.seed)

# --- load the saved weights onto `device` and evaluate on the test set ---
model_filename = f"models/lstm_model_2025-10-22_two_heads_no_oggm.pt"
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state)
test_metrics, test_df_preds = model.evaluate_with_preds(
    device, test_dl, ds_test_copy)
test_rmse_a, test_rmse_w = test_metrics['RMSE_annual'], test_metrics[
    'RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

### Probing dataloaders:

In [None]:
# --- Probe targets and representation to use ---
probe_targets = ["sd", "smlt", "sf", "snowc",
                 "rsn"]  # snow depth, snow melt, snowfall, snow cover, snow density
representation = "c"  # 'c' (cell state) or 'h' (hidden outputs)

# --- ElasticNet hyperparams (interpretable, robust to correlation) ---
alpha_grid = np.logspace(-3, 1, 12)
l1_ratio_grid = [0.1, 0.3, 0.5, 0.7, 0.9]
n_jobs_enet = -1

# Rebind `device` to wherever the loaded model lives, then freeze the model:
# eval mode and no gradients on any parameter.
device = next(model.parameters()).device
model.eval()
for p in model.parameters():
    p.requires_grad_(False)

# Same entries, in the same order, as MONTHLY_COLS defined earlier.
MONTHLY_COLS_TRAIN = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str', 'pcsr',
    'ELEVATION_DIFFERENCE'
]  # <-- the exact set used to train the LSTM (order matters)

# PROBE_COLS is an alias of probe_targets (same list object, not a copy).
PROBE_COLS = probe_targets  # side-car targets (monthly), not inputs
STATIC_COLS = ['aspect_sgi', 'slope_sgi', 'svf']

# 1) Build pristine datasets
ds_probe = mbm.data_processing.MBSequenceDataset.from_dataframe(
    df_train,
    MONTHLY_COLS_TRAIN,
    STATIC_COLS,
    months_tail_pad=months_tail_pad,
    months_head_pad=months_head_pad,
    expect_target=True,
    probe_cols=PROBE_COLS)

ds_probe_test = mbm.data_processing.MBSequenceDataset.from_dataframe(
    df_test,
    MONTHLY_COLS_TRAIN,
    STATIC_COLS,
    months_tail_pad=months_tail_pad,
    months_head_pad=months_head_pad,
    expect_target=True,
    probe_cols=PROBE_COLS)

# 2) Clone the train dataset once (the clone is the copy that gets mutated;
#    the test dataset is cloned in step 4)
ds_probe_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_probe)

# 3) Fit scalers + transform by creating train/val loaders
#    (no shuffling and no weighted sampler for the probing loaders)
train_dl_probe, val_dl_probe = ds_probe_copy.make_loaders(
    train_idx=train_idx,
    val_idx=val_idx,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=
    True,  # <-- this fits scalers on TRAIN and transforms in-place
    shuffle_train=False,
    use_weighted_sampler=False)

# 4) Now build the TEST loader, copying scalers from the *fitted* train dataset
ds_probe_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_probe_test)

test_dl_probe = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_probe_test_copy,  # will be transformed in-place
    ds_probe_copy,  # <-- has scalers populated after step 3
    batch_size=128,
    seed=cfg.seed)

In [None]:
results_rep = []  # probing on representation
results_raw = []  # baseline on raw inputs

# Extract one probe dataframe per split, in the order TRAIN, VAL, TEST, with
# progress bars enabled.
# MONTHLY_COLS_TRAIN is passed (rather than MONTHLY_COLS) because it is the
# list the probe datasets behind these loaders were built from; both lists
# currently hold the same entries, so using one name keeps them from drifting.
df_tr, df_va, df_te = (extract_probe_dataframe(model,
                                               loader,
                                               probe_targets,
                                               MONTHLY_COLS_TRAIN,
                                               rep=representation,
                                               drop_input=None,
                                               split_name=label,
                                               show_progress=True)
                       for label, loader in (
                           ("TRAIN", train_dl_probe),
                           ("VAL", val_dl_probe),
                           ("TEST", test_dl_probe),
                       ))

# The probes are fitted on train + validation rows together
df_fit = pd.concat([df_tr, df_va], ignore_index=True)


In [None]:
# Check unscaling function
def unscale_target_column(df: pd.DataFrame, col: str, mean: float, std: float):
    """Undo a standardisation of one column: ``df[col] * std + mean``.

    Args:
        df: Frame holding the scaled column.
        col: Name of the column to unscale.
        mean: Mean that was subtracted when scaling.
        std: Standard deviation that was divided by when scaling.

    Returns:
        The unscaled column. When ``std`` is ``None`` or ``0`` the column is
        returned unchanged (neither ``std`` nor ``mean`` is applied).
    """
    scaled = df[col]
    if std is not None and std != 0:
        return scaled * std + mean
    # No usable std: hand the column back untouched
    return scaled


# Unscale every probe target with the mean/std stored on the fitted train
# dataset
df_fit_unscaled = df_fit.copy()
for k in probe_targets:
    mu, st = (float(ds_probe_copy.probe_mean[k]),
              float(ds_probe_copy.probe_std[k]))
    df_fit_unscaled[k] = unscale_target_column(df_fit, col=k, mean=mu, std=st)

# Yearly mean of 'sd' on glacier 'rhone', three views side by side
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(12, 5))

# Left panel: the dataframe the datasets were built from
rhone_df = df_train.loc[df_train.GLACIER == 'rhone']
rhone_df.groupby('YEAR')['sd'].mean().plot(ax=axs[0])

# Middle panel: values as they come out of the probe extraction
rhone_df = df_fit.loc[df_fit.GLACIER == 'rhone']
rhone_df.groupby('YEAR')['sd'].mean().plot(ax=axs[1])

# Right panel: the extracted values after unscaling
rhone_df = df_fit_unscaled.loc[df_fit.GLACIER == 'rhone']
rhone_df.groupby('YEAR')['sd'].mean().plot(ax=axs[2])

In [None]:
# Evaluate each target with a progress bar.
# The accumulators are (re)created here so that re-running this cell starts
# from empty lists instead of appending duplicate rows to the summaries.
results_rep = []  # probing on representation
results_raw = []  # baseline on raw inputs

for tgt in tqdm(probe_targets, desc="Fit/Eval targets", leave=False):
    # Representation probe
    met_rep = eval_train_test(df_fit,
                              df_te,
                              tgt,
                              use_raw=False,
                              alpha_grid=alpha_grid,
                              l1_ratio_grid=l1_ratio_grid,
                              n_jobs_enet=n_jobs_enet)
    results_rep.append({"drop_input": "None", "target": tgt, **met_rep})
    # Raw-input baseline (same fit/test frames and grids, use_raw=True)
    met_raw = eval_train_test(df_fit,
                              df_te,
                              tgt,
                              use_raw=True,
                              alpha_grid=alpha_grid,
                              l1_ratio_grid=l1_ratio_grid,
                              n_jobs_enet=n_jobs_enet)
    results_raw.append({"drop_input": "None", "target": tgt, **met_raw})

# One row per target, ordered by target name
df_rep_summary = pd.DataFrame(results_rep).sort_values(["target"])
df_raw_summary = pd.DataFrame(results_raw).sort_values(["target"])

print("== Probe on LSTM representation ==")
display(df_rep_summary)
print("== Baseline on raw monthly inputs ==")
display(df_raw_summary)

### Snow depth:

In [None]:
# Fit the representation probe and the raw-input baseline for snow depth
rep_probe, raw_probe = fit_global_probes(
    df_fit,
    target="sd",
    alphas=alpha_grid,
    l1_ratios=l1_ratio_grid,
)

# Mean/std stored for "sd" on the test-side and train-side probe datasets
target_mean_test, target_std_test = (
    float(ds_probe_test_copy.probe_mean["sd"]),
    float(ds_probe_test_copy.probe_std["sd"]),
)
target_mean_train, target_std_train = (
    float(ds_probe_copy.probe_mean["sd"]),
    float(ds_probe_copy.probe_std["sd"]),
)

In [None]:
# One probe-vs-baseline figure per held-out glacier (target: snow depth)
for g in TEST_GLACIERS:
    plot_probe_vs_baseline_for_glacier(
        glacier_name=g,
        target="sd",
        split="TEST",
        df_te=df_te,
        df_fit_all=df_fit,
        rep_probe=rep_probe,
        raw_probe=raw_probe,
        target_mean=target_mean_test,
        target_std=target_std_test,
    )

In [None]:
# for g in train_glaciers:
#     plot_probe_vs_baseline_for_glacier(
#         df_te=df_te,
#         df_fit_all=df_fit,
#         target="sd",
#         glacier_name=g,
#         split="TRAIN",
#         target_mean=target_mean_train,
#         target_std=target_std_train,
#         rep_probe=rep_probe,
#         raw_probe=raw_probe,
#     )

### Snowfall:

In [None]:
# Fit the representation probe and the raw-input baseline for snowfall
rep_probe, raw_probe = fit_global_probes(
    df_fit,
    target="sf",
    alphas=alpha_grid,
    l1_ratios=l1_ratio_grid,
)

# Mean/std stored for "sf" on the test-side and train-side probe datasets
target_mean_test, target_std_test = (
    float(ds_probe_test_copy.probe_mean["sf"]),
    float(ds_probe_test_copy.probe_std["sf"]),
)
target_mean_train, target_std_train = (
    float(ds_probe_copy.probe_mean["sf"]),
    float(ds_probe_copy.probe_std["sf"]),
)

In [None]:
# One probe-vs-baseline figure per held-out glacier (target: snowfall)
for g in TEST_GLACIERS:
    plot_probe_vs_baseline_for_glacier(
        glacier_name=g,
        target="sf",
        split="TEST",
        df_te=df_te,
        df_fit_all=df_fit,
        rep_probe=rep_probe,
        raw_probe=raw_probe,
        target_mean=target_mean_test,
        target_std=target_std_test,
    )

### Snowmelt:

In [None]:
# Fit the representation probe and the raw-input baseline for snow melt
rep_probe, raw_probe = fit_global_probes(
    df_fit,
    target="smlt",
    alphas=alpha_grid,
    l1_ratios=l1_ratio_grid,
)

# Mean/std stored for "smlt" on the test-side and train-side probe datasets
target_mean_test, target_std_test = (
    float(ds_probe_test_copy.probe_mean["smlt"]),
    float(ds_probe_test_copy.probe_std["smlt"]),
)
target_mean_train, target_std_train = (
    float(ds_probe_copy.probe_mean["smlt"]),
    float(ds_probe_copy.probe_std["smlt"]),
)

In [None]:
# One probe-vs-baseline figure per held-out glacier (target: snow melt)
for g in TEST_GLACIERS:
    plot_probe_vs_baseline_for_glacier(
        glacier_name=g,
        target="smlt",
        split="TEST",
        df_te=df_te,
        df_fit_all=df_fit,
        rep_probe=rep_probe,
        raw_probe=raw_probe,
        target_mean=target_mean_test,
        target_std=target_std_test,
    )

### Snow cover:

In [None]:
# Fit the representation probe and the raw-input baseline for snow cover
rep_probe, raw_probe = fit_global_probes(
    df_fit,
    target="snowc",
    alphas=alpha_grid,
    l1_ratios=l1_ratio_grid,
)

# Mean/std stored for "snowc" on the test-side and train-side probe datasets
target_mean_test, target_std_test = (
    float(ds_probe_test_copy.probe_mean["snowc"]),
    float(ds_probe_test_copy.probe_std["snowc"]),
)
target_mean_train, target_std_train = (
    float(ds_probe_copy.probe_mean["snowc"]),
    float(ds_probe_copy.probe_std["snowc"]),
)

In [None]:
# One probe-vs-baseline figure per held-out glacier (target: snow cover)
for g in TEST_GLACIERS:
    plot_probe_vs_baseline_for_glacier(
        glacier_name=g,
        target="snowc",
        split="TEST",
        df_te=df_te,
        df_fit_all=df_fit,
        rep_probe=rep_probe,
        raw_probe=raw_probe,
        target_mean=target_mean_test,
        target_std=target_std_test,
    )

### Snow density:

In [None]:
# Fit the representation probe and the raw-input baseline for snow density
rep_probe, raw_probe = fit_global_probes(
    df_fit,
    target="rsn",
    alphas=alpha_grid,
    l1_ratios=l1_ratio_grid,
)

# Mean/std stored for "rsn" on the test-side and train-side probe datasets
target_mean_test, target_std_test = (
    float(ds_probe_test_copy.probe_mean["rsn"]),
    float(ds_probe_test_copy.probe_std["rsn"]),
)
target_mean_train, target_std_train = (
    float(ds_probe_copy.probe_mean["rsn"]),
    float(ds_probe_copy.probe_std["rsn"]),
)

In [None]:
# One probe-vs-baseline figure per held-out glacier (target: snow density)
for g in TEST_GLACIERS:
    plot_probe_vs_baseline_for_glacier(
        glacier_name=g,
        target="rsn",
        split="TEST",
        df_te=df_te,
        df_fit_all=df_fit,
        rep_probe=rep_probe,
        raw_probe=raw_probe,
        target_mean=target_mean_test,
        target_std=target_std_test,
    )