In [None]:
import sys, os
sys.path.append(os.path.join(os.getcwd(), '../../')) # Add root of repo to import MBM

import pandas as pd
import numpy as np
import warnings
import re
import matplotlib.pyplot as plt
import seaborn as sns
from cmcrameri import cm
import massbalancemachine as mbm
import logging
import torch.nn as nn
from skorch.helper import SliceDataset
from datetime import datetime
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset
import pickle 
from scipy import stats
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import torch 
from matplotlib.lines import Line2D

from regions.Iceland_mb.scripts.config_ICE import *
from regions.Iceland_mb.scripts.dataset import get_stakes_data_ICE
from regions.Iceland_mb.scripts.utils import *

from regions.Switzerland.scripts.dataset import process_or_load_data, get_CV_splits
from regions.Switzerland.scripts.plotting import plot_predictions_summary, plot_individual_glacier_pred, plot_history_lstm, get_cmap_hex,plot_tsne_overlap, plot_feature_kde_overlap, alpha_labels, pred_vs_truth_density
from regions.Switzerland.scripts.dataset import get_stakes_data, build_combined_LSTM_dataset, inspect_LSTM_sample, prepare_monthly_dfs_with_padding
from regions.Switzerland.scripts.models import compute_seasonal_scores, get_best_params_for_lstm

warnings.filterwarnings('ignore')
%load_ext autoreload
%autoreload 2

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

cfg = mbm.IcelandConfig()
mbm.utils.seed_all(cfg.seed)
mbm.utils.free_up_cuda()
mbm.plots.use_mbm_style()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Cross-Regional Transfer Learning (Switzerland → Iceland)

This approach uses the Swiss dataset to try to model Icelandic glaciers.

### Create Combined Swiss and Icelandic Glacier Dataset

Start with point mass balance measurements and transform them to monthly format with ERA5 climate data.

In [None]:
# Load point mass-balance measurements for both regions.
data_ICE = get_stakes_data_ICE(cfg)
data_CH = get_stakes_data(cfg)

# Align the Swiss frame with the Icelandic one so the two can be concatenated
# later: drop Swiss-only topographic columns (ignored if absent) and add
# GLACIER_ZONE / DATA_MODIFICATION as empty strings.
data_CH = (
    data_CH
    .drop(columns=['aspect_sgi', 'slope_sgi', 'topo_sgi', 'asvf', 'opns'],
          errors='ignore')
    .assign(GLACIER_ZONE='', DATA_MODIFICATION='')
)

print('Number ICE glaciers:', data_ICE['GLACIER'].nunique())
print('ICE glaciers:', data_ICE['GLACIER'].unique())
print('Number CH glaciers:', data_CH['GLACIER'].nunique())
print('CH glaciers:', data_CH['GLACIER'].unique())

In [None]:
# Normalise PERIOD labels (strip whitespace, lower-case) so the 'annual' /
# 'winter' filters below match regardless of source formatting.
data_ICE["PERIOD"] = data_ICE["PERIOD"].str.strip().str.lower()
data_CH["PERIOD"] = data_CH["PERIOD"].str.strip().str.lower()

fig, axes = plt.subplots(1, 2, figsize=(13, 5), sharey=True)

for ax, period in zip(axes, ["annual", "winter"]):
    mb_ice = data_ICE.loc[data_ICE.PERIOD == period, "POINT_BALANCE"].dropna()
    mb_ch = data_CH.loc[data_CH.PERIOD == period, "POINT_BALANCE"].dropna()

    # Common bins for a fair comparison; skip the panel if neither region
    # has data for this period (min/max of an empty array would raise).
    all_vals = np.concatenate([mb_ice, mb_ch])
    if all_vals.size == 0:
        ax.set_title(f"{period.capitalize()} Mass Balance (no data)")
        continue
    bins = np.linspace(all_vals.min(), all_vals.max(), 21)

    for values, label in ((mb_ice, "Iceland"), (mb_ch, "Switzerland")):
        _, _, patches = ax.hist(values, bins=bins, alpha=0.6, label=label)
        # Dashed mean line drawn in the same (opaque) colour as its histogram.
        if len(values):
            ax.axvline(values.mean(),
                       linestyle="--",
                       color=patches[0].get_facecolor()[:3])

    ax.set_title(f"{period.capitalize()} Mass Balance")
    ax.set_xlabel("Mass balance [m w.e.]")
    ax.legend()

axes[0].set_ylabel("Number of measurements")

fig.suptitle("Seasonal Point Mass Balance Distribution", fontsize=14)
fig.tight_layout()
plt.show()

In [None]:
# Time-varying inputs, passed as `monthly_cols` to the sequence datasets.
MONTHLY_COLS = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str',
    'ELEVATION_DIFFERENCE',
]
# Time-invariant inputs, passed as `static_cols`.
STATIC_COLS = ['aspect', 'slope', 'svf']

# All model inputs: monthly columns first, then static ones.
feature_columns = [*MONTHLY_COLS, *STATIC_COLS]

## Iceland only (within region):

In [None]:
# Glaciers held out for testing (a mix of RGI IDs and glacier names).
TEST_GLACIERS = [
    'RGI60-06.00311', 'RGI60-06.00305', 'Thjorsarjoekull (Hofsjoekull E)',
    'RGI60-06.00445', 'RGI60-06.00474', 'RGI60-06.00425', 'RGI60-06.00480',
    'Dyngjujoekull', 'RGI60-06.00478', 'Koeldukvislarjoekull',
    'Oeldufellsjoekull', 'RGI60-06.00350', 'RGI60-06.00340'
]

# Every other Icelandic glacier is used for training.
is_train_row = ~data_ICE.GLACIER.isin(TEST_GLACIERS)
TRAIN_GLACIERS = data_ICE.loc[is_train_row, "GLACIER"].unique()

print(f"Test glaciers ({len(TEST_GLACIERS)}):\n", TEST_GLACIERS)
print(f"Train glaciers ({len(TRAIN_GLACIERS)}):\n", TRAIN_GLACIERS)

# Mean elevation of the annual point measurements per glacier, highest first;
# used later to order the per-glacier panels.
annual_rows = data_ICE[data_ICE.PERIOD == 'annual']
gl_per_el = (annual_rows
             .groupby(['GLACIER'])['POINT_ELEVATION']
             .mean()
             .sort_values(ascending=False))

# Glacier areas looked up from the RGI v6 Iceland outlines.
shapefile_path = os.path.join(cfg.dataPath, "RGI_v6/RGI_06_Iceland",
                              "06_rgi60_Iceland.shp")
gl_area = get_gl_area_ICE(data_ICE, shapefile_path)

In [None]:
# Transform data to monthly format (run or load data).
# The ERA5 inputs must cover Iceland, so use the ICE_Alps files (the same ones
# the cross-regional section uses); the NOR_Alps files are a leftover from the
# Norway notebook and presumably do not cover Iceland.
paths = {
    'csv_path':
    os.path.join(cfg.dataPath, path_PMB_WGMS_csv),
    'era5_climate_data':
    os.path.join(cfg.dataPath, path_ERA5_raw,
                 "era5_monthly_averaged_data_ICE_Alps.nc"),
    'geopotential_data':
    os.path.join(cfg.dataPath, path_ERA5_raw,
                 "era5_geopotential_pressure_ICE_Alps.nc")
}

# Fail early with a clear message if an input file is missing.
for key, path in paths.items():
    if not os.path.exists(path):
        raise FileNotFoundError(f"Required file for {key} not found at {path}")

res_ICE = prepare_monthly_dfs_with_padding(
    cfg=cfg,
    df_region=data_ICE,
    region_name="ICE",
    # RGI first-order region of Iceland (glacier IDs RGI60-06.*, outlines
    # RGI_06_Iceland). NOTE(review): was 8 (Norway's region); confirm how
    # prepare_monthly_dfs_with_padding uses this value.
    region_id=6,
    paths=paths,
    test_glaciers=TEST_GLACIERS,
    vois_climate=VOIS_CLIMATE,
    vois_topographical=VOIS_TOPOGRAPHICAL,
    run_flag=True)

df_train_ICE = res_ICE["df_train"]
df_test_ICE = res_ICE["df_test"]
df_train_ICE_Aug = res_ICE["df_train_aug"]
df_test_ICE_Aug = res_ICE["df_test_aug"]

### Feature overlap:

In [None]:
# Train/test colours: first batlow colour for train, dark red for test.
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue = colors[0]
custom_palette = dict(Train=color_dark_blue, Test='#b2182b')

# Train-vs-test overlap of static and monthly features (Iceland only).
tsne_kwargs = dict(
    sublabels=("a", "b", "c"),
    label_fmt="({})",
    label_xy=(0.02, 0.98),
    label_fontsize=14,
    n_iter=1000,
    random_state=cfg.seed,
    custom_palette=custom_palette,
)
fig = plot_tsne_overlap(df_train_ICE, df_test_ICE, STATIC_COLS, MONTHLY_COLS,
                        **tsne_kwargs)

In [None]:
# Per-feature train/test distribution overlap, including the target;
# the figure is also written to figures/.
palette = dict(Train=color_dark_blue, Test='#b2182b')
kde_columns = [*STATIC_COLS, *MONTHLY_COLS, 'POINT_BALANCE']
fig = plot_feature_kde_overlap(df_train_ICE,
                               df_test_ICE,
                               kde_columns,
                               palette,
                               outfile="figures/app_feature_kde_overlap_ICE.png")

### Train LSTM:

In [None]:
mbm.utils.seed_all(cfg.seed)

# Options shared by the train and test sequence datasets.
lstm_ds_kwargs = dict(
    monthly_cols=MONTHLY_COLS,
    static_cols=STATIC_COLS,
    months_head_pad=res_ICE['months_head_pad'],
    months_tail_pad=res_ICE['months_tail_pad'],
    normalize_target=True,
    expect_target=True,
)
ds_train_ICE = build_combined_LSTM_dataset(df_loss=df_train_ICE,
                                           df_full=df_train_ICE_Aug,
                                           **lstm_ds_kwargs)
ds_test_ICE = build_combined_LSTM_dataset(df_loss=df_test_ICE,
                                          df_full=df_test_ICE_Aug,
                                          **lstm_ds_kwargs)

# Hold out 20% of the training sequences for validation.
seq_ds_cls = mbm.data_processing.MBSequenceDataset
train_idx, val_idx = seq_ds_cls.split_indices(len(ds_train_ICE),
                                              val_ratio=0.2,
                                              seed=cfg.seed)

# Month labels ordered by their position inside the padded sequences.
month_list, month_pos = mbm.data_processing.utils._rebuild_month_index(
    res_ICE['months_head_pad'], res_ICE['months_tail_pad'])
month_order = [
    month for month, _pos in sorted(month_pos.items(), key=lambda kv: kv[1])
]
print("Month order used in sequences:", month_order)

inspect_LSTM_sample(ds_train_ICE, 0, month_labels=month_order)

In [None]:
# Hyper-parameter search log for the Iceland out-of-sample setup.
# NOTE(review): the next cell hard-codes `best_params` instead of using
# `optimal_params` — confirm the two agree.
log_path = 'logs/lstm_param_search_progress_OOS_ICE_2026-02-09.csv'
optimal_params = get_best_params_for_lstm(log_path, select_by='avg_test_loss')
print(optimal_params)

# Rank every logged run by the mean of annual and winter test RMSE and
# show the best one.
df = (
    pd.read_csv(log_path)
    .assign(avg_test_loss=lambda runs:
            (runs["test_rmse_a"] + runs["test_rmse_w"]) / 2)
    .sort_values(by="avg_test_loss")
)
df.iloc[0]

In [None]:
# Hyper-parameters for the Iceland-only LSTM. Feature counts are derived from
# the column lists (as in the cross-regional section) so they stay in sync
# with the datasets built above.
best_params = {
    'Fm': len(MONTHLY_COLS),
    'Fs': len(STATIC_COLS),
    'hidden_size': 96,
    'num_layers': 2,
    'bidirectional': False,
    'dropout': 0.1,
    'static_layers': 1,
    'static_hidden': 32,
    'static_dropout': 0.1,
    'lr': 0.0005,
    'weight_decay': 1e-05,
    'loss_name': 'neutral',
    'two_heads': False,
    'head_dropout': 0.1,
    'loss_spec': None
}

# True: retrain and overwrite today's checkpoint; False: load a saved one.
TRAIN = False

# --- checkpoint path ---
# Training writes a checkpoint stamped with today's date. When only loading,
# such a file normally does not exist (it was trained on an earlier day), so
# fall back to the most recent Iceland-only checkpoint instead of crashing in
# torch.load. ISO dates sort lexicographically == chronologically.
os.makedirs("models", exist_ok=True)
current_date = datetime.now().strftime("%Y-%m-%d")
model_filename = f"models/lstm_{current_date}_ICE.pt"
if not TRAIN and not os.path.exists(model_filename):
    saved_checkpoints = sorted(
        fname for fname in os.listdir("models")
        if re.fullmatch(r"lstm_\d{4}-\d{2}-\d{2}_ICE\.pt", fname))
    if not saved_checkpoints:
        raise FileNotFoundError(
            "No Iceland-only checkpoint found in models/; "
            "set TRAIN = True to train one.")
    model_filename = os.path.join("models", saved_checkpoints[-1])
    print(f"Using most recent checkpoint: {model_filename}")

# --- loaders (fit scalers on TRAIN, apply to whole ds_train) ---
mbm.utils.seed_all(cfg.seed)
ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train_ICE)
ds_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_test_ICE)

train_dl, val_dl = ds_train_copy.make_loaders(
    train_idx=train_idx,
    val_idx=val_idx,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True  # use weighted sampler for training
)

# --- test loader (copies TRAIN scalers into ds_test and transforms it) ---
test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test_copy, ds_train_copy, batch_size=128, seed=cfg.seed)

# --- build model, resolve loss, train, reload best ---
model = mbm.models.LSTM_MB.build_model_from_params(cfg, best_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(best_params)

if TRAIN:
    if os.path.exists(model_filename):
        os.remove(model_filename)
        print(f"Deleted existing model file: {model_filename}")

    history, best_val, best_state = model.train_loop(
        device=device,
        train_dl=train_dl,
        val_dl=val_dl,
        epochs=150,
        lr=best_params['lr'],
        weight_decay=best_params['weight_decay'],
        clip_val=1,
        # scheduler
        sched_factor=0.5,
        sched_patience=6,
        sched_threshold=0.01,
        sched_threshold_mode="rel",
        sched_cooldown=1,
        sched_min_lr=1e-6,
        # early stopping
        es_patience=15,
        es_min_delta=1e-4,
        # logging
        log_every=5,
        verbose=True,
        # checkpoint
        save_best_path=model_filename,
        loss_fn=loss_fn,
    )
    plot_history_lstm(history)

# Load the best (or selected) checkpoint into the freshly built model.
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state)

#### Evaluate on Test:

In [None]:
# Score the Iceland-only model on the held-out test glaciers.
test_metrics, test_df_preds = model.evaluate_with_preds(
    device, test_dl, ds_test_copy)
test_rmse_a = test_metrics['RMSE_annual']
test_rmse_w = test_metrics['RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

scores_annual, scores_winter = compute_seasonal_scores(
    test_df_preds, target_col='target', pred_col='pred')

# Predicted-vs-observed density plot for all test points.
fig, ax1 = plt.subplots(figsize=(15, 10))
pred_vs_truth_density(ax1,
                      test_df_preds,
                      scores_annual,
                      add_legend=False,
                      palette=[mbm.plots.COLOR_ANNUAL, mbm.plots.COLOR_WINTER],
                      ax_xlim=(-8, 6),
                      ax_ylim=(-8, 6))

# Annual (a) / winter (w) RMSE, R^2 and bias in a box in the top-left corner.
metric_rows = [
    (r"$\mathrm{RMSE_a}=%.2f$, $\mathrm{RMSE_w}=%.2f$", "rmse"),
    (r"$\mathrm{R^2_a}=%.2f$, $\mathrm{R^2_w}=%.2f$", "R2"),
    (r"$\mathrm{Bias_a}=%.2f$, $\mathrm{Bias_w}=%.2f$", "Bias"),
]
legend_NN = "\n".join(template % (scores_annual[key], scores_winter[key])
                      for template, key in metric_rows)
ax1.text(0.02,
         0.98,
         legend_NN,
         transform=ax1.transAxes,
         verticalalignment="top",
         fontsize=20,
         bbox=dict(boxstyle="round", facecolor="white", alpha=0.5))

In [None]:
# Per-glacier predicted-vs-observed panels for the test glaciers, ordered by
# mean annual-point elevation (ascending).
test_df_preds['gl_elv'] = test_df_preds['GLACIER'].map(gl_per_el)
test_gl_per_el = gl_per_el[TEST_GLACIERS].sort_values().index

# Size the grid from the number of glaciers so every glacier gets a panel even
# if the test split changes (13 glaciers -> 5x3, the original layout).
n_cols = 3
n_rows = -(-len(TEST_GLACIERS) // n_cols)  # ceiling division
fig, axs = plt.subplots(n_rows,
                        n_cols,
                        figsize=(32, 6 * n_rows),
                        sharex=True,
                        squeeze=False)  # always a 2-D axes array

subplot_labels = alpha_labels(len(TEST_GLACIERS))

axs = plot_individual_glacier_pred(test_df_preds,
                                   axs=axs,
                                   subplot_labels=subplot_labels,
                                   color_annual=mbm.plots.COLOR_ANNUAL,
                                   color_winter=mbm.plots.COLOR_WINTER,
                                   custom_order=test_gl_per_el,
                                   gl_area=gl_area)

fig.supxlabel('Observed PMB [m w.e.]', fontsize=20, y=0.06)
fig.supylabel('Modeled PMB [m w.e.]', fontsize=20, x=0.09)

# One marker-only legend entry per season, shared by all panels.
handles = [
    Line2D([0], [0],
           marker='o',
           linestyle='None',
           linewidth=0,
           markersize=10,
           markerfacecolor=color,
           markeredgecolor='k',
           markeredgewidth=0.8,
           label=label)
    for label, color in (('Annual', mbm.plots.COLOR_ANNUAL),
                         ('Winter', mbm.plots.COLOR_WINTER))
]
fig.legend(handles=handles,
           loc='upper center',
           bbox_to_anchor=(0.5, 0.05),
           ncol=len(handles),
           fontsize=20)

plt.subplots_adjust(hspace=0.25, wspace=0.25)
plt.show()

#### Evaluate on Train:

In [None]:
# Score the model on the training glaciers. ds_train_copy must be the very
# object whose scalers were fitted by make_loaders(..., fit_and_transform=True).
assert ds_train_copy.month_mean is not None
assert ds_train_copy.y_std is not None

# Pristine copy of the training set, transformed with the train scalers.
seq_ds_cls = mbm.data_processing.MBSequenceDataset
ds_train_eval = seq_ds_cls._clone_untransformed_dataset(ds_train_ICE)
train_eval_dl = seq_ds_cls.make_test_loader(ds_train_eval,
                                            ds_train_copy,
                                            batch_size=128,
                                            seed=cfg.seed)

train_metrics, train_df_preds = model.evaluate_with_preds(
    device, train_eval_dl, ds_train_eval)

scores_annual, scores_winter = compute_seasonal_scores(
    train_df_preds, target_col='target', pred_col='pred')

# Predicted-vs-observed density plot for all training points.
fig, ax1 = plt.subplots(figsize=(15, 10))
pred_vs_truth_density(ax1,
                      train_df_preds,
                      scores_annual,
                      add_legend=False,
                      palette=[mbm.plots.COLOR_ANNUAL, mbm.plots.COLOR_WINTER],
                      ax_xlim=(-14, 12),
                      ax_ylim=(-14, 12))

# Annual (a) / winter (w) RMSE, R^2 and bias in a box in the top-left corner.
metric_rows = [
    (r"$\mathrm{RMSE_a}=%.2f$, $\mathrm{RMSE_w}=%.2f$", "rmse"),
    (r"$\mathrm{R^2_a}=%.2f$, $\mathrm{R^2_w}=%.2f$", "R2"),
    (r"$\mathrm{Bias_a}=%.2f$, $\mathrm{Bias_w}=%.2f$", "Bias"),
]
legend_NN = "\n".join(template % (scores_annual[key], scores_winter[key])
                      for template, key in metric_rows)
ax1.text(0.02,
         0.98,
         legend_NN,
         transform=ax1.transAxes,
         verticalalignment="top",
         fontsize=20,
         bbox=dict(boxstyle="round", facecolor="white", alpha=0.5))

In [None]:
# Per-glacier predicted-vs-observed panels for the training glaciers, ordered
# by mean annual-point elevation (ascending).
train_df_preds['gl_elv'] = train_df_preds['GLACIER'].map(gl_per_el)
train_gl_per_elv = gl_per_el[TRAIN_GLACIERS].sort_values().index

# Size the grid from the number of glaciers; a fixed 5x4 grid only has room
# for 20 training glaciers (<=20 glaciers reproduces the original layout).
n_cols = 4
n_rows = -(-len(TRAIN_GLACIERS) // n_cols)  # ceiling division
fig, axs = plt.subplots(n_rows,
                        n_cols,
                        figsize=(32, 6 * n_rows),
                        sharex=True,
                        squeeze=False)  # always a 2-D axes array

subplot_labels = alpha_labels(len(TRAIN_GLACIERS))

axs = plot_individual_glacier_pred(train_df_preds,
                                   axs=axs,
                                   subplot_labels=subplot_labels,
                                   color_annual=mbm.plots.COLOR_ANNUAL,
                                   color_winter=mbm.plots.COLOR_WINTER,
                                   custom_order=train_gl_per_elv,
                                   gl_area=gl_area,
                                   ax_xlim=(-14, 8),
                                   ax_ylim=(-14, 8))

fig.supxlabel('Observed PMB [m w.e.]', fontsize=20, y=0.06)
fig.supylabel('Modeled PMB [m w.e.]', fontsize=20, x=0.09)

# One marker-only legend entry per season, shared by all panels.
handles = [
    Line2D([0], [0],
           marker='o',
           linestyle='None',
           linewidth=0,
           markersize=10,
           markerfacecolor=color,
           markeredgecolor='k',
           markeredgewidth=0.8,
           label=label)
    for label, color in (('Annual', mbm.plots.COLOR_ANNUAL),
                         ('Winter', mbm.plots.COLOR_WINTER))
]
fig.legend(handles=handles,
           loc='upper center',
           bbox_to_anchor=(0.5, 0.05),
           ncol=len(handles),
           fontsize=20)

plt.subplots_adjust(hspace=0.25, wspace=0.25)
plt.show()

## Cross-regional (train CH - test ICE):

In [None]:
# Cross-regional setup: every Icelandic glacier is a test glacier, so
# (assuming glacier names do not collide between regions) training uses
# Swiss glaciers only.
TEST_GLACIERS = data_ICE.GLACIER.unique().tolist()

# Stack the Icelandic and Swiss measurements into one frame with a fresh index.
data_CH_ICE = pd.concat([data_ICE, data_CH], axis=0, ignore_index=True)
display(data_CH_ICE['GLACIER'].unique().size)
data_CH_ICE.head(2)

In [None]:
# Transform the combined CH+ICE data to monthly format (run or load data).
paths = {
    'csv_path':
    os.path.join(cfg.dataPath, path_PMB_WGMS_csv),
    'era5_climate_data':
    os.path.join(cfg.dataPath, path_ERA5_raw,
                 "era5_monthly_averaged_data_ICE_Alps.nc"),
    'geopotential_data':
    os.path.join(cfg.dataPath, path_ERA5_raw,
                 "era5_geopotential_pressure_ICE_Alps.nc")
}

# Fail early with a clear message if an input file is missing
# (same check as the Iceland-only preprocessing).
for key, path in paths.items():
    if not os.path.exists(path):
        raise FileNotFoundError(f"Required file for {key} not found at {path}")

res_CH_ICE = prepare_monthly_dfs_with_padding(
    cfg=cfg,
    df_region=data_CH_ICE,
    region_name="ICE",
    # RGI first-order region of Iceland (glacier IDs RGI60-06.*), matching
    # region_name="ICE". NOTE(review): was 8 (Norway's region); confirm how
    # prepare_monthly_dfs_with_padding uses this value.
    region_id=6,
    paths=paths,
    test_glaciers=TEST_GLACIERS,
    vois_climate=VOIS_CLIMATE,
    vois_topographical=VOIS_TOPOGRAPHICAL,
    run_flag=True,
    output_file_monthly='CH_ICE_wgms_dataset_monthly.csv',
    output_file_monthly_aug='CH_ICE_wgms_dataset_monthly_Aug.csv')

df_train_CH_ICE = res_CH_ICE["df_train"]
df_test_CH_ICE = res_CH_ICE["df_test"]
df_train_CH_ICE_Aug = res_CH_ICE["df_train_aug"]
df_test_CH_ICE_Aug = res_CH_ICE["df_test_aug"]

### Feature overlap:

In [None]:
# Train/test colours: first batlow colour for train (CH), dark red for test (ICE).
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue = colors[0]
custom_palette = dict(Train=color_dark_blue, Test='#b2182b')

# Train-vs-test overlap of static and monthly features (cross-regional split).
tsne_kwargs = dict(
    sublabels=("a", "b", "c"),
    label_fmt="({})",
    label_xy=(0.02, 0.98),
    label_fontsize=14,
    n_iter=1000,
    random_state=cfg.seed,
    custom_palette=custom_palette,
)
fig = plot_tsne_overlap(df_train_CH_ICE, df_test_CH_ICE, STATIC_COLS,
                        MONTHLY_COLS, **tsne_kwargs)

In [None]:
# Per-feature train (CH) / test (ICE) distribution overlap, including the
# target; the figure is also written to figures/.
palette = dict(Train=color_dark_blue, Test='#b2182b')
kde_columns = [*STATIC_COLS, *MONTHLY_COLS, 'POINT_BALANCE']
fig = plot_feature_kde_overlap(
    df_train_CH_ICE,
    df_test_CH_ICE,
    kde_columns,
    palette,
    outfile="figures/app_feature_kde_overlap_CH_ICE.png")

### LSTM:

In [None]:
mbm.utils.seed_all(cfg.seed)

# Options shared by the train and test sequence datasets.
lstm_ds_kwargs = dict(
    monthly_cols=MONTHLY_COLS,
    static_cols=STATIC_COLS,
    months_head_pad=res_CH_ICE['months_head_pad'],
    months_tail_pad=res_CH_ICE['months_tail_pad'],
    normalize_target=True,
    expect_target=True,
)
ds_train_CH_ICE = build_combined_LSTM_dataset(df_loss=df_train_CH_ICE,
                                              df_full=df_train_CH_ICE_Aug,
                                              **lstm_ds_kwargs)
ds_test_CH_ICE = build_combined_LSTM_dataset(df_loss=df_test_CH_ICE,
                                             df_full=df_test_CH_ICE_Aug,
                                             **lstm_ds_kwargs)

# Hold out 20% of the training sequences for validation.
seq_ds_cls = mbm.data_processing.MBSequenceDataset
train_idx, val_idx = seq_ds_cls.split_indices(len(ds_train_CH_ICE),
                                              val_ratio=0.2,
                                              seed=cfg.seed)

# Month labels ordered by their position inside the padded sequences.
month_list, month_pos = mbm.data_processing.utils._rebuild_month_index(
    res_CH_ICE['months_head_pad'], res_CH_ICE['months_tail_pad'])
month_order = [
    month for month, _pos in sorted(month_pos.items(), key=lambda kv: kv[1])
]
print("Month order used in sequences:", month_order)

In [None]:
# Hyper-parameters for the cross-regional (train CH, test ICE) LSTM.
best_params = {
    'Fm': len(MONTHLY_COLS),
    'Fs': len(STATIC_COLS),
    "hidden_size": 96,
    "num_layers": 2,
    "bidirectional": False,
    "dropout": 0.2,
    "static_layers": 1,
    "static_hidden": 128,
    "static_dropout": 0.3,
    "lr": 0.0005,
    "weight_decay": 1e-05,
    "loss_name": "neutral",
    "two_heads": False,
    "head_dropout": 0.0,
    "loss_spec": None,
}

# True: retrain and overwrite today's checkpoint; False: load a saved one.
TRAIN = True

# --- checkpoint path ---
# Training writes a checkpoint stamped with today's date. When only loading,
# fall back to the most recent CH_ICE checkpoint if none exists for today, so
# the notebook does not crash in torch.load. ISO dates sort chronologically.
os.makedirs("models", exist_ok=True)
current_date = datetime.now().strftime("%Y-%m-%d")
model_filename = f"models/lstm_{current_date}_CH_ICE.pt"
if not TRAIN and not os.path.exists(model_filename):
    saved_checkpoints = sorted(
        fname for fname in os.listdir("models")
        if re.fullmatch(r"lstm_\d{4}-\d{2}-\d{2}_CH_ICE\.pt", fname))
    if not saved_checkpoints:
        raise FileNotFoundError(
            "No CH_ICE checkpoint found in models/; "
            "set TRAIN = True to train one.")
    model_filename = os.path.join("models", saved_checkpoints[-1])
    print(f"Using most recent checkpoint: {model_filename}")

# --- loaders (fit scalers on TRAIN, apply to whole ds_train) ---
mbm.utils.seed_all(cfg.seed)
ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train_CH_ICE)
ds_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_test_CH_ICE)

train_dl, val_dl = ds_train_copy.make_loaders(
    train_idx=train_idx,
    val_idx=val_idx,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True  # use weighted sampler for training
)

# --- test loader (copies TRAIN scalers into ds_test and transforms it) ---
test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test_copy, ds_train_copy, batch_size=128, seed=cfg.seed)

# --- build model, resolve loss, train, reload best ---
model = mbm.models.LSTM_MB.build_model_from_params(cfg, best_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(best_params)

if TRAIN:
    if os.path.exists(model_filename):
        os.remove(model_filename)
        print(f"Deleted existing model file: {model_filename}")

    history, best_val, best_state = model.train_loop(
        device=device,
        train_dl=train_dl,
        val_dl=val_dl,
        epochs=150,
        lr=best_params['lr'],
        weight_decay=best_params['weight_decay'],
        clip_val=1,
        # scheduler
        sched_factor=0.5,
        sched_patience=6,
        sched_threshold=0.01,
        sched_threshold_mode="rel",
        sched_cooldown=1,
        sched_min_lr=1e-6,
        # early stopping
        es_patience=15,
        es_min_delta=1e-4,
        # logging
        log_every=5,
        verbose=True,
        # checkpoint
        save_best_path=model_filename,
        loss_fn=loss_fn,
    )
    plot_history_lstm(history)

# Load the best (or selected) checkpoint; test evaluation happens in the next cell.
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state)

In [None]:
# Score the cross-regional (train CH, test ICE) model on the Icelandic glaciers.
test_metrics, test_df_preds = model.evaluate_with_preds(
    device, test_dl, ds_test_copy)
test_rmse_a = test_metrics['RMSE_annual']
test_rmse_w = test_metrics['RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

scores_annual, scores_winter = compute_seasonal_scores(
    test_df_preds, target_col='target', pred_col='pred')

# Predicted-vs-observed density plot for all Icelandic test points.
fig, ax1 = plt.subplots(figsize=(15, 10))
pred_vs_truth_density(ax1,
                      test_df_preds,
                      scores_annual,
                      add_legend=False,
                      palette=[mbm.plots.COLOR_ANNUAL, mbm.plots.COLOR_WINTER],
                      ax_xlim=(-14, 8),
                      ax_ylim=(-14, 8))

# Annual (a) / winter (w) RMSE, R^2 and bias in a box in the top-left corner.
metric_rows = [
    (r"$\mathrm{RMSE_a}=%.2f$, $\mathrm{RMSE_w}=%.2f$", "rmse"),
    (r"$\mathrm{R^2_a}=%.2f$, $\mathrm{R^2_w}=%.2f$", "R2"),
    (r"$\mathrm{Bias_a}=%.2f$, $\mathrm{Bias_w}=%.2f$", "Bias"),
]
legend_NN = "\n".join(template % (scores_annual[key], scores_winter[key])
                      for template, key in metric_rows)
ax1.text(0.02,
         0.98,
         legend_NN,
         transform=ax1.transAxes,
         verticalalignment="top",
         fontsize=20,
         bbox=dict(boxstyle="round", facecolor="white", alpha=0.5))

In [None]:
# Per-glacier predicted-vs-observed panels for every Icelandic (test) glacier,
# ordered by mean annual-point elevation (ascending).
test_df_preds['gl_elv'] = test_df_preds['GLACIER'].map(gl_per_el)
test_gl_per_el = gl_per_el[TEST_GLACIERS].sort_values().index

# Size the grid from the number of glaciers; a fixed 8x4 grid only has room
# for 32 glaciers, while here all Icelandic glaciers are test glaciers.
n_cols = 4
n_rows = -(-len(TEST_GLACIERS) // n_cols)  # ceiling division
fig, axs = plt.subplots(n_rows,
                        n_cols,
                        figsize=(40, 3.75 * n_rows),
                        sharex=True,
                        squeeze=False)  # always a 2-D axes array

subplot_labels = alpha_labels(len(TEST_GLACIERS))

axs = plot_individual_glacier_pred(test_df_preds,
                                   axs=axs,
                                   subplot_labels=subplot_labels,
                                   color_annual=mbm.plots.COLOR_ANNUAL,
                                   color_winter=mbm.plots.COLOR_WINTER,
                                   custom_order=test_gl_per_el,
                                   gl_area=gl_area,
                                   ax_xlim=(-14, 8),
                                   ax_ylim=(-14, 8))

fig.supxlabel('Observed PMB [m w.e.]', fontsize=20, y=0.06)
fig.supylabel('Modeled PMB [m w.e.]', fontsize=20, x=0.09)

# One marker-only legend entry per season, shared by all panels.
handles = [
    Line2D([0], [0],
           marker='o',
           linestyle='None',
           linewidth=0,
           markersize=10,
           markerfacecolor=color,
           markeredgecolor='k',
           markeredgewidth=0.8,
           label=label)
    for label, color in (('Annual', mbm.plots.COLOR_ANNUAL),
                         ('Winter', mbm.plots.COLOR_WINTER))
]
fig.legend(handles=handles,
           loc='upper center',
           bbox_to_anchor=(0.5, 0.05),
           ncol=len(handles),
           fontsize=20)

plt.subplots_adjust(hspace=0.25, wspace=0.25)
plt.show()