In [None]:
import sys, os
sys.path.append(os.path.join(os.getcwd(), '../../')) # Add root of repo to import MBM
from typing import Optional, Iterable, Dict, List
import os

import pandas as pd
import numpy as np
import warnings
import re
import matplotlib.pyplot as plt
import seaborn as sns
from cmcrameri import cm
import massbalancemachine as mbm
import logging
import torch.nn as nn
from skorch.helper import SliceDataset
from datetime import datetime
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset
import pickle 
from scipy import stats
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import torch 
from matplotlib.lines import Line2D

from regions.TF_Europe.scripts.config_TF_Europe import *
from regions.TF_Europe.scripts.dataset import *
from regions.TF_Europe.scripts.plotting import *
from regions.TF_Europe.scripts.models import *
from regions.TF_Europe.scripts.models import *

# Silence ALL Python warnings for the rest of the session. NOTE: this also hides
# any warnings.warn(...) raised by helper functions defined later in this notebook.
warnings.filterwarnings('ignore')
# Automatically reload edited project modules (regions.TF_Europe.scripts.*, mbm)
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Experiment configuration; cfg.seed is passed to mbm.utils.seed_all
# (presumably seeds the python/numpy/torch RNGs — TODO confirm).
cfg = mbm.EuropeTFConfig()
mbm.utils.seed_all(cfg.seed)
# Release cached GPU memory (presumably harmless without CUDA — TODO confirm).
mbm.utils.free_up_cuda()
mbm.plots.use_mbm_style()

# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Cross-regional modelling:

### Read stakes datasets:

In [None]:
"""
Examples of loading data:
# Load Switzerland only
df = load_stakes(cfg, "CH")

# Load all Central Europe (FR+CH+IT+AT when you add them)
df_ceu = load_stakes_for_rgi_region(cfg, "11")

# Load all Europe regions configured
dfs = {rid: load_stakes_for_rgi_region(cfg, rid) for rid in RGI_REGIONS.keys()}"""

# Load all Europe regions configured
dfs = {rid: load_stakes_for_rgi_region(cfg, rid) for rid in RGI_REGIONS.keys()}
dfs["11"]

In [None]:
# Summary + plots of the loaded stake data of every region in `dfs`
# (summarize_and_plot_all_regions presumably comes from the star-imported project scripts).
summarize_and_plot_all_regions(dfs)

In [None]:
# Mass-balance distributions of every region in `dfs`
# (plot_mb_distributions_all_regions presumably comes from the star-imported project scripts).
plot_mb_distributions_all_regions(dfs)

### Monthly datasets:

In [None]:
# Transform data to monthly format (run or load data):
# ERA5 inputs needed to build the monthly dataset.
paths = {
    'era5_climate_data':
    os.path.join(cfg.dataPath, path_ERA5_raw,
                 "era5_monthly_averaged_data_Europe.nc"),
    'geopotential_data':
    os.path.join(cfg.dataPath, path_ERA5_raw,
                 "era5_geopotential_pressure_Europe.nc")
}

# Check that all these files exist (fail fast before any expensive processing).
for key, path in paths.items():
    if not os.path.exists(path):
        raise FileNotFoundError(f"Required file for {key} not found at {path}")

# Climate variables of interest. Defined once at top level (previously this was
# indented inside the loop above, re-created on every iteration and left
# undefined whenever `paths` was empty).
vois_climate = [
    "t2m",
    "tp",
    "slhf",
    "sshf",
    "ssrd",
    "fal",
    "str",
]

# Static topographical variables of interest.
vois_topographical = ["aspect", "slope", "svf"]

In [None]:
def build_crossregional_df_ceu_with_ch(dfs: dict) -> pd.DataFrame:
    """
    Stack every non-empty stake dataframe in `dfs` into a single dataframe.

    Parameters
    ----------
    dfs : dict
        Mapping RGI region id -> stake dataframe (or None). Entries that are
        None or have zero rows are skipped with a logged warning.
        Downstream code expects columns such as GLACIER, YEAR, ID, PERIOD,
        MONTHS, POINT_BALANCE and (for the CH/other split) SOURCE_CODE.

    Returns
    -------
    pd.DataFrame
        Row-wise concatenation in `dfs` iteration order, with a fresh 0..n-1 index.

    Raises
    ------
    ValueError
        If no entry of `dfs` has any rows.
    """
    kept = []
    for region_id, frame in dfs.items():
        if frame is not None and len(frame) > 0:
            kept.append(frame)
        else:
            logging.warning(f"RGI {region_id}: empty, skipping in concat.")

    if len(kept) == 0:
        raise ValueError("No non-empty dataframes in dfs.")

    return pd.concat(kept, ignore_index=True)


def compute_crossregional_test_glaciers(
    df_all: pd.DataFrame,
    ch_code: str = "CH",
    source_col: str = "SOURCE_CODE",
    glacier_col: str = "GLACIER",
):
    """
    Split glaciers into a reference (train) set and everything else (test).

    Train glaciers = glaciers with ``df_all[source_col] == ch_code``.
    Test glaciers  = glaciers with ``df_all[source_col] != ch_code``
    (rows with a missing source code therefore count as test).

    Parameters
    ----------
    df_all : pd.DataFrame
        Combined stake rows; must contain `source_col` and `glacier_col`.
    ch_code : str
        Source code of the training region (default "CH").
    source_col, glacier_col : str
        Column names for the source code and glacier name.

    Returns
    -------
    (train_glaciers, test_glaciers) : (list[str], list[str])
        Both sorted, NaN glacier names dropped.

    Raises
    ------
    ValueError
        If a required column is missing, or either side of the split is empty.
    """
    if source_col not in df_all.columns:
        raise ValueError(
            f"Missing column {source_col}. Needed to separate {ch_code} vs others.")
    if glacier_col not in df_all.columns:
        raise ValueError(f"Missing column {glacier_col}.")

    is_ref = df_all[source_col] == ch_code
    ch_gl = sorted(df_all.loc[is_ref, glacier_col].dropna().unique())
    non_ch_gl = sorted(df_all.loc[~is_ref, glacier_col].dropna().unique())

    # Messages use the actual ch_code/source_col (previously hard-coded to CH/SOURCE_CODE).
    if not ch_gl:
        raise ValueError(
            f"No {ch_code} glaciers found ({source_col}=='{ch_code}').")
    if not non_ch_gl:
        raise ValueError(
            f"No non-{ch_code} glaciers found ({source_col}!='{ch_code}').")

    # A glacier listed under both the reference and another source code would
    # end up in train AND test (leakage) — surface it instead of hiding it.
    overlap = sorted(set(ch_gl) & set(non_ch_gl))
    if overlap:
        logging.warning(
            f"{len(overlap)} glacier(s) appear under both {ch_code} and non-{ch_code} "
            f"{source_col} values (present in train and test): {overlap}")

    logging.info(
        f"Cross-regional split: {ch_code} train glaciers={len(ch_gl)}, "
        f"non-{ch_code} test glaciers={len(non_ch_gl)}")
    return ch_gl, non_ch_gl


def prepare_monthly_df_crossregional_CH_to_EU(
    cfg,
    dfs,
    paths,
    vois_climate,
    vois_topographical,
    run_flag=True,  # True recompute, False load
    region_name="XREG_CH_TO_EU",
    region_id=11,  # arbitrary/int tag used by your pipeline; keep 11 or 0
    csv_subfolder="CrossRegional/CH_to_Europe/csv",
):
    """
    Build (or load) one monthly dataset for the CH -> rest-of-Europe experiment.

    All stake rows from `dfs` are concatenated; CH glaciers form the training
    set and every non-CH glacier is passed as `test_glaciers` to
    `prepare_monthly_dfs_with_padding`.

    Parameters
    ----------
    cfg : config object providing `dataPath`.
    dfs : dict of RGI region id -> stake dataframe.
    paths : dict of input file paths; copied, never modified. The copy gets a
        `"csv_path"` entry pointing at `<cfg.dataPath>/<path_PMB_WGMS_csv>/<csv_subfolder>`,
        which is created if needed.
    vois_climate, vois_topographical : lists of variable names to use.
    run_flag : bool
        True recomputes, False loads previously computed output.
    region_name, region_id, csv_subfolder : experiment tags / output subfolder.

    Returns
    -------
    res : dict
        Same output dict as prepare_monthly_dfs_with_padding (df_train/df_test/aug/etc.)
    split_info : dict
        {"train_glaciers": [...], "test_glaciers": [...]}
    """
    # Every stake row of every region in one frame.
    df_all = build_crossregional_df_ceu_with_ch(dfs)

    # CH glaciers train the model; all other glaciers are held out.
    train_glaciers, test_glaciers = compute_crossregional_test_glaciers(
        df_all, ch_code="CH")

    # Experiment-specific CSV output folder.
    csv_dir = os.path.join(cfg.dataPath, path_PMB_WGMS_csv, csv_subfolder)
    os.makedirs(csv_dir, exist_ok=True)
    experiment_paths = {**paths, "csv_path": csv_dir}

    logging.info(
        f"Preparing cross-regional monthlies: {region_name} "
        f"(run_flag={run_flag}) | train(CH)={len(train_glaciers)} | test(non-CH)={len(test_glaciers)}"
    )

    res = prepare_monthly_dfs_with_padding(
        cfg=cfg,
        df_region=df_all,
        region_name=region_name,
        region_id=int(region_id),
        paths=experiment_paths,
        test_glaciers=test_glaciers,
        vois_climate=vois_climate,
        vois_topographical=vois_topographical,
        run_flag=run_flag,
    )

    split_info = {"train_glaciers": train_glaciers, "test_glaciers": test_glaciers}
    return res, split_info

In [None]:
# Reload the raw stake dataframes of every configured RGI region (repeated here
# so this cell does not depend on the earlier loading cell).
dfs = dict(
    (region_id, load_stakes_for_rgi_region(cfg, region_id))
    for region_id in RGI_REGIONS.keys())

# Monthly cross-regional dataset: train = CH glaciers, test = all non-CH glaciers.
# run_flag=False loads previously computed results instead of recomputing.
res_xreg, split_info = prepare_monthly_df_crossregional_CH_to_EU(
    cfg=cfg,
    dfs=dfs,
    paths=paths,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    run_flag=False,
)

df_train, df_test = res_xreg["df_train"], res_xreg["df_test"]

print("Train glaciers (CH):", len(split_info["train_glaciers"]))
print("Test glaciers (non-CH):", len(split_info["test_glaciers"]))
print("Train rows:", len(df_train), "Test rows:", len(df_test))

#### Finetuning glaciers:

In [None]:
import numpy as np
import pandas as pd


def pick_glaciers_by_row_fraction(
    df_test: pd.DataFrame,
    region_code: str,
    target_frac: float,
    source_col: str = "SOURCE_CODE",
    glacier_col: str = "GLACIER",
    seed: int = 42,
    method: str = "greedy_small_first",
    min_rows_per_glacier: int = 1,
):
    """
    Select glaciers whose df_test row counts sum to ~target_frac of total rows for region_code.

    Whole glaciers are selected (never partial ones), so the achieved fraction
    only approximates `target_frac`. A greedy pass visits glaciers in the order
    given by `method`, then one pairwise swap is tried to reduce the error.

    Parameters
    ----------
    df_test : pd.DataFrame
        Rows to choose from; must contain `source_col` and `glacier_col`.
    region_code : str
        Value of `source_col` identifying the region.
    target_frac : float
        Desired fraction of the region's rows, in [0, 1].
    source_col, glacier_col : str
        Column names for region code and glacier name.
    seed : int
        RNG seed; only used by method="shuffle_then_greedy".
    method : str
      - "greedy_small_first": visit glaciers by row count ascending
        (often best for small targets like 5% because it can finely tune)
      - "greedy_large_first": visit by row count descending (often fine for 50%)
      - "shuffle_then_greedy": seeded shuffle, then greedy (stochastic but reproducible)
    min_rows_per_glacier : int
        Glaciers with fewer rows are excluded from selection AND from the total.

    Returns
    -------
    selected_glaciers : list[str]
    summary : dict with totals and achieved fraction
    per_glacier_counts : pd.Series of row counts per glacier (after filtering), descending

    Raises
    ------
    ValueError
        If target_frac is outside [0, 1], the region has no rows, no glacier
        survives `min_rows_per_glacier`, or `method` is unknown.
    """
    if not 0.0 <= target_frac <= 1.0:
        raise ValueError(f"target_frac must be in [0, 1], got {target_frac}.")

    df_reg = df_test.loc[df_test[source_col] == region_code].copy()
    if df_reg.empty:
        raise ValueError(
            f"No rows in df_test for region '{region_code}' (source_col={source_col})."
        )

    counts = df_reg.groupby(glacier_col).size().sort_values(ascending=False)

    # optional: remove tiny glaciers (if you want)
    counts = counts[counts >= min_rows_per_glacier]
    if counts.empty:
        raise ValueError(
            f"After filtering min_rows_per_glacier={min_rows_per_glacier}, no glaciers remain for {region_code}."
        )

    total_rows = int(counts.sum())
    target_rows = int(round(target_frac * total_rows))

    # order glaciers for greedy
    if method == "greedy_small_first":
        ordered = counts.sort_values(ascending=True)
    elif method == "greedy_large_first":
        ordered = counts.sort_values(ascending=False)
    elif method == "shuffle_then_greedy":
        rng = np.random.default_rng(seed)
        # Copy: shuffling the array returned by Index.to_numpy() in place could
        # alias (and corrupt) the index of `counts`.
        idx = counts.index.to_numpy(copy=True)
        rng.shuffle(idx)
        ordered = counts.loc[idx]
    else:
        raise ValueError(f"Unknown method='{method}'")

    selected = []
    s = 0

    # Greedy accumulate. A glacier is taken only if it brings the running total
    # closer to the target: always true while it still fits under the target,
    # and true for an overshooting glacier only if overshooting beats stopping.
    # An overshooting glacier that would make things worse is skipped (a later,
    # smaller glacier may fill the gap better). The previous version always
    # took the overshooting glacier, e.g. jumping to 100% for a 50% target.
    for gl, n in ordered.items():
        if s >= target_rows:
            break
        n = int(n)
        if abs(s + n - target_rows) < abs(s - target_rows):
            selected.append(gl)
            s += n

    # small local improvement: try swapping one glacier if it improves error (optional, cheap)
    # (helps especially near 50% targets)
    selected_set = set(selected)
    not_selected = [g for g in counts.index if g not in selected_set]

    best_err = abs(s - target_rows)
    best_swap = None

    # limit search for speed (still usually enough)
    cand_sel = selected[:min(len(selected), 40)]
    cand_nsel = not_selected[:min(len(not_selected), 60)]

    sel_counts = counts.loc[cand_sel]
    nsel_counts = counts.loc[cand_nsel]

    for g_out, n_out in sel_counts.items():
        for g_in, n_in in nsel_counts.items():
            s2 = s - int(n_out) + int(n_in)
            err2 = abs(s2 - target_rows)
            if err2 < best_err:
                best_err = err2
                best_swap = (g_out, g_in, s2)

    if best_swap is not None:
        g_out, g_in, s2 = best_swap
        selected = [g for g in selected if g != g_out] + [g_in]
        s = int(s2)

    achieved_frac = s / total_rows if total_rows > 0 else np.nan

    summary = {
        "region": region_code,
        "target_frac": float(target_frac),
        "total_rows_region": total_rows,
        "target_rows": target_rows,
        "selected_rows": int(s),
        "achieved_frac": float(achieved_frac),
        "achieved_pct": float(100 * achieved_frac),
        "n_glaciers_total": int(counts.shape[0]),
        "n_glaciers_selected": int(len(selected)),
        "abs_row_error": int(abs(s - target_rows)),
    }

    return selected, summary, counts

In [None]:
# Suggested fine-tuning glacier sets for Svalbard (SJM) at ~5% and ~50% of its
# df_test rows. sjm_counts holds the per-glacier row counts for inspection.
# SJM 5%: small-first greedy usually gives best control for a small fraction
SJM_5pct, sjm5_info, sjm_counts = pick_glaciers_by_row_fraction(
    df_test=df_test,
    region_code="SJM",
    target_frac=0.05,
    method="greedy_small_first",
    seed=42,
)

# SJM 50%: large-first or small-first both work; start with large-first
SJM_50pct, sjm50_info, _ = pick_glaciers_by_row_fraction(
    df_test=df_test,
    region_code="SJM",
    target_frac=0.50,
    method="greedy_large_first",
    seed=42,
)

print("SJM 5% summary:", sjm5_info)
print("SJM 50% summary:", sjm50_info)

print("\nSJM_5pct glaciers:", SJM_5pct)
print("\nSJM_50pct glaciers:", SJM_50pct)

In [None]:
# ISL 5%: small-first greedy usually gives best control for a small fraction.
# Results use ISL-specific names: the previous copy of this cell stored them in
# sjm5_info / sjm50_info / sjm_counts, silently overwriting the SJM results.
ISL_5pct, isl5_info, isl_counts = pick_glaciers_by_row_fraction(
    df_test=df_test,
    region_code="ISL",
    target_frac=0.05,
    method="greedy_small_first",
    seed=42,
)

# ISL 50%: large-first or small-first both work; start with large-first
ISL_50pct, isl50_info, _ = pick_glaciers_by_row_fraction(
    df_test=df_test,
    region_code="ISL",
    target_frac=0.50,
    method="greedy_large_first",
    seed=42,
)

print("ISL 5% summary:", isl5_info)
print("ISL 50% summary:", isl50_info)

print("\nISL_5pct glaciers:", ISL_5pct)
print("\nISL_50pct glaciers:", ISL_50pct)

In [None]:
# FR 5%: small-first greedy usually gives best control for a small fraction.
# Results use FR-specific names: the previous copy of this cell stored them in
# sjm5_info / sjm50_info / sjm_counts, silently overwriting the SJM results.
FR_5pct, fr5_info, fr_counts = pick_glaciers_by_row_fraction(
    df_test=df_test,
    region_code="FR",
    target_frac=0.05,
    method="greedy_small_first",
    seed=42,
)

# FR 50%: large-first or small-first both work; start with large-first
FR_50pct, fr50_info, _ = pick_glaciers_by_row_fraction(
    df_test=df_test,
    region_code="FR",
    target_frac=0.50,
    method="greedy_large_first",
    seed=42,
)

print("FR 5% summary:", fr5_info)
print("FR 50% summary:", fr50_info)

print("\nFR_5pct glaciers:", FR_5pct)
print("\nFR_50pct glaciers:", FR_50pct)

In [None]:
# IT_AT 5%: small-first greedy usually gives best control for a small fraction.
# Results use IT_AT-specific names: the previous copy of this cell stored them in
# sjm5_info / sjm50_info / sjm_counts, silently overwriting the SJM results.
IT_AT_5pct, it_at5_info, it_at_counts = pick_glaciers_by_row_fraction(
    df_test=df_test,
    region_code="IT_AT",
    target_frac=0.05,
    method="greedy_small_first",
    seed=42,
)

# IT_AT 50%: large-first or small-first both work; start with large-first
IT_AT_50pct, it_at50_info, _ = pick_glaciers_by_row_fraction(
    df_test=df_test,
    region_code="IT_AT",
    target_frac=0.50,
    method="greedy_large_first",
    seed=42,
)

print("IT_AT 5% summary:", it_at5_info)
print("IT_AT 50% summary:", it_at50_info)

print("\nIT_AT_5pct glaciers:", IT_AT_5pct)
print("\nIT_AT_50pct glaciers:", IT_AT_50pct)

In [None]:
# NOR 5%: small-first greedy usually gives best control for a small fraction.
# Results use NOR-specific names: the previous copy of this cell stored them in
# sjm5_info / sjm50_info / sjm_counts, silently overwriting the SJM results.
NOR_5pct, nor5_info, nor_counts = pick_glaciers_by_row_fraction(
    df_test=df_test,
    region_code="NOR",
    target_frac=0.05,
    method="greedy_small_first",
    seed=42,
)

# NOR 50%: large-first or small-first both work; start with large-first
NOR_50pct, nor50_info, _ = pick_glaciers_by_row_fraction(
    df_test=df_test,
    region_code="NOR",
    target_frac=0.50,
    method="greedy_large_first",
    seed=42,
)

print("NOR 5% summary:", nor5_info)
print("NOR 50% summary:", nor50_info)

print("\nNOR_5pct glaciers:", NOR_5pct)
print("\nNOR_50pct glaciers:", NOR_50pct)

In [None]:
# Fixed fine-tuning glacier lists per region: "5pct" / "50pct" name the intended
# share of that region's df_test rows. Achieved shares are checked by
# verify_row_percentage further below. Note that some glaciers appear in both
# lists of a region (NOR: 'Cainhavarre'; ISL: 'RGI60-06.00305';
# SJM: 'WERENSKIOLDBREEN'), so the two splits are not disjoint.

# Norway
# 50% split
FT_50PCT_NOR = [
    'Nigardsbreen', 'Aalfotbreen', 'Engabreen', 'Storsteinsfjellbreen',
    'Cainhavarre'
]

# 5% split
FT_5PCT_NOR = [
    'Moesevassbrea', 'Vetlefjordbreen', 'Juvfonne', 'Graasubreen',
    'Hellstugubreen', 'Storglombreen N', 'Blabreen', 'Ruklebreen',
    'Vestre Memurubreen', 'Cainhavarre', 'Bondhusbrea'
]

# France
# 5% split
FT_5PCT_FR = ['Grands Montets', 'Sarennes', 'Talefre', 'Leschaux']

# 50% split
FT_50PCT_FR = ['Argentiere', 'Gebroulaz']

# IT-AT
# 5% split
FT_5PCT_IT_AT = [
    'CIARDONEY', 'CARESER CENTRALE', 'CAMPO SETT.', 'ZETTALUNITZ/MULLWITZ K.',
    'HALLSTAETTER G.', 'VENEDIGER K.', 'SURETTA MERIDIONALE', 'GOLDBERG K.',
    'CARESER OCCIDENTALE', 'GRAND ETRET', 'LUPO'
]

# 50% split
FT_50PCT_IT_AT = [
    'HINTEREIS F.', 'MALAVALLE (VEDR. DI) / UEBELTALF.',
    'LUNGA (VEDRETTA) / LANGENF.', 'RIES OCC. (VEDR. DI) / RIESERF. WESTL.'
]

# Iceland
# 5% split
FT_5PCT_ISL = [
    'RGI60-06.00306', 'RGI60-06.00296', 'RGI60-06.00479', 'RGI60-06.00425',
    'RGI60-06.00445', 'RGI60-06.00474', 'RGI60-06.00542',
    'Reykjafjardarjoekull', 'RGI60-06.00350', 'RGI60-06.00342',
    'RGI60-06.00301', 'RGI60-06.00422', 'RGI60-06.00320', 'RGI60-06.00359',
    'RGI60-06.00349', 'RGI60-06.00409', 'RGI60-06.00413', 'RGI60-06.00411',
    'Oeldufellsjoekull', 'RGI60-06.00476', 'RGI60-06.00549', 'RGI60-06.00228',
    'RGI60-06.00303', 'Kaldalonsjoekull', 'RGI60-06.00328', 'RGI60-06.00541',
    'Slettjoekull West', 'RGI60-06.00232', 'RGI60-06.00305'
]

# 50% split
FT_50PCT_ISL = [
    'RGI60-06.00238', 'Bruarjoekull', 'Skeidararjoekull',
    'Thjorsarjoekull (Hofsjoekull E)', 'Sidujoekull/Skaftarjoekull',
    'Hagafellsjoekull West', 'RGI60-06.00305'
]

# Svalbard
# 5% split (NOTE(review): with very few SJM glaciers a single glacier may be
# well above 5% of rows — an earlier comment said 15%; confirm with
# verify_row_percentage)
FT_5PCT_SJM = ['WERENSKIOLDBREEN']

# 50% split
FT_50PCT_SJM = ['GROENFJORD E', 'WERENSKIOLDBREEN']

# region code (SOURCE_CODE value) -> split name -> glacier names
FT_GLACIERS = {
    "FR": {
        "5pct": FT_5PCT_FR,
        "50pct": FT_50PCT_FR
    },
    "IT_AT": {
        "5pct": FT_5PCT_IT_AT,
        "50pct": FT_50PCT_IT_AT
    },
    "NOR": {
        "5pct": FT_5PCT_NOR,
        "50pct": FT_50PCT_NOR
    },
    "ISL": {
        "5pct": FT_5PCT_ISL,
        "50pct": FT_50PCT_ISL
    },
    "SJM": {
        "5pct": FT_5PCT_SJM,
        "50pct": FT_50PCT_SJM
    }
}

In [None]:
def verify_row_percentage(df_test,
                          FT_GLACIERS,
                          source_col="SOURCE_CODE",
                          glacier_col="GLACIER"):
    """
    Report the share of each region's df_test rows covered by its fine-tuning lists.

    Parameters
    ----------
    df_test : pd.DataFrame
        Must contain `source_col` and `glacier_col`.
    FT_GLACIERS : dict
        {region_code: {split_name: [glacier, ...]}}.
    source_col, glacier_col : str
        Column names for region code and glacier name.

    Returns
    -------
    pd.DataFrame
        One row per (region, split) with columns region, split,
        rows_total_region, rows_ft, pct_rows. Regions without rows in
        df_test are skipped (a message is printed).

    Notes
    -----
    Listed glacier names that do not occur in the region's rows are printed
    as a warning: a misspelled name would otherwise silently lower rows_ft.
    """
    results = []

    for region, splits in FT_GLACIERS.items():

        df_reg = df_test[df_test[source_col] == region]

        total_rows = len(df_reg)
        if total_rows == 0:
            print(f"{region}: no rows in df_test")
            continue

        region_glaciers = set(df_reg[glacier_col].dropna().unique())

        for split_name, glacier_list in splits.items():

            df_ft = df_reg[df_reg[glacier_col].isin(glacier_list)]
            ft_rows = len(df_ft)

            pct = 100 * ft_rows / total_rows

            results.append({
                "region": region,
                "split": split_name,
                "rows_total_region": total_rows,
                "rows_ft": ft_rows,
                "pct_rows": pct,
            })

            print(f"{region} | {split_name}: "
                  f"{ft_rows}/{total_rows} rows = {pct:.2f}%")

            missing = [g for g in glacier_list if g not in region_glaciers]
            if missing:
                print(f"  [warn] {region} | {split_name}: {len(missing)} "
                      f"glacier(s) not found in df_test: {missing}")

    return pd.DataFrame(results)


# Achieved row share of every fixed fine-tuning list, per region and split.
df_row_check = verify_row_percentage(df_test, FT_GLACIERS)
df_row_check

In [None]:
# Number of distinct glaciers per fine-tuning region in df_test — context for how
# coarse glacier-level row fractions (e.g. the 5% splits) can be.
for reg in FT_GLACIERS.keys():
    gls = sorted(df_test.loc[df_test["SOURCE_CODE"] == reg,
                             "GLACIER"].unique())
    print(reg, "unique glaciers in df_test:", len(gls))

#### Feature overlap:

In [None]:
# Time-varying (monthly) inputs: the climate variables of vois_climate plus
# ELEVATION_DIFFERENCE.
MONTHLY_COLS = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str', 'ELEVATION_DIFFERENCE'
]
# Static topographical inputs (same names as vois_topographical).
STATIC_COLS = ['aspect', 'slope', 'svf']

feature_columns = [*MONTHLY_COLS, *STATIC_COLS]

In [None]:
def plot_tsne_overlap_xreg_from_single_res(
        res_xreg: dict,
        cfg,
        STATIC_COLS,
        MONTHLY_COLS,
        group_col: str = "SOURCE_CODE",
        ch_code: str = "CH",
        use_aug: bool = False,  # True -> df_train_aug/df_test_aug
        n_iter: int = 1000,
        only_codes=None,  # e.g. ["IT_AT", "FR"]
        skip_codes=None,  # e.g. ["CH"]
):
    """
    t-SNE feature-space overlap between the CH training set and each other region.

    The CH reference is res_xreg["df_train"] (res_xreg["df_train_aug"] when
    `use_aug`). The test frame res_xreg["df_test"] (or "df_test_aug") is split
    by the upper-cased values of `group_col`, and each group is plotted
    against the CH reference with `plot_tsne_overlap`.

    Parameters
    ----------
    res_xreg : dict
        Single monthly-prep result holding the train/test frames.
    cfg : config object; `cfg.seed` is used as the t-SNE random_state.
    STATIC_COLS, MONTHLY_COLS : list[str]
        Feature columns forwarded to plot_tsne_overlap.
    group_col : str
        Column used to split the test frame into regions.
    ch_code : str
        Reference code; always excluded from the plotted groups.
    use_aug : bool
        Use the *_aug frames instead of the plain ones.
    n_iter : int
        Forwarded to plot_tsne_overlap.
    only_codes, skip_codes : list[str] or None
        Optional include / exclude filters on the (case-insensitive) codes.

    Returns
    -------
    dict
        Upper-cased code -> figure returned by plot_tsne_overlap.

    Raises
    ------
    ValueError
        If either frame is missing/empty or the test frame lacks `group_col`.
    """
    wanted = {c.upper() for c in (only_codes or [])} or None
    excluded = {c.upper() for c in (skip_codes or [])} | {ch_code.upper()}

    train_key, test_key, label_df = (
        ("df_train_aug", "df_test_aug", "(*_aug)") if use_aug else
        ("df_train", "df_test", ""))
    df_ch = res_xreg.get(train_key)
    df_test_all = res_xreg.get(test_key)

    if df_ch is None or len(df_ch) == 0:
        raise ValueError(f"df_train{label_df} missing/empty in res_xreg.")
    if df_test_all is None or len(df_test_all) == 0:
        raise ValueError(f"df_test{label_df} missing/empty in res_xreg.")
    if group_col not in df_test_all.columns:
        raise ValueError(
            f"'{group_col}' not found in df_test{label_df}. Needed to split by region."
        )
    if group_col not in df_ch.columns:
        # Not fatal: the CH reference set does not need the grouping column.
        print(
            f"[warn] '{group_col}' not in df_train{label_df}. That's OK for CH reference."
        )

    # Train (CH) = first batlow colour, Test (other region) = dark red.
    custom_palette = {
        "Train": get_cmap_hex(cm.batlow, 10)[0],
        "Test": "#b2182b",
    }

    candidates = sorted(
        code for code in
        df_test_all[group_col].dropna().astype(str).str.upper().unique()
        if code not in excluded)
    if wanted is not None:
        candidates = [code for code in candidates if code in wanted]

    upper_codes = df_test_all[group_col].astype(str).str.upper()

    figs = {}
    for code in candidates:
        df_other = df_test_all[upper_codes == code].copy()
        if not len(df_other):
            continue

        print(
            f"Plotting XREG t-SNE: CH(train n={len(df_ch)}) vs {code}(test n={len(df_other)})"
        )

        fig = plot_tsne_overlap(
            data_train=df_ch,
            data_test=df_other,
            STATIC_COLS=STATIC_COLS,
            MONTHLY_COLS=MONTHLY_COLS,
            sublabels=("a", "b", "c"),
            label_fmt="({})",
            label_xy=(0.02, 0.98),
            label_fontsize=14,
            n_iter=n_iter,
            random_state=cfg.seed,
            custom_palette=custom_palette,
        )
        fig.suptitle(f"XREG overlap: CH vs {code}", fontsize=14)
        figs[code] = fig

    return figs


# One t-SNE figure per non-CH SOURCE_CODE in res_xreg["df_test"], each compared
# with the CH training rows; figures are returned keyed by upper-cased code.
figs_by_code = plot_tsne_overlap_xreg_from_single_res(
    res_xreg=res_xreg,
    cfg=cfg,
    STATIC_COLS=STATIC_COLS,
    MONTHLY_COLS=MONTHLY_COLS,
    group_col="SOURCE_CODE",
    ch_code="CH",
    use_aug=False,  # or True if you want *_aug
    n_iter=1000,
    # only_codes=["IT_AT"],  # optional
)

In [None]:
import os


def plot_feature_kde_overlap_xreg_ch_vs_codes(
    res_xreg: dict,
    cfg,
    features,
    group_col: str = "SOURCE_CODE",
    ch_code: str = "CH",
    use_aug: bool = False,  # True -> df_train_aug/df_test_aug
    only_codes=None,  # e.g. ["IT_AT", "FR"]
    skip_codes=None,  # e.g. ["CH"]
    output_dir=None,  # e.g. "figures/xreg_kde"
    include_ch_in_title: bool = True,
):
    """
    Per-feature KDE overlap between the CH training set and each other region.

    The CH reference is res_xreg["df_train"] (or "df_train_aug"); each other
    region is the subset of res_xreg["df_test"] (or "df_test_aug") whose
    upper-cased `group_col` equals that code. Each pair is drawn with
    `plot_feature_kde_overlap`.

    Parameters
    ----------
    res_xreg : dict
        Output of prepare_monthly_df_crossregional_CH_to_EU (or similar) with
        df_train/df_test and optionally df_train_aug/df_test_aug. The test
        frame must contain `group_col`.
    cfg : object
        If it has a `dataPath` attribute, `output_dir` is resolved relative to it.
    features : list[str]
        Feature columns to plot.
    group_col : str
        Column used to split the test frame (default "SOURCE_CODE").
    ch_code : str
        Reference code; always skipped (case-insensitive).
    use_aug : bool
        Use the *_aug frames.
    only_codes : list[str] or None
        If given, only plot these codes.
    skip_codes : list[str] or None
        Additional codes to skip.
    output_dir : str or None
        If set, one PNG per code is saved as
        "xreg_kde_overlap_CH_vs_<CODE>[_aug].png" (directory is created).
    include_ch_in_title : bool
        Add a "CH vs CODE" suptitle and re-run tight_layout.

    Returns
    -------
    dict
        Upper-cased code -> matplotlib Figure.

    Raises
    ------
    ValueError
        If either frame is missing/empty or the test frame lacks `group_col`.
    """
    # Train = CH reference (first batlow colour), Test = other region (dark red).
    palette = {"Train": get_cmap_hex(cm.batlow, 10)[0], "Test": "#b2182b"}

    ref_code = str(ch_code).upper()
    keep = {c.upper() for c in only_codes} if only_codes else None
    drop = {c.upper() for c in (skip_codes or [])} | {ref_code}

    suffix = "_aug" if use_aug else ""
    df_ch = res_xreg.get(f"df_train{suffix}")
    df_test_all = res_xreg.get(f"df_test{suffix}")

    if df_ch is None or len(df_ch) == 0:
        raise ValueError(f"Missing/empty df_train{suffix} in res_xreg.")
    if df_test_all is None or len(df_test_all) == 0:
        raise ValueError(f"Missing/empty df_test{suffix} in res_xreg.")
    if group_col not in df_test_all.columns:
        raise ValueError(f"'{group_col}' not found in df_test{suffix}.")

    codes = [
        c for c in sorted(
            df_test_all[group_col].dropna().astype(str).str.upper().unique())
        if c not in drop and (keep is None or c in keep)
    ]

    if not output_dir:
        out_abs = None
    elif hasattr(cfg, "dataPath"):
        out_abs = os.path.join(cfg.dataPath, output_dir)
    else:
        out_abs = output_dir
    if out_abs is not None:
        os.makedirs(out_abs, exist_ok=True)

    upper_codes = df_test_all[group_col].astype(str).str.upper()

    figs = {}
    for code in codes:
        df_other = df_test_all[upper_codes == code].copy()
        if not len(df_other):
            continue

        print(
            f"Plotting XREG KDE: CH(train n={len(df_ch)}) vs {code}(test n={len(df_other)})"
        )

        # outfile=None: saving is done below so the file name is controlled here.
        fig = plot_feature_kde_overlap(
            df_train=df_ch,
            df_test=df_other,
            features=features,
            palette=palette,
            outfile=None,
        )

        if include_ch_in_title:
            fig.suptitle(f"XREG feature overlap: CH vs {code}", fontsize=14)
            fig.tight_layout()

        if out_abs:
            png_name = f"xreg_kde_overlap_CH_vs_{code}{suffix}.png"
            fig.savefig(os.path.join(out_abs, png_name),
                        dpi=300,
                        bbox_inches="tight")

        figs[code] = fig

    return figs


In [None]:
# Columns compared between CH and each other region: model inputs plus the target.
FEATURES = MONTHLY_COLS + STATIC_COLS + ["POINT_BALANCE"]

# One KDE figure per non-CH SOURCE_CODE; PNGs go to <cfg.dataPath>/figures/xreg_kde
# (output_dir is joined onto cfg.dataPath when cfg has that attribute).
figs_kde = plot_feature_kde_overlap_xreg_ch_vs_codes(
    res_xreg=res_xreg,
    cfg=cfg,
    features=FEATURES,
    group_col="SOURCE_CODE",
    ch_code="CH",
    use_aug=True,  # usually best for feature overlap
    # only_codes=["IT_AT", "FR"],    # optional
    output_dir="figures/xreg_kde",  # optional
)


## LSTM model
### LSTM datasets:

In [None]:
def _check_for_nans(key,
                    df_loss,
                    df_full,
                    monthly_cols,
                    static_cols,
                    strict=True):
    """
    Checks for NaNs/Infs in features and targets.
    Raises ValueError if strict=True, otherwise prints warning.
    """
    feat_cols = [
        c for c in (monthly_cols + static_cols) if c in df_full.columns
    ]

    # --- feature NaNs ---
    n_nan_feat = df_full[feat_cols].isna().sum().sum()
    n_inf_feat = np.isinf(df_full[feat_cols].to_numpy(dtype="float64",
                                                      copy=False)).sum()

    # --- target NaNs ---
    n_nan_target = df_loss["POINT_BALANCE"].isna().sum()
    n_inf_target = np.isinf(df_loss["POINT_BALANCE"].to_numpy(
        dtype="float64", copy=False)).sum()

    if any([n_nan_feat, n_inf_feat, n_nan_target, n_inf_target]):

        msg = (f"[{key}] Data integrity issue:\n"
               f"  Feature NaNs: {n_nan_feat}\n"
               f"  Feature Infs: {n_inf_feat}\n"
               f"  Target  NaNs: {n_nan_target}\n"
               f"  Target  Infs: {n_inf_target}")

        if strict:
            raise ValueError(msg)
        else:
            warnings.warn(msg)


def _lstm_cache_paths(cfg, key: str, cache_dir: str):
    out_dir = os.path.join(cache_dir)
    os.makedirs(out_dir, exist_ok=True)
    train_p = os.path.join(out_dir, f"{key}_train.joblib")
    test_p = os.path.join(out_dir, f"{key}_test.joblib")
    split_p = os.path.join(out_dir, f"{key}_split.joblib")
    return train_p, test_p, split_p

In [None]:
import os
import joblib
import logging


# ------------------------------------------------------------
# 1) Build/load a PRISTINE dataset only (no scalers inside)
# ------------------------------------------------------------
def build_or_load_lstm_dataset_only(
        cfg,
        key: str,
        df_loss,
        df_full,
        months_head_pad,
        months_tail_pad,
        MONTHLY_COLS,
        STATIC_COLS,
        cache_dir="logs/LSTM_cache",
        force_recompute=False,
        normalize_target=True,
        expect_target=True,
        strict_nan=True,
        kind="dataset",  # keep kind to avoid duplicate functions; default "dataset"
):
    """
    Return a pristine (scaler-free) LSTM dataset, reusing a joblib cache if present.

    The cache file is "<cache_dir>/<key>_<kind>.joblib". Unless
    `force_recompute` is True, an existing file is loaded and returned.
    Otherwise the inputs are checked with `_check_for_nans`, RNGs are seeded
    with `cfg.seed`, the dataset is built with `build_combined_LSTM_dataset`
    and written to the cache (compress=3).

    Returns
    -------
    The dataset object, whose `month_mean`, `static_mean` and `y_mean` are all None.

    Raises
    ------
    ValueError
        If the cached or freshly built dataset already has any of those
        scaler attributes set, or (with `strict_nan`) if the inputs contain
        NaN/Inf values.
    """

    def _has_scalers(dataset):
        # Checks the scaler attributes in order and stops at the first one set.
        return any(
            getattr(dataset, attr) is not None
            for attr in ("month_mean", "static_mean", "y_mean"))

    os.makedirs(cache_dir, exist_ok=True)
    cache_file = os.path.join(cache_dir, f"{key}_{kind}.joblib")

    use_cache = (not force_recompute) and os.path.exists(cache_file)
    if use_cache:
        # joblib.load unpickles the file — only point cache_dir at trusted caches.
        cached = joblib.load(cache_file)
        if _has_scalers(cached):
            raise ValueError(
                f"{key}_{kind}: cached dataset already has scalers set. "
                "Cache should store pristine datasets only.")
        return cached

    _check_for_nans(
        key,
        df_loss=df_loss,
        df_full=df_full,
        monthly_cols=MONTHLY_COLS,
        static_cols=STATIC_COLS,
        strict=strict_nan,
    )

    mbm.utils.seed_all(cfg.seed)

    dataset = build_combined_LSTM_dataset(
        df_loss=df_loss,
        df_full=df_full,
        monthly_cols=MONTHLY_COLS,
        static_cols=STATIC_COLS,
        months_head_pad=months_head_pad,
        months_tail_pad=months_tail_pad,
        normalize_target=normalize_target,
        expect_target=expect_target,
    )

    # Only pristine datasets may be cached.
    if _has_scalers(dataset):
        raise ValueError(
            f"{key}_{kind}: newly built dataset unexpectedly has scalers set.")

    joblib.dump(dataset, cache_file, compress=3)
    return dataset


# ------------------------------------------------------------
# 2) Transfer-learning slicing (no scaling logic here)
# ------------------------------------------------------------
def make_res_transfer_learning(
    res_xreg: dict,
    target_code: str,
    ft_glaciers: list,
    source_col="SOURCE_CODE",
    glacier_col="GLACIER",
):
    """
    Slice a cross-regional result dict into pretrain / fine-tune / holdout parts.

    Parameters
    ----------
    res_xreg : dict
        Needs keys df_train, df_train_aug, df_test, df_test_aug,
        months_head_pad, months_tail_pad.
    target_code : str
        `source_col` value of the target region (exact match).
    ft_glaciers : list
        Target-region glaciers used for fine-tuning; the rest are held out.
    source_col, glacier_col : str
        Column names for the region code and the glacier name.

    Returns
    -------
    res_pretrain : dict
        CH data: df_train, df_train_aug (the original objects) + pads.
    res_ft : dict
        Target rows on `ft_glaciers` as df_train / df_train_aug (copies) + pads.
    res_test : dict
        Remaining target rows as df_test / df_test_aug (copies) + pads.
    """
    res_pretrain = {
        "df_train": res_xreg["df_train"],
        "df_train_aug": res_xreg["df_train_aug"],
        "months_head_pad": res_xreg["months_head_pad"],
        "months_tail_pad": res_xreg["months_tail_pad"],
    }
    pads = {
        name: res_pretrain[name]
        for name in ("months_head_pad", "months_tail_pad")
    }

    def _split_target(frame):
        # (fine-tune rows, holdout rows) of the target region, as copies.
        in_target = frame.loc[frame[source_col] == target_code].copy()
        is_ft = in_target[glacier_col].isin(ft_glaciers)
        return in_target.loc[is_ft].copy(), in_target.loc[~is_ft].copy()

    ft_plain, hold_plain = _split_target(res_xreg["df_test"])
    ft_aug, hold_aug = _split_target(res_xreg["df_test_aug"])

    res_ft = {"df_train": ft_plain, "df_train_aug": ft_aug, **pads}
    res_test = {"df_test": hold_plain, "df_test_aug": hold_aug, **pads}

    return res_pretrain, res_ft, res_test


# ------------------------------------------------------------
# 3) Build/load CH train dataset + split + SCALER DONOR (Option 2)
# ------------------------------------------------------------
def build_or_load_lstm_train_only(
    cfg,
    key_train: str,
    res_train: dict,  # must contain df_train, df_train_aug, pads
    MONTHLY_COLS,
    STATIC_COLS,
    val_ratio=0.2,
    cache_dir="logs/LSTM_cache",
    force_recompute=False,
    normalize_target=True,
    expect_target=True,
    strict_nan=True,
):
    """
    Build (or load from cache) the pretraining dataset, its train/val split and
    a scaler donor whose scalers are fitted on the train split only.

    Three artefacts are cached:
      * the pristine (unscaled) dataset,
      * the train/val split indices together with the ``val_ratio`` and seed
        that produced them,
      * a clone of the dataset with scalers fitted on ``train_idx``
        (the "scaler donor").

    A cached split made with a different ``val_ratio`` or ``cfg.seed`` is treated
    as stale and everything is rebuilt. Caches written before this metadata
    was recorded are still accepted.

    Args:
        cfg: config object; ``cfg.seed`` seeds the build and the split.
        key_train: cache key used to name the cached files.
        res_train: dict with ``df_train``, ``df_train_aug``,
            ``months_head_pad`` and ``months_tail_pad``.
        MONTHLY_COLS, STATIC_COLS: feature columns for the dataset builder.
        val_ratio: validation fraction passed to ``split_indices``.
        cache_dir: directory of the scaler-donor cache (created if missing).
        force_recompute: ignore any cache and rebuild.
        normalize_target, expect_target: forwarded to
            ``build_combined_LSTM_dataset``.
        strict_nan: forwarded to ``_check_for_nans`` as ``strict``.

    Returns:
        (ds_train, train_idx, val_idx, ds_scalers)

    Raises:
        ValueError: if the cached dataset has scalers set, or the cached scaler
            donor is missing them.
    """
    train_p, _, split_p = _lstm_cache_paths(cfg,
                                            key_train,
                                            cache_dir=cache_dir)
    scaler_p = os.path.join(cache_dir, f"{key_train}_scalers.joblib")

    use_cache = (not force_recompute) and all(
        os.path.exists(p) for p in (train_p, split_p, scaler_p))

    if use_cache:
        split = joblib.load(split_p)
        # Older caches carry no metadata -> fall back to the requested values.
        cached_ratio = split.get("val_ratio", val_ratio)
        cached_seed = split.get("seed", cfg.seed)
        if cached_ratio != val_ratio or cached_seed != cfg.seed:
            logging.warning(
                f"{key_train}: cached split was made with val_ratio={cached_ratio}, "
                f"seed={cached_seed} (requested val_ratio={val_ratio}, "
                f"seed={cfg.seed}); rebuilding cache.")
            use_cache = False

    # ---- Load cached assets (train ds must be pristine; scalers ds must have scalers) ----
    if use_cache:
        ds_train = joblib.load(train_p)
        ds_scalers = joblib.load(scaler_p)

        # guards
        if (ds_train.month_mean is not None) or (
                ds_train.static_mean is not None) or (ds_train.y_mean
                                                      is not None):
            raise ValueError(
                f"{key_train}: cached TRAIN dataset has scalers set. "
                "train_p cache must store pristine dataset only.")
        if (ds_scalers.month_mean is None) or (
                ds_scalers.static_mean is None) or (ds_scalers.y_mean
                                                    is None):
            raise ValueError(
                f"{key_train}: cached SCALER donor is missing scalers.")

        return ds_train, split["train_idx"], split["val_idx"], ds_scalers

    # ---- Build fresh ----
    df_train = res_train["df_train"]
    df_train_aug = res_train["df_train_aug"]
    months_head_pad = res_train["months_head_pad"]
    months_tail_pad = res_train["months_tail_pad"]

    _check_for_nans(
        key_train,
        df_loss=df_train,
        df_full=df_train_aug,
        monthly_cols=MONTHLY_COLS,
        static_cols=STATIC_COLS,
        strict=strict_nan,
    )

    mbm.utils.seed_all(cfg.seed)

    ds_train = build_combined_LSTM_dataset(
        df_loss=df_train,
        df_full=df_train_aug,
        monthly_cols=MONTHLY_COLS,
        static_cols=STATIC_COLS,
        months_head_pad=months_head_pad,
        months_tail_pad=months_tail_pad,
        normalize_target=normalize_target,
        expect_target=expect_target,
    )

    # split indices
    train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
        len(ds_train), val_ratio=val_ratio, seed=cfg.seed)

    # Scaler donor: a clone with scalers fitted on the train split only, so
    # the pristine ds_train stays untouched.
    ds_scalers = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
        ds_train)
    ds_scalers.fit_scalers(train_idx)

    # ---- Cache ----
    # Make sure every target directory exists before dumping.
    for cache_path in (train_p, split_p, scaler_p):
        parent = os.path.dirname(cache_path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    joblib.dump(ds_train, train_p, compress=3)
    joblib.dump(
        {
            "train_idx": train_idx,
            "val_idx": val_idx,
            # metadata used to detect stale splits on reload
            "val_ratio": val_ratio,
            "seed": cfg.seed,
        },
        split_p,
        compress=3,
    )
    joblib.dump(ds_scalers, scaler_p, compress=3)

    return ds_train, train_idx, val_idx, ds_scalers

In [None]:
def build_transfer_learning_assets(
    cfg,
    res_xreg,
    FT_GLACIERS,
    MONTHLY_COLS,
    STATIC_COLS,
    cache_dir="logs/LSTM_cache_TL",
    force_recompute=False,
    val_ratio=0.2,
):
    """
    Build (or load from cache) all datasets needed for the CH -> target TL runs.

    The CH pretraining dataset (``res_xreg["df_train"]`` / ``["df_train_aug"]``),
    its train/val split and its scaler donor are built once and shared by every
    experiment. For each ``region -> {split_name: ft_glaciers}`` entry in
    ``FT_GLACIERS``, the target-region rows of ``res_xreg["df_test"]`` are split
    into a finetune subset (glaciers in ``ft_glaciers``) and a holdout subset
    (all other target glaciers). Pristine (unscaled) datasets are built for both.

    Args:
        cfg: config object; ``cfg.seed`` seeds the finetune split.
        res_xreg: cross-regional result dict (df_train/df_train_aug,
            df_test/df_test_aug, months_head_pad/months_tail_pad).
        FT_GLACIERS: ``{region: {split_name: [glacier, ...]}}``.
        MONTHLY_COLS, STATIC_COLS: feature columns for the dataset builders.
        cache_dir: cache directory for all datasets.
        force_recompute: rebuild instead of loading cached datasets.
        val_ratio: validation fraction for both CH and finetune splits.

    Returns:
        dict ``exp_key -> assets`` with ``exp_key = f"TL_CH_to_{region}_{split_name}"``.
        Experiments whose finetune subset is empty are omitted. ``ds_test``
        is ``None`` when the holdout subset is empty.
    """
    logging.info("\n" + "=" * 70)
    logging.info("TRANSFER LEARNING ASSET PREPARATION")
    logging.info("=" * 70)
    logging.info(f"Cache directory: {cache_dir}")
    logging.info(f"Regions in FT_GLACIERS: {list(FT_GLACIERS.keys())}")

    assets = {}

    # ------------------------------------------------------------------
    # 1) CH PRETRAIN DATASET (shared across all TL experiments)
    # ------------------------------------------------------------------
    key_train = "TL_CH_TRAIN"

    logging.info("\n--- CH PRETRAIN DATASET ---")
    logging.info(f"Cache key: {key_train}")
    logging.info(f"Force recompute: {force_recompute}")

    res_train = {
        "df_train": res_xreg["df_train"],
        "df_train_aug": res_xreg["df_train_aug"],
        "months_head_pad": res_xreg["months_head_pad"],
        "months_tail_pad": res_xreg["months_tail_pad"],
    }

    logging.info(f"CH train rows: {len(res_train['df_train'])} | "
                 f"Aug rows: {len(res_train['df_train_aug'])}")

    # Also returns ds_ch_scalers: a cached scaler donor fitted on the CH train
    # split. ds_ch itself stays pristine, so do NOT fit scalers on it here.
    ds_ch, train_idx, val_idx, ds_ch_scalers = build_or_load_lstm_train_only(
        cfg=cfg,
        key_train=key_train,
        res_train=res_train,
        MONTHLY_COLS=MONTHLY_COLS,
        STATIC_COLS=STATIC_COLS,
        val_ratio=val_ratio,
        cache_dir=cache_dir,
        force_recompute=force_recompute,
    )

    logging.info(f"CH dataset size (sequences): {len(ds_ch)} | "
                 f"Train split: {len(train_idx)} | Val split: {len(val_idx)}")

    # ------------------------------------------------------------------
    # 2) PER REGION × SPLIT
    # ------------------------------------------------------------------
    for reg, splits in FT_GLACIERS.items():

        logging.info("\n" + "-" * 60)
        logging.info(f"TARGET REGION: {reg}")
        logging.info("-" * 60)

        for split_name, ft_gls in splits.items():

            exp_key = f"TL_CH_to_{reg}_{split_name}"
            ft_cache_key = f"{exp_key}_FT"
            test_cache_key = f"{exp_key}_TEST"

            logging.info("\n" + "-" * 40)
            logging.info(f"Experiment: {exp_key}")
            logging.info(f"Finetune glacier count: {len(ft_gls)}")

            # ----------------------------------------------------------
            # Slice finetune + holdout (the pretrain slice is res_train)
            # ----------------------------------------------------------
            _res_pre, res_ft, res_test = make_res_transfer_learning(
                res_xreg=res_xreg,
                target_code=reg,
                ft_glaciers=ft_gls,
            )

            logging.info(f"FT rows: {len(res_ft['df_train'])} | "
                         f"FT aug rows: {len(res_ft['df_train_aug'])}")

            logging.info(f"Holdout rows: {len(res_test['df_test'])} | "
                         f"Holdout aug rows: {len(res_test['df_test_aug'])}")

            # Surface requested FT glaciers that have no target rows (e.g.
            # typos or glaciers absent from df_test); they would otherwise be
            # dropped silently. "GLACIER" is the glacier column used by the
            # slicing call above (its default).
            present_ft = set(res_ft["df_train"]["GLACIER"].unique())
            missing_ft = [g for g in ft_gls if g not in present_ft]
            if missing_ft:
                logging.warning(
                    f"{exp_key}: {len(missing_ft)} finetune glacier(s) not "
                    f"found in target data: {missing_ft}")

            if len(res_ft["df_train"]) == 0:
                logging.warning(f"{exp_key}: EMPTY FINETUNE SET -> skipping.")
                continue

            # ----------------------------------------------------------
            # Finetune dataset (PRISTINE)
            # ----------------------------------------------------------
            logging.info(f"Finetune cache key: {ft_cache_key}")

            ds_ft = build_or_load_lstm_dataset_only(
                cfg=cfg,
                key=ft_cache_key,
                df_loss=res_ft["df_train"],
                df_full=res_ft["df_train_aug"],
                months_head_pad=res_ft["months_head_pad"],
                months_tail_pad=res_ft["months_tail_pad"],
                MONTHLY_COLS=MONTHLY_COLS,
                STATIC_COLS=STATIC_COLS,
                cache_dir=cache_dir,
                force_recompute=force_recompute,
                kind="ft",
            )

            logging.info(f"Finetune dataset size (sequences): {len(ds_ft)}")

            ft_train_idx, ft_val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
                len(ds_ft), val_ratio=val_ratio, seed=cfg.seed)

            logging.info(f"FT train split: {len(ft_train_idx)} | "
                         f"FT val split: {len(ft_val_idx)}")

            # ----------------------------------------------------------
            # Holdout test dataset (PRISTINE)
            # ----------------------------------------------------------
            ds_test = None
            if len(res_test["df_test"]) > 0:

                logging.info(f"Holdout cache key: {test_cache_key}")

                ds_test = build_or_load_lstm_dataset_only(
                    cfg=cfg,
                    key=test_cache_key,
                    df_loss=res_test["df_test"],
                    df_full=res_test["df_test_aug"],
                    months_head_pad=res_test["months_head_pad"],
                    months_tail_pad=res_test["months_tail_pad"],
                    MONTHLY_COLS=MONTHLY_COLS,
                    STATIC_COLS=STATIC_COLS,
                    cache_dir=cache_dir,
                    force_recompute=force_recompute,
                    kind="test",
                )

                logging.info(
                    f"Holdout dataset size (sequences): {len(ds_test)}")

            else:
                logging.warning(f"{exp_key}: No holdout test set available.")

            # ----------------------------------------------------------
            # Store assets (ds_pretrain_scalers is the CH scaler donor used
            # by all downstream loaders)
            # ----------------------------------------------------------
            assets[exp_key] = {
                "ds_pretrain": ds_ch,  # pristine CH dataset
                "ds_pretrain_scalers": ds_ch_scalers,  # scaler donor
                "pretrain_train_idx": train_idx,
                "pretrain_val_idx": val_idx,
                "ds_finetune": ds_ft,  # pristine FT dataset
                "finetune_train_idx": ft_train_idx,
                "finetune_val_idx": ft_val_idx,
                "ds_test": ds_test,  # pristine test dataset (or None)
                "target_code": reg,
                "split_name": split_name,
                "ft_glaciers": ft_gls,
                "cache_keys": {
                    "pretrain": key_train,
                    "finetune": ft_cache_key,
                    "test": test_cache_key,
                },
            }

    logging.info("\nFinished building transfer learning assets.")
    logging.info("=" * 70 + "\n")

    return assets

In [None]:
# Fine-tuning glacier lists per target region and split size. The FT_* lists
# are not defined in this part of the notebook (presumably provided by the
# star-imported regions.TF_Europe.scripts modules — TODO confirm).
# Region keys are compared against the SOURCE_CODE column of res_xreg["df_test"]
# to select the target rows.
FT_GLACIERS = {
    "NOR": {
        "5pct": FT_5PCT_NOR,
        "50pct": FT_50PCT_NOR
    },
    "FR": {
        "5pct": FT_5PCT_FR  # only a 5pct split is configured for FR
    },
    "ISL": {
        "5pct": FT_5PCT_ISL,
        "50pct": FT_50PCT_ISL
    },
}

# force_recompute=True rebuilds every cached dataset on each run; set it to
# False to reuse the caches stored in cache_dir.
tl_assets = build_transfer_learning_assets(
    cfg=cfg,
    res_xreg=res_xreg,
    FT_GLACIERS=FT_GLACIERS,  # only NOR/FR/ISL as you defined
    MONTHLY_COLS=MONTHLY_COLS,
    STATIC_COLS=STATIC_COLS,
    cache_dir="logs/LSTM_cache_TL",
    force_recompute=True,
)

### LSTM parameters:

In [None]:
# Default LSTM hyper-parameters; used below both for the CH baseline and for
# every fine-tuning run. Key names are those expected by
# mbm.models.LSTM_MB.build_model_from_params / resolve_loss_fn.
default_params = dict(
    Fm=8,
    Fs=3,
    # recurrent part
    hidden_size=128,
    num_layers=1,
    bidirectional=False,
    dropout=0.0,
    # static-feature part
    static_layers=1,
    static_hidden=128,
    static_dropout=0.1,
    # optimiser
    lr=0.001,
    weight_decay=1e-05,
    # loss / output head
    loss_name='neutral',
    two_heads=False,
    head_dropout=0.1,
    loss_spec=None,
)

### Train model:

In [None]:
def make_finetune_loaders_for_exp(
    cfg,
    tl_assets_for_key,
    batch_size_train=64,
    batch_size_val=128,
):
    """
    Build finetune train/val loaders for one TL experiment using CH scalers.

    The pristine finetune dataset is cloned. The CH scaler donor's scalers are
    copied onto the clone and applied once. The loaders are then built
    without refitting (``fit_and_transform=False``).

    Args:
        cfg: config object; ``cfg.seed`` seeds the RNGs and the loaders.
        tl_assets_for_key: one entry of ``build_transfer_learning_assets``
            (needs ds_finetune, finetune_train_idx, finetune_val_idx,
            ds_pretrain_scalers).
        batch_size_train, batch_size_val: loader batch sizes.

    Returns:
        (ds_ft_copy, ft_train_dl, ft_val_dl)

    Raises:
        ValueError: if the CH scaler donor is missing any fitted scaler.
    """
    ds_ft = tl_assets_for_key["ds_finetune"]
    train_idx = tl_assets_for_key["finetune_train_idx"]
    val_idx = tl_assets_for_key["finetune_val_idx"]

    # Scaler donor (fitted on the CH train split). Checked explicitly rather
    # than with `assert` (stripped under `python -O`), and all three scalers
    # are required, not just the monthly one.
    ds_ch_scalers = tl_assets_for_key["ds_pretrain_scalers"]
    if (ds_ch_scalers.month_mean is None or ds_ch_scalers.static_mean is None
            or ds_ch_scalers.y_mean is None):
        raise ValueError("CH scaler donor has no fitted scalers!")

    # Seed before building loaders, as the CH baseline loader does, so FT runs
    # are reproducible independent of earlier RNG consumption.
    mbm.utils.seed_all(cfg.seed)

    ds_ft_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
        ds_ft)

    # ---- apply CH scalers ----
    ds_ft_copy.set_scalers_from(ds_ch_scalers)
    ds_ft_copy.transform_inplace()

    # build loaders, no fitting (data are already transformed)
    ft_train_dl, ft_val_dl = ds_ft_copy.make_loaders(
        train_idx=train_idx,
        val_idx=val_idx,
        batch_size_train=batch_size_train,
        batch_size_val=batch_size_val,
        seed=cfg.seed,
        fit_and_transform=False,  # IMPORTANT
        shuffle_train=True,
        use_weighted_sampler=True,
    )
    return ds_ft_copy, ft_train_dl, ft_val_dl


def freeze_lstm_only(model):
    """Freeze parameters whose name starts with ``"lstm."``; all others become trainable."""
    for param_name, param in model.named_parameters():
        param.requires_grad = not param_name.startswith("lstm.")


def unfreeze_all(model):
    """Mark every parameter of ``model`` as trainable."""
    for param in model.parameters():
        param.requires_grad_(True)


def finetune_or_load_one_TL(
    cfg,
    exp_key: str,  # e.g. "TL_CH_to_ISL_5pct"
    tl_assets_for_key: dict,
    best_params: dict,
    device,
    pretrained_ckpt_path: str,  # CH model checkpoint to start from
    models_dir="models",
    prefix="lstm_TL",
    strategy="safe",  # "safe" | "full" | "two_stage"
    force_retrain=False,
    batch_size_train=64,
    batch_size_val=128,
    epochs_safe=60,
    epochs_full=80,
    stage1_epochs=20,
    stage2_epochs=60,
    lr_safe=1e-4,
    lr_full=1e-5,
    lr_stage1=2e-4,
    lr_stage2=1e-5,
):
    """
    Fine-tune (or load) one CH-pretrained model on one TL experiment.

    Strategies:
      * ``"safe"``      – freeze parameters named ``lstm.*``, train the rest
                          (lr=lr_safe, early-stopping patience 8).
      * ``"full"``      – train all parameters (lr=lr_full, patience 10).
      * ``"two_stage"`` – ``safe``-style stage 1 (lr_stage1, patience 5) saved
                          to a temporary checkpoint, reloaded, then a
                          ``full``-style stage 2 (lr_stage2, patience 10).

    The output checkpoint is ``{models_dir}/{prefix}_{exp_key}_{strategy}_{YYYY-MM-DD}.pt``
    (today's date). When it exists and ``force_retrain`` is False it is loaded
    instead of training.

    Returns:
        (model, out_path, info) where ``info`` is ``None`` when loaded from disk,
        otherwise ``{"history": ..., "best_val": ...}`` from the last training run.

    Raises:
        ValueError: for an unknown ``strategy`` (checked before any I/O).
    """
    # Validate first so a typo never loads a model or deletes a checkpoint.
    if strategy not in ("safe", "full", "two_stage"):
        raise ValueError(f"Unknown strategy: {strategy}")

    os.makedirs(models_dir, exist_ok=True)
    current_date = datetime.now().strftime("%Y-%m-%d")

    out_name = f"{prefix}_{exp_key}_{strategy}_{current_date}.pt"
    out_path = os.path.join(models_dir, out_name)

    # load if exists
    if (not force_retrain) and os.path.exists(out_path):
        model = mbm.models.LSTM_MB.build_model_from_params(
            cfg, best_params, device)
        state = torch.load(out_path, map_location=device)
        model.load_state_dict(state)
        return model, out_path, None

    # build model + loss, then start from the pretrained CH weights
    model = mbm.models.LSTM_MB.build_model_from_params(cfg, best_params,
                                                       device)
    loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(best_params)

    state = torch.load(pretrained_ckpt_path, map_location=device)
    model.load_state_dict(state)

    # loaders (the transformed dataset copy itself is not needed here)
    _ds_ft_copy, ft_train_dl, ft_val_dl = make_finetune_loaders_for_exp(
        cfg,
        tl_assets_for_key,
        batch_size_train=batch_size_train,
        batch_size_val=batch_size_val,
    )

    # overwrite if retraining
    if os.path.exists(out_path):
        os.remove(out_path)
        logging.info(f"[{exp_key}] Deleted existing TL checkpoint: {out_path}")

    def _fit(params, lr, epochs, es_patience, save_path):
        # One AdamW run on the finetune loaders; returns train_loop's
        # (history, best_val, best_state) tuple.
        optimizer = torch.optim.AdamW(
            params,
            lr=lr,
            weight_decay=best_params["weight_decay"],
        )
        return model.train_loop(
            device=device,
            train_dl=ft_train_dl,
            val_dl=ft_val_dl,
            epochs=epochs,
            optimizer=optimizer,
            clip_val=1.0,
            loss_fn=loss_fn,
            es_patience=es_patience,
            save_best_path=save_path,
            verbose=True,
        )

    # --- strategies ---
    if strategy == "safe":
        freeze_lstm_only(model)
        history, best_val, _best_state = _fit(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr_safe, epochs_safe, 8, out_path)

    elif strategy == "full":
        unfreeze_all(model)
        history, best_val, _best_state = _fit(model.parameters(), lr_full,
                                              epochs_full, 10, out_path)

    else:  # "two_stage" (validated above)
        # Derive the temp name from the extension only; str.replace(".pt", ...)
        # would also rewrite ".pt" occurring in the directory or prefix.
        root, ext = os.path.splitext(out_path)
        tmp_stage1 = f"{root}_stage1_tmp{ext}"
        try:
            # stage 1: safe (freeze lstm)
            freeze_lstm_only(model)
            _fit(filter(lambda p: p.requires_grad, model.parameters()),
                 lr_stage1, stage1_epochs, 5, tmp_stage1)

            # load best stage1
            state = torch.load(tmp_stage1, map_location=device)
            model.load_state_dict(state)

            # stage 2: full (unfreeze all)
            unfreeze_all(model)
            history, best_val, _best_state = _fit(model.parameters(),
                                                  lr_stage2, stage2_epochs,
                                                  10, out_path)
        finally:
            # Remove the temp checkpoint even when a stage fails.
            try:
                os.remove(tmp_stage1)
            except OSError:
                pass

    # load best
    state = torch.load(out_path, map_location=device)
    model.load_state_dict(state)

    return model, out_path, {"history": history, "best_val": best_val}

In [None]:
def finetune_TL_models_all(
    cfg,
    tl_assets_by_key: dict,  # e.g. tl_assets["TL_CH_to_ISL_5pct"] -> {...}
    best_params: dict,
    device,
    pretrained_ckpt_path: str,
    strategies=("safe", "full", "two_stage"),
    train_keys=None,  # optional subset of exp_keys
    force_retrain=False,
    models_dir="models",
    prefix="lstm_TL",
    **finetune_kwargs,
):
    """
    Fine-tune (or load) one model per (experiment, strategy) pair.

    Args:
        cfg, best_params, device, pretrained_ckpt_path, models_dir, prefix,
        force_retrain: forwarded to ``finetune_or_load_one_TL``.
        tl_assets_by_key: dict ``exp_key -> assets`` from
            ``build_transfer_learning_assets``.
        strategies: strategies to run for each experiment.
        train_keys: optional subset of exp_keys. A falsy value (None or empty)
            means all experiments. Unknown keys are reported with a warning.
        **finetune_kwargs: extra keyword arguments forwarded to
            ``finetune_or_load_one_TL`` (e.g. batch sizes, epochs, learning rates).

    Returns:
        (models, infos), both keyed ``f"{exp_key}__{strategy}"``. ``infos`` holds
        the checkpoint path plus training info when a model was trained.
    """
    models = {}
    infos = {}

    train_keys_set = set(train_keys) if train_keys else None
    if train_keys_set is not None:
        # A typo here would otherwise silently train nothing.
        unknown = [k for k in train_keys if k not in tl_assets_by_key]
        if unknown:
            logging.warning(
                f"train_keys not found in TL assets (ignored): {unknown}")

    for exp_key in sorted(tl_assets_by_key.keys()):
        if train_keys_set is not None and exp_key not in train_keys_set:
            continue

        assets = tl_assets_by_key[exp_key]
        if assets is None or assets.get("ds_finetune", None) is None:
            logging.warning(f"Skipping {exp_key}: missing finetune dataset.")
            continue

        for strat in strategies:
            run_key = f"{exp_key}__{strat}"
            logging.info(f"\n=== FINETUNE {run_key} ===")

            model, path, info = finetune_or_load_one_TL(
                cfg=cfg,
                exp_key=exp_key,
                tl_assets_for_key=assets,
                best_params=best_params,
                device=device,
                pretrained_ckpt_path=pretrained_ckpt_path,
                models_dir=models_dir,
                prefix=prefix,
                strategy=strat,
                force_retrain=force_retrain,
                **finetune_kwargs,
            )

            models[run_key] = model
            infos[run_key] = {"model_path": path, **(info or {})}

    return models, infos

In [None]:
def train_or_load_CH_baseline(
    cfg,
    tl_assets: dict,  # the whole dict returned by build_transfer_learning_assets
    default_params: dict,
    device,
    models_dir="models",
    prefix="lstm_CH",
    key="BASELINE",
    train_flag=True,
    force_retrain=False,
    epochs=150,
    batch_size_train=64,
    batch_size_val=128,
):
    """
    Train (or load) the CH-only baseline on ds_pretrain using the CH scalers
    from ds_pretrain_scalers (no refitting).

    All tl_assets entries are assumed to share the same CH dataset, indices
    and scaler donor, so the first entry is used.

    The checkpoint path is ``{models_dir}/{prefix}_{key}_{YYYY-MM-DD}.pt`` with
    today's date, so an existing checkpoint is only found on the same day.

    Loading rules:
      * checkpoint exists and (not train_flag or not force_retrain) -> load it;
      * not train_flag and no checkpoint -> FileNotFoundError;
      * otherwise -> train from scratch, replacing any existing checkpoint.

    Returns:
        (model, model_path, info). ``info`` is None when loaded, otherwise
        ``{"history": ..., "best_val": ...}``.

    Raises:
        ValueError: if ``tl_assets`` is empty.
        FileNotFoundError: when loading is required but no checkpoint exists.
    """
    if not tl_assets:
        raise ValueError(
            "tl_assets is empty: run build_transfer_learning_assets first.")
    assets0 = next(iter(tl_assets.values()))

    ds_train_pristine = assets0["ds_pretrain"]  # pristine CH dataset
    # scaler donor (fitted on CH train split)
    ds_ch_scalers = assets0["ds_pretrain_scalers"]
    train_idx = assets0["pretrain_train_idx"]
    val_idx = assets0["pretrain_val_idx"]

    current_date = datetime.now().strftime("%Y-%m-%d")
    os.makedirs(models_dir, exist_ok=True)
    model_path = os.path.join(models_dir, f"{prefix}_{key}_{current_date}.pt")
    have_ckpt = os.path.exists(model_path)

    # Fail before building anything if loading is mandatory but impossible.
    if (not train_flag) and (not have_ckpt):
        raise FileNotFoundError(f"No CH checkpoint found: {model_path}")

    # build model + loss
    model = mbm.models.LSTM_MB.build_model_from_params(cfg, default_params,
                                                       device)
    loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(default_params)

    # Reuse an existing checkpoint unless an explicit retrain was requested.
    if have_ckpt and ((not train_flag) or (not force_retrain)):
        state = torch.load(model_path, map_location=device)
        model.load_state_dict(state)
        return model, model_path, None

    # loaders (DO NOT refit scalers; use ds_ch_scalers)
    mbm.utils.seed_all(cfg.seed)

    ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
        ds_train_pristine)

    # Apply CH scalers + transform once
    ds_train_copy.set_scalers_from(ds_ch_scalers)
    ds_train_copy.transform_inplace()

    train_dl, val_dl = ds_train_copy.make_loaders(
        train_idx=train_idx,
        val_idx=val_idx,
        batch_size_train=batch_size_train,
        batch_size_val=batch_size_val,
        seed=cfg.seed,
        fit_and_transform=False,  # IMPORTANT: already transformed
        shuffle_train=True,
        use_weighted_sampler=True,
    )

    # fresh checkpoint
    if os.path.exists(model_path):
        os.remove(model_path)
        print(f"Deleted existing CH model file: {model_path}")

    history, best_val, best_state = model.train_loop(
        device=device,
        train_dl=train_dl,
        val_dl=val_dl,
        epochs=epochs,
        lr=default_params["lr"],
        weight_decay=default_params["weight_decay"],
        clip_val=1,
        # scheduler
        sched_factor=0.5,
        sched_patience=6,
        sched_threshold=0.01,
        sched_threshold_mode="rel",
        sched_cooldown=1,
        sched_min_lr=1e-6,
        # early stopping
        es_patience=15,
        es_min_delta=1e-4,
        # logging
        log_every=5,
        verbose=True,
        # checkpoint
        save_best_path=model_path,
        loss_fn=loss_fn,
    )

    plot_history_lstm(history)

    # load best
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state)

    return model, model_path, {"history": history, "best_val": best_val}

In [None]:
# Train (or reload) the CH-only baseline with default_params. Its checkpoint
# path (ch_path) is the starting point for every fine-tuning run below.
# Checkpoint names include today's date, so with force_retrain=False an
# existing checkpoint is only reused on the same day.
model_ch, ch_path, ch_info = train_or_load_CH_baseline(
    cfg=cfg,
    tl_assets=tl_assets,
    default_params=default_params,
    device=device,
    models_dir="models",
    prefix="lstm_CH",
    key="defaultparams",
    train_flag=True,
    force_retrain=True,  # set False after you have it once
    epochs=150,
)
print("CH baseline saved at:", ch_path)

In [None]:
# Fine-tune the CH baseline (ch_path) for every TL experiment with each
# strategy. Results are keyed "<exp_key>__<strategy>".
# force_retrain=True retrains every run instead of reusing checkpoints.
models_tl, infos_tl = finetune_TL_models_all(
    cfg=cfg,
    tl_assets_by_key=tl_assets,
    best_params=default_params,
    device=device,
    pretrained_ckpt_path=ch_path,
    strategies=("safe", "full", "two_stage"),
    force_retrain=True,
    prefix="lstm_TL",
)

### Evaluate on test:

In [None]:
def make_test_loader_for_key_TL(cfg, tl_assets_for_key, batch_size=128):
    """
    Build the holdout test loader for one TL experiment.

    The pristine target-region holdout ``tl_assets_for_key["ds_test"]`` is
    cloned and handed to ``MBSequenceDataset.make_test_loader`` together with
    the CH scaler donor ``tl_assets_for_key["ds_pretrain_scalers"]``. Presumably
    ``make_test_loader`` transforms the clone with the donor's scalers; TODO
    confirm.

    Returns:
        (ds_scalers, ds_test_copy, test_dl). This matches the layout of the older
        (ds_train_copy, ds_test_copy, test_dl) helper.

    Raises:
        ValueError: if there is no holdout set or the donor lacks fitted scalers.
    """
    mbm.utils.seed_all(cfg.seed)

    donor = tl_assets_for_key["ds_pretrain_scalers"]  # CH scaler donor
    holdout = tl_assets_for_key["ds_test"]  # pristine holdout dataset

    if holdout is None:
        raise ValueError("TL assets have ds_test=None (no holdout set).")

    # All three scalers must have been fitted on the CH train split.
    scalers_missing = any(
        getattr(donor, attr) is None
        for attr in ("month_mean", "static_mean", "y_mean"))
    if scalers_missing:
        raise ValueError(
            "ds_pretrain_scalers is missing fitted scalers. Did Option-2 caching run?"
        )

    holdout_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
        holdout)

    loader = mbm.data_processing.MBSequenceDataset.make_test_loader(
        ds_test=holdout_copy,
        ds_train=donor,
        seed=cfg.seed,
        batch_size=batch_size,
    )

    return donor, holdout_copy, loader


def evaluate_one_model_TL(
        cfg,
        model,
        device,
        tl_assets_for_key,
        ax=None,
        ax_xlim=(-16, 9),
        ax_ylim=(-16, 9),
        title=None,
        legend_fontsize=16,
        batch_size=128,
):
    """
    TL-only evaluator (does not touch old within/xreg frameworks).

    - Builds a test loader with CH scalers via make_test_loader_for_key_TL
    - Uses model.evaluate_with_preds(device, test_dl, ds_test_copy) exactly like the original
    - Plots pred-vs-truth density exactly like the original

    Returns:
        (out, test_df_preds, created_fig, ax). ``out`` is a metrics dict.
        ``created_fig`` is None when ``ax`` was supplied. ``n_annual`` /
        ``n_winter`` are NaN when the score dicts carry no (integer) ``"n"``.
    """
    _ds_scalers, ds_test_copy, test_dl = make_test_loader_for_key_TL(
        cfg, tl_assets_for_key, batch_size=batch_size)

    test_metrics, test_df_preds = model.evaluate_with_preds(
        device, test_dl, ds_test_copy)

    scores_annual, scores_winter = compute_seasonal_scores(test_df_preds,
                                                           target_col="target",
                                                           pred_col="pred")

    def _count(scores):
        # int(scores["n"]) when available. A missing or NaN "n" gives NaN
        # instead of raising, since int(float("nan")) raises ValueError.
        if not isinstance(scores, dict):
            return np.nan
        try:
            return int(scores.get("n", np.nan))
        except (TypeError, ValueError):
            return np.nan

    out = {
        "RMSE_annual":
        float(test_metrics.get("RMSE_annual", scores_annual["rmse"])),
        "RMSE_winter":
        float(test_metrics.get("RMSE_winter", scores_winter["rmse"])),
        "R2_annual": float(scores_annual["R2"]),
        "R2_winter": float(scores_winter["R2"]),
        "Bias_annual": float(scores_annual["Bias"]),
        "Bias_winter": float(scores_winter["Bias"]),
        "n_preds": int(len(test_df_preds)),
        "n_annual": _count(scores_annual),
        "n_winter": _count(scores_winter),
    }

    # Plot
    created_fig = None
    if ax is None:
        created_fig, ax = plt.subplots(figsize=(15, 10))

    pred_vs_truth_density(
        ax,
        test_df_preds,
        scores_annual,
        add_legend=False,
        palette=[mbm.plots.COLOR_ANNUAL, mbm.plots.COLOR_WINTER],
        ax_xlim=ax_xlim,
        ax_ylim=ax_ylim,
    )

    def _fmt(x):
        # "NA" for None or any NaN float, including numpy floats such as float32.
        if x is None:
            return "NA"
        if isinstance(x, (float, np.floating)) and np.isnan(x):
            return "NA"
        return f"{x:.2f}"

    legend_NN = "\n".join([
        rf"$\mathrm{{RMSE_a}}={_fmt(scores_annual['rmse'])},\ \mathrm{{RMSE_w}}={_fmt(scores_winter['rmse'])}$",
        rf"$\mathrm{{R^2_a}}={_fmt(scores_annual['R2'])},\ \mathrm{{R^2_w}}={_fmt(scores_winter['R2'])}$",
        rf"$\mathrm{{Bias_a}}={_fmt(scores_annual['Bias'])},\ \mathrm{{Bias_w}}={_fmt(scores_winter['Bias'])}$",
    ])

    ax.text(
        0.02,
        0.98,
        legend_NN,
        transform=ax.transAxes,
        va="top",
        fontsize=legend_fontsize,
        bbox=dict(boxstyle="round", facecolor="white", alpha=0.5),
    )

    if title:
        ax.set_title(title, fontsize=20)

    return out, test_df_preds, created_fig, ax

In [None]:
def _pick_tl_exp_key_for_region(tl_assets_by_key, region, split_name="5pct"):
    k = f"TL_CH_to_{region}_{split_name}"
    if k not in tl_assets_by_key:
        raise KeyError(f"Missing TL assets for {k}")
    return k


def evaluate_transfer_learning_grid(
        cfg,
        regions,  # e.g. ["NOR","FR","ISL","SJM"]
        models_xreg_by_region: dict,  # baseline: models_xreg["NOR"] etc.
        models_tl_by_key: dict,  # finetuned: models_tl["TL_CH_to_NOR_5pct__safe"] etc.
        tl_assets_by_key: dict,  # tl_assets["TL_CH_to_NOR_5pct"] -> has ds_test + ds_pretrain_scalers
        device,
        split_name="5pct",
        save_dir=None,
        fig_size_per_cell=(5.2, 4.2),
        ax_xlim=(-16, 9),
        ax_ylim=(-16, 9),
        legend_fontsize=11,
        batch_size_eval=128,  # lets you control eval batch size
):
    """
    Grid with rows=regions and cols=[no_ft, safe, full, two_stage].
    Uses tl_assets[exp_key]["ds_test"] as the holdout set for all columns in that row.

    IMPORTANT:
      - Evaluation uses CH scalers via assets_row["ds_pretrain_scalers"] (TL-only evaluator).
      - This function does not call the shared evaluate_one_model to avoid framework clashes.
      - A row is blanked (with a warning) when the region has no TL assets for
        ``split_name``, no holdout set, or no scaler donor. A cell is blanked
        when its model is missing.
      - Works for a single region too (axes are always 2-D).

    Returns:
        (df_metrics, preds, fig). ``df_metrics`` is indexed by
        (region, strategy) when non-empty. ``preds`` maps (region, strategy)
        to the prediction frame.

    Raises:
        ValueError: if ``regions`` is empty.
    """
    strategies = ["no_ft", "safe", "full", "two_stage"]
    col_titles = [
        "No fine-tune",
        "Heads-only FT (freeze LSTM)",
        "Full FT (unfreeze all)",
        "Two-stage FT",
    ]

    regions = list(regions)
    if not regions:
        raise ValueError("evaluate_transfer_learning_grid: `regions` is empty.")

    nrows = len(regions)
    ncols = len(strategies)

    save_abs = None
    if save_dir:
        save_abs = os.fspath(save_dir)
        os.makedirs(save_abs, exist_ok=True)

    figsize = (fig_size_per_cell[0] * ncols, fig_size_per_cell[1] * nrows)
    # squeeze=False keeps `axes` 2-D even for a single region, so axes[r, c]
    # indexing is always valid.
    fig, axes = plt.subplots(nrows,
                             ncols,
                             figsize=figsize,
                             sharex=True,
                             sharey=True,
                             squeeze=False)

    def _blank_row(row):
        for col in range(ncols):
            axes[row, col].axis("off")

    rows = []
    preds = {}

    for r, region in enumerate(regions):
        # A region without assets for this split should not abort the grid.
        try:
            exp_key = _pick_tl_exp_key_for_region(tl_assets_by_key,
                                                  region,
                                                  split_name=split_name)
        except KeyError as err:
            logging.warning(f"Skipping region {region}: {err}")
            _blank_row(r)
            continue
        assets_row = tl_assets_by_key[exp_key]

        if assets_row is None or assets_row.get("ds_test", None) is None:
            logging.warning(
                f"Skipping region {region}: no ds_test in {exp_key}")
            _blank_row(r)
            continue

        # TL-only evaluator needs these:
        if assets_row.get("ds_pretrain_scalers", None) is None:
            logging.warning(
                f"Skipping region {region}: missing ds_pretrain_scalers in {exp_key}"
            )
            _blank_row(r)
            continue

        for c, strat in enumerate(strategies):
            ax = axes[r, c]

            if strat == "no_ft":
                model = models_xreg_by_region.get(region, None)
                title = f"{region}\nNo FT"
            else:
                model_key = f"{exp_key}__{strat}"
                model = models_tl_by_key.get(model_key, None)
                title = f"{region}\n{strat}"

            if model is None:
                ax.axis("off")
                logging.warning(
                    f"Missing model for region={region}, strategy={strat}")
                continue

            # ---- TL-only evaluation (uses CH scalers) ----
            metrics, df_preds, _fig_ind, _ = evaluate_one_model_TL(
                cfg=cfg,
                model=model,
                device=device,
                tl_assets_for_key=assets_row,
                ax=ax,
                ax_xlim=ax_xlim,
                ax_ylim=ax_ylim,
                title=title,
                legend_fontsize=legend_fontsize,
                batch_size=batch_size_eval,
            )

            metrics.update({
                "region": region,
                "strategy": strat,
                "exp_key": exp_key,
                "split_name": split_name,
            })
            rows.append(metrics)
            preds[(region, strat)] = df_preds

            # remove legend if present
            leg = ax.get_legend()
            if leg is not None:
                leg.remove()

    # column labels (override the per-cell titles set by the evaluator)
    for rr in range(nrows):
        for cc in range(ncols):
            axes[rr, cc].set_title(f"{regions[rr]} - {col_titles[cc]}",
                                   fontsize=14)
            if cc == 0:
                axes[rr, cc].set_ylabel("Modeled PMB [m w.e.]", fontsize=12)
            else:
                axes[rr, cc].set_ylabel("")
            if rr == nrows - 1:
                axes[rr, cc].set_xlabel("Observed PMB [m w.e.]", fontsize=12)
            else:
                axes[rr, cc].set_xlabel("")

    fig.suptitle(
        f"Transfer learning evaluation (holdout test) — split={split_name}",
        fontsize=18)
    fig.tight_layout()

    if save_abs:
        out_png = os.path.join(save_abs, f"TL_grid_{split_name}.png")
        fig.savefig(out_png, dpi=200, bbox_inches="tight")

    df_metrics = pd.DataFrame(rows)
    if len(df_metrics) > 0:
        df_metrics = df_metrics.set_index(["region", "strategy"]).sort_index()

    return df_metrics, preds, fig

In [None]:
def load_one_xreg_model(
    cfg,
    region,
    best_params,
    device,
    models_dir="models",
    prefix="lstm_xreg_CH_to",
    date=None,  # None -> use the matching checkpoint whose name sorts last
):
    """
    Rebuild one cross-regional CH→region LSTM and load its saved weights.

    Checkpoints are expected to be named ``{prefix}_{region}_{date}.pt``
    inside ``models_dir``.

    Parameters
    ----------
    cfg : config object
        Forwarded to ``mbm.models.LSTM_MB.build_model_from_params``.
    region : str
        Region code embedded in the checkpoint filename.
    best_params : mapping
        Hyper-parameters used to rebuild the architecture.
    device : torch.device or str
        Device for the rebuilt model; also passed as ``map_location``.
    models_dir : str
        Directory holding the ``.pt`` checkpoints.
    prefix : str
        Filename prefix shared by the cross-regional checkpoints.
    date : str or None
        Explicit date suffix. When None, every ``{prefix}_{region}_*.pt``
        file is collected and the one that sorts last (plain string
        order) is used.

    Returns
    -------
    tuple
        ``(model, path)``: the model with the state dict loaded, and the
        checkpoint path that was read.

    Raises
    ------
    FileNotFoundError
        If ``date`` is None and no matching checkpoint exists.
    """
    if date is not None:
        filename = f"{prefix}_{region}_{date}.pt"
    else:
        # Trailing "_" keeps e.g. region "NO" from matching "NOR" files.
        stem = f"{prefix}_{region}_"
        matches = sorted(
            name for name in os.listdir(models_dir)
            if name.startswith(stem) and name.endswith(".pt")
        )
        if not matches:
            raise FileNotFoundError(f"No checkpoint found for region {region}")
        # NOTE(review): "last by name" equals "latest" only if the date
        # suffix sorts chronologically as a string (e.g. ISO dates) — confirm.
        filename = matches[-1]

    checkpoint_path = os.path.join(models_dir, filename)

    # Rebuild the architecture first, then load the saved weights into it.
    model = mbm.models.LSTM_MB.build_model_from_params(cfg, best_params, device)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    return model, checkpoint_path


def load_xreg_models_all(
    cfg,
    regions,
    best_params,
    device,
    models_dir="models",
    prefix="lstm_xreg_CH_to",
    date=None,
):
    """
    Load the CH→region cross-regional model for every region in ``regions``.

    Each region is delegated to ``load_one_xreg_model`` with the shared
    ``cfg``, ``best_params``, ``device``, ``models_dir``, ``prefix`` and
    ``date``. A region whose checkpoint is missing (``FileNotFoundError``)
    is reported and mapped to ``None`` instead of aborting the whole load.
    Any other exception propagates.

    Returns
    -------
    tuple
        ``(models, paths)``: two dicts keyed by region (in the order of
        ``regions``), holding the loaded model and its checkpoint path.
        Both values are ``None`` for skipped regions.
    """
    models, paths = {}, {}

    for region in regions:
        try:
            loaded, path = load_one_xreg_model(
                cfg=cfg,
                region=region,
                best_params=best_params,
                device=device,
                models_dir=models_dir,
                prefix=prefix,
                date=date,
            )
        except FileNotFoundError as e:
            print(f"Skipping {region}: {e}")
            loaded = path = None
        else:
            print(f"Loaded CH→{region} from {path}")

        models[region] = loaded
        paths[region] = path

    return models, paths

In [None]:
# Baseline CH→region (no fine-tuning) models for the regions to evaluate.
# Regions without a matching checkpoint in `models_dir` are reported as
# "Skipping ..." and mapped to None in both returned dicts.
regions = ["FR", "NOR", "ISL"]  # any regions you have CH→region checkpoints for

models_xreg, paths_xreg = load_xreg_models_all(
    cfg=cfg,
    regions=regions,
    best_params=default_params,
    device=device,
    models_dir="models",
    prefix="lstm_xreg_CH_to",
    date=None,  # auto-detect latest (matching checkpoint whose name sorts last)
)

In [None]:
# Holdout evaluation of the TL strategies on the 5pct fine-tuning split.
# Results are also bound to split-specific names so that the 50pct cell
# (which reuses df_tl_grid / preds_tl_grid / fig_tl_grid) does not discard
# them; the generic names are kept for any later cells that use them.
df_tl_grid_5pct, preds_tl_grid_5pct, fig_tl_grid_5pct = evaluate_transfer_learning_grid(
    cfg=cfg,
    regions=regions,
    models_xreg_by_region=models_xreg,  # baseline CH→Region models
    models_tl_by_key=models_tl,  # TL models keyed by "exp__strategy"
    tl_assets_by_key=tl_assets,  # TL assets
    device=device,
    split_name="5pct",  # or "50pct"
    save_dir="figures/eval_TL",
)
df_tl_grid, preds_tl_grid, fig_tl_grid = (
    df_tl_grid_5pct, preds_tl_grid_5pct, fig_tl_grid_5pct)

In [None]:
# Holdout evaluation of the TL strategies on the 50pct fine-tuning split.
# Only NOR and ISL have 50pct splits, so the region list is narrowed here.
# Results get split-specific names so the 5pct results (bound by the
# previous cell) stay distinguishable; the generic df_tl_grid /
# preds_tl_grid / fig_tl_grid names are still rebound for later cells.
df_tl_grid_50pct, preds_tl_grid_50pct, fig_tl_grid_50pct = evaluate_transfer_learning_grid(
    cfg=cfg,
    regions=["NOR", "ISL"],  # only regions with 50pct splits
    models_xreg_by_region=models_xreg,  # baseline CH→Region models
    models_tl_by_key=models_tl,  # TL models keyed by "exp__strategy"
    tl_assets_by_key=tl_assets,  # TL assets
    device=device,
    split_name="50pct",
    save_dir="figures/eval_TL",
)
df_tl_grid, preds_tl_grid, fig_tl_grid = (
    df_tl_grid_50pct, preds_tl_grid_50pct, fig_tl_grid_50pct)