In [None]:
import sys, os
sys.path.append(os.path.join(os.getcwd(), '../../')) # Add root of repo to import MBM

import warnings
import massbalancemachine as mbm
import logging
from datetime import datetime
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import torch 

from regions.TF_Europe.scripts.config_TF_Europe import *
from regions.TF_Europe.scripts.dataset import *
from regions.TF_Europe.scripts.plotting import *
from regions.TF_Europe.scripts.models import *
from regions.TF_Europe.scripts.utils import *

# Blanket filter: suppress every Python warning for the rest of the session.
warnings.filterwarnings('ignore')
# autoreload mode 2: IPython re-imports edited modules before executing each cell.
%load_ext autoreload
%autoreload 2

# Notebook-wide configuration; cfg.seed and cfg.dataPath are reused by later cells.
cfg = mbm.EuropeTFConfig()
mbm.utils.seed_all(cfg.seed)  # presumably seeds every RNG MBM relies on — confirm in mbm.utils
mbm.utils.free_up_cuda()
mbm.plots.use_mbm_style()

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize logging: INFO level, every record prefixed with its timestamp.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')


## Read stakes datasets:

In [None]:
"""
Examples of loading data:
# Load Switzerland only
df = load_stakes(cfg, "CH")

# Load all Central Europe (FR+CH+IT+AT when you add them)
df_ceu = load_stakes_for_rgi_region(cfg, "11")

# Load all Europe regions configured
dfs = {rid: load_stakes_for_rgi_region(cfg, rid) for rid in RGI_REGIONS.keys()}
"""

# One stake DataFrame per configured RGI region, keyed by region id.
dfs = {
    region_id: load_stakes_for_rgi_region(cfg, region_id)
    for region_id in RGI_REGIONS
}
dfs["11"].head()

## Monthly datasets:

In [None]:
# Inputs for the monthly transformation (run or load data).
# Both ERA5 files live in the same raw-data folder under cfg.dataPath.
_era5_raw_dir = os.path.join(cfg.dataPath, path_ERA5_raw)
paths = {
    'era5_climate_data': os.path.join(_era5_raw_dir,
                                      "era5_monthly_averaged_data_Europe.nc"),
    'geopotential_data': os.path.join(_era5_raw_dir,
                                      "era5_geopotential_pressure_Europe.nc"),
}

# Fail fast, before any expensive processing, if an input file is missing.
for key, path in paths.items():
    if os.path.exists(path):
        continue
    raise FileNotFoundError(f"Required file for {key} not found at {path}")

# Variables of interest: climate variables first, then topographical ones.
vois_climate = ["t2m", "tp", "slhf", "sshf", "ssrd", "fal", "str"]
vois_topographical = ["aspect", "slope", "svf"]

In [None]:
# Stake DataFrames for every configured region (same load as the
# "Read stakes datasets" cell; repeated so this cell can run on its own).
dfs = {
    region_id: load_stakes_for_rgi_region(cfg, region_id)
    for region_id in RGI_REGIONS
}

# Monthly cross-regional dataset. run_flag=False is the "load if already
# computed" setting, as opposed to recomputing.
res_xreg, split_info = prepare_monthly_df_xreg_CH_to_EU(
    cfg=cfg,
    dfs=dfs,
    paths=paths,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    run_flag=False,
)

df_train, df_test = res_xreg["df_train"], res_xreg["df_test"]

# Quick size summary of the CH (train) / non-CH (test) split.
n_train_glaciers = len(split_info["train_glaciers"])
n_test_glaciers = len(split_info["test_glaciers"])
print("Train glaciers (CH):", n_train_glaciers)
print("Test glaciers (non-CH):", n_test_glaciers)
print("Train rows:", len(df_train), "Test rows:", len(df_test))

## Experiment design:

In [None]:
# SOURCE_CODE value that identifies Iceland rows (or maybe "06" depending on your setup).
ISL_RID = "ISL"

# Independent copy of the Iceland rows of the cross-regional test frame.
df_isl = df_test.loc[df_test["SOURCE_CODE"] == ISL_RID].copy()

print("ISL rows:", df_isl.shape[0])
print("ISL glaciers:", df_isl["GLACIER"].nunique())
print("ISL year range:", df_isl["YEAR"].min(), "-", df_isl["YEAR"].max())

### Fixed glacier hold-out split (spatial generalization):

In [None]:
import numpy as np
import pandas as pd


def glacier_level_features(
    df_isl: pd.DataFrame,
    glacier_col="GLACIER",
    year_col="YEAR",
):
    """Summarise a stake table into one row of features per glacier.

    Parameters
    ----------
    df_isl : pd.DataFrame
        Table with at least ``glacier_col``, ``year_col``, ``aspect``
        (degrees), ``slope`` and ``svf`` columns. It is not modified.
    glacier_col, year_col : str
        Names of the glacier-id and year columns.

    Returns
    -------
    pd.DataFrame
        One row per glacier with columns ``glacier_col``, ``nrows``,
        ``nyears``, ``slope_mean``, ``slope_std``, ``svf_mean``,
        ``asp_sin_mean`` and ``asp_cos_mean``.
    """
    # Aspect is circular, so average its sine and cosine instead of the raw
    # angle; the angle is wrapped into [0, 360) before conversion to radians.
    theta = np.deg2rad(df_isl["aspect"].astype(float) % 360.0)
    with_trig = df_isl.assign(_asp_sin=np.sin(theta), _asp_cos=np.cos(theta))

    # Output column -> (input column, aggregation), in output order.
    summary_spec = {
        "nrows": (glacier_col, "size"),
        "nyears": (year_col, pd.Series.nunique),
        "slope_mean": ("slope", "mean"),
        "slope_std": ("slope", "std"),
        "svf_mean": ("svf", "mean"),
        "asp_sin_mean": ("_asp_sin", "mean"),
        "asp_cos_mean": ("_asp_cos", "mean"),
    }
    per_glacier = with_trig.groupby(glacier_col).agg(**summary_spec)

    return per_glacier.reset_index()

In [None]:
# One feature row per Iceland glacier; the row count is the number of glaciers.
gfeat = glacier_level_features(df_isl)
print("Number of glaciers:", len(gfeat))
gfeat.head()

In [None]:
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans


def holdout_split_cluster_stratified(
        df_isl,
        holdout_frac=0.30,
        seed=cfg.seed,
        n_clusters=6,  # start conservative; increase if many glaciers
):
    """Split glaciers into a fine-tuning pool and a fixed hold-out set.

    Glaciers are described by ``glacier_level_features``, standardised and
    clustered with K-means; a share of every cluster is then drawn at random
    into the hold-out set so that both sides cover all clusters.

    Parameters
    ----------
    df_isl : pd.DataFrame
        Row-level table; must contain ``GLACIER`` plus the columns that
        ``glacier_level_features`` reads.
    holdout_frac : float
        Share of each cluster sent to the hold-out set. The per-cluster count
        is rounded *up*, so for ``holdout_frac > 0`` every cluster contributes
        at least one glacier and the overall share can exceed ``holdout_frac``.
    seed : int
        Seeds both K-means and the per-cluster shuffling.
    n_clusters : int
        Number of K-means clusters.

    Returns
    -------
    tuple
        ``(df_pool, df_holdout, holdout_glaciers, pool_glaciers, gfeat,
        summary)``: the two row-level frames (copies), the two glacier-id
        sets, the glacier feature table with an added ``cluster`` column, and
        a dict of split sizes.
    """
    gfeat = glacier_level_features(df_isl)

    # feature space for clustering
    feat_cols = [
        "slope_mean",
        "slope_std",
        "svf_mean",
        "asp_sin_mean",
        "asp_cos_mean",
        "nyears",  # optional but stabilizes split
        "nrows",  # optional but stabilizes split
    ]

    # Missing features (e.g. slope_std of a single-row glacier) are replaced
    # by the column median so K-means receives a complete matrix.
    X = gfeat[feat_cols].astype(float).fillna(gfeat[feat_cols].median())

    scaler = StandardScaler()
    Xs = scaler.fit_transform(X)

    km = KMeans(n_clusters=n_clusters, random_state=seed, n_init=10)
    gfeat["cluster"] = km.fit_predict(Xs)

    rng = np.random.default_rng(seed)
    holdout = []

    # stratified sampling per cluster
    for _, sub in gfeat.groupby("cluster"):
        # Shuffle an explicit copy: `to_numpy()` may hand back a view of the
        # frame's data, which is read-only under pandas Copy-on-Write, and an
        # in-place shuffle of it would fail (or touch pandas-owned memory).
        gls = sub["GLACIER"].to_numpy(copy=True)
        rng.shuffle(gls)
        k = int(np.ceil(holdout_frac * len(gls)))
        holdout.extend(gls[:k])

    holdout_glaciers = set(holdout)
    pool_glaciers = set(gfeat["GLACIER"]) - holdout_glaciers

    df_holdout = df_isl[df_isl["GLACIER"].isin(holdout_glaciers)].copy()
    df_pool = df_isl[df_isl["GLACIER"].isin(pool_glaciers)].copy()

    summary = {
        "holdout_frac": holdout_frac,
        "n_clusters": n_clusters,
        "n_glaciers_total": int(gfeat.shape[0]),
        "n_glaciers_holdout": int(len(holdout_glaciers)),
        "n_glaciers_pool": int(len(pool_glaciers)),
        "rows_holdout": int(len(df_holdout)),
        "rows_pool": int(len(df_pool)),
    }

    return df_pool, df_holdout, holdout_glaciers, pool_glaciers, gfeat, summary

In [None]:
# Fixed spatial split of the Iceland glaciers: 30 % hold-out per cluster,
# 6 clusters, seeded from cfg.seed. The returned glacier sets and `gfeat`
# (now carrying a "cluster" column) are reused by the cells below.
df_isl_pool, df_isl_holdout, holdout_glaciers, pool_glaciers, gfeat, split_summary = (
    holdout_split_cluster_stratified(df_isl,
                                     holdout_frac=0.30,
                                     seed=cfg.seed,
                                     n_clusters=6))

print(split_summary)

In [None]:
# Glacier-level feature rows on each side of the split.
g_hold = gfeat.loc[gfeat["GLACIER"].isin(holdout_glaciers)]
g_pool = gfeat.loc[gfeat["GLACIER"].isin(pool_glaciers)]

# Print the 10th/50th/90th percentiles of each feature for pool and hold-out
# so the two distributions can be compared by eye.
for col in ("slope_mean", "slope_std", "svf_mean", "nyears", "nrows"):
    print(col)
    print("  pool   :", g_pool[col].quantile([0.1, 0.5, 0.9]).to_dict())
    print("  holdout:", g_hold[col].quantile([0.1, 0.5, 0.9]).to_dict())

In [None]:
# Number of glaciers per K-means cluster on each side of the split, ordered
# by cluster id so the two listings line up.
print("Cluster counts - pool")
print(g_pool["cluster"].value_counts().sort_index())

print("\nCluster counts - holdout")
print(g_hold["cluster"].value_counts().sort_index())

In [None]:
# Glaciers handed to the helper under the single "spatial" split.
# NOTE(review): these are the *hold-out* glaciers, but they are passed as
# `ft_glaciers_by_split` together with ft_label_ft="Pool" — confirm in
# build_region_glacier_info_for_splits that the map labels are not inverted.
ft_glaciers_by_split = {
    "spatial": holdout_glaciers,
}

data_ISL, glacier_outline_rgi, glacier_info_by_split = build_region_glacier_info_for_splits(
    cfg,
    rgi_region_id="06",
    # os.path.join (as used for the ERA5 paths) works whether or not
    # cfg.dataPath ends with a path separator.
    outline_shp_path=os.path.join(
        cfg.dataPath, "RGI_v6/RGI_06_Iceland/06_rgi60_Iceland.shp"),
    ft_glaciers_by_split=ft_glaciers_by_split,
    split_names=["spatial"],
    ft_label_col="Pool/Hold-out glacier",
    ft_label_ft="Pool",
    ft_label_holdout="Hold-out",
)

# NOTE(review): the "_5pct" suffix looks like a leftover from another
# notebook; this is simply the glacier info of the "spatial" split.
glacier_df_ISL_5pct = glacier_info_by_split["spatial"]

# First of 10 colours sampled from the batlow colormap.
cmap_for_train = cm.batlow
# requires your helper
colors = get_cmap_hex(cmap_for_train, 10)  # noqa: F821
train_color = colors[0]

palette = {"Hold-out": train_color, "Pool": "#b2182b"}

fig, ax, glacier_info_plot, scaled_size_fn = plot_glacier_measurements_map(
    glacier_info=glacier_df_ISL_5pct,
    glacier_outline_rgi=glacier_outline_rgi,
    title="Glacier PMB location Iceland",
    extent=(-25, -11, 62, 68),
    sizes=(100, 1500),
    size_legend_values=(30, 100, 1000),
    palette=palette,
    cmap_for_train=cmap_for_train,  # optional, uses your get_cmap_hex if available
    split_col="Pool/Hold-out glacier")

### Monitoring subsamples:

In [None]:
# Per-month model inputs: the seven ERA5 variables plus ELEVATION_DIFFERENCE.
MONTHLY_COLS = [
    "t2m", "tp", "slhf", "sshf", "ssrd", "fal", "str",
    "ELEVATION_DIFFERENCE",
]
# Inputs that are constant in time.
STATIC_COLS = ["aspect", "slope", "svf"]

# Full feature list: monthly columns first, then static ones.
feature_columns = [*MONTHLY_COLS, *STATIC_COLS]

In [None]:
# res_xreg is the ONE dict from your cross-regional monthly prep
# The call below compares CH (ch_code) with ISL only (only_codes), grouped by
# SOURCE_CODE, on the non-augmented frames (use_aug=False).
figs_by_code = plot_tsne_overlap_xreg_from_single_res(
    res_xreg=res_xreg,
    cfg=cfg,
    STATIC_COLS=STATIC_COLS,
    MONTHLY_COLS=MONTHLY_COLS,
    group_col="SOURCE_CODE",
    ch_code="CH",
    use_aug=False,  # or True if you want *_aug
    n_iter=1000,
    only_codes=["ISL"],  # optional
)

In [None]:
def sample_monitoring_subset_from_pool(
    df_pool: pd.DataFrame,
    G: int,
    Y: int,
    M: int,
    seed: int = 0,
    glacier_pick_method:
    str = "random",  # "random" / "small_first" / "large_first" / "shuffle"
    min_rows_per_glacier: int = 1,
):
    """Draw a monitoring subset limited to G glaciers, Y years and M rows.

    Parameters
    ----------
    df_pool : pd.DataFrame
        Candidate rows; needs ``GLACIER`` and ``YEAR`` columns.
    G : int
        Number of glaciers to keep.
    Y : int
        Number of earliest distinct years kept per chosen glacier.
    M : int
        Maximum number of rows kept per glacier-year.
    seed : int
        Seed of the generator that drives glacier choice and row sampling.
    glacier_pick_method : str
        ``"random"`` / ``"shuffle"`` pick at random, ``"small_first"`` /
        ``"large_first"`` pick the glaciers with the fewest / most rows.
    min_rows_per_glacier : int
        Glaciers with fewer rows than this are not eligible.

    Returns
    -------
    (pd.DataFrame, set)
        The sampled rows (fresh index, extra ``GLACIER_YEAR`` column) and the
        set of chosen glacier ids.

    Raises
    ------
    ValueError
        If ``G`` exceeds the number of eligible glaciers or the pick method is
        unknown.
    """
    rng = np.random.default_rng(seed)

    # Rows per glacier, largest first, restricted to eligible glaciers.
    rows_per_glacier = (df_pool.groupby("GLACIER").size().sort_values(
        ascending=False))
    rows_per_glacier = rows_per_glacier[
        rows_per_glacier >= min_rows_per_glacier]
    candidates = rows_per_glacier.index.to_numpy()
    if G > len(candidates):
        raise ValueError(
            f"G={G} > available pool glaciers ({len(candidates)})")

    if glacier_pick_method == "random":
        picked = rng.choice(candidates, size=G, replace=False)
    elif glacier_pick_method == "small_first":
        picked = rows_per_glacier.sort_values(
            ascending=True).index[:G].to_numpy()
    elif glacier_pick_method == "large_first":
        picked = rows_per_glacier.sort_values(
            ascending=False).index[:G].to_numpy()
    elif glacier_pick_method == "shuffle":
        shuffled = candidates.copy()
        rng.shuffle(shuffled)
        picked = shuffled[:G]
    else:
        raise ValueError(
            f"Unknown glacier_pick_method='{glacier_pick_method}'")

    chosen = set(picked)
    chosen_rows = df_pool[df_pool["GLACIER"].isin(chosen)].copy()

    # Per glacier, keep only the rows of its Y earliest distinct years.
    year_limited = []
    for _, glacier_rows in chosen_rows.groupby("GLACIER"):
        ordered_years = np.array(sorted(glacier_rows["YEAR"].unique()))
        year_limited.append(
            glacier_rows[glacier_rows["YEAR"].isin(ordered_years[:Y])])
    df_y = pd.concat(year_limited, ignore_index=True)

    # Key identifying one glacier-year; it stays in the returned frame.
    df_y["GLACIER_YEAR"] = df_y["GLACIER"].astype(
        str) + "_" + df_y["YEAR"].astype(int).astype(str)

    # Cap every glacier-year at M rows. A fresh random_state is drawn from
    # `rng` only for the groups that actually need down-sampling.
    capped = []
    for _, gy_rows in df_y.groupby("GLACIER_YEAR"):
        if len(gy_rows) > M:
            draw_seed = int(rng.integers(1, 1_000_000))
            gy_rows = gy_rows.sample(n=M, random_state=draw_seed)
        capped.append(gy_rows)

    return pd.concat(capped, ignore_index=True), chosen

In [None]:
def make_res_transfer_learning_custom(
    res_xreg: dict,
    target_code: str,
    df_ft: pd.DataFrame,
    holdout_glaciers: set,
    source_col="SOURCE_CODE",
):
    """Slice the cross-regional result dict into pretrain / finetune / test parts.

    - pretrain: the training frames of ``res_xreg``, passed through as-is
    - finetune: the supplied ``df_ft`` plus the matching rows of the target
      region's augmented test frame
    - test: rows of the target region that belong to ``holdout_glaciers``
      (plain and augmented)

    Augmented rows are matched to ``df_ft`` on GLACIER, YEAR, ID and PERIOD;
    PERIOD is compared after ``str`` conversion, stripping and lower-casing,
    and the augmented frames returned carry that normalised PERIOD.

    Returns
    -------
    (dict, dict, dict)
        ``res_pretrain``, ``res_ft`` and ``res_test``; each also carries the
        ``months_head_pad`` / ``months_tail_pad`` entries of ``res_xreg``.
    """

    def _normalised_period(frame):
        return frame["PERIOD"].astype(str).str.strip().str.lower()

    res_pretrain = {
        name: res_xreg[name]
        for name in ("df_train", "df_train_aug", "months_head_pad",
                     "months_tail_pad")
    }
    pads = {
        "months_head_pad": res_pretrain["months_head_pad"],
        "months_tail_pad": res_pretrain["months_tail_pad"],
    }

    # Rows of the target region, plain and augmented.
    test_plain = res_xreg["df_test"]
    test_aug = res_xreg["df_test_aug"]
    target_plain = test_plain.loc[test_plain[source_col] ==
                                  target_code].copy()
    target_aug = test_aug.loc[test_aug[source_col] == target_code].copy()
    target_aug["PERIOD"] = _normalised_period(target_aug)

    # finetune aug: keep the augmented rows whose keys occur in df_ft
    key_cols = ["GLACIER", "YEAR", "ID", "PERIOD"]
    wanted_keys = df_ft[key_cols].copy()
    wanted_keys["PERIOD"] = _normalised_period(wanted_keys)
    df_ft_aug = target_aug.merge(wanted_keys.drop_duplicates(),
                                 on=key_cols,
                                 how="inner")

    # holdout = fixed glaciers
    hold_plain = target_plain[target_plain["GLACIER"].isin(
        holdout_glaciers)].copy()
    hold_aug = target_aug[target_aug["GLACIER"].isin(
        holdout_glaciers)].copy()

    res_ft = {"df_train": df_ft, "df_train_aug": df_ft_aug, **pads}
    res_test = {"df_test": hold_plain, "df_test_aug": hold_aug, **pads}
    return res_pretrain, res_ft, res_test

In [None]:
def build_static_tl_assets_CH_and_holdout(
    cfg,
    res_xreg,
    target_code: str,  # "ISL"
    holdout_glaciers: set,  # fixed glacier IDs
    MONTHLY_COLS,
    STATIC_COLS,
    cache_dir="logs/LSTM_cache_TL_budget",
    force_recompute=False,
    val_ratio=0.2,
    key_train="TL_CH_TRAIN",
    key_holdout=None,  # if None -> auto name
    show_progress=True,
):
    """Build (or load) the assets shared by every (G, Y, M, seed) experiment.

    Two things are prepared:
      - the pretraining dataset from ``res_xreg``'s training frames, with its
        train/validation indices and scaler-donor dataset
      - the evaluation-only dataset made of the target-region rows that
        belong to ``holdout_glaciers``

    Returns
    -------
    dict
        Keys ``ds_pretrain``, ``ds_pretrain_scalers``, ``pretrain_train_idx``,
        ``pretrain_val_idx``, ``pretrain_source_codes``, ``ds_test``,
        ``test_source_codes``, ``target_code`` and ``cache_keys``.

    Raises
    ------
    ValueError
        If no target-region row belongs to ``holdout_glaciers``.
    """
    # NOTE(review): the automatic key does not encode which glaciers are held
    # out; if the cache is looked up by key alone, changing the hold-out set
    # without force_recompute=True could load stale data — confirm.
    holdout_key = (f"TL_CH_to_{target_code}_HOLDOUT_FIXED"
                   if key_holdout is None else key_holdout)

    # ---- pretrain datasets + scaler donor
    pretrain_inputs = {
        name: res_xreg[name]
        for name in ("df_train", "df_train_aug", "months_head_pad",
                     "months_tail_pad")
    }
    ds_ch, train_idx, val_idx, ds_ch_scalers = build_or_load_lstm_train_only(
        cfg=cfg,
        key_train=key_train,
        res_train=pretrain_inputs,
        MONTHLY_COLS=MONTHLY_COLS,
        STATIC_COLS=STATIC_COLS,
        val_ratio=val_ratio,
        cache_dir=cache_dir,
        force_recompute=force_recompute,
        show_progress=show_progress)

    ch_source_codes = build_source_codes_for_dataset(
        ds_ch, pretrain_inputs["df_train_aug"], source_col="SOURCE_CODE")

    # ---- fixed holdout rows of the target region (plain + augmented)
    def _holdout_rows(frame):
        in_target = frame.loc[frame["SOURCE_CODE"] == target_code].copy()
        return in_target[in_target["GLACIER"].isin(holdout_glaciers)].copy()

    hold_plain = _holdout_rows(res_xreg["df_test"])
    hold_aug = _holdout_rows(res_xreg["df_test_aug"])

    if len(hold_plain) == 0:
        raise ValueError(
            f"{target_code}: fixed holdout is empty. Check holdout_glaciers.")

    ds_holdout = build_or_load_lstm_dataset_only(
        cfg=cfg,
        key=holdout_key,
        df_loss=hold_plain,
        df_full=hold_aug,
        months_head_pad=pretrain_inputs["months_head_pad"],
        months_tail_pad=pretrain_inputs["months_tail_pad"],
        MONTHLY_COLS=MONTHLY_COLS,
        STATIC_COLS=STATIC_COLS,
        cache_dir=cache_dir,
        force_recompute=force_recompute,
        kind="test",
        show_progress=show_progress)

    holdout_source_codes = build_source_codes_for_dataset(
        ds_holdout, hold_aug, source_col="SOURCE_CODE")

    return {
        "ds_pretrain": ds_ch,
        "ds_pretrain_scalers": ds_ch_scalers,
        "pretrain_train_idx": train_idx,
        "pretrain_val_idx": val_idx,
        "pretrain_source_codes": ch_source_codes,
        "ds_test": ds_holdout,
        "test_source_codes": holdout_source_codes,
        "target_code": target_code,
        "cache_keys": {
            "pretrain": key_train,
            "test": holdout_key,
        },
    }

In [None]:
def build_budget_assets_finetune_only(
    cfg,
    res_xreg,
    static_assets: dict,
    df_ft: pd.DataFrame,
    exp_key: str,
    MONTHLY_COLS,
    STATIC_COLS,
    cache_dir="logs/LSTM_cache_TL_budget",
    force_recompute=False,
    val_ratio=0.2,
    show_progress=True,
):
    """Build the per-experiment finetune dataset and merge it with the static assets.

    Only the finetune dataset and its train/validation split vary between
    experiments; everything in ``static_assets`` is copied through.

    Returns
    -------
    dict
        ``{exp_key: assets}`` where ``assets`` holds every ``static_assets``
        entry plus ``ds_finetune``, ``finetune_train_idx``,
        ``finetune_val_idx``, ``ft_source_codes``, ``domain_vocab``,
        ``split_name`` and an extended ``cache_keys``.

    Raises
    ------
    ValueError
        If ``df_ft`` is empty or no augmented row matches its keys.
    """

    def _normalised_period(frame):
        return frame["PERIOD"].astype(str).str.strip().str.lower()

    target_code = static_assets["target_code"]
    key_cols = ["GLACIER", "YEAR", "ID", "PERIOD"]

    # Augmented rows of the target region, PERIOD normalised for matching.
    test_aug = res_xreg["df_test_aug"]
    target_aug = test_aug.loc[test_aug["SOURCE_CODE"] == target_code].copy()
    target_aug["PERIOD"] = _normalised_period(target_aug)

    # Keep the augmented rows whose (GLACIER, YEAR, ID, PERIOD) occur in df_ft.
    wanted_keys = df_ft[key_cols].copy()
    wanted_keys["PERIOD"] = _normalised_period(wanted_keys)
    df_ft_aug = target_aug.merge(wanted_keys.drop_duplicates(),
                                 on=key_cols,
                                 how="inner")

    if not len(df_ft) or not len(df_ft_aug):
        raise ValueError(f"{exp_key}: finetune df or aug df is empty.")

    # build finetune dataset (pristine)
    ft_cache_key = f"{exp_key}_FT"
    ds_ft = build_or_load_lstm_dataset_only(
        cfg=cfg,
        key=ft_cache_key,
        df_loss=df_ft,
        df_full=df_ft_aug,
        months_head_pad=res_xreg["months_head_pad"],
        months_tail_pad=res_xreg["months_tail_pad"],
        MONTHLY_COLS=MONTHLY_COLS,
        STATIC_COLS=STATIC_COLS,
        cache_dir=cache_dir,
        force_recompute=force_recompute,
        kind="ft",
        show_progress=show_progress)

    split_indices = mbm.data_processing.MBSequenceDataset.split_indices
    ft_train_idx, ft_val_idx = split_indices(len(ds_ft),
                                             val_ratio=val_ratio,
                                             seed=cfg.seed)

    ft_source_codes = build_source_codes_for_dataset(
        ds_ft, df_ft_aug, source_col="SOURCE_CODE")

    # domain vocab: sorted union of pretrain, finetune and holdout codes
    all_codes = set(static_assets["pretrain_source_codes"])
    all_codes |= set(ft_source_codes)
    all_codes |= set(static_assets["test_source_codes"])
    domain_vocab = sorted(all_codes)

    # Static entries first, then the per-experiment ones; "cache_keys" is
    # rebuilt so it also records the finetune cache key.
    experiment_assets = {
        **static_assets,
        "ds_finetune": ds_ft,
        "finetune_train_idx": ft_train_idx,
        "finetune_val_idx": ft_val_idx,
        "ft_source_codes": ft_source_codes,
        "domain_vocab": domain_vocab,
        "split_name": exp_key,  # optional
        "cache_keys": {
            **static_assets["cache_keys"],
            "finetune": ft_cache_key,
        },
    }
    return {exp_key: experiment_assets}

#### Build static assets once (CH + fixed ISL holdout)

In [None]:
# Assets shared by every budget experiment: the CH pretraining dataset and the
# fixed ISL hold-out dataset. force_recompute=False lets the helper reuse
# whatever is already stored under cache_dir.
static_assets = build_static_tl_assets_CH_and_holdout(
    cfg=cfg,
    res_xreg=res_xreg,
    target_code="ISL",
    holdout_glaciers=holdout_glaciers,
    MONTHLY_COLS=MONTHLY_COLS,
    STATIC_COLS=STATIC_COLS,
    cache_dir="logs/LSTM_cache_TL_budget",
    force_recompute=False,
    show_progress=True)

#### Define the pools for experiments:

In [None]:
# All Iceland rows whose glacier is in the fine-tuning pool (not held out).
df_isl_pool_rows = df_isl.loc[df_isl["GLACIER"].isin(pool_glaciers)]

In [None]:
# Work on a copy so the pool rows themselves stay untouched.
df_pool = df_isl_pool_rows.copy()

# Upper bounds for the G and Y axes of the budget grid.
G_max = df_pool["GLACIER"].nunique()
years_per_glacier = df_pool.groupby("GLACIER")["YEAR"].nunique()
Y_max = years_per_glacier.max()  # max years available on any glacier

# How many rows per glacier-year exist? (this sets an upper bound for M)
rows_per_gy = df_pool.groupby(["GLACIER", "YEAR"]).size()
M_p50, M_p90, M_max = (
    int(rows_per_gy.median()),
    int(rows_per_gy.quantile(0.90)),
    int(rows_per_gy.max()),
)

print("POOL CAPACITY")
print("G_max:", G_max)
print("Y_max (max years on a glacier):", Y_max)
print("Rows per glacier-year: median", M_p50, "| p90", M_p90, "| max", M_max)

In [None]:
# Candidate values for each budget axis:
# G = glaciers, Y = years per glacier, M = rows per glacier-year.
G_set = [1, 2, 3, 5, 8, 13, 21, 35]
Y_set = [1, 2, 3, 5, 8, 13, 21, 37]
M_set = [4, 8, 16, 32, 64, 128, 200, 300]

# Reference point held fixed while one axis is swept.
G0, Y0, M0 = 8, 8, 64

# Tier A: sweeps (G first, then Y, then M; the other two axes stay at the
# reference point)
tierA = [
    *({"G": g, "Y": Y0, "M": M0} for g in G_set),
    *({"G": G0, "Y": y, "M": M0} for y in Y_set),
    *({"G": G0, "Y": Y0, "M": m} for m in M_set),
]

# Tier B: corners/near-corners, given as (G, Y, M) triples
tierB = [
    dict(zip(("G", "Y", "M"), corner)) for corner in (
        (1, 1, 4),
        (35, 37, 300),
        (35, 37, 4),
        (35, 1, 300),
        (1, 37, 300),
        (35, 1, 4),
        (1, 37, 4),
        (1, 1, 300),
    )
]

# Tier C: interaction sampling
import numpy as np


def sample_tierC(n=20, seed=cfg.seed):
    """Draw ``n`` random (G, Y, M) budgets from the notebook-level grids.

    Each draw picks G from ``G_set``, Y from ``Y_set`` and M from ``M_set``
    (in that order) with a generator seeded by ``seed``. Repeated
    combinations are dropped, keeping the first occurrence, so the returned
    list can be shorter than ``n``.

    Returns
    -------
    list[dict]
        Budgets as ``{"G": ..., "Y": ..., "M": ...}`` in draw order.
    """
    rng = np.random.default_rng(seed)
    # Insertion-ordered dict keyed by the (G, Y, M) triple: it dedupes and
    # preserves first-seen order in one pass.
    unique_budgets = {}
    for _ in range(n):
        g = int(rng.choice(G_set))
        y = int(rng.choice(Y_set))
        m = int(rng.choice(M_set))
        unique_budgets.setdefault((g, y, m), {"G": g, "Y": y, "M": m})
    return list(unique_budgets.values())


tierC = sample_tierC(n=20, seed=cfg.seed)


# final list (dedupe)
def dedupe_budgets(lst):
    """Drop budgets whose (G, Y, M) triple was already seen.

    The first occurrence of every triple is kept and the input order is
    preserved; the returned list holds the original dict objects.
    """
    first_seen = {}
    for budget in lst:
        triple = (budget["G"], budget["Y"], budget["M"])
        first_seen.setdefault(triple, budget)
    return list(first_seen.values())


# All three tiers in order (A, B, C), with repeated (G, Y, M) points removed.
BUDGETS = dedupe_budgets([*tierA, *tierB, *tierC])
print("Total budget points:", len(BUDGETS))

In [None]:
from tqdm.auto import tqdm

SEEDS = [10, 20, 30, 40, 50]  # R=5; later: add more

# Every (budget, seed) pair is one experiment; budgets vary slowest.
TASKS = [(b, seed) for b in BUDGETS for seed in SEEDS]
print("Total experiments to build:", len(TASKS))

# exp_key -> assets of that experiment
assets_all = {}

for b, seed in tqdm(TASKS, desc="Building LSTM assets"):

    # Budget dicts hold exactly the G / Y / M arguments of the sampler.
    df_ft, chosen = sample_monitoring_subset_from_pool(
        df_pool=df_isl_pool_rows,
        seed=seed,
        glacier_pick_method="random",
        **b,
    )

    exp_key = f"TL_CH_to_ISL_G{b['G']}_Y{b['Y']}_M{b['M']}_seed{seed}"

    # Only the finetune dataset is built here; the rest comes from
    # static_assets.
    assets_one = build_budget_assets_finetune_only(
        cfg=cfg,
        res_xreg=res_xreg,
        static_assets=static_assets,
        df_ft=df_ft,
        exp_key=exp_key,
        MONTHLY_COLS=MONTHLY_COLS,
        STATIC_COLS=STATIC_COLS,
        cache_dir="logs/LSTM_cache_TL_budget",
        force_recompute=False,
        val_ratio=0.2,
        show_progress=False,
    )

    assets_all.update(assets_one)

print("Total experiments built:", len(assets_all))

In [None]:
# # Sanity check:
# for k, v in assets_all.items():
#     print("\n", "=" * 60)
#     print("Experiment:", k)
#     print("Available keys:", list(v.keys()))

In [None]:
# # Sanity check:
# for exp_key, assets in assets_all.items():
#     ft_unique = set(assets["ft_source_codes"])
#     test_unique = set(
#         assets["test_source_codes"]) if assets["test_source_codes"] else set()
#     print(f"{exp_key} | FT domains: {ft_unique} | TEST domains: {test_unique}")

### LSTM CH Baseline:

In [None]:
# Grid-search progress logs, one CSV per region code; they are handed to
# build_lstm_params_by_key together with the defaults below.
log_path_gs_results = {
    "ISL": 'logs/GS_results/lstm_param_search_progress_OOS_ISL_2026-02-11.csv',
    "NOR": 'logs/GS_results/lstm_param_search_progress_OOS_NOR_2026-02-09.csv',
    "FR": 'logs/GS_results/lstm_param_search_progress_OOS_FR_2026-02-06.csv',
    "CH": 'logs/GS_results/lstm_param_search_progress_CH_2026-02-18.csv',
}

# Default LSTM hyper-parameters.
# NOTE(review): Fm=8 / Fs=3 equal len(MONTHLY_COLS) / len(STATIC_COLS) —
# presumably the monthly and static feature counts; confirm before changing
# either column list.
default_params = {
    'Fm': 8,
    'Fs': 3,
    'hidden_size': 128,
    'num_layers': 1,
    'bidirectional': False,
    'dropout': 0.0,
    'static_layers': 1,
    'static_hidden': 128,
    'static_dropout': 0.1,
    'lr': 0.001,
    'weight_decay': 1e-05,
    'loss_name': 'neutral',
    'two_heads': False,
    'head_dropout': 0.1,
    'loss_spec': None
}

# Parameter sets keyed by region key; "11_CH" and "06_ISL" are the keys read
# later in this notebook.
params_by_key = build_lstm_params_by_key(
    default_params=default_params,
    log_path_gs_results=log_path_gs_results,
    RGI_REGIONS=RGI_REGIONS,
)

# The trainer takes a dict of asset dicts, so the static assets are wrapped
# under a single "STATIC" key.
tl_assets_static = {"STATIC": static_assets}
model_ch, ch_path, ch_info = train_or_load_CH_baseline(
    cfg=cfg,
    tl_assets=tl_assets_static,
    default_params=params_by_key["11_CH"],
    device=device,
    models_dir="models/ISL_experiment",
    prefix="lstm_CH",
    key="BASELINE",
    train_flag=True,  # or False to only load
    force_retrain=False,
    epochs=150,
    batch_size_train=64,
    batch_size_val=128,
    verbose=False,
)

### Compute E_ZERO:
E_ZERO = error of the CH baseline model evaluated on the fixed ISL holdout dataset (unseen glaciers), with no finetuning

In [None]:
# Wrap the static assets in a one-key dict, the shape the codebase often
# expects (dict of experiment keys).
tl_assets_zero = dict(TL_CH_to_ISL_ZERO=static_assets)

fig, ax = plt.subplots(figsize=(6, 6))

# Zero-shot evaluation: the CH baseline model, with no finetuning, on the
# fixed ISL hold-out dataset.
metrics_zero, df_preds_zero, _, _ = evaluate_one_model_TL(
    cfg=cfg,
    model=model_ch,
    device=device,
    tl_assets_for_key=tl_assets_zero["TL_CH_to_ISL_ZERO"],
    ax=ax,
    title="E_ZERO: CH baseline on ISL holdout",
    batch_size=128,
    domain_vocab=tl_assets_zero["TL_CH_to_ISL_ZERO"].get("domain_vocab"),
)

plt.show()

# Reference error used by every later comparison.
E_ZERO = metrics_zero["RMSE_annual"]
print("E_ZERO (RMSE_annual):", E_ZERO)
print(metrics_zero)

### E_TL:

#### Train adapter-only models for experiments:

In [None]:
# Fine-tune one adapter model per budget experiment, all starting from the CH
# baseline checkpoint. The evaluation cell looks the returned models up under
# "<exp_key>__adapter".
# NOTE(review): date=None here vs. date="fixed" in the E_FULL cell — confirm
# both resolve to the intended checkpoint files.
models_tl, infos_tl = finetune_TL_models_all(
    cfg=cfg,
    tl_assets_by_key=assets_all,
    best_params=params_by_key["11_CH"],
    device=device,
    pretrained_ckpt_path=ch_path,
    strategies=("adapter", ),
    force_retrain=False,
    models_dir="models/ISL_experiment/",
    prefix="lstm_TL",
    verbose=False,
    best_by_region=None,
    date=None,  # optional: to load old dates
)

#### Evaluate E_TL:
For each budget point on the same holdout

In [None]:
# One metrics dict per experiment that has a fine-tuned adapter model.
rows = []
for exp_key in tqdm(sorted(assets_all), desc="Evaluating TL models"):
    run_key = f"{exp_key}__adapter"
    model = models_tl.get(run_key, None)
    if model is not None:
        assets = assets_all[exp_key]

        # Every model is scored on the same fixed hold-out dataset.
        metrics, df_preds, _, _ = evaluate_one_model_TL(
            cfg=cfg,
            model=model,
            device=device,
            tl_assets_for_key=assets,
            ax=None,
            title=None,
            batch_size=128,
            domain_vocab=assets.get("domain_vocab", None),
            show_plot=False,
        )

        metrics["exp_key"] = exp_key
        rows.append(metrics)
    # else: checkpoint might not exist / training skipped

df_etl = pd.DataFrame(rows).set_index("exp_key").sort_index()

# Attach the zero-shot reference; a negative delta means the fine-tuned model
# has a lower RMSE than the CH baseline.
df_etl["E_ZERO_RMSE_annual"] = E_ZERO
df_etl["Delta_vs_ZERO"] = df_etl["RMSE_annual"] - E_ZERO

display(df_etl[[
    "RMSE_annual",
    "R2_annual",
    "Bias_annual",
    "n_annual",
    "E_ZERO_RMSE_annual",
    "Delta_vs_ZERO",
]])

In [None]:
# `re` is first needed here, well before the notebook's later import cell;
# importing it locally keeps this cell working on a fresh kernel without
# relying on a star import to provide the name.
import re

# Matches the "_G<int>_Y<int>_M<int>_seed<int>" part of an experiment key.
_re_budget = re.compile(
    r"_G(?P<G>\d+)_Y(?P<Y>\d+)_M(?P<M>\d+)_seed(?P<seed>\d+)")


def parse_budget(s: str):
    """Extract the G / Y / M / seed integers encoded in an experiment key.

    Parameters
    ----------
    s : str
        Any object; it is converted with ``str`` and searched (not anchored),
        so prefixes and suffixes around the budget part are allowed.

    Returns
    -------
    dict
        ``{"G": int, "Y": int, "M": int, "seed": int}`` on a match, otherwise
        the same keys all set to ``np.nan``.
    """
    s = str(s)
    m = _re_budget.search(s)  # <-- search, not match
    if not m:
        return {"G": np.nan, "Y": np.nan, "M": np.nan, "seed": np.nan}
    return {k: int(v) for k, v in m.groupdict().items()}


# Flatten the index into an "exp_key" column and parse the budget out of it.
_etl_flat = df_etl.reset_index().rename(columns={"index": "exp_key"})
meta = _etl_flat["exp_key"].apply(parse_budget).apply(pd.Series)

# Budget columns are appended on the right; rows whose key carries no budget
# (parsed as NaN) are dropped.
df_etl2 = pd.concat([_etl_flat, meta],
                    axis=1).dropna(subset=["G", "Y", "M", "seed"])

# Spread of RMSE_annual across seeds for every (G, Y, M) budget.
agg = (df_etl2.groupby(["G", "Y", "M"]).agg(
    RMSE_med=("RMSE_annual", "median"),
    RMSE_p10=("RMSE_annual", lambda values: np.quantile(values, 0.10)),
    RMSE_p90=("RMSE_annual", lambda values: np.quantile(values, 0.90)),
    n=("RMSE_annual", "size"),
).reset_index().sort_values(["G", "Y", "M"]))

display(agg)

### Compute E_FULL: 
Adapter fine-tuned on all ISL pool data (everything that is not in the fixed holdout glaciers), evaluated on the same fixed holdout (ds_test).

In [None]:
exp_key_full = "TL_CH_to_ISL_FULLPOOL"

# Upper-bound experiment: fine-tune on ALL pool rows (no G/Y/M sub-sampling).
assets_full = build_budget_assets_finetune_only(
    cfg=cfg,
    res_xreg=res_xreg,
    static_assets=static_assets,
    df_ft=df_isl_pool_rows,
    exp_key=exp_key_full,
    MONTHLY_COLS=MONTHLY_COLS,
    STATIC_COLS=STATIC_COLS,
    cache_dir="logs/LSTM_cache_TL_budget",
    force_recompute=False,
    val_ratio=0.2,
)

# Budget experiments plus the full-pool one, in a new dict so assets_all
# itself stays budget-only (optional but convenient).
assets_all_plus = {**assets_all, **assets_full}

print("FULL asset built:", exp_key_full)
print("Full finetune sequences:",
      len(assets_all_plus[exp_key_full]["ds_finetune"]))

In [None]:
# Fine-tune a single adapter model on the full pool (only this one key).
models_full, infos_full = finetune_TL_models_all(
    cfg=cfg,
    tl_assets_by_key={exp_key_full: assets_all_plus[exp_key_full]},
    best_params=params_by_key["11_CH"],
    device=device,
    pretrained_ckpt_path=ch_path,
    strategies=("adapter", ),
    force_retrain=False,
    models_dir="models/ISL_experiment/",
    prefix="lstm_TL",
    verbose=False,
    best_by_region=None,
    date="fixed",
)

# Models are stored under "<exp_key>__adapter"; this is the model object.
run_key_full = f"{exp_key_full}__adapter"
model_full = models_full[run_key_full]

# Score it on the same fixed hold-out dataset as every other experiment.
metrics_full, df_preds_full, _, _ = evaluate_one_model_TL(
    cfg=cfg,
    model=model_full,
    device=device,
    tl_assets_for_key=assets_all_plus[exp_key_full],
    ax=None,
    title=None,
    batch_size=128,
    domain_vocab=assets_all_plus[exp_key_full].get("domain_vocab"),
)

E_FULL = metrics_full["RMSE_annual"]
print("E_FULL (RMSE_annual):", E_FULL)
print(metrics_full)

### Compute E_SCRATCH:
“no-transfer” (from-scratch) baseline trained on the same small ISL monitoring subset and evaluated on the same fixed ISL holdout

In [None]:
# `re` is needed here, before the notebook's later import cell; importing it
# locally keeps this cell working on a fresh kernel without relying on a star
# import to provide the name.
import re

# NOTE(review): this regex and `parse_budget` are defined in several cells of
# this notebook with the same behaviour; consider keeping a single definition.
_re_budget = re.compile(
    r"_G(?P<G>\d+)_Y(?P<Y>\d+)_M(?P<M>\d+)_seed(?P<seed>\d+)")


def parse_budget(exp_key: str):
    """Extract the G / Y / M / seed integers encoded in an experiment key.

    The key is converted with ``str`` and searched (not anchored), so it also
    works on names with extra suffixes such as ``..._FT_ft.joblib``.

    Returns
    -------
    dict
        ``{"G": int, "Y": int, "M": int, "seed": int}`` on a match, otherwise
        the same keys all set to ``np.nan``.
    """
    m = _re_budget.search(str(exp_key))  # works even with ..._FT_ft.joblib
    if not m:
        return {"G": np.nan, "Y": np.nan, "M": np.nan, "seed": np.nan}
    return {k: int(v) for k, v in m.groupdict().items()}


def within_assets_from_tl_assets(tl_assets_for_key: dict):
    """Re-key one experiment's TL assets into the within-region layout.

    The finetune subset becomes the training set and the fixed hold-out stays
    the test set; the values themselves are passed through unchanged.
    """
    # (within-region key, transfer-learning key)
    renames = (
        ("ds_train", "ds_finetune"),
        ("ds_test", "ds_test"),
        ("train_idx", "finetune_train_idx"),
        ("val_idx", "finetune_val_idx"),
    )
    return {new: tl_assets_for_key[old] for new, old in renames}


# Per-experiment outputs of the from-scratch (no transfer) baseline.
models_within, infos_within, E_SCRATCH_by_key = {}, {}, {}
rows = []

exp_keys = sorted(assets_all)

pbar = tqdm(exp_keys, desc="Training+Evaluating E_SCRATCH", dynamic_ncols=True)

for exp_key in pbar:
    # Show the current budget next to the progress bar.
    meta = parse_budget(exp_key)
    if meta:
        pbar.set_postfix(meta)

    # Same finetune subset and same fixed hold-out as the TL experiment.
    w_assets = within_assets_from_tl_assets(assets_all[exp_key])

    # NOTE(review): force_retrain=True retrains every model on each run,
    # unlike the force_retrain=False used for the TL models — confirm this is
    # intended.
    model_w, path_w, info_w = train_or_load_one_within_region(
        cfg=cfg,
        key=exp_key,
        lstm_assets=w_assets,
        best_params=params_by_key["06_ISL"],
        device=device,
        models_dir="models/ISL_experiment/",
        prefix="lstm_within_ISL",
        train_flag=True,
        force_retrain=True,
        epochs=150,
        batch_size_train=64,
        batch_size_val=128,
        batch_size_test=128,
        verbose=False,
    )

    models_within[exp_key] = model_w
    infos_within[exp_key] = {"model_path": path_w, **(info_w or {})}

    # ---- Evaluate ----
    # Uses the test loader and dataset returned by the training helper.
    met_w, df_w = model_w.evaluate_with_preds(device, info_w["test_dl"],
                                              info_w["ds_test"])

    E_SCRATCH = float(met_w["RMSE_annual"])
    E_SCRATCH_by_key[exp_key] = E_SCRATCH

    rows.append(dict(exp_key=exp_key, RMSE_SCRATCH=E_SCRATCH, **meta))

df_scratch = pd.DataFrame(rows).set_index("exp_key").sort_index()
display(df_scratch)

In [None]:
# save to CSV (optional)
# to_csv does not create missing parent folders, so make sure the results
# directory exists before writing.
_scratch_csv_path = "results/ISL_experiments/ISL_experiment_E_SCRATCH.csv"
os.makedirs(os.path.dirname(_scratch_csv_path), exist_ok=True)
df_scratch.to_csv(_scratch_csv_path)

## Evaluate experiments:

### Build one unified results dataframe (with G/Y/M/seed parsed):

In [None]:
import re, numpy as np, pandas as pd

# Matches the "_G<int>_Y<int>_M<int>_seed<int>" part of an experiment key.
# NOTE(review): this regex and `parse_budget` are also defined in earlier
# cells of this notebook; consider keeping a single definition.
_re_budget = re.compile(
    r"_G(?P<G>\d+)_Y(?P<Y>\d+)_M(?P<M>\d+)_seed(?P<seed>\d+)")


def parse_budget(s: str):
    """Return the G / Y / M / seed integers found in ``s``.

    ``s`` is converted with ``str`` and searched anywhere for the budget
    pattern. Without a match, all four keys map to ``np.nan``.
    """
    fields = ("G", "Y", "M", "seed")
    found = _re_budget.search(str(s))
    if found is None:
        return dict.fromkeys(fields, np.nan)
    return {name: int(found.group(name)) for name in fields}


# Build one results table: start from the transfer-learning results (df_etl)
# and attach the scratch RMSE. The column assignment aligns on the index, so
# rows of df_etl without a matching df_scratch entry get NaN.
df = df_etl.copy()
df["RMSE_SCRATCH"] = df_scratch["RMSE_SCRATCH"]

# Parse G/Y/M/seed out of each index label and prepend them as columns.
# NOTE(review): if df_etl already carries G/Y/M/seed columns, this concat would
# create duplicate column names -- confirm against how df_etl is built.
meta = df.index.to_series().apply(parse_budget).apply(pd.Series)
df = pd.concat([meta, df], axis=1)

# Drop rows whose label had no budget tag (parse_budget returned NaN for them).
df = df.dropna(subset=["G", "Y", "M", "seed"]).copy()

# Broadcast the two scalar reference errors to every row, then derive metrics.
df["E_ZERO"] = float(E_ZERO)
df["E_FULL"] = float(E_FULL)

df["Effort"] = df["G"] * df["Y"] * df["M"]  # simple monitoring effort proxy
# Recovery_TL: share of the E_ZERO -> E_FULL error gap closed by the TL run
# (0 when RMSE_annual equals E_ZERO, 1 when it equals E_FULL).
# NOTE(review): the denominator is zero when E_ZERO == E_FULL.
df["Recovery_TL"] = (df["E_ZERO"] - df["RMSE_annual"]) / (df["E_ZERO"] -
                                                          df["E_FULL"])
df["Transfer_gain"] = df["RMSE_SCRATCH"] - df["RMSE_annual"]  # >0 = TL better

df.head()

### Figure: RMSE vs monitoring effort (TL + Scratch + reference lines)

In [None]:
# RMSE on the ISL holdout against the monitoring-effort proxy, for the
# transfer-learning runs and the from-scratch baselines.
fig, ax = plt.subplots(figsize=(7.5, 5.5))

ax.scatter(df["Effort"], df["RMSE_annual"], alpha=0.35, label="TL (adapter)")
ax.scatter(df["Effort"],
           df["RMSE_SCRATCH"],
           alpha=0.35,
           label="Scratch (within ISL)")

# Style the two reference lines explicitly: axhline draws every line with the
# default line colour, so two plain dashed lines would be indistinguishable
# from each other in the legend.
ax.axhline(E_ZERO,
           color="black",
           linestyle="--",
           label="E_ZERO (CH only)")
ax.axhline(E_FULL,
           color="black",
           linestyle=":",
           label="E_FULL (max monitoring)")

ax.set_xscale("log")
ax.set_xlabel("Monitoring effort proxy: G × Y × M (log)")
ax.set_ylabel("RMSE_annual on ISL holdout")
ax.set_title("Performance vs monitoring effort")
ax.legend()
plt.tight_layout()
plt.show()

### Figure: Transfer gain vs effort (evidence of where transfer helps)

In [None]:
# Transfer gain against the effort proxy; the dashed zero line separates runs
# where TL has the lower RMSE (above) from runs where scratch does (below).
fig, ax = plt.subplots(figsize=(7.5, 5.0))

ax.scatter(df["Effort"], df["Transfer_gain"], alpha=0.5)
ax.axhline(0.0, linestyle="--")

ax.set(
    xscale="log",
    xlabel="Effort proxy: G × Y × M (log)",
    ylabel="Transfer gain = RMSE_SCRATCH − RMSE_TL",
    title="Where transfer helps (positive = TL better)",
)
plt.tight_layout()
plt.show()

### Figure: Recovery vs effort (normalized 0–1 scale)

In [None]:
# Normalised recovery against the effort proxy, with dashed guides at
# 0 (no better than E_ZERO), 0.5 and 1 (matches E_FULL).
fig, ax = plt.subplots(figsize=(7.5, 5.0))

ax.scatter(df["Effort"], df["Recovery_TL"], alpha=0.45)
for guide_level in (0.0, 0.5, 1.0):
    ax.axhline(guide_level, linestyle="--")

ax.set(
    xscale="log",
    xlabel="Effort proxy: G × Y × M (log)",
    ylabel="Recovery = (E_ZERO − E_TL)/(E_ZERO − E_FULL)",
    title="Recovery vs monitoring effort",
)
plt.tight_layout()
plt.show()

### Heatmaps: median Recovery over (G,Y) at fixed M

In [None]:
def heatmap_recovery(df, M_fixed):
    """Show a heatmap of the median ``Recovery_TL`` over (G, Y) for one M.

    Args:
        df: results frame with the columns "G", "Y", "M" and "Recovery_TL".
        M_fixed: value of "M" to select; rows with any other M are ignored.

    Returns:
        None. The figure is shown via ``plt.show()``; if no row has
        ``M == M_fixed`` a message is printed and nothing is drawn.
    """
    at_m = df[df["M"] == M_fixed]
    if at_m.empty:
        print(f"No runs for M={M_fixed}")
        return

    # Median over the runs sharing a (G, Y) cell, laid out as a Y-by-G grid
    # with both axes in ascending order.
    medians = at_m.groupby(["G", "Y"])["Recovery_TL"].median().reset_index()
    grid = medians.pivot(index="Y", columns="G", values="Recovery_TL")
    grid = grid.sort_index().sort_index(axis=1)

    fig, ax = plt.subplots(figsize=(8.5, 5.5))
    # origin="lower" puts the smallest Y in the bottom row.
    image = ax.imshow(grid.values, aspect="auto", origin="lower")

    ax.set_title(f"Median Recovery (TL) — fixed M={M_fixed}")
    ax.set_xlabel("G (glaciers)")
    ax.set_ylabel("Y (years)")

    # imshow places cells at integer positions; label them with the real values.
    n_rows, n_cols = grid.shape
    ax.set_xticks(np.arange(n_cols))
    ax.set_xticklabels(grid.columns.tolist(), rotation=45)
    ax.set_yticks(np.arange(n_rows))
    ax.set_yticklabels(grid.index.tolist())

    fig.colorbar(image, ax=ax).set_label("Recovery")
    plt.tight_layout()
    plt.show()


# One heatmap per fixed M value.
for M_fixed in (64, 200):
    heatmap_recovery(df, M_fixed=M_fixed)

### Table: “minimum budget” candidates for target recovery (Pareto-ish)

In [None]:
# Recovery levels for which the cheapest qualifying budgets are listed.
targets = [0.25, 0.50, 0.75]

# One row per (G, Y, M) budget: median / p10 / p90 of Recovery_TL, median
# RMSE_annual, the effort proxy and the number of runs in the group
# (e.g. the different seeds), ordered from cheapest to most expensive.
agg = (
    df.groupby(["G", "Y", "M"])
    .agg(
        Rec_med=("Recovery_TL", "median"),
        Rec_p10=("Recovery_TL", lambda x: np.quantile(x, 0.10)),
        Rec_p90=("Recovery_TL", lambda x: np.quantile(x, 0.90)),
        RMSE_med=("RMSE_annual", "median"),
        Effort=("Effort", "median"),
        n=("Recovery_TL", "size"),
    )
    .reset_index()
    .sort_values("Effort")
)

# Columns shown for every target, in display order.
budget_report_cols = [
    "G", "Y", "M", "Effort", "Rec_med", "Rec_p10", "Rec_p90", "RMSE_med", "n"
]

for t in targets:
    ok = agg[agg["Rec_med"] >= t].copy()
    print(f"\n=== Smallest-effort budgets with median Recovery ≥ {t} ===")
    if ok.empty:
        print("None reached.")
        continue
    # agg is already sorted by Effort, so head(10) is the ten cheapest budgets.
    display(ok.head(10)[budget_report_cols])

### Optional: Tier A sweep plots (which knob matters most?)

In [None]:
# Anchor budget: when one knob is swept, the other two are pinned to these.
G0, Y0, M0 = 8, 8, 64


def plot_sweep(df, mode):
    """Plot RMSE_annual along one budget knob with the other two held fixed.

    The rows of ``df`` are restricted to those matching the module-level
    anchors ``G0``/``Y0``/``M0`` on the two knobs that are not swept. For each
    value of the swept knob the median RMSE_annual is drawn as a line with a
    p10–p90 band, together with the E_ZERO and E_FULL reference levels.

    Args:
        df: results frame with the columns "G", "Y", "M" and "RMSE_annual".
        mode: knob to sweep, one of "G", "Y" or "M".

    Returns:
        None. The figure is shown via ``plt.show()``; if no row matches the
        anchors a message is printed and nothing is drawn.

    Raises:
        ValueError: if ``mode`` is not "G", "Y" or "M".
    """
    if mode not in ("G", "Y", "M"):
        raise ValueError(f"mode must be 'G', 'Y' or 'M', got {mode!r}")

    # Knobs held at their anchor value (dict order G, Y, M fixes the title order).
    fixed = {
        knob: anchor
        for knob, anchor in (("G", G0), ("Y", Y0), ("M", M0))
        if knob != mode
    }
    x = mode
    title = f"Sweep {mode} (" + ", ".join(
        f"{knob}={anchor}" for knob, anchor in fixed.items()) + ")"

    sub = df
    for knob, anchor in fixed.items():
        sub = sub[sub[knob] == anchor]

    if sub.empty:
        print("No data for", mode, "sweep")
        return

    # Spread of RMSE_annual over the runs that share one value of the swept knob.
    g = sub.groupby(x)["RMSE_annual"].agg(
        med="median",
        p10=lambda s: np.quantile(s, 0.10),
        p90=lambda s: np.quantile(s, 0.90),
        n="size").reset_index().sort_values(x)

    fig, ax = plt.subplots(figsize=(6.8, 4.6))
    ax.plot(g[x], g["med"], marker="o", label="median")
    ax.fill_between(g[x], g["p10"], g["p90"], alpha=0.2, label="p10–p90")

    # Label and style the reference levels differently so they can be told
    # apart; the labels match the RMSE-vs-effort figure.
    ax.axhline(E_ZERO,
               color="black",
               linestyle="--",
               label="E_ZERO (CH only)")
    ax.axhline(E_FULL,
               color="black",
               linestyle=":",
               label="E_FULL (max monitoring)")
    ax.set_title(title)
    ax.set_xlabel(x)
    ax.set_ylabel("RMSE_annual (median ± p10–p90)")
    ax.legend()
    plt.tight_layout()
    plt.show()


# Sweep each knob in turn; the other two stay at their anchor values.
for mode in ("G", "Y", "M"):
    plot_sweep(df, mode=mode)