In [None]:
import sys, os
sys.path.append(os.path.join(os.getcwd(), '../../')) # Add root of repo to import MBM
from typing import Optional, Iterable, Dict, List
import os

import pandas as pd
import numpy as np
import warnings
import re
import matplotlib.pyplot as plt
import seaborn as sns
from cmcrameri import cm
import massbalancemachine as mbm
import logging
import torch.nn as nn
from skorch.helper import SliceDataset
from datetime import datetime
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset
import pickle 
from scipy import stats
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import torch 
from matplotlib.lines import Line2D

from regions.TF_Europe.scripts.config_TF_Europe import *
from regions.TF_Europe.scripts.dataset import *
from regions.TF_Europe.scripts.plotting import *
from regions.TF_Europe.scripts.models import *
from regions.TF_Europe.scripts.models import *

# Suppress all Python warnings for the rest of the session.
# NOTE(review): this also hides every later `warnings.warn(...)` in this
# notebook (e.g. the non-strict branch of the NaN check) — consider narrowing.
warnings.filterwarnings('ignore')
# Auto-reload edited project modules before each cell executes.
%load_ext autoreload
%autoreload 2

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Experiment configuration; cfg.seed is reused by the seeded steps below.
cfg = mbm.EuropeTFConfig()
mbm.utils.seed_all(cfg.seed)
mbm.utils.free_up_cuda()  # presumably releases cached GPU memory — TODO confirm
mbm.plots.use_mbm_style()

# Use the GPU when torch can see one, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Cross-regional modelling: train on Switzerland (CH), test on all other European sources

### Read stakes datasets:

In [None]:
"""
Examples of loading data:
# Load Switzerland only
df = load_stakes(cfg, "CH")

# Load all Central Europe (FR+CH+IT+AT when you add them)
df_ceu = load_stakes_for_rgi_region(cfg, "11")

# Load all Europe regions configured
dfs = {rid: load_stakes_for_rgi_region(cfg, rid) for rid in RGI_REGIONS.keys()}"""

# Load all Europe regions configured
dfs = {rid: load_stakes_for_rgi_region(cfg, rid) for rid in RGI_REGIONS.keys()}
dfs["11"]

In [None]:
# Summary and plots of the stake data of every loaded RGI region (project helper).
summarize_and_plot_all_regions(dfs)

In [None]:
# Mass-balance distribution plots across all loaded RGI regions (project helper).
plot_mb_distributions_all_regions(dfs)

### Monthly datasets:

In [None]:
# Cross-regional split: every glacier of each non-CH source is a test glacier.
# "CH" is intentionally absent — Swiss glaciers form the training set.
TEST_GLACIERS_BY_CODE = {
    source: load_stakes(cfg, source).GLACIER.unique().tolist()
    for source in ("SJM", "ISL", "NOR", "FR", "IT_AT")
}
TEST_GLACIERS_SJM = TEST_GLACIERS_BY_CODE["SJM"]
TEST_GLACIERS_ISL = TEST_GLACIERS_BY_CODE["ISL"]
TEST_GLACIERS_NOR = TEST_GLACIERS_BY_CODE["NOR"]
TEST_GLACIERS_FR = TEST_GLACIERS_BY_CODE["FR"]
TEST_GLACIERS_IT_AT = TEST_GLACIERS_BY_CODE["IT_AT"]

In [None]:
# Transform data to monthly format (run or load data):
# ERA5 climate and geopotential inputs (Europe) used for the monthly transform.
paths = {
    'era5_climate_data':
    os.path.join(cfg.dataPath, path_ERA5_raw,
                 "era5_monthly_averaged_data_Europe.nc"),
    'geopotential_data':
    os.path.join(cfg.dataPath, path_ERA5_raw,
                 "era5_geopotential_pressure_Europe.nc")
}

# Fail fast if any required input file is missing.
for key, path in paths.items():
    if not os.path.exists(path):
        raise FileNotFoundError(f"Required file for {key} not found at {path}")

# Variables of interest, defined once at cell level (independent of the
# file-existence loop above).
vois_climate = [
    "t2m",
    "tp",
    "slhf",
    "sshf",
    "ssrd",
    "fal",
    "str",
]

vois_topographical = ["aspect", "slope", "svf"]

In [None]:
def build_crossregional_df_ceu_with_ch(dfs: dict) -> pd.DataFrame:
    """
    Stack every non-empty stake dataframe of `dfs` into one Europe-wide frame.

    Expects:
      - Each df has at least columns: GLACIER, YEAR, ID, PERIOD, MONTHS, POINT_BALANCE
      - Central Europe df includes SOURCE_CODE identifying CH/FR/IT_AT etc.

    Parameters
    ----------
    dfs : dict
        Mapping RGI region id -> stake dataframe (or None). None/empty entries
        are skipped with a warning.

    Returns
    -------
    pd.DataFrame
        All rows of all non-empty frames, with a fresh 0..n-1 index.

    Raises
    ------
    ValueError
        If every entry of `dfs` is None or empty.
    """
    non_empty = []
    for region_id, frame in dfs.items():
        if frame is not None and len(frame) > 0:
            non_empty.append(frame)
        else:
            logging.warning(f"RGI {region_id}: empty, skipping in concat.")

    if len(non_empty) == 0:
        raise ValueError("No non-empty dataframes in dfs.")

    return pd.concat(non_empty, ignore_index=True)


def compute_crossregional_test_glaciers(
    df_all: pd.DataFrame,
    ch_code: str = "CH",
    source_col: str = "SOURCE_CODE",
    glacier_col: str = "GLACIER",
):
    """
    Split glaciers by data source for the cross-regional experiment.

    Train glaciers = all glaciers with `source_col` == `ch_code`
    Test glaciers  = all glaciers with `source_col` != `ch_code`
    (rows whose `source_col` is missing therefore count as test).

    Parameters
    ----------
    df_all : pd.DataFrame
        Concatenated stake rows; must contain `source_col` and `glacier_col`.
    ch_code : str
        Source code of the training region.
    source_col, glacier_col : str
        Column names for the source code and glacier name.

    Returns
    -------
    (train_glaciers, test_glaciers) : (list[str], list[str])
        Sorted, de-duplicated glacier names (NaN glacier names dropped).

    Raises
    ------
    ValueError
        If a required column is missing or either side of the split is empty.
    """
    if source_col not in df_all.columns:
        raise ValueError(
            f"Missing column {source_col}. Needed to separate {ch_code} vs others.")
    if glacier_col not in df_all.columns:
        raise ValueError(f"Missing column {glacier_col}.")

    is_ch = df_all[source_col] == ch_code
    # `~is_ch` equals `!= ch_code` element-wise, including NaN sources (-> test).
    ch_gl = sorted(df_all.loc[is_ch, glacier_col].dropna().unique())
    non_ch_gl = sorted(df_all.loc[~is_ch, glacier_col].dropna().unique())

    if not ch_gl:
        raise ValueError(
            f"No {ch_code} glaciers found ({source_col}=='{ch_code}').")
    if not non_ch_gl:
        raise ValueError(
            f"No non-{ch_code} glaciers found ({source_col}!='{ch_code}').")

    n_missing_src = int(df_all[source_col].isna().sum())
    if n_missing_src:
        logging.warning(
            f"{n_missing_src} row(s) have no {source_col}; their glaciers are "
            f"treated as non-{ch_code} (test).")

    # A glacier listed under both sources ends up in both returned lists;
    # surface it so the split can be checked.
    overlap = sorted(set(ch_gl) & set(non_ch_gl))
    if overlap:
        logging.warning(
            f"{len(overlap)} glacier(s) appear under both {ch_code} and other "
            f"sources: {overlap}")

    logging.info(
        f"Cross-regional split: {ch_code} train glaciers={len(ch_gl)}, "
        f"non-{ch_code} test glaciers={len(non_ch_gl)}")
    return ch_gl, non_ch_gl


def prepare_monthly_df_crossregional_CH_to_EU(
    cfg,
    dfs,
    paths,
    vois_climate,
    vois_topographical,
    run_flag=True,  # True recompute, False load
    region_name="XREG_CH_TO_EU",
    region_id=11,  # arbitrary/int tag used by your pipeline; keep 11 or 0
    csv_subfolder="CrossRegional/CH_to_Europe/csv",
):
    """
    Build ONE monthly-prepped dataset for the CH -> Europe experiment.

      - data  = concatenation of all Europe sources in `dfs`
      - train = CH glaciers
      - test  = all non-CH glaciers

    Outputs are written under
    `os.path.join(cfg.dataPath, path_PMB_WGMS_csv, csv_subfolder)`, which is
    created if needed and passed as `paths["csv_path"]` (the caller's `paths`
    dict is not modified).

    Returns
    -------
    res : dict
        Same output dict as prepare_monthly_dfs_with_padding (df_train/df_test/aug/etc.)
    split_info : dict
        {"train_glaciers": [...], "test_glaciers": [...]}
    """
    # Pool every source, then hold out all non-CH glaciers as test.
    df_all = build_crossregional_df_ceu_with_ch(dfs)
    train_glaciers, test_glaciers = compute_crossregional_test_glaciers(
        df_all, ch_code="CH")

    # Experiment-specific CSV folder.
    csv_dir = os.path.join(cfg.dataPath, path_PMB_WGMS_csv, csv_subfolder)
    os.makedirs(csv_dir, exist_ok=True)
    paths_xreg = {**paths, "csv_path": csv_dir}

    logging.info(
        f"Preparing cross-regional monthlies: {region_name} "
        f"(run_flag={run_flag}) | train(CH)={len(train_glaciers)} | test(non-CH)={len(test_glaciers)}"
    )

    res = prepare_monthly_dfs_with_padding(
        cfg=cfg,
        df_region=df_all,
        region_name=region_name,
        region_id=int(region_id),
        paths=paths_xreg,
        test_glaciers=test_glaciers,  # every non-CH glacier is test
        vois_climate=vois_climate,
        vois_topographical=vois_topographical,
        run_flag=run_flag,
    )

    split_info = {
        "train_glaciers": train_glaciers,
        "test_glaciers": test_glaciers,
    }
    return res, split_info

In [None]:
# Reload every configured RGI region so this cell is self-contained.
dfs = {
    region_id: load_stakes_for_rgi_region(cfg, region_id)
    for region_id in RGI_REGIONS.keys()
}

# Build the CH-train / non-CH-test monthly tables (run_flag=False: load if already computed).
res_xreg, split_info = prepare_monthly_df_crossregional_CH_to_EU(
    cfg=cfg,
    dfs=dfs,
    paths=paths,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    run_flag=False,
)

df_train, df_test = res_xreg["df_train"], res_xreg["df_test"]

print(f"Train glaciers (CH): {len(split_info['train_glaciers'])}")
print(f"Test glaciers (non-CH): {len(split_info['test_glaciers'])}")
print(f"Train rows: {len(df_train)} Test rows: {len(df_test)}")

#### Feature overlap:

In [None]:
# Monthly (time-varying) inputs: the ERA5 climate variables plus ELEVATION_DIFFERENCE.
MONTHLY_COLS = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str',
    'ELEVATION_DIFFERENCE',
]
# Static topographic inputs.
STATIC_COLS = ['aspect', 'slope', 'svf']

feature_columns = [*MONTHLY_COLS, *STATIC_COLS]

In [None]:
def plot_tsne_overlap_xreg_from_single_res(
        res_xreg: dict,
        cfg,
        STATIC_COLS,
        MONTHLY_COLS,
        group_col: str = "SOURCE_CODE",
        ch_code: str = "CH",
        use_aug: bool = False,  # True -> df_train_aug/df_test_aug
        n_iter: int = 1000,
        only_codes=None,  # e.g. ["IT_AT", "FR"]
        skip_codes=None,  # e.g. ["CH"]
):
    """
    t-SNE overlap plots for the CH-train / non-CH-test cross-regional split.

    The CH reference is res_xreg["df_train"] (or "df_train_aug" when `use_aug`).
    The test frame res_xreg["df_test"] (or "df_test_aug") is split by
    `group_col`, and each code gets its own figure against the full reference.

    Parameters
    ----------
    res_xreg : dict
        One monthly result dict from the cross-regional preparation.
    cfg : object
        Only `cfg.seed` is read (t-SNE random_state).
    STATIC_COLS, MONTHLY_COLS : list[str]
        Feature columns forwarded to `plot_tsne_overlap`.
    group_col : str
        Column used to split the test frame.
    ch_code : str
        Reference code; always excluded from the plotted codes.
    use_aug : bool
        Use the *_aug frames instead of the plain ones.
    n_iter : int
        Forwarded to `plot_tsne_overlap`.
    only_codes, skip_codes : iterable of str or None
        Case-insensitive include / exclude filters on the codes.

    Returns
    -------
    dict
        Upper-cased code -> matplotlib Figure.

    Raises
    ------
    ValueError
        If a selected frame is missing/empty or the test frame lacks `group_col`.
    """
    # Upper-cased filters; the reference code is always excluded.
    wanted = {c.upper() for c in (only_codes or [])} or None
    excluded = {c.upper() for c in (skip_codes or [])} | {ch_code.upper()}

    if use_aug:
        train_key, test_key, label_df = "df_train_aug", "df_test_aug", "(*_aug)"
    else:
        train_key, test_key, label_df = "df_train", "df_test", ""
    df_ch = res_xreg.get(train_key)
    df_test_all = res_xreg.get(test_key)

    if df_ch is None or len(df_ch) == 0:
        raise ValueError(f"df_train{label_df} missing/empty in res_xreg.")
    if df_test_all is None or len(df_test_all) == 0:
        raise ValueError(f"df_test{label_df} missing/empty in res_xreg.")
    if group_col not in df_test_all.columns:
        raise ValueError(
            f"'{group_col}' not found in df_test{label_df}. Needed to split by region."
        )
    if group_col not in df_ch.columns:
        # Only a sanity hint: the CH reference is never split.
        print(
            f"[warn] '{group_col}' not in df_train{label_df}. That's OK for CH reference."
        )

    custom_palette = {
        "Train": get_cmap_hex(cm.batlow, 10)[0],
        "Test": "#b2182b",
    }

    # Per-row upper-cased code, computed once and reused for every subset.
    row_codes = df_test_all[group_col].astype(str).str.upper()
    distinct_codes = df_test_all[group_col].dropna().astype(str).str.upper().unique()
    codes_present = sorted(code for code in distinct_codes if code not in excluded)
    if wanted is not None:
        codes_present = [code for code in codes_present if code in wanted]

    figs = {}
    for code in codes_present:
        df_other = df_test_all[row_codes == code].copy()
        if df_other.empty:
            continue

        print(
            f"Plotting XREG t-SNE: CH(train n={len(df_ch)}) vs {code}(test n={len(df_other)})"
        )

        figure = plot_tsne_overlap(
            data_train=df_ch,
            data_test=df_other,
            STATIC_COLS=STATIC_COLS,
            MONTHLY_COLS=MONTHLY_COLS,
            sublabels=("a", "b", "c"),
            label_fmt="({})",
            label_xy=(0.02, 0.98),
            label_fontsize=14,
            n_iter=n_iter,
            random_state=cfg.seed,
            custom_palette=custom_palette,
        )
        figure.suptitle(f"XREG overlap: CH vs {code}", fontsize=14)
        figs[code] = figure

    return figs


# res_xreg is the ONE dict from your cross-regional monthly prep;
# produces one t-SNE figure per non-CH source, each against the CH training set.
figs_by_code = plot_tsne_overlap_xreg_from_single_res(
    res_xreg=res_xreg,
    cfg=cfg,
    STATIC_COLS=STATIC_COLS,
    MONTHLY_COLS=MONTHLY_COLS,
    group_col="SOURCE_CODE",
    ch_code="CH",
    use_aug=False,  # or True if you want *_aug
    n_iter=1000,
    # only_codes=["IT_AT"],  # optional: restrict to a subset of sources
)

In [None]:
import os


def plot_feature_kde_overlap_xreg_ch_vs_codes(
    res_xreg: dict,
    cfg,
    features,
    group_col: str = "SOURCE_CODE",
    ch_code: str = "CH",
    use_aug: bool = False,  # True -> df_train_aug/df_test_aug
    only_codes=None,  # e.g. ["IT_AT", "FR"]
    skip_codes=None,  # e.g. ["CH"]
    output_dir=None,  # e.g. "figures/xreg_kde"
    include_ch_in_title: bool = True,
):
    """
    Plot KDE-based feature overlap for XREG: reference (CH) vs each SOURCE_CODE subset.

    Uses:
      - reference: res_xreg["df_train"] (or "df_train_aug" if use_aug)
      - other region: subset of res_xreg["df_test"] (or "df_test_aug") by `group_col`

    Parameters
    ----------
    res_xreg : dict
        Output dict from prepare_monthly_df_crossregional_CH_to_EU (or similar),
        containing df_train/df_test and optionally df_train_aug/df_test_aug.
        The selected test frame must contain `group_col` (SOURCE_CODE).
    cfg : object
        If it has a `dataPath` attribute, `output_dir` is resolved relative to
        it; otherwise `output_dir` is used as given.
    features : list[str]
        Feature columns to plot.
    group_col : str
        Column to split test set by (default: "SOURCE_CODE").
    ch_code : str
        Reference code (default: "CH"); compared case-insensitively and used
        in titles, messages and file names.
    use_aug : bool
        If True uses df_train_aug/df_test_aug.
    only_codes : list[str] or None
        If given, only plot these codes (case-insensitive).
    skip_codes : list[str] or None
        Codes to skip (`ch_code` is always skipped).
    output_dir : str or None
        If set, saves one PNG per code into this directory as
        `xreg_kde_overlap_{CH}_vs_{CODE}{_aug}.png`.
    include_ch_in_title : bool
        Adds a "<CH> vs <CODE>" suptitle on each figure.

    Returns
    -------
    dict
        Upper-cased code -> matplotlib Figure (figures are left open).

    Raises
    ------
    ValueError
        If the selected train/test frame is missing or empty, or `group_col`
        is absent from the test frame.
    """
    ch_code = str(ch_code).upper()
    only_set = {c.upper() for c in only_codes} if only_codes else None
    skip_set = {c.upper() for c in (skip_codes or [])}
    skip_set.add(ch_code)

    if use_aug:
        df_ch = res_xreg.get("df_train_aug")
        df_test_all = res_xreg.get("df_test_aug")
        suffix = "_aug"
    else:
        df_ch = res_xreg.get("df_train")
        df_test_all = res_xreg.get("df_test")
        suffix = ""

    if df_ch is None or len(df_ch) == 0:
        raise ValueError(f"Missing/empty df_train{suffix} in res_xreg.")
    if df_test_all is None or len(df_test_all) == 0:
        raise ValueError(f"Missing/empty df_test{suffix} in res_xreg.")
    if group_col not in df_test_all.columns:
        raise ValueError(f"'{group_col}' not found in df_test{suffix}.")

    # Palette is built only after validation, so bad inputs fail before any
    # plotting setup. Train = reference (CH), Test = other region.
    color_dark_blue = get_cmap_hex(cm.batlow, 10)[0]
    palette = {"Train": color_dark_blue, "Test": "#b2182b"}

    codes = sorted(
        df_test_all[group_col].dropna().astype(str).str.upper().unique())
    codes = [c for c in codes if c not in skip_set]
    if only_set is not None:
        codes = [c for c in codes if c in only_set]

    if output_dir:
        out_abs = os.path.join(cfg.dataPath, output_dir) if hasattr(
            cfg, "dataPath") else output_dir
        os.makedirs(out_abs, exist_ok=True)
    else:
        out_abs = None

    # Per-row upper-cased code, computed once and reused for every subset.
    row_codes = df_test_all[group_col].astype(str).str.upper()

    figs = {}
    for code in codes:
        df_other = df_test_all[row_codes == code].copy()
        if len(df_other) == 0:
            continue

        print(
            f"Plotting XREG KDE: {ch_code}(train n={len(df_ch)}) vs {code}(test n={len(df_other)})"
        )

        fig = plot_feature_kde_overlap(
            df_train=df_ch,
            df_test=df_other,
            features=features,
            palette=palette,
            outfile=None,  # saved below so naming is controlled here
        )

        if include_ch_in_title:
            fig.suptitle(f"XREG feature overlap: {ch_code} vs {code}",
                         fontsize=14)
            fig.tight_layout()

        if out_abs:
            out_png = os.path.join(
                out_abs, f"xreg_kde_overlap_{ch_code}_vs_{code}{suffix}.png")
            fig.savefig(out_png, dpi=300, bbox_inches="tight")

        figs[code] = fig

    return figs


In [None]:
# Inputs plus the target, so feature and target distributions are compared together.
FEATURES = [*MONTHLY_COLS, *STATIC_COLS, "POINT_BALANCE"]

figs_kde = plot_feature_kde_overlap_xreg_ch_vs_codes(
    res_xreg=res_xreg,
    cfg=cfg,
    features=FEATURES,
    group_col="SOURCE_CODE",
    ch_code="CH",
    use_aug=True,  # compare the augmented (*_aug) frames
    output_dir="figures/xreg_kde",  # optional; resolved under cfg.dataPath when present
    # only_codes=["IT_AT", "FR"],  # optional subset of sources
)


## LSTM model
### LSTM datasets:

In [None]:
def build_crossregional_res_all(
    res_xreg: dict,
    target_source_codes=None,
    source_col="SOURCE_CODE",
    ch_code="CH",
    key_prefix="XREG_CH_TO",
):
    """
    Split one cross-regional result dict into one res-like dict per target source.

    Every entry shares the CH training frames (df_train, df_train_aug) and the
    padding values, and carries only one target source's rows in
    df_test / df_test_aug.

    Parameters
    ----------
    res_xreg : dict
        Must contain df_train, df_train_aug, df_test, df_test_aug,
        months_head_pad and months_tail_pad. Both test frames need `source_col`.
    target_source_codes : iterable or None
        Codes to build entries for. None auto-discovers every non-null code in
        df_test except `ch_code` (exact match), sorted.
    source_col : str
        Column holding each row's source code.
    ch_code : str
        Training source excluded from auto-discovery.
    key_prefix : str
        Output keys are f"{key_prefix}_{code}".

    Returns
    -------
    dict
        {f"{key_prefix}_{code}": {"df_train", "df_train_aug", "df_test",
        "df_test_aug", "months_head_pad", "months_tail_pad"}}

    Raises
    ------
    ValueError
        If `source_col` is missing from df_test or df_test_aug.
    """
    df_test = res_xreg["df_test"]
    df_test_aug = res_xreg["df_test_aug"]
    # Validate both frames up front instead of failing with a KeyError mid-loop.
    for name, frame in (("df_test", df_test), ("df_test_aug", df_test_aug)):
        if source_col not in frame.columns:
            raise ValueError(f"Missing {source_col} in res_xreg['{name}'].")

    if target_source_codes is None:
        target_source_codes = sorted(
            set(df_test[source_col].dropna().unique()) - {ch_code})

    res_all = {}
    for sc in target_source_codes:
        key = f"{key_prefix}_{sc}"
        df_test_sc = df_test.loc[df_test[source_col] == sc].copy()
        df_test_aug_sc = df_test_aug.loc[df_test_aug[source_col] == sc].copy()

        if len(df_test_sc) == 0:
            # The entry is still returned, but with an empty df_test.
            logging.warning(f"{key}: no test rows for {source_col}={sc}.")

        res_all[key] = {
            "df_train": res_xreg["df_train"],
            "df_train_aug": res_xreg["df_train_aug"],
            "df_test": df_test_sc,
            "df_test_aug": df_test_aug_sc,
            "months_head_pad": res_xreg["months_head_pad"],
            "months_tail_pad": res_xreg["months_tail_pad"],
        }

    return res_all

In [None]:
# One res-like dict per target source: CH training frames + that source's test rows.
res_all_xreg = build_crossregional_res_all(
    res_xreg=res_xreg,
    source_col="SOURCE_CODE",
    key_prefix="XREG_CH_TO",
    target_source_codes=None,  # auto-discover from df_test
)
res_all_xreg.keys()

In [None]:
# outputs_xreg = build_or_load_lstm_all(
#     cfg=cfg,
#     res_all=res_all_xreg,
#     MONTHLY_COLS=MONTHLY_COLS,
#     STATIC_COLS=STATIC_COLS,
#     cache_dir="logs/LSTM_cache",
#     only_keys=None,  # e.g. ["XREG_CH_TO_NO"] to recompute only Norway
#     force_recompute=False,  # global default: load if cached
#     val_ratio=0.2,
# )

In [None]:
def _check_for_nans(key,
                    df_loss,
                    df_full,
                    monthly_cols,
                    static_cols,
                    strict=True):
    """
    Checks for NaNs/Infs in features and targets.
    Raises ValueError if strict=True, otherwise prints warning.
    """
    feat_cols = [
        c for c in (monthly_cols + static_cols) if c in df_full.columns
    ]

    # --- feature NaNs ---
    n_nan_feat = df_full[feat_cols].isna().sum().sum()
    n_inf_feat = np.isinf(df_full[feat_cols].to_numpy(dtype="float64",
                                                      copy=False)).sum()

    # --- target NaNs ---
    n_nan_target = df_loss["POINT_BALANCE"].isna().sum()
    n_inf_target = np.isinf(df_loss["POINT_BALANCE"].to_numpy(
        dtype="float64", copy=False)).sum()

    if any([n_nan_feat, n_inf_feat, n_nan_target, n_inf_target]):

        msg = (f"[{key}] Data integrity issue:\n"
               f"  Feature NaNs: {n_nan_feat}\n"
               f"  Feature Infs: {n_inf_feat}\n"
               f"  Target  NaNs: {n_nan_target}\n"
               f"  Target  Infs: {n_inf_target}")

        if strict:
            raise ValueError(msg)
        else:
            warnings.warn(msg)


def _lstm_cache_paths(cfg, key: str, cache_dir: str):
    out_dir = os.path.join(cache_dir)
    os.makedirs(out_dir, exist_ok=True)
    train_p = os.path.join(out_dir, f"{key}_train.joblib")
    test_p = os.path.join(out_dir, f"{key}_test.joblib")
    split_p = os.path.join(out_dir, f"{key}_split.joblib")
    return train_p, test_p, split_p


def build_or_load_lstm_train_only(
    cfg,
    key_train: str,
    res_train: dict,  # must contain df_train, df_train_aug, pads
    MONTHLY_COLS,
    STATIC_COLS,
    val_ratio=0.2,
    cache_dir="logs/LSTM_cache",
    force_recompute=False,
    normalize_target=True,
    expect_target=True,
    strict_nan=True,
):
    """
    Build (or load from cache) the LSTM training dataset and its train/val split.

    The dataset is cached as `{key_train}_train.joblib`. The split indices and
    the settings that produced them go into `{key_train}_split.joblib`, both in
    `cache_dir`. A cache is reused only if its stored settings (val_ratio,
    cfg.seed, normalize_target, expect_target, feature columns) match the
    current call. Older caches without stored settings are reused with a
    warning.

    Parameters
    ----------
    cfg : object
        `cfg.seed` seeds dataset construction and the split.
    key_train : str
        Cache key.
    res_train : dict
        df_train, df_train_aug, months_head_pad, months_tail_pad.
    MONTHLY_COLS, STATIC_COLS : list[str]
        Feature columns.
    val_ratio : float
        Validation fraction passed to `MBSequenceDataset.split_indices`.
    cache_dir : str
        Cache directory.
    force_recompute : bool
        Ignore any cache and rebuild.
    normalize_target, expect_target : bool
        Forwarded to `build_combined_LSTM_dataset`.
    strict_nan : bool
        Forwarded to the NaN check.

    Returns
    -------
    (ds_train, train_idx, val_idx)

    Raises
    ------
    ValueError
        From the NaN check when `strict_nan` and the data contain NaN/Inf.

    NOTE(review): `joblib` is not imported explicitly in this notebook; it is
    presumably provided by one of the star imports — confirm.
    """
    train_p, _, split_p = _lstm_cache_paths(cfg,
                                            key_train,
                                            cache_dir=cache_dir)

    # Settings that determine the cached artifacts; stored with the split so a
    # cache built with different settings is not silently reused.
    cache_meta = {
        "val_ratio": val_ratio,
        "seed": cfg.seed,
        "normalize_target": normalize_target,
        "expect_target": expect_target,
        "monthly_cols": list(MONTHLY_COLS),
        "static_cols": list(STATIC_COLS),
    }

    if (not force_recompute) and all(
            os.path.exists(p) for p in [train_p, split_p]):
        split = joblib.load(split_p)
        stored_meta = split.get("meta")
        if stored_meta is None:
            logging.warning(
                f"[{key_train}] Cached split has no stored settings; reusing "
                f"it as-is (use force_recompute=True if settings changed).")
        if stored_meta is None or stored_meta == cache_meta:
            ds_train = joblib.load(train_p)
            return ds_train, split["train_idx"], split["val_idx"]
        logging.info(
            f"[{key_train}] Cached train dataset was built with different "
            f"settings ({stored_meta} != {cache_meta}); recomputing.")

    df_train = res_train["df_train"]
    df_train_aug = res_train["df_train_aug"]
    months_head_pad = res_train["months_head_pad"]
    months_tail_pad = res_train["months_tail_pad"]

    _check_for_nans(
        key_train,
        df_loss=df_train,
        df_full=df_train_aug,
        monthly_cols=MONTHLY_COLS,
        static_cols=STATIC_COLS,
        strict=strict_nan,
    )

    # Re-seed so dataset construction and the split are reproducible.
    mbm.utils.seed_all(cfg.seed)

    ds_train = build_combined_LSTM_dataset(
        df_loss=df_train,
        df_full=df_train_aug,
        monthly_cols=MONTHLY_COLS,
        static_cols=STATIC_COLS,
        months_head_pad=months_head_pad,
        months_tail_pad=months_tail_pad,
        normalize_target=normalize_target,
        expect_target=expect_target,
    )

    train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
        len(ds_train), val_ratio=val_ratio, seed=cfg.seed)

    # Dataset first, split (with settings) last: an interrupted write leaves
    # no split file, so the next call rebuilds instead of loading a partial cache.
    joblib.dump(ds_train, train_p, compress=3)
    joblib.dump(
        {
            "train_idx": train_idx,
            "val_idx": val_idx,
            "meta": cache_meta,
        },
        split_p,
        compress=3)

    return ds_train, train_idx, val_idx


def build_or_load_lstm_test_only(
    cfg,
    key_test: str,
    res_test: dict,  # must contain df_test, df_test_aug, pads
    MONTHLY_COLS,
    STATIC_COLS,
    cache_dir="logs/LSTM_cache",
    force_recompute=False,
    normalize_target=True,
    expect_target=True,
    strict_nan=True,
):
    """
    Build (or load from `{key_test}_test.joblib` in `cache_dir`) one LSTM test dataset.

    Parameters
    ----------
    cfg : object
        `cfg.seed` re-seeds before construction.
    key_test : str
        Cache key.
    res_test : dict
        df_test, df_test_aug, months_head_pad, months_tail_pad.
    MONTHLY_COLS, STATIC_COLS : list[str]
        Feature columns.
    cache_dir : str
        Cache directory.
    force_recompute : bool
        Ignore any cache and rebuild.
    normalize_target, expect_target : bool
        Forwarded to `build_combined_LSTM_dataset`.
        NOTE(review): with normalize_target=True the test dataset is built from
        the test frames alone — confirm no training statistics are required.
    strict_nan : bool
        Forwarded to the NaN check.

    Returns
    -------
    The cached or freshly built test dataset.
    """
    test_path = _lstm_cache_paths(cfg, key_test, cache_dir=cache_dir)[1]

    if not force_recompute and os.path.exists(test_path):
        return joblib.load(test_path)

    df_loss, df_full = res_test["df_test"], res_test["df_test_aug"]
    head_pad, tail_pad = res_test["months_head_pad"], res_test["months_tail_pad"]

    _check_for_nans(
        key_test,
        df_loss=df_loss,
        df_full=df_full,
        monthly_cols=MONTHLY_COLS,
        static_cols=STATIC_COLS,
        strict=strict_nan,
    )

    mbm.utils.seed_all(cfg.seed)

    ds_test = build_combined_LSTM_dataset(
        df_loss=df_loss,
        df_full=df_full,
        monthly_cols=MONTHLY_COLS,
        static_cols=STATIC_COLS,
        months_head_pad=head_pad,
        months_tail_pad=tail_pad,
        normalize_target=normalize_target,
        expect_target=expect_target,
    )

    joblib.dump(ds_test, test_path, compress=3)
    return ds_test


def build_or_load_lstm_all_crossregional(
    cfg,
    res_xreg: dict,
    MONTHLY_COLS,
    STATIC_COLS,
    target_source_codes=None,
    source_col="SOURCE_CODE",
    ch_code="CH",
    cache_dir="logs/LSTM_cache",
    force_recompute_train=False,
    force_recompute_tests=False,
    only_test_keys=None,
    val_ratio=0.2,
    normalize_target=True,
    expect_target=True,
    strict_nan=True,
):
    """
    Build (or load) one shared CH training dataset plus one test dataset per target.

    Parameters
    ----------
    cfg : object
        Experiment config forwarded to the dataset builders.
    res_xreg : dict
        Must contain df_train, df_train_aug, df_test, df_test_aug,
        months_head_pad and months_tail_pad. Both test frames need `source_col`.
    MONTHLY_COLS, STATIC_COLS : list[str]
        Feature columns forwarded to the builders.
    target_source_codes : list or None
        Targets to build tests for. None auto-detects every non-null code in
        df_test except `ch_code`.
    source_col : str
        Column identifying each row's data source.
    ch_code : str
        Training source excluded from auto-detected targets.
    cache_dir : str
        Directory holding the joblib caches.
    force_recompute_train : bool
        Rebuild the CH training dataset even if cached.
    force_recompute_tests : bool
        Rebuild every test dataset even if cached. Overridden per target when
        `only_test_keys` is given.
    only_test_keys : iterable or None
        If given, exactly these targets are rebuilt. They are matched by code
        (e.g. "FR") or by cache key (e.g. "XREG_CH_TO_FR"). All other targets
        load from cache.
    val_ratio : float
        Validation fraction for the CH train/val split.
    normalize_target, expect_target, strict_nan :
        Forwarded to the dataset builders / NaN check.

    Returns
    -------
    dict
        code -> {"ds_train", "ds_test", "train_idx", "val_idx", ...}.
        Targets with no test rows get ds_test=None and a "note". The other
        targets carry "cache_keys".

    Raises
    ------
    ValueError
        If `source_col` is missing from df_test or df_test_aug.
    """

    logging.info("\n" + "=" * 60)
    logging.info("CROSS-REGIONAL LSTM DATASET PREPARATION (CH → EU)")
    logging.info("=" * 60)

    df_test_all = res_xreg["df_test"]
    df_test_aug_all = res_xreg["df_test_aug"]
    # Validate both test frames before any (possibly expensive) training build.
    for name, frame in (("df_test", df_test_all), ("df_test_aug",
                                                   df_test_aug_all)):
        if source_col not in frame.columns:
            raise ValueError(f"Missing {source_col} in res_xreg['{name}'].")

    # ---- discover target codes ----
    if target_source_codes is None:
        target_source_codes = sorted(
            set(df_test_all[source_col].dropna().unique()) - {ch_code})
        logging.info(
            f"Auto-detected target SOURCE_CODEs (excluding {ch_code}): "
            f"{target_source_codes}")
    else:
        logging.info(
            f"Using provided target SOURCE_CODEs: {target_source_codes}")

    logging.info(f"Total target regions: {len(target_source_codes)}")
    logging.info(f"Cache directory: {cache_dir}")

    # ---- train (CH) cached once ----
    key_train = "XREG_CH_TRAIN"

    logging.info("\n--- CH TRAIN DATASET ---")
    logging.info(f"Cache key: {key_train}")
    logging.info(f"Force recompute train: {force_recompute_train}")

    res_train = {
        "df_train": res_xreg["df_train"],
        "df_train_aug": res_xreg["df_train_aug"],
        "months_head_pad": res_xreg["months_head_pad"],
        "months_tail_pad": res_xreg["months_tail_pad"],
    }

    logging.info(f"CH train rows: {len(res_train['df_train'])} | "
                 f"Aug rows: {len(res_train['df_train_aug'])}")

    ds_train, train_idx, val_idx = build_or_load_lstm_train_only(
        cfg=cfg,
        key_train=key_train,
        res_train=res_train,
        MONTHLY_COLS=MONTHLY_COLS,
        STATIC_COLS=STATIC_COLS,
        val_ratio=val_ratio,
        cache_dir=cache_dir,
        force_recompute=force_recompute_train,
        normalize_target=normalize_target,
        expect_target=expect_target,
        strict_nan=strict_nan,
    )

    logging.info(f"CH train dataset size: {len(ds_train)} | "
                 f"Train split: {len(train_idx)} | Val split: {len(val_idx)}")

    # ---- tests cached per target ----
    logging.info("\n--- TARGET REGION TEST DATASETS ---")

    outputs = {}
    only_set = set(only_test_keys) if only_test_keys else None

    for sc in target_source_codes:
        key_test = f"XREG_CH_TO_{sc}"

        fr_test = force_recompute_tests
        if only_set is not None:
            fr_test = (sc in only_set) or (key_test in only_set)

        logging.info("\n" + "-" * 50)
        logging.info(f"Target region: {sc}")
        logging.info(f"Force recompute test: {fr_test}")

        df_test_sc = df_test_all.loc[df_test_all[source_col] == sc].copy()
        df_test_aug_sc = df_test_aug_all.loc[df_test_aug_all[source_col] ==
                                             sc].copy()

        logging.info(f"Test rows: {len(df_test_sc)} | "
                     f"Aug rows: {len(df_test_aug_sc)}")

        if len(df_test_sc) == 0 or len(df_test_aug_sc) == 0:
            logging.warning(f"Skipping {sc}: no usable test rows.")
            outputs[sc] = {
                "ds_train": ds_train,
                "ds_test": None,
                "train_idx": train_idx,
                "val_idx": val_idx,
                "note": f"No test rows for SOURCE_CODE={sc}",
            }
            continue

        res_sc = {
            "df_test": df_test_sc,
            "df_test_aug": df_test_aug_sc,
            "months_head_pad": res_xreg["months_head_pad"],
            "months_tail_pad": res_xreg["months_tail_pad"],
        }

        logging.info(f"Cache key (test): {key_test}")

        ds_test = build_or_load_lstm_test_only(
            cfg=cfg,
            key_test=key_test,
            res_test=res_sc,
            MONTHLY_COLS=MONTHLY_COLS,
            STATIC_COLS=STATIC_COLS,
            cache_dir=cache_dir,
            force_recompute=fr_test,
            normalize_target=normalize_target,
            expect_target=expect_target,
            strict_nan=strict_nan,
        )

        logging.info(f"Test dataset size (sequences): {len(ds_test)}")

        outputs[sc] = {
            "ds_train": ds_train,
            "ds_test": ds_test,
            "train_idx": train_idx,
            "val_idx": val_idx,
            "cache_keys": {
                "train": key_train,
                "test": key_test,
            },
        }

    logging.info("\nFinished cross-regional LSTM dataset preparation.")
    logging.info("=" * 60 + "\n")

    return outputs

In [None]:
# Shared CH training dataset + one cached test dataset per target SOURCE_CODE.
outputs_xreg = build_or_load_lstm_all_crossregional(
    cfg=cfg,
    res_xreg=res_xreg,
    cache_dir="logs/LSTM_cache",
    MONTHLY_COLS=MONTHLY_COLS,
    STATIC_COLS=STATIC_COLS,
)

### LSTM parameters:

In [None]:
# Default LSTM hyper-parameters.
# Fm/Fs were hard-coded as 8/3, which equal len(MONTHLY_COLS)/len(STATIC_COLS).
# They are derived from the lists so editing the feature lists cannot silently
# desynchronize them.
default_params = {
    'Fm': len(MONTHLY_COLS),  # presumably the monthly-feature input width
    'Fs': len(STATIC_COLS),  # presumably the static-feature input width
    'hidden_size': 128,
    'num_layers': 1,
    'bidirectional': False,
    'dropout': 0.0,
    'static_layers': 1,
    'static_hidden': 128,
    'static_dropout': 0.1,
    'lr': 0.001,
    'weight_decay': 1e-05,
    'loss_name': 'neutral',
    'two_heads': False,
    'head_dropout': 0.1,
    'loss_spec': None
}

### Train model:

In [None]:
# Train (or load) the cross-regional models for the targets in outputs_xreg.
# NOTE(review): force_retrain=True presumably retrains every target even when a
# saved model exists under `models_dir`; set it to False to reuse saved models —
# confirm against train_crossregional_models_all.
models_xreg, infos_xreg = train_crossregional_models_all(
    cfg=cfg,
    lstm_assets_by_key=outputs_xreg,  # from build_or_load_lstm_all_crossregional
    default_params=default_params,
    device=device,
    train_keys=None,  # train/load all targets
    force_retrain=True,
    models_dir="models",
    prefix="lstm_xreg_CH_to",
    epochs=150,
)

### Evaluate on test:

In [None]:
def evaluate_all_models_crossregional(
        cfg,
        models_by_key: dict,
        lstm_assets_by_key: dict,
        device,
        save_dir=None,
        grid_shape=(2, 3),
        grid_figsize=(20, 12),
        ax_xlim=(-16, 9),
        ax_ylim=(-16, 9),
        title_prefix="CH → ",
        file_prefix="CH_to_",
        order=None,  # e.g. ["FR","IT","AT","NO","SE","IS"]
):
    """
    Evaluate every cross-regional model on its target test set.

    A key is evaluated only if it has a non-None model in `models_by_key` and
    a non-None "ds_test" in `lstm_assets_by_key`. Each key gets an individual
    pred-vs-truth figure (saved, then closed) and a panel in a shared grid.

    Parameters
    ----------
    cfg, device :
        Forwarded to `evaluate_one_model`.
    models_by_key : dict
        key -> trained model (or None).
    lstm_assets_by_key : dict
        key -> assets dict containing "ds_test".
    save_dir : str or None
        If set, figures are written there as PNGs.
    grid_shape : (int, int)
        Rows/cols of the summary grid. Keys beyond nrows*ncols are evaluated
        but not drawn in the grid (a warning is logged).
    grid_figsize, ax_xlim, ax_ylim :
        Figure size and shared axis limits.
    title_prefix, file_prefix : str
        Prefixes for titles and file names.
    order : list or None
        Requested key order. Valid keys not listed are appended sorted.
        Requested keys that cannot be evaluated are logged.

    Returns
    -------
    df_metrics : pd.DataFrame
        One row per evaluated key, indexed by key, in evaluation order
        (the custom `order` when given, sorted otherwise).
    preds_by_key : dict
        key -> prediction dataframe.
    figs_by_key : dict
        key -> individual figure (already closed).
    fig_grid : matplotlib Figure
        The summary grid (left open).

    Raises
    ------
    ValueError
        If no (model, ds_test) pair is available.
    """
    # ---- collect keys that have both a model and a test dataset ----
    valid_keys = []
    for k, model in models_by_key.items():
        assets = lstm_assets_by_key.get(k, None)
        if model is None or assets is None or assets.get("ds_test",
                                                         None) is None:
            continue
        valid_keys.append(k)

    valid_set = set(valid_keys)

    # ---- apply custom order (if provided) ----
    if order is None:
        keys = sorted(valid_keys)
    else:
        # keep only valid keys in the requested order
        ordered = [k for k in order if k in valid_set]
        # append any remaining valid keys not mentioned in order, deterministically
        remainder = sorted(valid_set - set(ordered))
        keys = ordered + remainder

        missing_in_outputs = [k for k in order if k not in valid_set]
        if missing_in_outputs:
            logging.warning(
                "Some requested order keys were not evaluated (missing model or ds_test): "
                + ", ".join(missing_in_outputs))

    if len(keys) == 0:
        raise ValueError(
            "No valid (model, ds_test) pairs found for evaluation.")

    # output directory
    save_abs = save_dir or None
    if save_abs:
        os.makedirs(save_abs, exist_ok=True)

    nrows, ncols = grid_shape
    fig_grid, axes = plt.subplots(nrows,
                                  ncols,
                                  figsize=grid_figsize,
                                  sharex=True,
                                  sharey=True)
    axes = np.array(axes).reshape(-1)
    n_slots = len(axes)
    if len(keys) > n_slots:
        logging.warning(
            f"grid_shape={grid_shape} has only {n_slots} panels for "
            f"{len(keys)} targets; not shown in the grid: "
            + ", ".join(map(str, keys[n_slots:])))

    rows = []
    preds_by_key = {}
    figs_by_key = {}

    for i, key in enumerate(keys):
        model = models_by_key[key]
        print(f"\nEvaluating CH -> {key} ...")

        # --- individual figure ---
        metrics, df_preds, fig_ind, _ax_ind = evaluate_one_model(
            cfg=cfg,
            model=model,
            device=device,
            lstm_assets_for_key=lstm_assets_by_key[key],
            ax=None,
            ax_xlim=ax_xlim,
            ax_ylim=ax_ylim,
            title=f"{title_prefix}{key} – Pred vs Truth (Test)",
            legend_fontsize=14,
        )

        metrics["key"] = key
        rows.append(metrics)
        preds_by_key[key] = df_preds
        figs_by_key[key] = fig_ind

        if save_abs:
            out_png = os.path.join(save_abs,
                                   f"pred_vs_truth_{file_prefix}{key}.png")
            fig_ind.savefig(out_png, dpi=200, bbox_inches="tight")
        plt.close(fig_ind)

        # --- grid subplot ---
        if i < n_slots:
            evaluate_one_model(
                cfg=cfg,
                model=model,
                device=device,
                lstm_assets_for_key=lstm_assets_by_key[key],
                ax=axes[i],
                ax_xlim=ax_xlim,
                ax_ylim=ax_ylim,
                title=f"{title_prefix}{key}",
                legend_fontsize=15,
            )

    # turn off unused axes
    for j in range(len(keys), n_slots):
        axes[j].axis("off")

    fig_grid.suptitle("Pred vs Truth (Test) — Cross-regional (train CH)",
                      fontsize=20)
    fig_grid.tight_layout()

    if save_abs:
        out_grid = os.path.join(
            save_abs, f"pred_vs_truth_{file_prefix}ALL_targets_grid.png")
        fig_grid.savefig(out_grid, dpi=200, bbox_inches="tight")

    # Rows are already in evaluation order (the custom `order`, or sorted when
    # none was given). No re-sort, so the requested order is kept.
    df_metrics = pd.DataFrame(rows).set_index("key")

    return df_metrics, preds_by_key, figs_by_key, fig_grid

In [None]:
custom_order = ["FR", "IT_AT", "NOR", "SJM", "ISL"]  # whatever you want
df_metrics_xreg, preds_xreg, figs_xreg, fig_grid_xreg = evaluate_all_models_crossregional(
    cfg=cfg,
    models_by_key=models_xreg,
    lstm_assets_by_key=outputs_xreg,
    device=device,
    save_dir="figures/eval_xreg",
    grid_shape=(2, 3),
    order=custom_order,
)