In [None]:
import sys, os
sys.path.append(os.path.join(os.getcwd(), '../../')) # Add root of repo to import MBM

import pandas as pd
import numpy as np
import warnings
import re
import matplotlib.pyplot as plt
import seaborn as sns
from cmcrameri import cm
import massbalancemachine as mbm
import logging
import torch.nn as nn
from skorch.helper import SliceDataset
from datetime import datetime
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset
import pickle 
from scipy import stats
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import torch 
from matplotlib.lines import Line2D

from regions.French_Alps.scripts.config_FR import *
from regions.French_Alps.scripts.dataset import get_stakes_data_FR
from regions.French_Alps.scripts.utils import *

from regions.Switzerland.scripts.dataset import process_or_load_data, get_CV_splits
from regions.Switzerland.scripts.plotting import plot_predictions_summary, plot_individual_glacier_pred, plot_history_lstm, get_cmap_hex,plot_tsne_overlap, plot_feature_kde_overlap, pred_vs_truth_density
from regions.Switzerland.scripts.dataset import get_stakes_data, build_combined_LSTM_dataset, inspect_LSTM_sample, prepare_monthly_dfs_with_padding
from regions.Switzerland.scripts.models import compute_seasonal_scores, get_best_params_for_lstm

# Silence ALL warnings to keep notebook output readable.
# NOTE(review): this also hides deprecation / numerical warnings that may matter.
warnings.filterwarnings('ignore')
# Automatically reload edited project modules (mbm, regions.*) before each cell.
%load_ext autoreload
%autoreload 2

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# France configuration; its `seed` is reused by every seed_all() call in this notebook.
cfg = mbm.FranceConfig()
mbm.utils.seed_all(cfg.seed)
# Presumably releases cached GPU memory left by a previous run — confirm in mbm.utils.
mbm.utils.free_up_cuda()
mbm.plots.use_mbm_style()

# Run on the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Time-varying inputs, passed as `monthly_cols` to build_combined_LSTM_dataset.
MONTHLY_COLS = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str', 'ELEVATION_DIFFERENCE'
]
# Time-invariant inputs, passed as `static_cols` to build_combined_LSTM_dataset.
STATIC_COLS = ['aspect', 'slope', 'svf']

# Complete feature list: monthly columns first, then static columns.
feature_columns = [*MONTHLY_COLS, *STATIC_COLS]

# Cross-Regional Transfer Learning (Switzerland → France)

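This approach uses the Swiss dataset to model French glaciers: an LSTM trained on Swiss point mass balances is applied to French glaciers, both directly and after fine-tuning on a small subset of them.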

## Create Combined Swiss and French Glacier Dataset

Start with point mass balance measurements and transform them to monthly format with ERA5 climate data.

In [None]:
# Load the point mass-balance (stake) measurements of both regions
data_FR = get_stakes_data_FR(cfg)
data_CH = get_stakes_data(cfg)

# Align the Swiss frame with the French one: drop Swiss-specific topographic
# columns (when present) and add empty GLACIER_ZONE / DATA_MODIFICATION columns
# (presumably the FR-side metadata columns — confirm against data_FR).
data_CH = data_CH.drop(
    columns=['aspect_sgi', 'slope_sgi', 'topo_sgi', 'asvf', 'opns'],
    errors='ignore',
).assign(GLACIER_ZONE='', DATA_MODIFICATION='')

# Quick overview of the glaciers available in each region
for region_label, region_df in (('FR', data_FR), ('CH', data_CH)):
    print(f'Number {region_label} glaciers:', region_df['GLACIER'].nunique())
    print(f'{region_label} glaciers:', region_df['GLACIER'].unique())

In [None]:
# Normalise PERIOD labels (strip whitespace, lower-case) so the "annual" /
# "winter" filters below match regardless of source formatting. Idempotent.
data_FR["PERIOD"] = data_FR["PERIOD"].str.strip().str.lower()
data_CH["PERIOD"] = data_CH["PERIOD"].str.strip().str.lower()

# One color per region, shared by the histogram and its dashed mean line so
# the two can be matched visually.
REGION_COLORS = {"France": "C0", "Switzerland": "C1"}

fig, axes = plt.subplots(1, 2, figsize=(13, 5), sharey=True)

for ax, period in zip(axes, ["annual", "winter"]):
    mb_fr = data_FR.loc[data_FR.PERIOD == period, "POINT_BALANCE"].dropna()
    mb_ch = data_CH.loc[data_CH.PERIOD == period, "POINT_BALANCE"].dropna()

    ax.set_title(f"{period.capitalize()} Mass Balance")
    ax.set_xlabel("Mass balance [m w.e.]")

    # Common bins for a fair comparison between regions
    all_vals = np.concatenate([mb_fr, mb_ch])
    if all_vals.size == 0:
        # min()/max() would raise on an empty array; leave an explicit marker instead
        ax.text(0.5, 0.5, "No data", transform=ax.transAxes,
                ha="center", va="center")
        continue
    bins = np.linspace(all_vals.min(), all_vals.max(), 21)

    for label, mb in (("France", mb_fr), ("Switzerland", mb_ch)):
        color = REGION_COLORS[label]
        ax.hist(mb, bins=bins, alpha=0.6, color=color, label=label)
        if len(mb) > 0:  # mean of an empty series is NaN; skip its line
            ax.axvline(mb.mean(), linestyle="--", color=color)

    ax.legend()

axes[0].set_ylabel("Number of measurements")

fig.suptitle("Seasonal Point Mass Balance Distribution", fontsize=14)
plt.tight_layout()
plt.show()

## French fine-tuning datasets

In [None]:
# Input files used to build the FR monthly datasets: the GLACIOCLIM point
# mass-balance CSV plus ERA5 monthly climate and geopotential NetCDF files.
# NOTE(review): the ERA5 file names carry a "NOR_Alps" suffix — confirm these
# are the intended files for the French Alps and not left over from another region.
paths = dict(
    csv_path=os.path.join(cfg.dataPath, path_PMB_GLACIOCLIM_csv),
    era5_climate_data=os.path.join(cfg.dataPath, path_ERA5_raw,
                                   "era5_monthly_averaged_data_NOR_Alps.nc"),
    geopotential_data=os.path.join(cfg.dataPath, path_ERA5_raw,
                                   "era5_geopotential_pressure_NOR_Alps.nc"),
)

#### 5% split:

In [None]:
# ---------------------------
# 5% FINE-TUNING SPLIT (FR)
# ---------------------------
# A handful of French glaciers is used to fine-tune the Swiss model; every other
# French glacier is held out for evaluation.
finetune_glaciers_5pct = ['Sarennes', 'Talefre', 'Grands Montets', 'Leschaux']

all_fr_glaciers = list(data_FR['GLACIER'].unique())
all_nor_glaciers = all_fr_glaciers  # legacy name kept for compatibility

# Fail early on typos: a misspelled name would otherwise silently shrink the
# fine-tuning set and only be noticed after the monthly data has been rebuilt.
missing_ft_glaciers = sorted(set(finetune_glaciers_5pct) - set(all_fr_glaciers))
assert not missing_ft_glaciers, (
    f"Fine-tuning glaciers not found in data_FR: {missing_ft_glaciers}")

# All remaining glaciers = holdout set (order follows data_FR)
holdout_glaciers_5pct = [
    g for g in all_fr_glaciers if g not in finetune_glaciers_5pct
]

# One boolean mask defines both complementary subsets
is_ft_5pct = data_FR['GLACIER'].isin(finetune_glaciers_5pct)
data_FR_ft_5pct = data_FR[is_ft_5pct].copy()
data_FR_holdout_5pct = data_FR[~is_ft_5pct].copy()

print(
    f"5% fine-tuning glaciers ({len(finetune_glaciers_5pct)}): {finetune_glaciers_5pct}"
)
print(
    f"Hold-out glaciers ({len(holdout_glaciers_5pct)}): {holdout_glaciers_5pct}"
)
# Report the actual share so the "5%" label can be checked against the data
print(f"Share of FR measurements used for fine-tuning: {is_ft_5pct.mean():.1%}")


In [None]:
# Build the FR monthly tables. The hold-out glaciers are passed as
# `test_glaciers`, so the "train" outputs contain only the fine-tuning glaciers.
# NOTE(review): run_flag=True presumably recomputes instead of reading the
# cached CSVs below — confirm, and consider False once the files exist.
res_FR_5pct = prepare_monthly_dfs_with_padding(
    cfg=cfg,
    df_region=data_FR,
    region_name="FR",
    region_id=11,
    paths=paths,
    test_glaciers=holdout_glaciers_5pct,
    vois_climate=VOIS_CLIMATE,
    vois_topographical=VOIS_TOPOGRAPHICAL,
    run_flag=True,
    output_file_monthly='FR_5pct_ft_dataset_monthly.csv',
    output_file_monthly_aug='FR_5pct_ft_dataset_monthly_Aug.csv',
)

# Unpack (fine-tune, hold-out) x (hydrological-year rows, August-anchored rows)
(df_ft_FR_5pct, df_holdout_FR_5pct,
 df_ft_FR_5pct_Aug, df_holdout_FR_5pct_Aug) = (
     res_FR_5pct[key]
     for key in ("df_train", "df_test", "df_train_aug", "df_test_aug"))


In [None]:
mbm.utils.seed_all(cfg.seed)

# Build the fine-tune and hold-out LSTM datasets with identical settings
# (built in that order, fine-tune first).
ds_ft_FR_5pct, ds_holdout_FR_5pct = (
    build_combined_LSTM_dataset(
        df_loss=loss_rows,
        df_full=full_rows,
        monthly_cols=MONTHLY_COLS,
        static_cols=STATIC_COLS,
        months_head_pad=res_FR_5pct['months_head_pad'],
        months_tail_pad=res_FR_5pct['months_tail_pad'],
        normalize_target=True,
        expect_target=True,
    ) for loss_rows, full_rows in (
        (df_ft_FR_5pct, df_ft_FR_5pct_Aug),
        (df_holdout_FR_5pct, df_holdout_FR_5pct_Aug),
    ))

# Sanity checks: fine-tune data holds exactly the chosen glaciers, and none of
# them leaks into the hold-out data.
assert set(finetune_glaciers_5pct) == set(df_ft_FR_5pct.GLACIER.unique())
assert set(finetune_glaciers_5pct).isdisjoint(
    set(df_holdout_FR_5pct.GLACIER.unique()))

## In-sample CH dataset (used for the pre-trained model)

In [None]:
# In-sample Swiss dataset: no glacier is held out (test_glaciers=[]), so every
# CH measurement goes into the training data of the pre-trained model.
# NOTE(review): `paths` points at the French GLACIOCLIM CSV / ERA5 files; this
# presumably only works because run_flag=False reuses the cached CH CSVs — confirm.
res_CH = prepare_monthly_dfs_with_padding(
    cfg=cfg,
    df_region=data_CH,
    region_name="CH",
    region_id=11,
    paths=paths,
    test_glaciers=[],
    vois_climate=VOIS_CLIMATE,
    vois_topographical=VOIS_TOPOGRAPHICAL,
    run_flag=False,
    add_pcsr=False,
    output_file_monthly='CH_wgms_dataset_monthly_LSTM_IS.csv',
    output_file_monthly_aug='CH_wgms_dataset_monthly_LSTM_Aug_IS.csv',
)

df_train, df_train_Aug = res_CH["df_train"], res_CH["df_train_aug"]

# Report how many CH glaciers ended up in the training data
existing_glaciers = set(df_train.GLACIER.unique())
print('Number of glaciers in train data:', len(existing_glaciers))

mbm.utils.seed_all(cfg.seed)
ds_train_CH = build_combined_LSTM_dataset(
    df_loss=df_train,        # hydrological-year POINT_BALANCE
    df_full=df_train_Aug,    # August-anchored monthly sequences
    monthly_cols=MONTHLY_COLS,
    static_cols=STATIC_COLS,
    months_head_pad=res_CH['months_head_pad'],
    months_tail_pad=res_CH['months_tail_pad'],
    normalize_target=True,
    expect_target=True,
)

# Train/validation split of the CH sequences (20% validation, seeded by cfg.seed)
train_idx_CH, val_idx_CH = mbm.data_processing.MBSequenceDataset.split_indices(
    len(ds_train_CH), val_ratio=0.2, seed=cfg.seed)

### Train CH model (w/o pcsr)

In [None]:
# Fixed hyper-parameters of the Swiss (CH) LSTM used as the pre-trained model.
best_params = {
    "Fm": 8,
    "Fs": 3,
    "hidden_size": 96,
    "num_layers": 2,
    "bidirectional": False,
    "dropout": 0.2,
    "static_layers": 1,
    "static_hidden": 128,
    "static_dropout": 0.3,
    "lr": 0.0005,
    "weight_decay": 1e-05,
    "loss_name": "neutral",
    "two_heads": False,
    "head_dropout": 0.0,
    "loss_spec": None,
}

# Checkpoint of the pre-trained CH model that this notebook builds on.
PRETRAINED_CH_MODEL = "models/lstm_CH_model_2026-02-09_IS_norm_y_past.pt"

# Set TRAIN = True to retrain the CH model. The new checkpoint is written under
# today's date, and that freshly trained checkpoint is the one evaluated below;
# otherwise the pre-trained checkpoint is evaluated.
TRAIN = False
current_date = datetime.now().strftime("%Y-%m-%d")
model_filename = (f"models/lstm_CH_model_{current_date}_IS_norm_y_past.pt"
                  if TRAIN else PRETRAINED_CH_MODEL)

# --- loaders (fit scalers on TRAIN, apply to whole ds_train) ---
ds_train_CH_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train_CH)

train_dl_CH, val_dl_CH = ds_train_CH_copy.make_loaders(
    train_idx=train_idx_CH,
    val_idx=val_idx_CH,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True,  # use weighted sampler for training
)

# --- build model and resolve loss ---
model_CH = mbm.models.LSTM_MB.build_model_from_params(cfg, best_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(best_params)

if TRAIN:
    # Start from a clean checkpoint so a stale file can't be mistaken for the new best.
    os.makedirs(os.path.dirname(model_filename), exist_ok=True)
    if os.path.exists(model_filename):
        os.remove(model_filename)

    history, best_val, best_state = model_CH.train_loop(
        device=device,
        train_dl=train_dl_CH,
        val_dl=val_dl_CH,
        epochs=150,
        lr=best_params['lr'],
        weight_decay=best_params['weight_decay'],
        clip_val=1,
        # scheduler
        sched_factor=0.5,
        sched_patience=6,
        sched_threshold=0.01,
        sched_threshold_mode="rel",
        sched_cooldown=1,
        sched_min_lr=1e-6,
        # early stopping
        es_patience=15,
        es_min_delta=1e-4,
        # logging
        log_every=5,
        verbose=True,
        # checkpoint
        save_best_path=model_filename,
        loss_fn=loss_fn,
    )
    plot_history_lstm(history)

# --- in-sample evaluation ---
# The evaluation set is a fresh clone of the *whole* CH training dataset
# (train + validation indices), so these scores are in-sample, not a hold-out
# test. make_test_loader also receives ds_train_CH_copy, presumably to reuse its
# fitted scalers — confirm in MBSequenceDataset.
ds_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train_CH)

test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test_copy, ds_train_CH_copy, batch_size=128, seed=cfg.seed)

# Load the checkpoint selected above (new one if TRAIN, otherwise pre-trained)
state = torch.load(model_filename, map_location=device)
model_CH.load_state_dict(state)
test_metrics, test_df_preds = model_CH.evaluate_with_preds(
    device, test_dl, ds_test_copy)

scores_annual, scores_winter = compute_seasonal_scores(test_df_preds,
                                                       target_col='target',
                                                       pred_col='pred')

print("Annual scores:", scores_annual)
print("Winter scores:", scores_winter)

fig = plot_predictions_summary(
    grouped_ids=test_df_preds,
    scores_annual=scores_annual,
    scores_winter=scores_winter,
    ax_xlim=(-14, 6),
    ax_ylim=(-14, 6),
    color_annual=mbm.plots.COLOR_ANNUAL,
    color_winter=mbm.plots.COLOR_WINTER,
)

### Load pre-trained CH model:

In [None]:
# Re-create the CH loaders and the CH model from the pre-trained checkpoint,
# without any training. ds_train_CH_copy is re-fitted here; its scalers are later
# copied onto the FR datasets (set_scalers_from / make_test_loader below).
# --- loaders (fit scalers on TRAIN, apply to whole ds_train) ---
ds_train_CH_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train_CH)

train_dl, val_dl = ds_train_CH_copy.make_loaders(
    train_idx=train_idx_CH,
    val_idx=val_idx_CH,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=
    True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True  # use weighted sampler for training
)

# --- build model and resolve loss (no training in this cell) ---
model_CH = mbm.models.LSTM_MB.build_model_from_params(cfg, best_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(best_params)

# Load the pre-trained CH weights (state dict) onto the current device
model_filename = f"models/lstm_CH_model_2026-02-09_IS_norm_y_past.pt"
state = torch.load(model_filename, map_location=device)
model_CH.load_state_dict(state)

## Fine-tuning:

### 5% split:

In [None]:
# ---------------------------------------
# DATALOADERS for FR 5% fine-tune split
# ---------------------------------------
# The FR data is NOT re-scaled with its own statistics: the scalers fitted on
# the CH training split are copied over, so the inputs follow the same scaling
# the pre-trained CH model was trained with.

# pristine clone (fine-tune set)
ds_ft_FR_5pct_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_ft_FR_5pct)

# split indices on FR 5%-ft (20% of the fine-tune sequences go to validation)
train_idx_FR_5pct, val_idx_FR_5pct = mbm.data_processing.MBSequenceDataset.split_indices(
    len(ds_ft_FR_5pct_copy), val_ratio=0.2, seed=cfg.seed)

# IMPORTANT: copy CH scalers -> FR 5%-ft, then transform in-place
ds_ft_FR_5pct_copy.set_scalers_from(ds_train_CH_copy)
ds_ft_FR_5pct_copy.transform_inplace()

# create loaders WITHOUT fitting scalers (data was already transformed above;
# fitting again would overwrite the CH scalers)
ft_train_dl_FR_5pct, ft_val_dl_FR_5pct = ds_ft_FR_5pct_copy.make_loaders(
    train_idx=train_idx_FR_5pct,
    val_idx=val_idx_FR_5pct,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=False,  # <-- key!
    shuffle_train=True,
    use_weighted_sampler=True)

# holdout loader (FR 5% split); make_test_loader receives the CH dataset,
# presumably to apply its scalers to the hold-out clone — confirm.
ds_holdout_FR_5pct_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_holdout_FR_5pct)

holdout_dl_FR_5pct = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_holdout_FR_5pct_copy, ds_train_CH_copy, batch_size=128, seed=cfg.seed)

#### “Safe” fine-tune for small FR-ft set (freeze LSTM, train only static+head):

In [None]:
# "Safe" fine-tune for the small FR fine-tuning set: start from the pre-trained
# CH weights, freeze the recurrent encoder and train only the remaining
# parameters (static branch + head) with a small learning rate.
FT_HEADS_PATH = "models/lstm_finetuned_CH_to_FR_5pct.pt"

# --- build model and resolve loss ---
model_CH_ft = mbm.models.LSTM_MB.build_model_from_params(
    cfg, best_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(best_params)

# Load the pre-trained CH weights
model_filename = "models/lstm_CH_model_2026-02-09_IS_norm_y_past.pt"
state = torch.load(model_filename, map_location=device)
model_CH_ft.load_state_dict(state)

# 1) freeze recurrent encoder (every parameter under the `lstm.` prefix)
n_frozen = 0
for name, p in model_CH_ft.named_parameters():
    if name.startswith("lstm."):
        p.requires_grad = False
        n_frozen += 1
# If the encoder were registered under another attribute name, nothing would be
# frozen and this would silently become a full fine-tune.
assert n_frozen > 0, "no 'lstm.*' parameters found to freeze"

# 2) new optimizer on trainable params only (small LR)
opt = torch.optim.AdamW(
    (p for p in model_CH_ft.parameters() if p.requires_grad),
    lr=1e-4,
    weight_decay=best_params["weight_decay"],
)

# 3) fine-tune. Remove a checkpoint left by a previous run first, so the reload
#    below can only pick up weights written by this run.
os.makedirs(os.path.dirname(FT_HEADS_PATH), exist_ok=True)
if os.path.exists(FT_HEADS_PATH):
    os.remove(FT_HEADS_PATH)

history, best_val, best_state = model_CH_ft.train_loop(
    device=device,
    train_dl=ft_train_dl_FR_5pct,
    val_dl=ft_val_dl_FR_5pct,
    epochs=60,
    optimizer=opt,
    clip_val=1.0,
    loss_fn=loss_fn,
    es_patience=8,
    save_best_path=FT_HEADS_PATH,
    verbose=True,
)

# 4) evaluate the best-validation checkpoint rather than whatever weights the
#    last epoch left in memory (harmless if train_loop already restored them).
if os.path.exists(FT_HEADS_PATH):
    model_CH_ft.load_state_dict(torch.load(FT_HEADS_PATH, map_location=device))

#### “Full” fine-tune (unfreeze everything, very small LR):

In [None]:
# "Full" fine-tune: start from the pre-trained CH weights and update every
# parameter (LSTM included) with a very small learning rate.
FT_FULL_PATH = "models/lstm_finetuned_CH_to_FR_full_5pct.pt"

# --- build model and resolve loss ---
model_CH_ft_2 = mbm.models.LSTM_MB.build_model_from_params(
    cfg, best_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(best_params)

# Load the pre-trained CH weights
model_filename = "models/lstm_CH_model_2026-02-09_IS_norm_y_past.pt"
state = torch.load(model_filename, map_location=device)
model_CH_ft_2.load_state_dict(state)

# unfreeze everything
for p in model_CH_ft_2.parameters():
    p.requires_grad = True

opt = torch.optim.AdamW(
    model_CH_ft_2.parameters(),
    lr=1e-5,  # smaller because we're updating the LSTM too
    weight_decay=best_params["weight_decay"],
)

# Remove a checkpoint left by a previous run so the reload below can only pick
# up weights written by this run.
os.makedirs(os.path.dirname(FT_FULL_PATH), exist_ok=True)
if os.path.exists(FT_FULL_PATH):
    os.remove(FT_FULL_PATH)

history, best_val, best_state = model_CH_ft_2.train_loop(
    device=device,
    train_dl=ft_train_dl_FR_5pct,
    val_dl=ft_val_dl_FR_5pct,
    epochs=80,
    optimizer=opt,
    clip_val=1.0,
    loss_fn=loss_fn,
    es_patience=10,
    save_best_path=FT_FULL_PATH,
)

# Evaluate the best-validation checkpoint, not the last-epoch weights
# (harmless if train_loop already restored them).
if os.path.exists(FT_FULL_PATH):
    model_CH_ft_2.load_state_dict(torch.load(FT_FULL_PATH, map_location=device))

#### Two-stage fine-tune (heads first, then the full model)

In [None]:
# Two-stage fine-tune from the pre-trained CH weights:
#   stage 1 - LSTM frozen, static branch + head trained (lr 2e-4);
#   stage 2 - everything unfrozen, continued from the BEST stage-1 weights (lr 1e-5).
FT_STAGE1_PATH = "models/tmp_stage1.pt"
FT_2STAGE_PATH = "models/lstm_finetuned_CH_to_FR_2stage_5pct.pt"

# --- build model and resolve loss ---
model_CH_ft_3 = mbm.models.LSTM_MB.build_model_from_params(
    cfg, best_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(best_params)

# Load the pre-trained CH weights
model_filename = "models/lstm_CH_model_2026-02-09_IS_norm_y_past.pt"
state = torch.load(model_filename, map_location=device)
model_CH_ft_3.load_state_dict(state)

# Remove checkpoints from a previous run so the reloads below can only pick up
# weights written by this run.
os.makedirs(os.path.dirname(FT_2STAGE_PATH), exist_ok=True)
for ckpt_path in (FT_STAGE1_PATH, FT_2STAGE_PATH):
    if os.path.exists(ckpt_path):
        os.remove(ckpt_path)

# Stage 1: freeze LSTM, tune heads
for name, p in model_CH_ft_3.named_parameters():
    p.requires_grad = not name.startswith("lstm.")
# Guard against a renamed encoder attribute silently freezing nothing.
assert any(not p.requires_grad for p in model_CH_ft_3.parameters()), \
    "no 'lstm.*' parameters found to freeze"

opt1 = torch.optim.AdamW(
    (p for p in model_CH_ft_3.parameters() if p.requires_grad),
    lr=2e-4,
    weight_decay=best_params["weight_decay"])

model_CH_ft_3.train_loop(device,
                         ft_train_dl_FR_5pct,
                         ft_val_dl_FR_5pct,
                         epochs=20,
                         optimizer=opt1,
                         loss_fn=loss_fn,
                         es_patience=5,
                         save_best_path=FT_STAGE1_PATH)

# Continue from the best stage-1 checkpoint, not from the last stage-1 epoch.
if os.path.exists(FT_STAGE1_PATH):
    model_CH_ft_3.load_state_dict(
        torch.load(FT_STAGE1_PATH, map_location=device))

# Stage 2: unfreeze all, very small LR
for p in model_CH_ft_3.parameters():
    p.requires_grad = True

opt2 = torch.optim.AdamW(model_CH_ft_3.parameters(),
                         lr=1e-5,
                         weight_decay=best_params["weight_decay"])

history, best_val, best_state = model_CH_ft_3.train_loop(
    device=device,
    train_dl=ft_train_dl_FR_5pct,
    val_dl=ft_val_dl_FR_5pct,
    epochs=60,
    optimizer=opt2,
    loss_fn=loss_fn,
    es_patience=10,
    save_best_path=FT_2STAGE_PATH,
)

# Evaluate the best-validation stage-2 checkpoint (harmless if train_loop
# already restored it).
if os.path.exists(FT_2STAGE_PATH):
    model_CH_ft_3.load_state_dict(
        torch.load(FT_2STAGE_PATH, map_location=device))

#### Compare fine-tuning methods:

In [None]:
def eval_and_scores(model_, holdout_dl, ds_holdout_copy):
    """Evaluate ``model_`` on a hold-out loader and score it per season.

    Args:
        model_: Model exposing ``evaluate_with_preds(device, loader, dataset)``.
        holdout_dl: DataLoader over the hold-out sequences.
        ds_holdout_copy: Dataset the loader was built from.

    Returns:
        Tuple ``(test_metrics, df_preds, scores_annual, scores_winter)``.
    """
    # `device` is the notebook-level torch device from the setup cell.
    metrics, preds = model_.evaluate_with_preds(device, holdout_dl,
                                                ds_holdout_copy)
    annual, winter = compute_seasonal_scores(preds,
                                             target_col="target",
                                             pred_col="pred")
    return metrics, preds, annual, winter


def add_metrics_box(ax, scores_annual, scores_winter, title=None):
    """Write annual/winter RMSE, R² and bias into a boxed note on ``ax``.

    Args:
        ax: Matplotlib axes to annotate.
        scores_annual: Mapping with keys "rmse", "R2" and "Bias" (annual points).
        scores_winter: Mapping with the same keys for winter points.
        title: Optional panel title; skipped when falsy.
    """
    if title:
        ax.set_title(title, fontsize=18)

    # (LaTeX label, score key): one text line per metric, annual then winter.
    metric_rows = (("RMSE", "rmse"), ("R^2", "R2"), ("Bias", "Bias"))
    legend_txt = "\n".join(
        r"$\mathrm{%s_a}=%.2f$, $\mathrm{%s_w}=%.2f$" %
        (label, scores_annual[key], label, scores_winter[key])
        for label, key in metric_rows)

    box_props = dict(boxstyle="round", facecolor="white", alpha=0.6)
    ax.text(0.02,
            0.98,
            legend_txt,
            transform=ax.transAxes,
            va="top",
            fontsize=14,
            bbox=box_props)


# Models to compare, in panel order: the untouched CH model and the three
# fine-tuning variants trained above.
methods = [
    ("No fine-tune (CH→FR)", model_CH),
    ("Heads-only FT (freeze LSTM)", model_CH_ft),
    ("Full FT (unfreeze all)", model_CH_ft_2),
    ("Two-stage FT", model_CH_ft_3),
]

# Score every model on the same FR hold-out glaciers (5% split)
results = []
for name, m in methods:
    test_metrics, df_preds, s_a, s_w = eval_and_scores(
        m, holdout_dl_FR_5pct, ds_holdout_FR_5pct_copy)
    results.append((name, df_preds, s_a, s_w))
    print(name, "| RMSE annual:", test_metrics["RMSE_annual"],
          "| RMSE winter:", test_metrics["RMSE_winter"])

# 2x2 grid of predicted-vs-observed density panels with shared axis limits
fig, axes = plt.subplots(2, 2, figsize=(18, 14), sharex=True, sharey=True)
axes = axes.flatten()

for ax, (name, df_preds, s_a, s_w) in zip(axes, results):
    pred_vs_truth_density(ax,
                          df_preds,
                          s_a,
                          add_legend=False,
                          palette=[
                              mbm.plots.COLOR_ANNUAL,
                              mbm.plots.COLOR_WINTER,
                          ],
                          ax_xlim=(-14, 8),
                          ax_ylim=(-14, 8))
    add_metrics_box(ax, s_a, s_w, title=name)

fig.supxlabel("Observed PMB [m w.e.]", fontsize=20)
fig.supylabel("Modeled PMB [m w.e.]", fontsize=20)
plt.tight_layout()
plt.show()