In [None]:
import sys, os
sys.path.append(os.path.join(os.getcwd(), '../../')) # Add root of repo to import MBM

import pandas as pd
import numpy as np
import warnings
import re
import matplotlib.pyplot as plt
import seaborn as sns
from cmcrameri import cm
import massbalancemachine as mbm
import logging
import torch.nn as nn
from skorch.helper import SliceDataset
from datetime import datetime
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset
import pickle 
from scipy import stats
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import torch 
from matplotlib.lines import Line2D

from regions.French_Alps.scripts.config_FR import *
from regions.French_Alps.scripts.dataset import get_stakes_data_FR
from regions.French_Alps.scripts.utils import *

from regions.Switzerland.scripts.dataset import process_or_load_data, get_CV_splits
from regions.Switzerland.scripts.plotting import plot_predictions_summary, plot_individual_glacier_pred, plot_history_lstm, get_cmap_hex,plot_tsne_overlap, plot_feature_kde_overlap
from regions.Switzerland.scripts.dataset import get_stakes_data, build_combined_LSTM_dataset, inspect_LSTM_sample, prepare_monthly_dfs_with_padding
from regions.Switzerland.scripts.models import compute_seasonal_scores

# Suppress every Python warning for the rest of the session.
warnings.filterwarnings('ignore')
# IPython magics: load the autoreload extension in mode 2, so edited modules
# are re-imported before each cell runs.
%load_ext autoreload
%autoreload 2

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Configuration object for the French region; `cfg.seed` and `cfg.dataPath`
# are read throughout the notebook.
cfg = mbm.FranceConfig()
# NOTE(review): the three helpers below are assumed to seed the RNGs, release
# cached GPU memory and apply the project plotting style - confirm in mbm.
mbm.utils.seed_all(cfg.seed)
mbm.utils.free_up_cuda()
mbm.plots.use_mbm_style()

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Cross-Regional Transfer Learning (Switzerland → French Alps)

This approach trains a model on the Swiss dataset and tests how well it transfers to glaciers in the French Alps.

### Create Combined Swiss and French Alps Glacier Dataset

Start with point mass balance measurements and transform them to monthly format with ERA5 climate data.

In [None]:
# Load the point mass balance (stake) measurements of both regions.
data_FR = get_stakes_data_FR(cfg)

# Align the Swiss frame with the French one: drop columns that only the Swiss
# data carries (silently skipping any that are absent) and add the two
# string columns as empty placeholders.
ch_extra_cols = ['aspect_sgi', 'slope_sgi', 'topo_sgi', 'asvf', 'opns']
data_CH = (get_stakes_data(cfg)
           .drop(columns=ch_extra_cols, errors='ignore')
           .assign(GLACIER_ZONE='', DATA_MODIFICATION=''))

# Quick overview of how many glaciers each region contributes.
for tag, frame in (('FR', data_FR), ('CH', data_CH)):
    glacier_names = frame['GLACIER']
    print(f'Number {tag} glaciers:', glacier_names.nunique())
    print(f'{tag} glaciers:', glacier_names.unique())

In [None]:
# Clean PERIOD column just in case. This is done in place on purpose: later
# cells filter on the normalised lowercase values (e.g. PERIOD == 'annual').
data_FR["PERIOD"] = data_FR["PERIOD"].str.strip().str.lower()
data_CH["PERIOD"] = data_CH["PERIOD"].str.strip().str.lower()

# One fixed colour per region ("C0"/"C1" are the first two entries of the
# active colour cycle, i.e. what the histograms used by default), so that each
# dashed mean line can be matched to its histogram.
region_colors = {"France": "C0", "Switzerland": "C1"}

fig, axes = plt.subplots(1, 2, figsize=(13, 5), sharey=True)

for ax, period in zip(axes, ["annual", "winter"]):
    mb_fr = data_FR.loc[data_FR.PERIOD == period, "POINT_BALANCE"].dropna()
    mb_ch = data_CH.loc[data_CH.PERIOD == period, "POINT_BALANCE"].dropna()

    # Common bins (20 equal-width bins over the pooled range) for a fair
    # comparison between the two regions.
    all_vals = np.concatenate([mb_fr, mb_ch])
    bins = np.linspace(all_vals.min(), all_vals.max(), 21)

    for region, values in (("France", mb_fr), ("Switzerland", mb_ch)):
        color = region_colors[region]
        ax.hist(values, bins=bins, alpha=0.6, label=region, color=color)
        # Dashed vertical line at the regional mean, in the region's colour.
        ax.axvline(values.mean(), linestyle="--", color=color)

    ax.set_title(f"{period.capitalize()} Mass Balance")
    ax.set_xlabel("Mass balance [m w.e.]")
    ax.legend()

axes[0].set_ylabel("Number of measurements")

plt.suptitle("Seasonal Point Mass Balance Distribution", fontsize=14)
plt.tight_layout()
plt.show()

## France only (within region):

In [None]:
# Glacier-wise split for the France-only experiment.
# Alternative hold-out set, kept for reference:
#   TEST_GLACIERS = ['Talefre', 'Grands Montets', 'Saint Sorlin']
TRAIN_GLACIERS = [
    'Blanc',
    'Mer de Glace',
    'Leschaux',
    'Sarennes',
    'Saint Sorlin',
    'Grands Montets',
]
TEST_GLACIERS = [
    'Talefre',
    'Argentiere',
    'Gebroulaz',
]

In [None]:
# Transform data to monthly format (run or load data).
# Input files: the GLACIOCLIM point mass balance CSV and the two ERA5 NetCDF
# files, all located below cfg.dataPath.
era5_dir = os.path.join(cfg.dataPath, path_ERA5_raw)
paths = {
    'csv_path': os.path.join(cfg.dataPath, path_PMB_GLACIOCLIM_csv),
    'era5_climate_data': os.path.join(era5_dir,
                                      "era5_monthly_averaged_data_Alps.nc"),
    'geopotential_data': os.path.join(era5_dir,
                                      "era5_geopotential_pressure_Alps.nc"),
}

# NOTE(review): run_flag=True presumably forces recomputation instead of
# loading a cached result - confirm in prepare_monthly_dfs_with_padding.
res_FR = prepare_monthly_dfs_with_padding(cfg=cfg,
                                          df_region=data_FR,
                                          region_name="FR",
                                          region_id=11,
                                          paths=paths,
                                          test_glaciers=TEST_GLACIERS,
                                          vois_climate=VOIS_CLIMATE,
                                          vois_topographical=VOIS_TOPOGRAPHICAL,
                                          run_flag=True)

# Unpack the train/test frames and their "_aug" counterparts.
df_train_FR, df_test_FR, df_train_FR_Aug, df_test_FR_Aug = (
    res_FR[key]
    for key in ("df_train", "df_test", "df_train_aug", "df_test_aug"))

### Feature overlap:

In [None]:
# Per-glacier features that do not change from month to month.
STATIC_COLS = ['aspect', 'slope', 'svf']

# Features that vary month by month (fed to the LSTM as a sequence).
MONTHLY_COLS = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str',
    'ELEVATION_DIFFERENCE',
]

# Full feature set: monthly features first, static features after.
feature_columns = [*MONTHLY_COLS, *STATIC_COLS]

In [None]:
# Train/test colours: first of ten batlow samples for Train, a fixed red for
# Test.
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue = colors[0]
custom_palette = dict(Train=color_dark_blue, Test='#b2182b')

# Layout and t-SNE options, grouped so the call below stays short.
tsne_kwargs = dict(
    sublabels=("a", "b", "c"),
    label_fmt="({})",
    label_xy=(0.02, 0.98),
    label_fontsize=14,
    n_iter=1000,
    random_state=cfg.seed,
    custom_palette=custom_palette,
)

# t-SNE comparison of the French train and test feature distributions.
fig = plot_tsne_overlap(df_train_FR, df_test_FR, STATIC_COLS, MONTHLY_COLS,
                        **tsne_kwargs)

In [None]:
# KDE comparison of every feature plus the target between the French train
# and test sets; the figure is written to `outfile`.
kde_columns = [*STATIC_COLS, *MONTHLY_COLS, 'POINT_BALANCE']
palette = dict(Train=color_dark_blue, Test='#b2182b')
fig = plot_feature_kde_overlap(df_train_FR,
                               df_test_FR,
                               kde_columns,
                               palette,
                               outfile="figures/app_feature_kde_overlap_FR.png")

### LSTM:

In [None]:
mbm.utils.seed_all(cfg.seed)

# Padding months returned by the monthly preparation step; shared by both
# datasets and by the month index below.
head_pad = res_FR['months_head_pad']
tail_pad = res_FR['months_tail_pad']

# Options that are identical for the train and test datasets.
lstm_ds_kwargs = dict(
    monthly_cols=MONTHLY_COLS,
    static_cols=STATIC_COLS,
    months_head_pad=head_pad,
    months_tail_pad=tail_pad,
    normalize_target=True,
    expect_target=True,
)

ds_train_FR = build_combined_LSTM_dataset(df_loss=df_train_FR,
                                          df_full=df_train_FR_Aug,
                                          **lstm_ds_kwargs)
ds_test_FR = build_combined_LSTM_dataset(df_loss=df_test_FR,
                                         df_full=df_test_FR_Aug,
                                         **lstm_ds_kwargs)

# Hold out 20 % of the training samples for validation.
train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
    len(ds_train_FR), val_ratio=0.2, seed=cfg.seed)

# Recover the month labels ordered by their position in the sequence.
month_list, month_pos = mbm.data_processing.utils._rebuild_month_index(
    head_pad, tail_pad)
month_order = sorted(month_pos, key=lambda month: month_pos[month])
print("Month order used in sequences:", month_order)

# Print one training sample as a sanity check.
inspect_LSTM_sample(ds_train_FR, 0, month_labels=month_order)

In [None]:
# Hyper-parameters of the France-only LSTM. The two feature counts are derived
# from the column lists so they cannot drift out of sync with the datasets.
best_params = {
    'Fm': len(MONTHLY_COLS),
    'Fs': len(STATIC_COLS),
    'hidden_size': 64,
    'num_layers': 2,
    'bidirectional': False,
    'dropout': 0.2,
    'static_layers': 2,
    'static_hidden': 128,
    'static_dropout': 0.2,
    'lr': 0.001,
    'weight_decay': 1e-05,
    'loss_name': 'neutral',
    'two_heads': False,
    'head_dropout': 0.05,
    'loss_spec': None
}

# --- checkpoint path (date-stamped, so it changes from one day to the next) ---
current_date = datetime.now().strftime("%Y-%m-%d")
model_filename = f"models/lstm_{current_date}_FR.pt"

# --- loaders (fit scalers on TRAIN, apply to whole ds_train) ---
mbm.utils.seed_all(cfg.seed)
# Work on clones so that ds_train_FR / ds_test_FR themselves stay untransformed.
ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train_FR)
ds_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_test_FR)

train_dl, val_dl = ds_train_copy.make_loaders(
    train_idx=train_idx,
    val_idx=val_idx,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=
    True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True  # use weighted sampler for training
)

# --- test loader (copies TRAIN scalers into ds_test and transforms it) ---
test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test_copy, ds_train_copy, batch_size=128, seed=cfg.seed)

# --- build model, resolve loss, train, reload best ---
model = mbm.models.LSTM_MB.build_model_from_params(cfg, best_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(best_params)

TRAIN = True
if TRAIN:
    # On a fresh checkout the checkpoint directory may not exist yet; create
    # it up front so saving the best model cannot fail on a missing folder.
    checkpoint_dir = os.path.dirname(model_filename)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    # Start from a clean slate: drop a checkpoint left by an earlier run today.
    if os.path.exists(model_filename):
        os.remove(model_filename)
        print(f"Deleted existing model file: {model_filename}")

    history, best_val, best_state = model.train_loop(
        device=device,
        train_dl=train_dl,
        val_dl=val_dl,
        epochs=150,
        lr=best_params['lr'],
        weight_decay=best_params['weight_decay'],
        clip_val=1,
        # scheduler
        sched_factor=0.5,
        sched_patience=6,
        sched_threshold=0.01,
        sched_threshold_mode="rel",
        sched_cooldown=1,
        sched_min_lr=1e-6,
        # early stopping
        es_patience=15,
        es_min_delta=1e-4,
        # logging
        log_every=5,
        verbose=True,
        # checkpoint
        save_best_path=model_filename,
        loss_fn=loss_fn,
    )
    plot_history_lstm(history)

# Load and evaluate on test
# To evaluate an earlier checkpoint instead, override the path here, e.g.:
# model_filename = f"models/lstm_model_2026-01-02_OOS_norm_y_past.pt"
if not os.path.exists(model_filename):
    # The default file name contains today's date, so with TRAIN = False it
    # only exists if the model was trained earlier today.
    raise FileNotFoundError(
        f"No checkpoint found at {model_filename}. Set TRAIN = True or point "
        "model_filename at an existing checkpoint.")
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state)
test_metrics, test_df_preds = model.evaluate_with_preds(
    device, test_dl, ds_test_copy)
test_rmse_a, test_rmse_w = test_metrics['RMSE_annual'], test_metrics[
    'RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

In [None]:
# Score the test predictions separately for annual and winter balances.
scores_annual, scores_winter = compute_seasonal_scores(
    test_df_preds, target_col='target', pred_col='pred')

for season, season_scores in (("Annual", scores_annual),
                              ("Winter", scores_winter)):
    print(f"{season} scores:", season_scores)

# Same limits on both axes.
axis_limits = (-8, 6)
fig = plot_predictions_summary(grouped_ids=test_df_preds,
                               scores_annual=scores_annual,
                               scores_winter=scores_winter,
                               ax_xlim=axis_limits,
                               ax_ylim=axis_limits,
                               color_annual=mbm.plots.COLOR_ANNUAL,
                               color_winter=mbm.plots.COLOR_WINTER)

In [None]:
# No per-glacier elevation is looked up for this figure; a constant is stored
# instead. NOTE(review): presumably plot_individual_glacier_pred requires the
# 'gl_elv' column - confirm.
test_df_preds['gl_elv'] = 0

# One panel per test glacier, in a single row.
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(32, 15), sharex=True)

subplot_labels = [f'({letter})' for letter in 'abcdefghi']
axs = plot_individual_glacier_pred(
    test_df_preds,
    axs=axs,
    subplot_labels=subplot_labels,
    color_annual=mbm.plots.COLOR_ANNUAL,
    color_winter=mbm.plots.COLOR_WINTER,
    custom_order=None,
    gl_area={},
)

fig.supxlabel('Observed PMB [m w.e.]', fontsize=20, y=0.06)
fig.supylabel('Modeled PMB [m w.e.]', fontsize=20, x=0.09)

# Proxy artists for the figure legend: marker-only handles that differ only in
# face colour and label.
marker_style = dict(marker='o',
                    linestyle='None',
                    linewidth=0,
                    markersize=10,
                    markeredgecolor='k',
                    markeredgewidth=0.8)
legend_scatter_annual = Line2D([0], [0],
                               markerfacecolor=mbm.plots.COLOR_ANNUAL,
                               label='Annual',
                               **marker_style)
legend_scatter_winter = Line2D([0], [0],
                               markerfacecolor=mbm.plots.COLOR_WINTER,
                               label='Winter',
                               **marker_style)
handles = [legend_scatter_annual, legend_scatter_winter]

# Shared legend anchored at the bottom centre of the figure.
fig.legend(handles=handles,
           loc='upper center',
           bbox_to_anchor=(0.5, 0.05),
           ncol=4,
           fontsize=20)

plt.subplots_adjust(hspace=0.25, wspace=0.25)
plt.show()

## Cross-regional (train CH - test FR):

In [None]:
# Train/test splitting strategy
# Spatial generalization approach: select the test set by glacier; the
# remaining glaciers form the train set.
# 5-10% or 653 IDs in train set
# Alternative: hold out only a subset of the French glaciers:
# TEST_GLACIERS = [
#     'Argentiere', 'Gebroulaz', 'Mer de Glace', 'Saint Sorlin', 'Blanc'
# ]

# Every French glacier is held out for testing.
TEST_GLACIERS = data_FR['GLACIER'].unique().tolist()

# Merge FR with CH (French rows first) under a fresh 0..n-1 index.
data_CH_FR = pd.concat([data_FR, data_CH], axis=0, ignore_index=True)

# Number of distinct glaciers in the merged frame, then a two-row preview.
display(len(data_CH_FR['GLACIER'].unique()))
data_CH_FR.head(2)

In [None]:
# Transform data to monthly format (run or load data).
# Input files: the GLACIOCLIM point mass balance CSV and the two ERA5 NetCDF
# files, all located below cfg.dataPath.
era5_dir = os.path.join(cfg.dataPath, path_ERA5_raw)
paths = {
    'csv_path': os.path.join(cfg.dataPath, path_PMB_GLACIOCLIM_csv),
    'era5_climate_data': os.path.join(era5_dir,
                                      "era5_monthly_averaged_data_Alps.nc"),
    'geopotential_data': os.path.join(era5_dir,
                                      "era5_geopotential_pressure_Alps.nc"),
}

# The merged CH+FR frame is processed under dedicated output file names.
# NOTE(review): run_flag=True presumably forces recomputation instead of
# loading a cached result - confirm in prepare_monthly_dfs_with_padding.
res_CH_FR = prepare_monthly_dfs_with_padding(
    cfg=cfg,
    df_region=data_CH_FR,
    region_name="FR",
    region_id=11,
    paths=paths,
    test_glaciers=TEST_GLACIERS,
    vois_climate=VOIS_CLIMATE,
    vois_topographical=VOIS_TOPOGRAPHICAL,
    run_flag=True,
    output_file_monthly='CH_FR_wgms_dataset_monthly.csv',
    output_file_monthly_aug='CH_FR_wgms_dataset_monthly_Aug.csv',
)

# Unpack the train/test frames and their "_aug" counterparts.
(df_train_CH_FR, df_test_CH_FR, df_train_CH_FR_Aug, df_test_CH_FR_Aug) = (
    res_CH_FR[key]
    for key in ("df_train", "df_test", "df_train_aug", "df_test_aug"))

### Feature overlap:

In [None]:
# Train/test colours: first of ten batlow samples for Train, a fixed red for
# Test.
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue = colors[0]
custom_palette = dict(Train=color_dark_blue, Test='#b2182b')

# Layout and t-SNE options, grouped so the call below stays short.
tsne_kwargs = dict(
    sublabels=("a", "b", "c"),
    label_fmt="({})",
    label_xy=(0.02, 0.98),
    label_fontsize=14,
    n_iter=1000,
    random_state=cfg.seed,
    custom_palette=custom_palette,
)

# t-SNE comparison of the cross-regional train and test feature distributions.
fig = plot_tsne_overlap(df_train_CH_FR, df_test_CH_FR, STATIC_COLS,
                        MONTHLY_COLS, **tsne_kwargs)

In [None]:
# KDE comparison of every feature plus the target for the cross-regional
# split. The cross-regional frames are passed here: the France-only frames
# would merely reproduce the within-region figure under the CH_FR file name.
palette = {'Train': color_dark_blue, 'Test': '#b2182b'}
fig = plot_feature_kde_overlap(
    df_train_CH_FR,
    df_test_CH_FR,
    STATIC_COLS + MONTHLY_COLS + ['POINT_BALANCE'],
    palette,
    outfile="figures/app_feature_kde_overlap_CH_FR.png")

### LSTM:

In [None]:
mbm.utils.seed_all(cfg.seed)

# Padding months returned by the monthly preparation step; shared by both
# datasets and by the month index below.
head_pad = res_CH_FR['months_head_pad']
tail_pad = res_CH_FR['months_tail_pad']

# Options that are identical for the train and test datasets.
lstm_ds_kwargs = dict(
    monthly_cols=MONTHLY_COLS,
    static_cols=STATIC_COLS,
    months_head_pad=head_pad,
    months_tail_pad=tail_pad,
    normalize_target=True,
    expect_target=True,
)

ds_train_CH_FR = build_combined_LSTM_dataset(df_loss=df_train_CH_FR,
                                             df_full=df_train_CH_FR_Aug,
                                             **lstm_ds_kwargs)
ds_test_CH_FR = build_combined_LSTM_dataset(df_loss=df_test_CH_FR,
                                            df_full=df_test_CH_FR_Aug,
                                            **lstm_ds_kwargs)

# Hold out 20 % of the training samples for validation.
train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
    len(ds_train_CH_FR), val_ratio=0.2, seed=cfg.seed)

# Recover the month labels ordered by their position in the sequence.
month_list, month_pos = mbm.data_processing.utils._rebuild_month_index(
    head_pad, tail_pad)
month_order = sorted(month_pos, key=lambda month: month_pos[month])
print("Month order used in sequences:", month_order)

In [None]:
# Hyper-parameters of the cross-regional LSTM. The two feature counts are
# derived from the column lists so they cannot drift out of sync with the
# datasets.
best_params = {
    'Fm': len(MONTHLY_COLS),
    'Fs': len(STATIC_COLS),
    'hidden_size': 64,
    'num_layers': 2,
    'bidirectional': False,
    'dropout': 0.2,
    'static_layers': 2,
    'static_hidden': 128,
    'static_dropout': 0.2,
    'lr': 0.001,
    'weight_decay': 1e-05,
    'loss_name': 'neutral',
    'two_heads': False,
    'head_dropout': 0.05,
    'loss_spec': None
}

# --- checkpoint path (date-stamped, so it changes from one day to the next) ---
current_date = datetime.now().strftime("%Y-%m-%d")
model_filename = f"models/lstm_{current_date}_CH_FR.pt"

# --- loaders (fit scalers on TRAIN, apply to whole ds_train) ---
mbm.utils.seed_all(cfg.seed)
# Work on clones so that ds_train_CH_FR / ds_test_CH_FR themselves stay
# untransformed.
ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train_CH_FR)
ds_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_test_CH_FR)

train_dl, val_dl = ds_train_copy.make_loaders(
    train_idx=train_idx,
    val_idx=val_idx,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=
    True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True  # use weighted sampler for training
)

# --- test loader (copies TRAIN scalers into ds_test and transforms it) ---
test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test_copy, ds_train_copy, batch_size=128, seed=cfg.seed)

# --- build model, resolve loss, train, reload best ---
model = mbm.models.LSTM_MB.build_model_from_params(cfg, best_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(best_params)

TRAIN = True
if TRAIN:
    # On a fresh checkout the checkpoint directory may not exist yet; create
    # it up front so saving the best model cannot fail on a missing folder.
    checkpoint_dir = os.path.dirname(model_filename)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    # Start from a clean slate: drop a checkpoint left by an earlier run today.
    if os.path.exists(model_filename):
        os.remove(model_filename)
        print(f"Deleted existing model file: {model_filename}")

    history, best_val, best_state = model.train_loop(
        device=device,
        train_dl=train_dl,
        val_dl=val_dl,
        epochs=150,
        lr=best_params['lr'],
        weight_decay=best_params['weight_decay'],
        clip_val=1,
        # scheduler
        sched_factor=0.5,
        sched_patience=6,
        sched_threshold=0.01,
        sched_threshold_mode="rel",
        sched_cooldown=1,
        sched_min_lr=1e-6,
        # early stopping
        es_patience=15,
        es_min_delta=1e-4,
        # logging
        log_every=5,
        verbose=True,
        # checkpoint
        save_best_path=model_filename,
        loss_fn=loss_fn,
    )
    plot_history_lstm(history)

# Load and evaluate on test
# To evaluate an earlier checkpoint instead, override the path here, e.g.:
# model_filename = f"models/lstm_model_2026-01-02_OOS_norm_y_past.pt"
if not os.path.exists(model_filename):
    # The default file name contains today's date, so with TRAIN = False it
    # only exists if the model was trained earlier today.
    raise FileNotFoundError(
        f"No checkpoint found at {model_filename}. Set TRAIN = True or point "
        "model_filename at an existing checkpoint.")
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state)
test_metrics, test_df_preds = model.evaluate_with_preds(
    device, test_dl, ds_test_copy)
test_rmse_a, test_rmse_w = test_metrics['RMSE_annual'], test_metrics[
    'RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

In [None]:
# Score the test predictions separately for annual and winter balances.
scores_annual, scores_winter = compute_seasonal_scores(
    test_df_preds, target_col='target', pred_col='pred')

for season, season_scores in (("Annual", scores_annual),
                              ("Winter", scores_winter)):
    print(f"{season} scores:", season_scores)

# Same limits on both axes.
axis_limits = (-14, 6)
fig = plot_predictions_summary(grouped_ids=test_df_preds,
                               scores_annual=scores_annual,
                               scores_winter=scores_winter,
                               ax_xlim=axis_limits,
                               ax_ylim=axis_limits,
                               color_annual=mbm.plots.COLOR_ANNUAL,
                               color_winter=mbm.plots.COLOR_WINTER)

In [None]:
# Mean elevation of the annual measurements per French glacier, highest first.
annual_obs = data_FR.loc[data_FR['PERIOD'] == 'annual']
gl_per_el = (annual_obs.groupby(['GLACIER'])['POINT_ELEVATION']
             .mean()
             .sort_values(ascending=False))
# Test glaciers ordered from lowest to highest mean elevation (panel order).
test_gl_per_el = gl_per_el.loc[TEST_GLACIERS].sort_values().index

# RGI outlines, passed to get_gl_area_FR to obtain the glacier areas.
shapefile_path = os.path.join(cfg.dataPath,
                              "GLAMOS/RGI/nsidc0770_11.rgi60.CentralEurope",
                              "11_rgi60_CentralEurope.shp")
gl_area = get_gl_area_FR(data_FR, shapefile_path)

# Attach each prediction's glacier mean elevation.
test_df_preds['gl_elv'] = test_df_preds['GLACIER'].map(gl_per_el)

# 3 x 3 grid of panels.
fig, axs = plt.subplots(nrows=3, ncols=3, figsize=(32, 15), sharex=True)

subplot_labels = [f'({letter})' for letter in 'abcdefghi']

# Same limits on both axes.
axis_limits = (-14, 6)
axs = plot_individual_glacier_pred(
    test_df_preds,
    axs=axs,
    subplot_labels=subplot_labels,
    color_annual=mbm.plots.COLOR_ANNUAL,
    color_winter=mbm.plots.COLOR_WINTER,
    custom_order=test_gl_per_el,
    gl_area=gl_area,
    ax_xlim=axis_limits,
    ax_ylim=axis_limits,
)

fig.supxlabel('Observed PMB [m w.e.]', fontsize=20, y=0.06)
fig.supylabel('Modeled PMB [m w.e.]', fontsize=20, x=0.09)

# Proxy artists for the figure legend: marker-only handles that differ only in
# face colour and label.
marker_style = dict(marker='o',
                    linestyle='None',
                    linewidth=0,
                    markersize=10,
                    markeredgecolor='k',
                    markeredgewidth=0.8)
legend_scatter_annual = Line2D([0], [0],
                               markerfacecolor=mbm.plots.COLOR_ANNUAL,
                               label='Annual',
                               **marker_style)
legend_scatter_winter = Line2D([0], [0],
                               markerfacecolor=mbm.plots.COLOR_WINTER,
                               label='Winter',
                               **marker_style)

# Further handles (e.g. bands or means) could be prepended to this list.
handles = [legend_scatter_annual, legend_scatter_winter]

# The legend text comes from the handles' own labels, so no `labels=` is
# passed.
fig.legend(handles=handles,
           loc='upper center',
           bbox_to_anchor=(0.5, 0.05),
           ncol=4,
           fontsize=20)

plt.subplots_adjust(hspace=0.25, wspace=0.25)
plt.show()