In [None]:
import sys, os
sys.path.append(os.path.join(os.getcwd(), '../../')) # Add root of repo to import MBM

import warnings
import massbalancemachine as mbm
import logging
from datetime import datetime
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import torch 

from regions.TF_Europe.scripts.config_TF_Europe import *
from regions.TF_Europe.scripts.dataset import *
from regions.TF_Europe.scripts.plotting import *
from regions.TF_Europe.scripts.models import *
from regions.TF_Europe.scripts.utils import *

# Silence every Python warning for the rest of the session. Note that this also
# hides deprecation / numerical warnings emitted by the libraries used below.
warnings.filterwarnings('ignore')
# Reload edited project modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

# Project configuration; `cfg.seed` and `cfg.dataPath` are read later in the notebook.
cfg = mbm.EuropeTFConfig()
# Seed before any model or data work so that runs are repeatable.
mbm.utils.seed_all(cfg.seed)
# NOTE(review): presumably releases cached GPU memory -- confirm in mbm.utils.
mbm.utils.free_up_cuda()
mbm.plots.use_mbm_style()

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize logging (INFO level, timestamp followed by the message).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')


# Cross-regional modelling:

### Read stakes datasets:

In [None]:
"""
Examples of loading data:
# Load Switzerland only
df = load_stakes(cfg, "CH")

# Load all Central Europe (FR+CH+IT+AT when you add them)
df_ceu = load_stakes_for_rgi_region(cfg, "11")

# Load all Europe regions configured
dfs = {rid: load_stakes_for_rgi_region(cfg, rid) for rid in RGI_REGIONS.keys()}
"""

# Stake measurements of every configured RGI region, keyed by region id.
dfs = dict(
    (rid, load_stakes_for_rgi_region(cfg, rid)) for rid in RGI_REGIONS.keys())
# Preview the first rows of region "11".
dfs["11"].head()

### Monthly datasets:
Build the monthly-resolution datasets that are used as input to the LSTM.

In [None]:
# Inputs for the monthly transformation (which is either run or loaded from disk
# in the next cell): the raw ERA5 files that must be present.
paths = dict(
    era5_climate_data=os.path.join(cfg.dataPath, path_ERA5_raw,
                                   "era5_monthly_averaged_data_Europe.nc"),
    geopotential_data=os.path.join(cfg.dataPath, path_ERA5_raw,
                                   "era5_geopotential_pressure_Europe.nc"),
)

# Fail fast: stop at the first required file that is missing.
for key, path in paths.items():
    if os.path.exists(path):
        continue
    raise FileNotFoundError(f"Required file for {key} not found at {path}")

# Climate variables of interest (names as stored in the ERA5 file -- TODO confirm).
vois_climate = ["t2m", "tp", "slhf", "sshf", "ssrd", "fal", "str"]

# Topographical variables of interest.
vois_topographical = ["aspect", "slope", "svf"]

In [None]:
# Stake dataframes for every configured region, keyed by region id.
dfs = dict(
    (rid, load_stakes_for_rgi_region(cfg, rid)) for rid in RGI_REGIONS.keys())

# Monthly datasets for the cross-regional setup. With run_flag=False the
# previously computed result is loaded instead of being recomputed.
res_xreg, split_info = prepare_monthly_df_xreg_CH_to_EU(
    cfg=cfg,
    dfs=dfs,
    paths=paths,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    run_flag=False,
)

df_train, df_test = res_xreg["df_train"], res_xreg["df_test"]

# Quick size summary of the split.
n_train_glaciers = len(split_info["train_glaciers"])
n_test_glaciers = len(split_info["test_glaciers"])
print(f"Train glaciers (CH): {n_train_glaciers}")
print(f"Test glaciers (non-CH): {n_test_glaciers}")
print(f"Train rows: {len(df_train)} Test rows: {len(df_test)}")

### Finetuning glaciers:

#### Automatic picking:

In [None]:
def _pick_ft_candidates(df_test, region_code, seed):
    """Automatically pick fine-tuning glaciers for one region, for both splits.

    Calls ``pick_glaciers_by_row_fraction`` twice on ``df_test``:

    * 5% target with ``greedy_small_first`` (small-first greedy usually gives
      the best control for a small fraction),
    * 50% target with ``greedy_large_first`` (large-first or small-first both
      work; large-first is the starting choice).

    The picked glaciers are printed and everything returned by the picker is
    kept, so no region's diagnostics overwrite another region's.

    Parameters
    ----------
    df_test : test dataframe handed to the picker unchanged.
    region_code : region code to pick glaciers for (e.g. ``"SJM"``).
    seed : seed forwarded to the picker.

    Returns
    -------
    dict with keys ``"5pct"``, ``"50pct"`` (picked glaciers), ``"info_5pct"``,
    ``"info_50pct"`` (second return value of the picker) and ``"counts"``
    (third return value of the 5% call).
    """
    picked_5pct, info_5pct, counts = pick_glaciers_by_row_fraction(
        df_test=df_test,
        region_code=region_code,
        target_frac=0.05,
        method="greedy_small_first",
        seed=seed,
    )
    picked_50pct, info_50pct, _ = pick_glaciers_by_row_fraction(
        df_test=df_test,
        region_code=region_code,
        target_frac=0.50,
        method="greedy_large_first",
        seed=seed,
    )

    print(f"\n{region_code}_5pct glaciers:", picked_5pct)
    print(f"\n{region_code}_50pct glaciers:", picked_50pct)

    return {
        "5pct": picked_5pct,
        "50pct": picked_50pct,
        "info_5pct": info_5pct,
        "info_50pct": info_50pct,
        "counts": counts,
    }


# Regions are processed in the same order as before: SJM, ISL, FR, IT_AT, NOR.
AUTO_PICK_REGIONS = ("SJM", "ISL", "FR", "IT_AT", "NOR")
auto_picks = {
    code: _pick_ft_candidates(df_test, code, cfg.seed)
    for code in AUTO_PICK_REGIONS
}

# Per-region names kept for backward compatibility with the rest of the notebook.
SJM_5pct, SJM_50pct = auto_picks["SJM"]["5pct"], auto_picks["SJM"]["50pct"]
ISL_5pct, ISL_50pct = auto_picks["ISL"]["5pct"], auto_picks["ISL"]["50pct"]
FR_5pct, FR_50pct = auto_picks["FR"]["5pct"], auto_picks["FR"]["50pct"]
IT_AT_5pct, IT_AT_50pct = (auto_picks["IT_AT"]["5pct"],
                           auto_picks["IT_AT"]["50pct"])
NOR_5pct, NOR_50pct = auto_picks["NOR"]["5pct"], auto_picks["NOR"]["50pct"]

# These names used to be overwritten by every later region and ended up holding
# the NOR values; they now hold the Svalbard (SJM) diagnostics their names
# promise. Diagnostics for the other regions are available in `auto_picks`.
sjm5_info = auto_picks["SJM"]["info_5pct"]
sjm50_info = auto_picks["SJM"]["info_50pct"]
sjm_counts = auto_picks["SJM"]["counts"]

#### Final fine-tuning (FT) and hold-out glaciers:

In [None]:
# Manually fixed fine-tuning (FT) glacier lists per region and split. Glaciers of
# a region that are not listed here are the hold-out glaciers of that split.

# Norway
# 5% split
FT_5PCT_NOR = [
    'Moesevassbrea', 'Vetlefjordbreen', 'Juvfonne', 'Graasubreen',
    'Hellstugubreen', 'Storglombreen N', 'Blabreen', 'Ruklebreen',
    'Vestre Memurubreen', 'Cainhavarre', 'Bondhusbrea'
]

# 50% split
FT_50PCT_NOR = [
    'Nigardsbreen', 'Aalfotbreen', 'Engabreen', 'Storsteinsfjellbreen',
    'Cainhavarre'
]

# France
# 5% split
FT_5PCT_FR = ['Grands Montets', 'Sarennes', 'Talefre', 'Leschaux']

# 50% split
FT_50PCT_FR = ['Argentiere', 'Gebroulaz']

# IT-AT
# 5% split
FT_5PCT_IT_AT = [
    'CIARDONEY', 'CARESER CENTRALE', 'CAMPO SETT.', 'ZETTALUNITZ/MULLWITZ K.',
    'HALLSTAETTER G.', 'VENEDIGER K.', 'SURETTA MERIDIONALE', 'GOLDBERG K.',
    'CARESER OCCIDENTALE', 'GRAND ETRET', 'LUPO'
]

# 50% split
FT_50PCT_IT_AT = [
    'HINTEREIS F.', 'MALAVALLE (VEDR. DI) / UEBELTALF.',
    'LUNGA (VEDRETTA) / LANGENF.', 'RIES OCC. (VEDR. DI) / RIESERF. WESTL.'
]

# Iceland
# 5% split
FT_5PCT_ISL = [
    "Tungnaarjoekull", "Slettjoekull West",
    "Hagafellsjoekull East (Langjoekull S Dome)", "RGI60-06.00478",
    "Mulajoekull"
]

# 50% split
FT_50PCT_ISL = [
    'RGI60-06.00238', 'Bruarjoekull', 'Skeidararjoekull',
    'Thjorsarjoekull (Hofsjoekull E)', 'Sidujoekull/Skaftarjoekull',
    'Hagafellsjoekull West', 'RGI60-06.00305'
]

# Svalbard
# 5% split
# NOTE(review): the actual row share of each list is reported by
# verify_row_percentage in the next cell -- check it there.
FT_5PCT_SJM = ['WERENSKIOLDBREEN']

# 50% split
FT_50PCT_SJM = ['GROENFJORD E', 'WERENSKIOLDBREEN']

# Summary of splits for all regions: region code -> split name -> FT glacier list.
FT_GLACIERS = {
    "FR": {
        "5pct": FT_5PCT_FR,
        "50pct": FT_50PCT_FR
    },
    "IT_AT": {
        "5pct": FT_5PCT_IT_AT,
        "50pct": FT_50PCT_IT_AT
    },
    "NOR": {
        "5pct": FT_5PCT_NOR,
        "50pct": FT_50PCT_NOR
    },
    "ISL": {
        "5pct": FT_5PCT_ISL,
        "50pct": FT_50PCT_ISL
    },
    "SJM": {
        "5pct": FT_5PCT_SJM,
        "50pct": FT_50PCT_SJM
    }
}

In [None]:
# Check the manually chosen FT lists against the test set.
# NOTE(review): presumably reports the share of test rows covered by each
# region/split list -- confirm against verify_row_percentage.
df_row_check = verify_row_percentage(df_test, FT_GLACIERS)
df_row_check

In [None]:
# Sanity check: number of distinct glaciers each FT region has in the test set.
for reg in FT_GLACIERS:
    in_region = df_test["SOURCE_CODE"] == reg
    gls = sorted(df_test.loc[in_region, "GLACIER"].unique())
    print(reg, "unique glaciers in df_test:", len(gls))

### Plot test/train glaciers:

##### Central European Alps:

In [None]:
# The Central European Alps FT sets are the French and Italian/Austrian lists
# concatenated.
FT_GL_CEU_5pct = FT_GLACIERS["FR"]["5pct"] + FT_GLACIERS["IT_AT"]["5pct"]
FT_GL_CEU_50pct = FT_GLACIERS["FR"]["50pct"] + FT_GLACIERS["IT_AT"]["50pct"]

ft_glaciers_by_split = {"5pct": FT_GL_CEU_5pct, "50pct": FT_GL_CEU_50pct}

data_CEU, glacier_outline_rgi, glacier_info_by_split = build_region_glacier_info_for_splits(
    cfg,
    rgi_region_id="11",
    outline_shp_path=cfg.dataPath +
    "RGI_v6/RGI_11_CentralEurope/11_rgi60_CentralEurope.shp",
    ft_glaciers_by_split=ft_glaciers_by_split,
    split_names=("5pct", "50pct"),
    ft_label_col="FT/Hold-out glacier",
)

# Per-split glacier tables with the Swiss (CH) glaciers removed.
glacier_df_CEU_5pct, glacier_df_CEU_50pct = (
    info.loc[~info["SOURCE_CODE"].isin(["CH"])]
    for info in (glacier_info_by_split["5pct"], glacier_info_by_split["50pct"]))

# Hold-out colour is the first of 10 hex colours sampled from the colormap
# (requires the get_cmap_hex helper); FT glaciers use a fixed red.
cmap_for_train = cm.batlow
colors = get_cmap_hex(cmap_for_train, 10)  # noqa: F821
train_color = colors[0]

palette = {"Hold-out": train_color, "FT": "#b2182b"}

# Options shared by both maps; only the data and the title differ per split.
shared_map_kwargs = dict(
    glacier_outline_rgi=glacier_outline_rgi,
    extent=(5.8, 15, 44, 48),
    sizes=(100, 1500),
    size_legend_values=(30, 100, 1000),
    palette=palette,
    cmap_for_train=cmap_for_train,  # optional, uses get_cmap_hex if available
    split_col="FT/Hold-out glacier",
)

for split_label, map_df in (("5pct", glacier_df_CEU_5pct),
                            ("50pct", glacier_df_CEU_50pct)):
    fig, ax, glacier_info_plot, scaled_size_fn = plot_glacier_measurements_map(
        glacier_info=map_df,
        title=
        f"Glacier measurement locations Central European Alps ({split_label})",
        **shared_map_kwargs)

##### Norway:

In [None]:
FT_GL_NOR_5pct = FT_GLACIERS["NOR"]["5pct"]
FT_GL_NOR_50pct = FT_GLACIERS["NOR"]["50pct"]

ft_glaciers_by_split = {"5pct": FT_GL_NOR_5pct, "50pct": FT_GL_NOR_50pct}

data_NOR, glacier_outline_rgi, glacier_info_by_split = build_region_glacier_info_for_splits(
    cfg,
    rgi_region_id="08",
    outline_shp_path=cfg.dataPath +
    "RGI_v6/RGI_08_Scandinavia/08_rgi60_Scandinavia.shp",
    ft_glaciers_by_split=ft_glaciers_by_split,
    split_names=("5pct", "50pct"),
    ft_label_col="FT/Hold-out glacier",
)

glacier_df_NOR_5pct, glacier_df_NOR_50pct = (glacier_info_by_split["5pct"],
                                             glacier_info_by_split["50pct"])

# Hold-out colour is the first of 10 hex colours sampled from the colormap
# (requires the get_cmap_hex helper); FT glaciers use a fixed red.
cmap_for_train = cm.batlow
colors = get_cmap_hex(cmap_for_train, 10)  # noqa: F821
train_color = colors[0]

palette = {"Hold-out": train_color, "FT": "#b2182b"}

# Options shared by both maps; only the data and the title differ per split.
shared_map_kwargs = dict(
    glacier_outline_rgi=glacier_outline_rgi,
    extent=(4, 24, 57, 71),
    sizes=(100, 1500),
    size_legend_values=(30, 100, 1000),
    palette=palette,
    cmap_for_train=cmap_for_train,  # optional, uses get_cmap_hex if available
    split_col="FT/Hold-out glacier",
    legend_ncol=2,
)

for split_label, map_df in (("5pct", glacier_df_NOR_5pct),
                            ("50pct", glacier_df_NOR_50pct)):
    fig, ax, glacier_info_plot, scaled_size_fn = plot_glacier_measurements_map(
        glacier_info=map_df,
        title=f"Glacier PMB location Norway ({split_label})",
        **shared_map_kwargs)

##### Svalbard:

In [None]:
FT_GL_SJM_5pct = FT_GLACIERS["SJM"]["5pct"]
FT_GL_SJM_50pct = FT_GLACIERS["SJM"]["50pct"]

ft_glaciers_by_split = {"5pct": FT_GL_SJM_5pct, "50pct": FT_GL_SJM_50pct}

data_SJM, glacier_outline_rgi, glacier_info_by_split = build_region_glacier_info_for_splits(
    cfg,
    rgi_region_id="07",
    outline_shp_path=cfg.dataPath +
    "RGI_v6/RGI_07_Svalbard/07_rgi60_Svalbard.shp",
    ft_glaciers_by_split=ft_glaciers_by_split,
    split_names=("5pct", "50pct"),
    ft_label_col="FT/Hold-out glacier",
)

glacier_df_SJM_5pct, glacier_df_SJM_50pct = (glacier_info_by_split["5pct"],
                                             glacier_info_by_split["50pct"])

# Hold-out colour is the first of 10 hex colours sampled from the colormap
# (requires the get_cmap_hex helper); FT glaciers use a fixed red.
cmap_for_train = cm.batlow
colors = get_cmap_hex(cmap_for_train, 10)  # noqa: F821
train_color = colors[0]

palette = {"Hold-out": train_color, "FT": "#b2182b"}

# Options shared by both maps; only the data and the title differ per split.
shared_map_kwargs = dict(
    glacier_outline_rgi=glacier_outline_rgi,
    extent=(5, 30, 76, 80),
    sizes=(100, 1000),
    size_legend_values=(30, 100, 400),
    palette=palette,
    cmap_for_train=cmap_for_train,  # optional, uses get_cmap_hex if available
    split_col="FT/Hold-out glacier",
)

for split_label, map_df in (("5pct", glacier_df_SJM_5pct),
                            ("50pct", glacier_df_SJM_50pct)):
    fig, ax, glacier_info_plot, scaled_size_fn = plot_glacier_measurements_map(
        glacier_info=map_df,
        title=f"Glacier PMB location Svalbard ({split_label})",
        **shared_map_kwargs)

##### Iceland:

In [None]:
FT_GL_ISL_5pct = FT_GLACIERS["ISL"]["5pct"]
FT_GL_ISL_50pct = FT_GLACIERS["ISL"]["50pct"]

ft_glaciers_by_split = {"5pct": FT_GL_ISL_5pct, "50pct": FT_GL_ISL_50pct}

data_ISL, glacier_outline_rgi, glacier_info_by_split = build_region_glacier_info_for_splits(
    cfg,
    rgi_region_id="06",
    outline_shp_path=cfg.dataPath +
    "RGI_v6/RGI_06_Iceland/06_rgi60_Iceland.shp",
    ft_glaciers_by_split=ft_glaciers_by_split,
    split_names=("5pct", "50pct"),
    ft_label_col="FT/Hold-out glacier",
)

glacier_df_ISL_5pct, glacier_df_ISL_50pct = (glacier_info_by_split["5pct"],
                                             glacier_info_by_split["50pct"])

# Hold-out colour is the first of 10 hex colours sampled from the colormap
# (requires the get_cmap_hex helper); FT glaciers use a fixed red.
cmap_for_train = cm.batlow
colors = get_cmap_hex(cmap_for_train, 10)  # noqa: F821
train_color = colors[0]

palette = {"Hold-out": train_color, "FT": "#b2182b"}

# Options shared by both maps; only the data and the title differ per split.
shared_map_kwargs = dict(
    glacier_outline_rgi=glacier_outline_rgi,
    extent=(-25, -11, 62, 68),
    sizes=(100, 1500),
    size_legend_values=(30, 100, 1000),
    palette=palette,
    cmap_for_train=cmap_for_train,  # optional, uses get_cmap_hex if available
    split_col="FT/Hold-out glacier",
)

for split_label, map_df in (("5pct", glacier_df_ISL_5pct),
                            ("50pct", glacier_df_ISL_50pct)):
    fig, ax, glacier_info_plot, scaled_size_fn = plot_glacier_measurements_map(
        glacier_info=map_df,
        title=f"Glacier PMB location Iceland ({split_label})",
        **shared_map_kwargs)

### Plot feature overlap:

In [None]:
# Features that vary by month: the climate variables plus ELEVATION_DIFFERENCE.
MONTHLY_COLS = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str', 'ELEVATION_DIFFERENCE'
]
# Features that are constant per measurement point.
STATIC_COLS = ['aspect', 'slope', 'svf']

# Full feature list: monthly features first, then static ones.
feature_columns = [*MONTHLY_COLS, *STATIC_COLS]

In [None]:
# # res_xreg is the ONE dict from your cross-regional monthly prep
# figs_by_code = plot_tsne_overlap_xreg_from_single_res(
#     res_xreg=res_xreg,
#     cfg=cfg,
#     STATIC_COLS=STATIC_COLS,
#     MONTHLY_COLS=MONTHLY_COLS,
#     group_col="SOURCE_CODE",
#     ch_code="CH",
#     use_aug=False,  # or True if you want *_aug
#     n_iter=1000,
#     # only_codes=["IT_AT"],  # optional
# )

In [None]:
# FEATURES = MONTHLY_COLS + STATIC_COLS + ["POINT_BALANCE"]

# figs_kde = plot_feature_kde_overlap_xreg_ch_vs_codes(
#     res_xreg=res_xreg,
#     cfg=cfg,
#     features=FEATURES,
#     group_col="SOURCE_CODE",
#     ch_code="CH",
#     use_aug=True,  # usually best for feature overlap
#     # only_codes=["IT_AT", "FR"],    # optional
#     output_dir="figures/xreg_kde",  # optional
# )


## LSTM model
### LSTM datasets:

In [None]:
# Build the per-experiment transfer-learning assets from the cross-regional
# monthly data and the manual FT glacier lists. The result is a dict keyed by
# experiment name (it is iterated with .items() in the sanity-check cells below).
# NOTE(review): with force_recompute=False the assets are presumably loaded from
# `cache_dir` when present -- confirm in build_transfer_learning_assets.
tl_assets = build_transfer_learning_assets(
    cfg=cfg,
    res_xreg=res_xreg,
    FT_GLACIERS=FT_GLACIERS,
    MONTHLY_COLS=MONTHLY_COLS,
    STATIC_COLS=STATIC_COLS,
    cache_dir="logs/LSTM_cache_TL",
    force_recompute=False,
)

In [None]:
# Sanity check: list every experiment and the asset keys it provides.
banner = "=" * 60
for k, v in tl_assets.items():
    print("\n", banner)
    print("Experiment:", k)
    print("Available keys:", list(v))

In [None]:
# Sanity check: which source codes (domains) appear in the fine-tuning and test
# parts of each experiment.
for exp_key, assets in tl_assets.items():
    ft_unique = set(assets["ft_source_codes"])
    # Missing/empty test codes are reported as an empty set.
    if assets["test_source_codes"]:
        test_unique = set(assets["test_source_codes"])
    else:
        test_unique = set()
    print(f"{exp_key} | FT domains: {ft_unique} | TEST domains: {test_unique}")

### LSTM parameters:

In [None]:
# Grid-search progress logs, one CSV per region code.
log_path_gs_results = dict(
    ISL='logs/GS_results/lstm_param_search_progress_OOS_ISL_2026-02-11.csv',
    NOR='logs/GS_results/lstm_param_search_progress_OOS_NOR_2026-02-09.csv',
    FR='logs/GS_results/lstm_param_search_progress_OOS_FR_2026-02-06.csv',
    CH='logs/GS_results/lstm_param_search_progress_CH_2026-02-18.csv',
)

# Fallback LSTM hyper-parameters.
# NOTE(review): Fm=8 and Fs=3 equal len(MONTHLY_COLS) and len(STATIC_COLS);
# presumably they are the monthly/static feature counts -- keep them in sync.
default_params = dict(
    Fm=8,
    Fs=3,
    hidden_size=128,
    num_layers=1,
    bidirectional=False,
    dropout=0.0,
    static_layers=1,
    static_hidden=128,
    static_dropout=0.1,
    lr=0.001,
    weight_decay=1e-05,
    loss_name='neutral',
    two_heads=False,
    head_dropout=0.1,
    loss_spec=None,
)

params_by_key = build_lstm_params_by_key(
    default_params=default_params,
    log_path_gs_results=log_path_gs_results,
    RGI_REGIONS=RGI_REGIONS,
)

# Show which parameter sets are available (e.g. "11_CH" is used below).
params_by_key.keys()

### Train or load CH baseline:

In [None]:
# Obtain the Swiss (CH) baseline model. `ch_path` is the checkpoint path that the
# fine-tuning and tuning cells below pass on as the pretrained checkpoint.
# NOTE(review): with train_flag=False and force_retrain=False this presumably
# loads an existing checkpoint from `models_dir` -- confirm in
# train_or_load_CH_baseline.
model_ch, ch_path, ch_info = train_or_load_CH_baseline(
    cfg=cfg,
    tl_assets=tl_assets,
    default_params=params_by_key["11_CH"],
    device=device,
    models_dir="models",
    prefix="lstm_CH",
    key="defaultparams",
    train_flag=False,
    force_retrain=False,  # set False after you have it once
    epochs=150,
)
print("CH baseline saved at:", ch_path)

### Grid search:

In [None]:
# Regions that get the extra adapter / DAN tuning. The codes must match
# SOURCE_CODE; Svalbard is "SJM" everywhere else in this notebook (FT_GLACIERS,
# the automatic picking, the evaluation cells), so "SVAL" was replaced by "SJM".
SPECIAL_REGIONS = {"NOR", "ISL", "SJM"}
# Split used for tuning.
SPLIT = "5pct"

# Candidate settings for adapter fine-tuning: bottleneck width, adapter learning
# rate and adapter dropout.
ADAPTER_GRID = [
    {"adapter_bottleneck": 16, "lr_adapter": 3e-5, "adapter_dropout": 0.0},
    {"adapter_bottleneck": 32, "lr_adapter": 1e-4, "adapter_dropout": 0.0},
    {"adapter_bottleneck": 32, "lr_adapter": 1e-4, "adapter_dropout": 0.1},
    {"adapter_bottleneck": 64, "lr_adapter": 1e-4, "adapter_dropout": 0.1},
]

# Candidate settings for DAN training: `dan_alpha` and `mix_ratio_ft` are passed
# unchanged to train_dan_one_TL.
DAN_GRID = [
    {"dan_alpha": 0.05, "mix_ratio_ft": 1.0},
    {"dan_alpha": 0.10, "mix_ratio_ft": 1.0},
    {"dan_alpha": 0.10, "mix_ratio_ft": 2.0},
    {"dan_alpha": 0.20, "mix_ratio_ft": 2.0},
    {"dan_alpha": 0.20, "mix_ratio_ft": 4.0},
]

# Keep grl_lambda=1.0, disc_hidden=128, disc_dropout=0.1 initially.

In [None]:
import copy
import pandas as pd


def tune_special_regions(
    cfg,
    tl_assets,
    base_params,  # your default_params / best_params
    device,
    ch_path,
    regions=("NOR", "ISL", "SVAL"),
    split_name="5pct",
    *,
    tune_adapter=True,
    tune_dan=True,
    adapter_grid=None,
    dan_grid=None,
    # training knobs
    adapter_epochs=60,
    dan_epochs=60,
    force_retrain=True,
):
    """Grid-search adapter and DAN transfer-learning settings per region.

    For every region the experiment ``TL_CH_to_{region}_{split_name}`` is looked
    up in ``tl_assets``. Regions whose experiment is missing, or whose assets
    have no ``"ds_finetune"`` entry, are skipped with a ``[WARN]`` message.
    Each grid candidate is trained (or loaded) through
    ``finetune_or_load_one_TL`` (adapter) or ``train_dan_one_TL`` (DAN) with
    checkpoints written under ``models/GS/``. The candidate with the lowest
    ``info["best_val"]`` wins; the first one wins ties.

    NOTE(review): the default ``regions`` contains ``"SVAL"`` while the rest of
    the notebook uses ``"SJM"`` for Svalbard; with the default, Svalbard is
    presumably skipped as "missing assets". The default is kept for backward
    compatibility -- pass ``regions`` explicitly.

    Parameters
    ----------
    cfg : project configuration, forwarded to the training helpers.
    tl_assets : mapping from experiment key to that experiment's assets.
    base_params : base model parameters. They are deep-copied before the
        adapter-specific keys are added, so the caller's dict is not modified
        by this function.
    device : device forwarded to the training helpers.
    ch_path : path of the pretrained checkpoint to start from.
    regions : iterable of region codes to tune.
    split_name : split suffix of the experiment key (e.g. ``"5pct"``).
    tune_adapter, tune_dan : switch the two tuning stages on or off.
    adapter_grid, dan_grid : lists of candidate dicts; ``None`` selects the
        module-level ``ADAPTER_GRID`` / ``DAN_GRID``.
    adapter_epochs, dan_epochs : epochs per candidate for each stage.
    force_retrain : forwarded to the training helpers.

    Returns
    -------
    best_by_region : dict
        ``{region: {"best_adapter": row_or_None, "best_dan": row_or_None}}``
        for every region that was not skipped.
    df : pandas.DataFrame
        One row per evaluated candidate, sorted by region, method and
        ``best_val``. When nothing was evaluated, an empty frame with the
        columns ``region, split, method, best_val, ckpt`` is returned (the
        previous code raised ``KeyError`` while sorting in that case).
    """
    if adapter_grid is None:
        adapter_grid = ADAPTER_GRID
    if dan_grid is None:
        dan_grid = DAN_GRID

    results = []
    best_by_region = {}

    def _record(region, method, cand, path, info, incumbent):
        """Store one candidate's result row and return the better of row/incumbent."""
        # A run that returns no info counts as the worst possible result.
        best_val = info["best_val"] if info else float("inf")
        row = {
            "region": region,
            "split": split_name,
            "method": method,
            **cand,
            "best_val": best_val,
            "ckpt": path,
        }
        results.append(row)
        if incumbent is None or best_val < incumbent["best_val"]:
            return row
        return incumbent

    for region in regions:
        exp_key = f"TL_CH_to_{region}_{split_name}"
        if exp_key not in tl_assets:
            print(f"[WARN] Missing assets for {exp_key}, skipping.")
            continue

        assets = tl_assets[exp_key]
        if assets.get("ds_finetune", None) is None:
            print(f"[WARN] No finetune dataset for {exp_key}, skipping.")
            continue

        # --------- ADAPTER tuning ----------
        best_adapter = None
        if tune_adapter:
            for cand in adapter_grid:
                # Work on a copy so candidates do not leak into each other or
                # into the caller's parameters.
                params = copy.deepcopy(base_params)
                # model-building knobs
                params["use_adapter"] = True
                params["adapter_bottleneck"] = cand["adapter_bottleneck"]
                params["adapter_dropout"] = cand["adapter_dropout"]
                # Domain-wise adapters are on unless base_params disables them;
                # in that case the number of domains comes from the vocabulary.
                params["adapter_domainwise"] = bool(
                    params.get("adapter_domainwise", True))
                if params["adapter_domainwise"]:
                    dv = assets.get("domain_vocab", None)
                    params["n_domains"] = len(dv) if dv is not None else 1

                _, path, info = finetune_or_load_one_TL(
                    cfg=cfg,
                    exp_key=exp_key,
                    tl_assets_for_key=assets,
                    best_params=params,  # NOTE: pass modified params
                    device=device,
                    pretrained_ckpt_path=ch_path,
                    strategy="adapter",
                    force_retrain=force_retrain,
                    epochs_safe=adapter_epochs,
                    lr_adapter=cand["lr_adapter"],
                    models_dir="models/GS/",
                )
                best_adapter = _record(region, "adapter", cand, path, info,
                                       best_adapter)

        # --------- DAN tuning ----------
        best_dan = None
        if tune_dan:
            for cand in dan_grid:
                _, path, info = train_dan_one_TL(
                    cfg=cfg,
                    exp_key=exp_key,
                    tl_assets_for_key=assets,
                    best_params=base_params,
                    device=device,
                    pretrained_ckpt_path=ch_path,
                    force_retrain=force_retrain,
                    epochs=dan_epochs,
                    dan_alpha=cand["dan_alpha"],
                    mix_ratio_ft=cand["mix_ratio_ft"],
                    models_dir="models/GS/",
                )
                best_dan = _record(region, "dan", cand, path, info, best_dan)

        best_by_region[region] = {
            "best_adapter": best_adapter,
            "best_dan": best_dan
        }

    if not results:
        # Nothing was evaluated (all regions skipped or both grids empty):
        # sorting a column-less frame would raise KeyError.
        empty = pd.DataFrame(
            columns=["region", "split", "method", "best_val", "ckpt"])
        return best_by_region, empty

    df = pd.DataFrame(results).sort_values(["region", "method", "best_val"])
    return best_by_region, df

In [None]:
import json

# Tune adapter and DAN settings on the 5pct split.
# NOTE(review): "SVAL" is not a region code used anywhere else in this notebook
# (Svalbard is "SJM"). tune_special_regions skips regions without assets, so
# this entry presumably only produces a warning; Svalbard is tuned separately
# in the next cell.
best_by_region, df_tuning = tune_special_regions(
    cfg=cfg,
    tl_assets=tl_assets,
    base_params=params_by_key["11_CH"],
    device=device,
    ch_path=ch_path,
    regions=("NOR", "ISL", "SVAL"),
    split_name="5pct",
    force_retrain=True,
)

# Timestamped file names so that repeated runs never overwrite each other.
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
save_dir = "results/tuning_adapter"
os.makedirs(save_dir, exist_ok=True)

# All evaluated candidates (adapter and DAN rows).
csv_path = os.path.join(save_dir, f"adapter_tuning_{timestamp}.csv")
df_tuning.to_csv(csv_path, index=False)

print(f"Saved tuning results to: {csv_path}")

json_path = os.path.join(save_dir, f"adapter_best_by_region_{timestamp}.json")

with open(json_path, "w") as f:
    # default=str: values json cannot encode natively (e.g. a pathlib.Path
    # checkpoint path or a numpy scalar) are written as strings instead of
    # raising TypeError at the very end of a long tuning run.
    json.dump(best_by_region, f, indent=2, default=str)

print(f"Saved best params to: {json_path}")

In [None]:
import json

# Tune adapter and DAN settings for Svalbard ("SJM") on the 5pct split.
# This rebinds `best_by_region` and `df_tuning`; earlier results remain
# available only in the files written by the cell that produced them.
best_by_region, df_tuning = tune_special_regions(
    cfg=cfg,
    tl_assets=tl_assets,
    base_params=params_by_key["11_CH"],
    device=device,
    ch_path=ch_path,
    regions=["SJM"],
    split_name="5pct",
    force_retrain=True,
)

# Timestamped file names so that repeated runs never overwrite each other.
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
save_dir = "results/tuning_adapter"
os.makedirs(save_dir, exist_ok=True)

# All evaluated candidates (adapter and DAN rows).
csv_path = os.path.join(save_dir, f"adapter_tuning_{timestamp}.csv")
df_tuning.to_csv(csv_path, index=False)

print(f"Saved tuning results to: {csv_path}")

json_path = os.path.join(save_dir, f"adapter_best_by_region_{timestamp}.json")

with open(json_path, "w") as f:
    # default=str: values json cannot encode natively (e.g. a pathlib.Path
    # checkpoint path or a numpy scalar) are written as strings instead of
    # raising TypeError at the very end of a long tuning run.
    json.dump(best_by_region, f, indent=2, default=str)

print(f"Saved best params to: {json_path}")

### Fine tune LSTM:

In [None]:
# Fine-tune (or load) the TL models for all experiments in `tl_assets`, starting
# from the CH baseline checkpoint, once per listed strategy.
# The returned models are keyed "<experiment>__<strategy>" (see the lookup
# "TL_CH_to_ISL_5pct__adapter" in the next cell).
# NOTE(review): force_retrain=False presumably reuses existing checkpoints --
# confirm in finetune_TL_models_all.
models_tl_xreg, infos_tl_xreg = finetune_TL_models_all(
    cfg=cfg,
    tl_assets_by_key=tl_assets,
    best_params=params_by_key["11_CH"],
    device=device,
    pretrained_ckpt_path=ch_path,
    strategies=("safe", "full", "disc_full", "adapter"),
    force_retrain=False,
    prefix="lstm",
    verbose=False)

In [None]:
# Inspect one adapter run (any would do) to confirm that the domain-wise adapter
# is switched on.
m = models_tl_xreg["TL_CH_to_ISL_5pct__adapter"]

# Flag values as stored on the model (None when the attribute is absent).
for flag_name in ("use_adapter", "adapter_domainwise"):
    print(f"{flag_name}:", getattr(m, flag_name, None))

# Which adapter container the model exposes.
for label, attr_name in (("has adapters", "adapters"),
                         ("has single adapter", "adapter")):
    print(f"{label}:", hasattr(m, attr_name))

### DAN:

In [None]:
# Train (or load) the DAN models, restricted to Iceland, Norway and Svalbard.
# The result is merged with `models_tl_xreg` in the evaluation cells below.
# NOTE(review): force_retrain=False presumably reuses existing checkpoints --
# confirm in finetune_TL_models_all.
models_dan, infos_dan = finetune_TL_models_all(
    cfg=cfg,
    tl_assets_by_key=tl_assets,
    best_params=params_by_key["11_CH"],
    device=device,
    pretrained_ckpt_path=ch_path,
    strategies=["dan"],
    force_retrain=False,
    prefix="lstm",
    verbose=True,
    regions_only=["ISL", "NOR", "SJM"],  # only run DAN on these regions
)

### Evaluate on test:

In [None]:
# Regions to load the cross-regional (CH -> region) baseline models for.
regions = ["FR", "IT_AT", "NOR", "ISL", "SJM"]  # pick any you have models for

# NOTE(review): this cell passes `default_params`, whereas the training cells
# above use params_by_key["11_CH"]. Confirm that the saved xreg checkpoints
# were built with `default_params`; otherwise loading may fail or mismatch.
models_xreg, paths_xreg = load_xreg_models_all(
    cfg=cfg,
    regions=regions,
    best_params=default_params,
    device=device,
    models_dir="models",
    prefix="lstm_xreg_CH_to",
    date=None,  # auto-detect latest
)

##### 5 percent split (sparse):

In [None]:
# Evaluate the 5pct split for the Alpine regions (France, Italy/Austria):
# no fine-tuning versus each fine-tuning strategy.
# The three results are rebound by every evaluation cell below.
df_tl_grid, preds_tl_grid, fig_tl_grid = evaluate_transfer_learning_grid(
    cfg=cfg,
    regions=["FR", "IT_AT"],
    models_xreg_by_region=models_xreg,  # baseline CH→Region models
    models_tl_by_key=models_tl_xreg,  # TL models keyed by "exp__strategy"
    tl_assets_by_key=tl_assets,  # TL assets
    device=device,
    split_name="5pct",  # or "50pct"
    save_dir="figures/eval_TL",
    strategies=["no_ft", "safe", "full", "disc_full", "adapter"],
)

In [None]:
# Fine-tuned and DAN models in one lookup; on a key clash the DAN entry wins.
models_total = dict(models_tl_xreg)
models_total.update(models_dan)

# Evaluate the 5pct split for Norway, Iceland and Svalbard, including DAN.
df_tl_grid, preds_tl_grid, fig_tl_grid = evaluate_transfer_learning_grid(
    cfg=cfg,
    regions=["NOR", "ISL", "SJM"],
    models_xreg_by_region=models_xreg,  # baseline CH→Region models
    models_tl_by_key=models_total,  # TL models keyed by "exp__strategy"
    tl_assets_by_key=tl_assets,  # TL assets
    device=device,
    split_name="5pct",  # or "50pct"
    save_dir="figures/eval_TL",
    strategies=["no_ft", "safe", "full", "adapter", "dan"],
)

##### 50 percent split (moderate):

In [None]:
# Evaluate the 50pct split for the Alpine regions (France, Italy/Austria):
# no fine-tuning versus each fine-tuning strategy.
# This rebinds the results of the previous evaluation cell.
df_tl_grid, preds_tl_grid, fig_tl_grid = evaluate_transfer_learning_grid(
    cfg=cfg,
    regions=["FR", "IT_AT"],
    models_xreg_by_region=models_xreg,  # baseline CH→Region models
    models_tl_by_key=models_tl_xreg,  # TL models keyed by "exp__strategy"
    tl_assets_by_key=tl_assets,  # TL assets
    device=device,
    split_name="50pct",  # or "5pct"
    save_dir="figures/eval_TL",
    strategies=["no_ft", "safe", "full", "disc_full", "adapter"],
)

In [None]:
# Fine-tuned and DAN models in one lookup; on a key clash the DAN entry wins.
models_total = dict(models_tl_xreg)
models_total.update(models_dan)

# Evaluate the 50pct split for Norway, Iceland and Svalbard, including DAN.
df_tl_grid, preds_tl_grid, fig_tl_grid = evaluate_transfer_learning_grid(
    cfg=cfg,
    regions=["NOR", "ISL", "SJM"],
    models_xreg_by_region=models_xreg,  # baseline CH→Region models
    models_tl_by_key=models_total,  # TL models keyed by "exp__strategy"
    tl_assets_by_key=tl_assets,  # TL assets
    device=device,
    split_name="50pct",  # or "5pct"
    save_dir="figures/eval_TL",
    strategies=["no_ft", "safe", "full", "adapter", "dan"],
)