In [None]:
import sys, os
sys.path.append(os.path.join(os.getcwd(), '../../')) # Add root of repo to import MBM

import pandas as pd
import numpy as np
import warnings
import re
import matplotlib.pyplot as plt
import seaborn as sns
from cmcrameri import cm
import massbalancemachine as mbm
import logging
import torch.nn as nn
from skorch.helper import SliceDataset
from datetime import datetime
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset
import pickle 
from scipy import stats
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import torch 
from matplotlib.lines import Line2D
import xarray as xr

from regions.Norway_mb.scripts.config_NOR import *
from regions.Norway_mb.scripts.dataset import get_stakes_data_NOR
from regions.Norway_mb.scripts.utils import *

from regions.Switzerland.scripts.dataset import process_or_load_data, get_CV_splits
from regions.Switzerland.scripts.plotting import plot_predictions_summary, plot_individual_glacier_pred, plot_history_lstm, get_cmap_hex,plot_tsne_overlap, plot_feature_kde_overlap
from regions.Switzerland.scripts.dataset import get_stakes_data, build_combined_LSTM_dataset, inspect_LSTM_sample, prepare_monthly_dfs_with_padding
from regions.Switzerland.scripts.models import compute_seasonal_scores

warnings.filterwarnings('ignore')
%load_ext autoreload
%autoreload 2

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

cfg = mbm.NorwayConfig()
mbm.utils.seed_all(cfg.seed)
mbm.utils.free_up_cuda()
mbm.plots.use_mbm_style()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Cross-Regional Transfer Learning (Switzerland → Norway)

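This notebook trains an LSTM mass-balance model on Swiss glaciers and evaluates how well it transfers to Norwegian glaciers. It is applied both directly and after fine-tuning on a subset of Norwegian glaciers.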

### Create Combined Swiss and Norway Glacier Dataset

Load point mass balance (PMB) measurements for both regions. They are later converted to monthly sequences paired with ERA5 climate data.

In [None]:
# Load point mass balance (PMB) stake data for both regions.
data_NOR = get_stakes_data_NOR(cfg)
data_CH = get_stakes_data(cfg)

# Make the Swiss frame match the Norwegian one: drop Swiss-specific columns
# (silently skipped if absent) and add the two metadata columns as empty strings.
CH_ONLY_COLS = ['aspect_sgi', 'slope_sgi', 'topo_sgi', 'asvf', 'opns']
data_CH = data_CH.drop(columns=CH_ONLY_COLS, errors='ignore')
for extra_col in ('GLACIER_ZONE', 'DATA_MODIFICATION'):
    data_CH[extra_col] = ''

# Quick summary of glacier coverage per region.
for region_tag, region_df in (('NOR', data_NOR), ('CH', data_CH)):
    print(f'Number {region_tag} glaciers:', region_df['GLACIER'].nunique())
    print(f'{region_tag} glaciers:', region_df['GLACIER'].unique())

In [None]:
# Normalise PERIOD labels (strip whitespace, lower-case) so the 'annual' /
# 'winter' filters below match regardless of source formatting. Idempotent.
data_NOR["PERIOD"] = data_NOR["PERIOD"].str.strip().str.lower()
data_CH["PERIOD"] = data_CH["PERIOD"].str.strip().str.lower()

# Explicit colour per region so each dashed mean line matches its histogram.
# axvline() does not take its colour from the axes colour cycle, so without an
# explicit colour both mean lines render identically and cannot be told apart.
region_specs = [("Norway", data_NOR, "C0"), ("Switzerland", data_CH, "C1")]

fig, axes = plt.subplots(1, 2, figsize=(13, 5), sharey=True)

for ax, period in zip(axes, ["annual", "winter"]):
    per_region = [
        (label, df.loc[df.PERIOD == period, "POINT_BALANCE"].dropna(), color)
        for label, df, color in region_specs
    ]

    # Common bins across both regions for a fair comparison.
    all_vals = np.concatenate([vals.to_numpy() for _, vals, _ in per_region])
    if all_vals.size == 0:
        # np.linspace on an empty array would fail; show an empty panel instead.
        ax.set_title(f"{period.capitalize()} Mass Balance (no data)")
        continue
    bins = np.linspace(all_vals.min(), all_vals.max(), 21)

    for label, vals, color in per_region:
        ax.hist(vals, bins=bins, alpha=0.6, color=color, label=label)
        if not vals.empty:
            ax.axvline(vals.mean(),
                       linestyle="--",
                       color=color,
                       label=f"{label} mean")

    ax.set_title(f"{period.capitalize()} Mass Balance")
    ax.set_xlabel("Mass balance [m w.e.]")
    ax.legend()

axes[0].set_ylabel("Number of measurements")

fig.suptitle("Seasonal Point Mass Balance Distribution", fontsize=14)
plt.tight_layout()
plt.show()

### Norway fine-tuning and hold-out datasets

In [None]:
# Per-month inputs, passed as `monthly_cols` when building the LSTM datasets.
MONTHLY_COLS = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str', 'ELEVATION_DIFFERENCE'
]
# Time-invariant inputs, passed as `static_cols` when building the LSTM datasets.
STATIC_COLS = ['aspect', 'slope', 'svf']

# All model inputs: monthly columns first, then static ones.
feature_columns = [*MONTHLY_COLS, *STATIC_COLS]

In [None]:
# Fine-tuning split (original note: ~50% of the Norway glaciers). These glaciers
# are used to fine-tune the CH model; every other Norway glacier is held out.
finetune_glaciers = [
    'Engabreen', 'Storglombreen N', 'Moesevassbrea', 'Blaaisen', 'Blabreen',
    'Harbardsbreen', 'Graasubreen', 'Svelgjabreen', 'Aalfotbreen',
    'Rundvassbreen', 'Juvfonne', 'Storsteinsfjellbreen', 'Hansebreen',
    'Vesledalsbreen', 'Vetlefjordbreen', 'Blomstoelskardsbreen',
    'Vestre Memurubreen', 'Austre Memurubreen'
]
finetune_set = set(finetune_glaciers)

all_norway_glaciers = list(data_NOR['GLACIER'].unique())

# Fail fast on misspelled glacier names; otherwise the mismatch would only
# surface after the expensive monthly processing below.
unknown_ft = sorted(finetune_set - set(all_norway_glaciers))
if unknown_ft:
    raise ValueError(
        f"Fine-tuning glaciers not found in data_NOR: {unknown_ft}")

# Hold-out (test) glaciers: all remaining Norway glaciers, in data order.
holdout_glaciers = [g for g in all_norway_glaciers if g not in finetune_set]

data_NOR_ft = data_NOR[data_NOR['GLACIER'].isin(finetune_glaciers)].copy()
data_NOR_holdout = data_NOR[~data_NOR['GLACIER'].isin(finetune_glaciers)].copy()

print(f"Fine-tuning glaciers ({len(finetune_glaciers)}): {finetune_glaciers}")
print(f"Hold-out glaciers ({len(holdout_glaciers)}): {holdout_glaciers}")

# NOTE(review): these inputs look copy-pasted from another region's notebook
# (`path_PMB_GLACIOCLIM_csv`, "*_NOR_Alps.nc" file names). Since run_flag=True
# below, confirm they point at the Norway stake CSV and Norway ERA5 files.
paths = {
    'csv_path':
    os.path.join(cfg.dataPath, path_PMB_GLACIOCLIM_csv),
    'era5_climate_data':
    os.path.join(cfg.dataPath, path_ERA5_raw,
                 "era5_monthly_averaged_data_NOR_Alps.nc"),
    'geopotential_data':
    os.path.join(cfg.dataPath, path_ERA5_raw,
                 "era5_geopotential_pressure_NOR_Alps.nc")
}

# Monthly (and August-anchored) datasets; hold-out glaciers go to the test split.
res_ft = prepare_monthly_dfs_with_padding(
    cfg=cfg,
    df_region=data_NOR,
    region_name="NOR",
    region_id=8,
    paths=paths,
    test_glaciers=holdout_glaciers,
    vois_climate=VOIS_CLIMATE,
    vois_topographical=VOIS_TOPOGRAPHICAL,
    run_flag=True,
    output_file_monthly='NOR_ft_wgms_dataset_monthly.csv',
    output_file_monthly_aug='NOR_ft_wgms_dataset_monthly_Aug.csv')

df_ft_NOR = res_ft["df_train"]
df_holdout_NOR = res_ft["df_test"]
df_ft_NOR_Aug = res_ft["df_train_aug"]
df_holdout_NOR_Aug = res_ft["df_test_aug"]

mbm.utils.seed_all(cfg.seed)

# Sequence datasets for the LSTM: fine-tuning set and hold-out set share the
# same head/tail padding returned by prepare_monthly_dfs_with_padding.
ds_ft_NOR = build_combined_LSTM_dataset(
    df_loss=df_ft_NOR,
    df_full=df_ft_NOR_Aug,
    monthly_cols=MONTHLY_COLS,
    static_cols=STATIC_COLS,
    months_head_pad=res_ft['months_head_pad'],
    months_tail_pad=res_ft['months_tail_pad'],
    normalize_target=True,
    expect_target=True)

ds_holdout_NOR = build_combined_LSTM_dataset(
    df_loss=df_holdout_NOR,
    df_full=df_holdout_NOR_Aug,
    monthly_cols=MONTHLY_COLS,
    static_cols=STATIC_COLS,
    months_head_pad=res_ft['months_head_pad'],
    months_tail_pad=res_ft['months_tail_pad'],
    normalize_target=True,
    expect_target=True)

In [None]:
# The fine-tuning frame must contain exactly the chosen glaciers, and none of
# them may leak into the hold-out frame.
ft_glacier_set = set(finetune_glaciers)
assert set(df_ft_NOR['GLACIER'].unique()) == ft_glacier_set
assert ft_glacier_set.isdisjoint(df_holdout_NOR['GLACIER'].unique())

### In-sample CH dataset (used to train the pre-trained model)

In [None]:
# Swiss (CH) monthly datasets with no test glaciers: every CH glacier is used
# for training. run_flag=False presumably reuses the monthly CSVs written by an
# earlier run (TODO confirm); `paths` is the NOR dict defined above.
res_CH = prepare_monthly_dfs_with_padding(
    cfg=cfg,
    df_region=data_CH,
    region_name="CH",
    region_id=11,
    paths=paths,
    test_glaciers=[],
    vois_climate=VOIS_CLIMATE,
    vois_topographical=VOIS_TOPOGRAPHICAL,
    run_flag=False,
    add_pcsr=False,
    output_file_monthly='CH_wgms_dataset_monthly_LSTM_IS.csv',
    output_file_monthly_aug='CH_wgms_dataset_monthly_LSTM_Aug_IS.csv')

df_train = res_CH["df_train"]
df_train_Aug = res_CH["df_train_aug"]

# Check that the train set contains all CH glaciers and report any that were
# dropped during monthly processing.
existing_glaciers = set(df_train.GLACIER.unique())
print('Number of glaciers in train data:', len(existing_glaciers))
missing_glaciers = sorted(set(data_CH.GLACIER.unique()) - existing_glaciers)
if missing_glaciers:
    print('CH glaciers missing from train data:', missing_glaciers)

mbm.utils.seed_all(cfg.seed)
ds_train_CH = build_combined_LSTM_dataset(
    df_loss=df_train,  # hydrological-year POINT_BALANCE
    df_full=df_train_Aug,  # August-anchored monthly sequences
    monthly_cols=MONTHLY_COLS,
    static_cols=STATIC_COLS,
    months_head_pad=res_CH['months_head_pad'],
    months_tail_pad=res_CH['months_tail_pad'],
    normalize_target=True,
    expect_target=True)

# 80/20 seeded train/validation split over CH sequences.
train_idx_CH, val_idx_CH = mbm.data_processing.MBSequenceDataset.split_indices(
    len(ds_train_CH), val_ratio=0.2, seed=cfg.seed)


### Train the CH model (without `pcsr`)

In [None]:
# Hyper-parameters of the CH LSTM (architecture, optimiser, loss).
best_params = {
    "Fm": 8,
    "Fs": 3,
    "hidden_size": 96,
    "num_layers": 2,
    "bidirectional": False,
    "dropout": 0.2,
    "static_layers": 1,
    "static_hidden": 128,
    "static_dropout": 0.3,
    "lr": 0.0005,
    "weight_decay": 1e-05,
    "loss_name": "neutral",
    "two_heads": False,
    "head_dropout": 0.0,
    "loss_spec": None,
}

# Set TRAIN = True to retrain from scratch; otherwise the fixed pre-trained
# checkpoint is evaluated.
TRAIN = False
PRETRAINED_CH_MODEL = "models/lstm_CH_model_2026-02-09_IS_norm_y_past.pt"

# Date-stamped destination for a fresh training run.
current_date = datetime.now().strftime("%Y-%m-%d")
model_filename = f"models/lstm_CH_model_{current_date}_IS_norm_y_past.pt"

# --- loaders (fit scalers on TRAIN, apply to whole ds_train) ---
ds_train_CH_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train_CH)

train_dl_CH, val_dl_CH = ds_train_CH_copy.make_loaders(
    train_idx=train_idx_CH,
    val_idx=val_idx_CH,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True,  # use weighted sampler for training
)

# --- build model and resolve loss ---
model_CH = mbm.models.LSTM_MB.build_model_from_params(cfg, best_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(best_params)

if TRAIN:
    # Start from a clean checkpoint so a stale file is never reloaded, and make
    # sure the target directory exists before train_loop writes to it.
    os.makedirs(os.path.dirname(model_filename), exist_ok=True)
    if os.path.exists(model_filename):
        os.remove(model_filename)

    history, best_val, best_state = model_CH.train_loop(
        device=device,
        train_dl=train_dl_CH,
        val_dl=val_dl_CH,
        epochs=150,
        lr=best_params['lr'],
        weight_decay=best_params['weight_decay'],
        clip_val=1,
        # scheduler
        sched_factor=0.5,
        sched_patience=6,
        sched_threshold=0.01,
        sched_threshold_mode="rel",
        sched_cooldown=1,
        sched_min_lr=1e-6,
        # early stopping
        es_patience=15,
        es_min_delta=1e-4,
        # logging
        log_every=5,
        verbose=True,
        # checkpoint
        save_best_path=model_filename,
        loss_fn=loss_fn,
    )
    plot_history_lstm(history)

# Load the checkpoint that belongs to this run: the one just written when TRAIN
# is True, the fixed pre-trained one otherwise. Loading the hard-coded
# pre-trained path unconditionally would silently discard a new training run.
model_filename = model_filename if TRAIN else PRETRAINED_CH_MODEL
state = torch.load(model_filename, map_location=device)
model_CH.load_state_dict(state)

# In-sample evaluation: the "test" set is the whole CH dataset (train + val
# indices), presumably scaled with the scalers fitted on ds_train_CH_copy.
ds_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train_CH)
test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test_copy, ds_train_CH_copy, batch_size=128, seed=cfg.seed)

test_metrics, test_df_preds = model_CH.evaluate_with_preds(
    device, test_dl, ds_test_copy)
test_rmse_a, test_rmse_w = test_metrics['RMSE_annual'], test_metrics[
    'RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

scores_annual, scores_winter = compute_seasonal_scores(test_df_preds,
                                                       target_col='target',
                                                       pred_col='pred')

print("Annual scores:", scores_annual)
print("Winter scores:", scores_winter)

fig = plot_predictions_summary(
    grouped_ids=test_df_preds,
    scores_annual=scores_annual,
    scores_winter=scores_winter,
    ax_xlim=(-14, 6),
    ax_ylim=(-14, 6),
    color_annual=mbm.plots.COLOR_ANNUAL,
    color_winter=mbm.plots.COLOR_WINTER,
)

### Load the pre-trained CH model

In [None]:
# Hyper-parameters of the pre-trained CH LSTM; they must match the checkpoint's
# architecture for load_state_dict to succeed.
best_params = dict(
    Fm=8,
    Fs=3,
    hidden_size=96,
    num_layers=2,
    bidirectional=False,
    dropout=0.2,
    static_layers=1,
    static_hidden=128,
    static_dropout=0.3,
    lr=0.0005,
    weight_decay=1e-05,
    loss_name="neutral",
    two_heads=False,
    head_dropout=0.0,
    loss_spec=None,
)

# Re-create the CH scalers. make_loaders(fit_and_transform=True) fits them on
# the CH training indices and transforms this clone in place; ds_train_CH_copy
# is reused below as the scaler source for the NOR datasets.
ds_train_CH_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train_CH)
ch_loader_kwargs = dict(
    train_idx=train_idx_CH,
    val_idx=val_idx_CH,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=True,
    shuffle_train=True,
    use_weighted_sampler=True,
)
train_dl, val_dl = ds_train_CH_copy.make_loaders(**ch_loader_kwargs)

# Rebuild the architecture, resolve the loss and load the pre-trained weights.
model_CH = mbm.models.LSTM_MB.build_model_from_params(cfg, best_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(best_params)

model_filename = "models/lstm_CH_model_2026-02-09_IS_norm_y_past.pt"
state = torch.load(model_filename, map_location=device)
model_CH.load_state_dict(state)

### Fine-tuning: NOR data loaders (scaled with the CH scalers)

In [None]:
# Pristine (untransformed) clone of the NOR fine-tuning dataset.
ds_ft_NOR_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_ft_NOR)

# Seeded 80/20 train/validation split over the NOR fine-tuning sequences.
train_idx_NOR, val_idx_NOR = mbm.data_processing.MBSequenceDataset.split_indices(
    len(ds_ft_NOR_copy), val_ratio=0.2, seed=cfg.seed)

# Put NOR data in the same normalised space the CH model was trained in:
# copy the CH scalers, then transform the NOR clone in place.
ds_ft_NOR_copy.set_scalers_from(ds_train_CH_copy)
ds_ft_NOR_copy.transform_inplace()

# The data is already scaled, so the loaders must NOT fit new scalers.
ft_loader_kwargs = dict(
    train_idx=train_idx_NOR,
    val_idx=val_idx_NOR,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=False,
    shuffle_train=True,
    use_weighted_sampler=True,
)
ft_train_dl, ft_val_dl = ds_ft_NOR_copy.make_loaders(**ft_loader_kwargs)

# Hold-out loader; the CH dataset copy is passed as the scaler reference.
ds_holdout_NOR_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_holdout_NOR)
holdout_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_holdout_NOR_copy, ds_train_CH_copy, batch_size=128, seed=cfg.seed)

### Baseline: pre-trained CH model on the NOR hold-out set (no fine-tuning)

In [None]:
# Baseline: the pre-trained CH model applied directly to the NOR hold-out set.
test_metrics, df_preds = model_CH.evaluate_with_preds(
    device, holdout_dl, ds_holdout_NOR_copy)
test_rmse_a = test_metrics['RMSE_annual']
test_rmse_w = test_metrics['RMSE_winter']

print(f'Test RMSE annual: {test_rmse_a:.3f} | winter: {test_rmse_w:.3f}')

scores_annual, scores_winter = compute_seasonal_scores(
    df_preds, target_col='target', pred_col='pred')

print("Annual scores:", scores_annual)
print("Winter scores:", scores_winter)

# Shared axis limits and colours for the observed-vs-predicted summary figure.
summary_style = dict(
    ax_xlim=(-14, 6),
    ax_ylim=(-14, 6),
    color_annual=mbm.plots.COLOR_ANNUAL,
    color_winter=mbm.plots.COLOR_WINTER,
)
fig = plot_predictions_summary(grouped_ids=df_preds,
                               scores_annual=scores_annual,
                               scores_winter=scores_winter,
                               **summary_style)

#### “Safe” fine-tuning for the small NOR fine-tuning set (freeze the LSTM; train only the static branch and head)

In [None]:
# "Safe" fine-tune: start from the pre-trained CH weights, freeze the LSTM
# encoder and train only the remaining layers, then reload the best checkpoint.
PRETRAINED_CH_MODEL = "models/lstm_CH_model_2026-02-09_IS_norm_y_past.pt"
FT_SAFE_PATH = "models/lstm_finetuned_CH_to_NOR.pt"

model_CH_ft = mbm.models.LSTM_MB.build_model_from_params(
    cfg, best_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(best_params)

# Load pre-trained weights.
model_filename = PRETRAINED_CH_MODEL
state = torch.load(model_filename, map_location=device)
model_CH_ft.load_state_dict(state)

# 1) Freeze the recurrent encoder (parameters whose names start with "lstm.").
frozen_names = []
for name, p in model_CH_ft.named_parameters():
    if name.startswith("lstm."):
        p.requires_grad = False
        frozen_names.append(name)
# If the encoder is not registered under `lstm`, nothing would be frozen and
# this would silently turn into a full fine-tune.
assert frozen_names, "No parameters start with 'lstm.'; check the model's module names."

# 2) New optimizer over the trainable parameters only (small LR).
opt = torch.optim.AdamW(
    [p for p in model_CH_ft.parameters() if p.requires_grad],
    lr=1e-4,
    weight_decay=best_params["weight_decay"],
)

# 3) Fine-tune. Remove any checkpoint left by a previous run so the reload
# below can only pick up this run's best weights.
if os.path.exists(FT_SAFE_PATH):
    os.remove(FT_SAFE_PATH)

history, best_val, best_state = model_CH_ft.train_loop(
    device=device,
    train_dl=ft_train_dl,
    val_dl=ft_val_dl,
    epochs=60,
    optimizer=opt,
    clip_val=1.0,
    loss_fn=loss_fn,
    es_patience=8,
    save_best_path=FT_SAFE_PATH,
    verbose=True,
)

# Evaluate the best-validation weights rather than whatever the last epoch left
# in the model. Assumes save_best_path stores a plain state dict, as the
# pre-trained checkpoint loaded above does; a no-op if train_loop already
# restored the best weights.
if os.path.exists(FT_SAFE_PATH):
    model_CH_ft.load_state_dict(torch.load(FT_SAFE_PATH, map_location=device))

test_metrics, df_preds = model_CH_ft.evaluate_with_preds(
    device, holdout_dl, ds_holdout_NOR_copy)
test_rmse_a, test_rmse_w = test_metrics['RMSE_annual'], test_metrics[
    'RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

scores_annual, scores_winter = compute_seasonal_scores(df_preds,
                                                       target_col='target',
                                                       pred_col='pred')

print("Annual scores:", scores_annual)
print("Winter scores:", scores_winter)

fig = plot_predictions_summary(
    grouped_ids=df_preds,
    scores_annual=scores_annual,
    scores_winter=scores_winter,
    ax_xlim=(-14, 6),
    ax_ylim=(-14, 6),
    color_annual=mbm.plots.COLOR_ANNUAL,
    color_winter=mbm.plots.COLOR_WINTER,
)

#### “Full” fine-tuning (unfreeze all layers; very small learning rate)

In [None]:
# "Full" fine-tune: start from the pre-trained CH weights, train every layer
# with a very small learning rate, then reload the best checkpoint.
PRETRAINED_CH_MODEL = "models/lstm_CH_model_2026-02-09_IS_norm_y_past.pt"
FT_FULL_PATH = "models/lstm_finetuned_CH_to_NOR_full.pt"

model_CH_ft_2 = mbm.models.LSTM_MB.build_model_from_params(
    cfg, best_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(best_params)

# Load pre-trained weights.
model_filename = PRETRAINED_CH_MODEL
state = torch.load(model_filename, map_location=device)
model_CH_ft_2.load_state_dict(state)

# Unfreeze everything.
for p in model_CH_ft_2.parameters():
    p.requires_grad = True

opt = torch.optim.AdamW(
    model_CH_ft_2.parameters(),
    lr=1e-5,  # smaller because the LSTM is updated too
    weight_decay=best_params["weight_decay"],
)

# Remove any checkpoint left by a previous run so the reload below can only
# pick up this run's best weights.
if os.path.exists(FT_FULL_PATH):
    os.remove(FT_FULL_PATH)

history, best_val, best_state = model_CH_ft_2.train_loop(
    device=device,
    train_dl=ft_train_dl,
    val_dl=ft_val_dl,
    epochs=80,
    optimizer=opt,
    clip_val=1.0,
    loss_fn=loss_fn,
    es_patience=10,
    save_best_path=FT_FULL_PATH,
)

# Evaluate the best-validation weights rather than the last epoch's. Assumes
# save_best_path stores a plain state dict, as the pre-trained checkpoint loaded
# above does; a no-op if train_loop already restored the best weights.
if os.path.exists(FT_FULL_PATH):
    model_CH_ft_2.load_state_dict(torch.load(FT_FULL_PATH, map_location=device))

test_metrics, df_preds = model_CH_ft_2.evaluate_with_preds(
    device, holdout_dl, ds_holdout_NOR_copy)
test_rmse_a, test_rmse_w = test_metrics['RMSE_annual'], test_metrics[
    'RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

scores_annual, scores_winter = compute_seasonal_scores(df_preds,
                                                       target_col='target',
                                                       pred_col='pred')

print("Annual scores:", scores_annual)
print("Winter scores:", scores_winter)

fig = plot_predictions_summary(
    grouped_ids=df_preds,
    scores_annual=scores_annual,
    scores_winter=scores_winter,
    ax_xlim=(-14, 6),
    ax_ylim=(-14, 6),
    color_annual=mbm.plots.COLOR_ANNUAL,
    color_winter=mbm.plots.COLOR_WINTER,
)

#### Two-stage fine-tuning (train the heads first, then all layers)

In [None]:
# Two-stage fine-tune: (1) freeze the LSTM encoder and adapt the other layers,
# (2) unfreeze everything and continue with a very small learning rate. Each
# stage ends by reloading its best-validation checkpoint.
PRETRAINED_CH_MODEL = "models/lstm_CH_model_2026-02-09_IS_norm_y_past.pt"
STAGE1_PATH = "models/tmp_stage1.pt"
STAGE2_PATH = "models/lstm_finetuned_CH_to_NOR_2stage.pt"

model_CH_ft_3 = mbm.models.LSTM_MB.build_model_from_params(
    cfg, best_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(best_params)

# Load pre-trained weights.
model_filename = PRETRAINED_CH_MODEL
state = torch.load(model_filename, map_location=device)
model_CH_ft_3.load_state_dict(state)

# Stage 1: freeze LSTM, tune the remaining layers.
for name, p in model_CH_ft_3.named_parameters():
    p.requires_grad = not name.startswith("lstm.")
# If the encoder is not registered under `lstm`, nothing would be frozen.
assert any(not p.requires_grad for p in model_CH_ft_3.parameters()), \
    "No parameters start with 'lstm.'; check the model's module names."

opt1 = torch.optim.AdamW(
    [p for p in model_CH_ft_3.parameters() if p.requires_grad],
    lr=2e-4,
    weight_decay=best_params["weight_decay"])

# Remove stale checkpoints so each reload only sees this run's best weights.
for ckpt in (STAGE1_PATH, STAGE2_PATH):
    if os.path.exists(ckpt):
        os.remove(ckpt)

model_CH_ft_3.train_loop(device=device,
                         train_dl=ft_train_dl,
                         val_dl=ft_val_dl,
                         epochs=20,
                         optimizer=opt1,
                         loss_fn=loss_fn,
                         es_patience=5,
                         save_best_path=STAGE1_PATH)

# Start stage 2 from the best stage-1 weights, not the last stage-1 epoch.
# Assumes save_best_path stores a plain state dict, as the pre-trained
# checkpoint loaded above does.
if os.path.exists(STAGE1_PATH):
    model_CH_ft_3.load_state_dict(torch.load(STAGE1_PATH, map_location=device))

# Stage 2: unfreeze all, very small LR.
for p in model_CH_ft_3.parameters():
    p.requires_grad = True

opt2 = torch.optim.AdamW(model_CH_ft_3.parameters(),
                         lr=1e-5,
                         weight_decay=best_params["weight_decay"])

history, best_val, best_state = model_CH_ft_3.train_loop(
    device=device,
    train_dl=ft_train_dl,
    val_dl=ft_val_dl,
    epochs=60,
    optimizer=opt2,
    loss_fn=loss_fn,
    es_patience=10,
    save_best_path=STAGE2_PATH,
)

# Evaluate the best stage-2 weights rather than the last epoch's.
if os.path.exists(STAGE2_PATH):
    model_CH_ft_3.load_state_dict(torch.load(STAGE2_PATH, map_location=device))

test_metrics, df_preds = model_CH_ft_3.evaluate_with_preds(
    device, holdout_dl, ds_holdout_NOR_copy)
test_rmse_a, test_rmse_w = test_metrics['RMSE_annual'], test_metrics[
    'RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

scores_annual, scores_winter = compute_seasonal_scores(df_preds,
                                                       target_col='target',
                                                       pred_col='pred')

print("Annual scores:", scores_annual)
print("Winter scores:", scores_winter)

fig = plot_predictions_summary(
    grouped_ids=df_preds,
    scores_annual=scores_annual,
    scores_winter=scores_winter,
    ax_xlim=(-14, 6),
    ax_ylim=(-14, 6),
    color_annual=mbm.plots.COLOR_ANNUAL,
    color_winter=mbm.plots.COLOR_WINTER,
)


In [None]:
# Per-glacier observed-vs-predicted panels for the NOR hold-out glaciers,
# ordered from lowest to highest mean annual-stake elevation.
# NOTE: df_preds is whichever prediction frame was computed last; run top to
# bottom, that is the two-stage fine-tuned model on the hold-out set.

# Mean elevation of annual stake measurements per glacier.
gl_per_el = (data_NOR[data_NOR.PERIOD == 'annual']
             .groupby('GLACIER')['POINT_ELEVATION'].mean())

# reindex (instead of [] lookup) so a hold-out glacier without annual stakes
# does not raise KeyError; such glaciers get NaN and are sorted last.
test_gl_per_el = gl_per_el.reindex(holdout_glaciers).sort_values().index

shapefile_path = os.path.join(cfg.dataPath, "RGI_v6/RGI_08_Scandinavia",
                              "08_rgi60_Scandinavia.shp")

gl_area = get_gl_area_NOR(data_NOR, shapefile_path)

df_preds['gl_elv'] = df_preds['GLACIER'].map(gl_per_el)

# Size the grid to the number of hold-out glaciers (3 columns; 6 inches of
# height per row, i.e. 30 inches for 5 rows). squeeze=False keeps `axs` 2-D
# even for a single row.
n_cols = 3
n_rows = max(1, int(np.ceil(len(test_gl_per_el) / n_cols)))
fig, axs = plt.subplots(n_rows,
                        n_cols,
                        figsize=(30, 6 * n_rows),
                        sharex=True,
                        squeeze=False)


def _panel_label(i: int) -> str:
    """Return '(a)', '(b)', ..., '(z)', '(aa)', ... for 0-based panel index i."""
    letters = ""
    i += 1
    while i > 0:
        i, rem = divmod(i - 1, 26)
        letters = chr(ord("a") + rem) + letters
    return f"({letters})"


# One label per axis (the previous fixed list had only 9 labels for 15 axes).
subplot_labels = [_panel_label(i) for i in range(axs.size)]

axs = plot_individual_glacier_pred(
    df_preds,
    axs=axs,
    subplot_labels=subplot_labels,
    color_annual=mbm.plots.COLOR_ANNUAL,
    color_winter=mbm.plots.COLOR_WINTER,
    custom_order=test_gl_per_el,
    gl_area=gl_area,
    ax_xlim=(-14, 6),
    ax_ylim=(-14, 6),
)

fig.supxlabel('Observed PMB [m w.e.]', fontsize=20, y=0.06)
fig.supylabel('Modeled PMB [m w.e.]', fontsize=20, x=0.09)

# Marker-only legend handles, one per season.
handles = [
    Line2D([0], [0],
           marker='o',
           linestyle='None',
           markersize=10,
           markerfacecolor=color,
           markeredgecolor='k',
           markeredgewidth=0.8,
           label=label)
    for label, color in (('Annual', mbm.plots.COLOR_ANNUAL),
                         ('Winter', mbm.plots.COLOR_WINTER))
]

fig.legend(handles=handles,
           loc='upper center',
           bbox_to_anchor=(0.5, 0.05),
           ncol=len(handles),
           fontsize=20)

plt.subplots_adjust(hspace=0.25, wspace=0.25)
plt.show()

In [None]:
# Observed vs predicted winter MB for a single hold-out glacier, coloured by
# year. df_preds is the most recently computed prediction frame (run top to
# bottom: the two-stage fine-tuned model on the hold-out set).
GLACIER_NAME = "Nigardsbreen"

# Kept as `df_w`: the outlier cells below select points from it.
df_w = (df_preds[(df_preds["GLACIER"] == GLACIER_NAME)
                 & (df_preds["PERIOD"] == "winter")]
        .sort_values("YEAR")
        .copy())

fig, ax = plt.subplots(figsize=(8, 6))

sc = ax.scatter(df_w["target"],
                df_w["pred"],
                c=df_w["YEAR"],
                cmap="viridis",
                s=80)

fig.colorbar(sc, ax=ax, label="Year")
ax.set_xlabel("Observed Winter MB")
ax.set_ylabel("Predicted Winter MB")
ax.set_title(f"{GLACIER_NAME} – Winter MB (Observed vs Predicted)")
ax.axline((0, 0), slope=1, linestyle="--", color="gray")  # 1:1 reference
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

In [None]:
# Locate the stakes behind the near-zero observed winter points of df_w
# (|target| < 0.5) and draw them over the glacier's masked RGI grid.
WEIRD_ABS_THRESHOLD = 0.5
ids_weird_points = df_w.loc[df_w.target.abs() < WEIRD_ABS_THRESHOLD,
                            'ID'].values
point_ids_weird_points = df_holdout_NOR.loc[
    df_holdout_NOR.ID.isin(ids_weird_points), 'POINT_ID'].unique()
df_outliers = data_NOR_holdout[data_NOR_holdout.POINT_ID.isin(
    point_ids_weird_points)]

# Without any matching stake there is no glacier to plot; fail with a clear
# message instead of an IndexError from unique()[0].
if df_outliers.empty:
    raise ValueError(
        "No stake records matched the selected near-zero winter points.")

# Coordinates to plot (one row per stake location).
coords = df_outliers[['POINT_LAT', 'POINT_LON', 'POINT_ID']].drop_duplicates()

# --- open the glacier's masked RGI grid (zarr) ---
# os.path.join works whether or not cfg.dataPath ends with a separator.
path_RGI = os.path.join(cfg.dataPath,
                        'RGI_v6/RGI_08_Scandinavia/xr_masked_grids/')
rgi_nigardsbreen = df_outliers.RGIId.unique()[0]

ds = xr.open_zarr(os.path.join(path_RGI, f"{rgi_nigardsbreen}.zarr"))

# Underlay variable; swap for a DEM variable if the grid provides one.
da = ds["masked_aspect"]

fig, ax = plt.subplots(figsize=(10, 8))

# Draw the raster on the explicit axis so the overlay below shares it.
da.plot(ax=ax, add_colorbar=True)

# --- overlay points ---
# NOTE(review): assumes the grid's x/y coordinates are lon/lat so the stake
# positions line up with the raster — confirm the grid's CRS.
ax.scatter(coords["POINT_LON"].values,
           coords["POINT_LAT"].values,
           s=60,
           marker="o",
           facecolors="none",
           edgecolors="red",
           linewidths=1.8,
           zorder=10,
           label=f"Weird points (n={len(coords)})")

# Annotate point IDs (can get cluttered for many points).
for _, r in coords.iterrows():
    ax.text(r["POINT_LON"],
            r["POINT_LAT"],
            str(r["POINT_ID"]),
            fontsize=8,
            ha="left",
            va="bottom",
            zorder=11)

ax.set_title(f"{rgi_nigardsbreen} — Outlier points over {da.name}")
ax.legend(loc="upper right")
plt.tight_layout()
plt.show()

In [None]:
# Display the stake records behind the selected near-zero winter points.
df_outliers

In [None]:
# Zoom on the near-zero observed winter points (-0.5 < target < 0.5), coloured by year.
near_zero_mask = (df_w.target > -0.5) & (df_w.target < 0.5)
weird_points_df = df_w[near_zero_mask]

fig, ax = plt.subplots(figsize=(8, 6))

sc = ax.scatter(weird_points_df["target"],
                weird_points_df["pred"],
                c=weird_points_df["YEAR"],
                cmap="viridis",
                s=80)

fig.colorbar(sc, ax=ax, label="Year")
ax.set_xlabel("Observed Winter MB")
ax.set_ylabel("Predicted Winter MB")
ax.set_title("Nigardsbreen – Winter MB (Observed vs Predicted)")
ax.axline((0, 0), slope=1, linestyle="--", color="gray")  # 1:1 reference
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()