In [None]:
import sys, os
sys.path.append(os.path.join(os.getcwd(), '../../')) # Add root of repo to import MBM
from typing import Optional, Iterable, Dict, List
import os

import pandas as pd
import numpy as np
import warnings
import re
import matplotlib.pyplot as plt
import seaborn as sns
from cmcrameri import cm
import massbalancemachine as mbm
import logging
import torch.nn as nn
from skorch.helper import SliceDataset
from datetime import datetime
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset
import pickle 
from scipy import stats
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import torch 
from matplotlib.lines import Line2D

# from regions.Svalbard.scripts.config_SVA import *
# from regions.Svalbard.scripts.dataset import get_stakes_data_SVA
# from regions.Svalbard.scripts.utils import *

# from regions.Switzerland.scripts.dataset import process_or_load_data, get_CV_splits
# from regions.Switzerland.scripts.plotting import plot_predictions_summary, plot_individual_glacier_pred, plot_history_lstm, get_cmap_hex,plot_tsne_overlap, plot_feature_kde_overlap, alpha_labels, pred_vs_truth_density
# from regions.Switzerland.scripts.dataset import get_stakes_data, build_combined_LSTM_dataset, inspect_LSTM_sample, prepare_monthly_dfs_with_padding
# from regions.Switzerland.scripts.models import compute_seasonal_scores, get_best_params_for_lstm

from regions.TF_Europe.scripts.config_TF_Europe import *
from regions.TF_Europe.scripts.dataset import *
from regions.TF_Europe.scripts.plotting import *
from regions.TF_Europe.scripts.models import *

# Silence all Python warnings so cell output stays readable
warnings.filterwarnings('ignore')
# Reload edited imported modules automatically, without restarting the kernel
%load_ext autoreload
%autoreload 2

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Shared Europe-wide configuration. cfg.seed seeds everything here and is
# reused later for the train/val splits.
cfg = mbm.EuropeConfig()
mbm.utils.seed_all(cfg.seed)
# NOTE(review): presumably releases cached GPU memory — confirm in mbm.utils
mbm.utils.free_up_cuda()
# NOTE(review): presumably applies the project matplotlib style — confirm in mbm.plots
mbm.plots.use_mbm_style()

# Torch device: GPU when available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Within-region modelling

### Read stake datasets

In [None]:
# Examples of loading data (reference only; nothing here is executed):
#
# Load Switzerland only
#   df = load_stakes(cfg, "CH")
#
# Load all of Central Europe (FR + CH + IT + AT, once they are all configured)
#   df_ceu = load_stakes_for_rgi_region(cfg, "11")
#
# Load all configured Europe regions
#   dfs = {rid: load_stakes_for_rgi_region(cfg, rid) for rid in RGI_REGIONS.keys()}

In [None]:
# Stake data of every configured RGI region, keyed by region id
dfs = {
    region_id: load_stakes_for_rgi_region(cfg, region_id)
    for region_id in RGI_REGIONS
}
dfs["11"]

In [None]:
import matplotlib.pyplot as plt


def summarize_and_plot_all_regions(dfs):
    """Print glacier counts and plot yearly measurement counts for each region.

    For every ``(region_id, dataframe)`` pair in ``dfs``:

    * rows are restricted to ``PERIOD`` in {"annual", "winter"} when a
      ``PERIOD`` column exists;
    * the number of unique ``GLACIER`` values per ``SOURCE_CODE`` is printed,
      largest first;
    * for each ``SOURCE_CODE`` (or the whole frame, labelled "ALL", when that
      column is absent) a stacked bar chart of annual + winter measurement
      counts per ``YEAR`` is shown.

    None/empty frames and frames lacking the needed columns are reported with
    a printed message and skipped rather than raising.
    """
    kept_periods = ["annual", "winter"]

    for region_id, region_df in dfs.items():
        if region_df is None or len(region_df) == 0:
            print(f"\n=== RGI {region_id}: empty ===")
            continue

        frame = region_df.copy()
        if "PERIOD" in frame.columns:
            frame = frame[frame["PERIOD"].isin(kept_periods)].copy()

        print(f"\n========== RGI {region_id} ==========")

        has_source = "SOURCE_CODE" in frame.columns

        # Glacier counts per subregion
        if has_source and "GLACIER" in frame.columns:
            n_glaciers = frame.groupby("SOURCE_CODE")["GLACIER"].nunique()
            print("Unique glaciers per subregion (SOURCE_CODE):")
            print(n_glaciers.sort_values(ascending=False))
        else:
            print(
                "[warn] Missing SOURCE_CODE and/or GLACIER columns; skipping glacier counts."
            )

        if not {"YEAR", "PERIOD"}.issubset(frame.columns):
            print("[warn] Missing YEAR/PERIOD columns; skipping plots.")
            continue

        # Without subregions the whole region is plotted as one group
        groups = (list(frame.groupby("SOURCE_CODE"))
                  if has_source else [("ALL", frame)])

        for code, sub_df in groups:
            per_year = sub_df.groupby(["YEAR", "PERIOD"]).size()
            per_year = per_year.unstack(fill_value=0)
            per_year = per_year.reindex(columns=kept_periods,
                                        fill_value=0).sort_index()
            annual_counts = per_year["annual"].values
            winter_counts = per_year["winter"].values

            plt.figure(figsize=(20, 6))
            plt.bar(per_year.index,
                    annual_counts,
                    label="annual",
                    color=mbm.plots.COLOR_ANNUAL)
            # Winter bars are stacked on top of the annual ones
            plt.bar(per_year.index,
                    winter_counts,
                    bottom=annual_counts,
                    label="winter",
                    color=mbm.plots.COLOR_WINTER)
            plt.title(
                f"RGI {region_id} – {code}: measurements per year (annual + winter)")
            plt.xlabel("Year")
            plt.ylabel("Number of measurements")
            plt.legend()
            plt.tight_layout()
            plt.show()


# Glacier counts + yearly annual/winter measurement bars for every region
summarize_and_plot_all_regions(dfs=dfs)

In [None]:
def plot_mb_distributions_all_regions(
    dfs,
    periods=("annual", "winter"),
    value_col="POINT_BALANCE",
    group_col="SOURCE_CODE",
    bins_n=21,
):
    """Plot per-period histograms of point mass balance for every region.

    One figure per region, one panel per entry of ``periods``. Within a panel,
    each group (distinct ``group_col`` value, or a single "ALL" group when that
    column is absent) gets a semi-transparent histogram on bins shared by all
    groups of that panel. Each group also gets a dashed vertical line at its
    mean, drawn in the same colour as its histogram.

    Parameters
    ----------
    dfs : dict
        Mapping region id -> DataFrame (or None). None/empty frames and frames
        lacking ``PERIOD`` or ``value_col`` are reported and skipped.
    periods : tuple of str
        ``PERIOD`` values to plot, one panel each.
    value_col : str
        Column holding the values to histogram.
    group_col : str
        Column used to split each panel into overlaid histograms.
    bins_n : int
        Number of bin *edges* (``bins_n - 1`` bins) shared by all groups.
    """
    for rid, df in dfs.items():
        if df is None or len(df) == 0:
            print(f"\n=== RGI {rid}: empty ===")
            continue

        if not {"PERIOD", value_col}.issubset(df.columns):
            print(
                f"\n=== RGI {rid}: missing PERIOD or {value_col}, skipping ==="
            )
            continue

        # keep only the periods we want
        d = df[df["PERIOD"].isin(periods)].copy()

        # choose grouping
        if group_col in d.columns:
            groups = sorted(d[group_col].dropna().unique())
        else:
            groups = ["ALL"]
            d[group_col] = "ALL"

        fig, axes = plt.subplots(1, len(periods), figsize=(14, 5), sharey=True)
        if len(periods) == 1:
            axes = [axes]

        for ax, period in zip(axes, periods):
            # Extract each group's values once; reused for the bins and the plot
            in_period = d["PERIOD"] == period
            group_vals = []
            for g in groups:
                vals = d.loc[in_period & (d[group_col] == g),
                             value_col].dropna().values
                if vals.size:
                    group_vals.append((g, vals))

            if not group_vals:
                ax.set_title(f"{period.capitalize()} Mass Balance (no data)")
                ax.set_xlabel("Mass balance [m w.e.]")
                continue

            # Common bins across groups so the histograms are comparable
            vals_all = np.concatenate([vals for _, vals in group_vals])
            vmin, vmax = float(vals_all.min()), float(vals_all.max())
            if np.isclose(vmin, vmax):
                # degenerate case: all values identical
                bins = np.linspace(vmin - 1e-6, vmax + 1e-6, bins_n)
            else:
                bins = np.linspace(vmin, vmax, bins_n)

            for g, vals in group_vals:
                _, _, patches = ax.hist(vals, bins=bins, alpha=0.5, label=str(g))
                # axvline does not advance the colour cycle, so reuse this
                # histogram's colour (RGB only, fully opaque) for its mean line.
                mean_color = patches[0].get_facecolor()[:3]
                ax.axvline(vals.mean(), linestyle="--", color=mean_color)

            ax.set_title(f"{period.capitalize()} Mass Balance")
            ax.set_xlabel("Mass balance [m w.e.]")
            ax.legend()

        axes[0].set_ylabel("Number of measurements")
        fig.suptitle(f"RGI {rid} – Seasonal Point Mass Balance Distribution",
                     fontsize=14)
        fig.tight_layout()
        plt.show()


# Seasonal point mass-balance histograms for every region
plot_mb_distributions_all_regions(dfs=dfs)

### Monthly datasets:

In [None]:
# Held-out test glaciers per (sub)region code. Names must match the GLACIER
# values of the corresponding stake data exactly.
TEST_GLACIERS_SJM = ["WERENSKIOLDBREEN"]

TEST_GLACIERS_CH = [
    "tortin", "plattalva", "schwarzberg", "hohlaub",
    "sanktanna", "corvatsch", "tsanfleuron", "forno",
]

TEST_GLACIERS_NOR = [
    "Cainhavarre",
    "Rundvassbreen",
    "Svartisheibreen",
    "Trollbergdalsbreen",
    "Hansebreen",
    "Tunsbergdalsbreen",
    "Austdalsbreen",
    "Hellstugubreen",
    "Austre Memurubreen",
    "Bondhusbrea",
    "Svelgjabreen",
    "Moesevassbrea",
    "Blomstoelskardsbreen",
]

TEST_GLACIERS_IT_AT = [
    "GOLDBERG K.",
    "HALLSTAETTER G.",
    "HINTEREIS F.",
    "JAMTAL F.",
    "KESSELWAND F.",
    "KLEINFLEISS K.",
    "OE. WURTEN K.",
    "VENEDIGER K.",
    "VERNAGT F.",
    "ZETTALUNITZ/MULLWITZ K.",
]

# Mix of RGI ids and glacier names, as they appear in the Icelandic data
TEST_GLACIERS_ISL = [
    "RGI60-06.00311",
    "RGI60-06.00305",
    "Thjorsarjoekull (Hofsjoekull E)",
    "RGI60-06.00445",
    "RGI60-06.00474",
    "RGI60-06.00425",
    "RGI60-06.00480",
    "Dyngjujoekull",
    "RGI60-06.00478",
    "Koeldukvislarjoekull",
    "Oeldufellsjoekull",
    "RGI60-06.00350",
    "RGI60-06.00340",
]

TEST_GLACIERS_FR = ["Talefre", "Argentiere", "Gebroulaz"]

# Lookup used by prepare_monthlies_for_all_regions (keys are upper-case codes)
TEST_GLACIERS_BY_CODE = dict(
    SJM=TEST_GLACIERS_SJM,
    ISL=TEST_GLACIERS_ISL,
    NOR=TEST_GLACIERS_NOR,
    FR=TEST_GLACIERS_FR,
    CH=TEST_GLACIERS_CH,
    IT_AT=TEST_GLACIERS_IT_AT,
)

In [None]:
# ERA5 inputs needed to transform the stake data to monthly format
paths = dict(
    era5_climate_data=os.path.join(cfg.dataPath, path_ERA5_raw,
                                   "era5_monthly_averaged_data_Europe.nc"),
    geopotential_data=os.path.join(cfg.dataPath, path_ERA5_raw,
                                   "era5_geopotential_pressure_Europe.nc"),
)

# Fail fast when any required input file is absent
for key, path in paths.items():
    if os.path.exists(path):
        continue
    raise FileNotFoundError(f"Required file for {key} not found at {path}")

In [None]:
def codes_for_rgi_region(rid: str) -> list[str]:
    """Return the upper-cased processing codes for one RGI region.

    Regions whose ``RGI_REGIONS`` entry lists non-empty ``subregions_codes``
    are processed per subregion, so those codes are returned. Otherwise the
    region's own ``code`` is returned as a one-element list.

    ``rid`` may be an int or str; it is zero-padded to two characters before
    lookup. Raises KeyError if the region is not configured.
    """
    spec = RGI_REGIONS[str(rid).zfill(2)]
    sub_codes = spec.get("subregions_codes")
    selected = sub_codes if sub_codes else [spec["code"]]
    return [c.upper() for c in selected]


def country_folder_for_code(rid: str, code: str) -> Optional[str]:
    """
    Returns the WGMS country folder name for a given RGI region id + code.

    * Region with ``subregions_codes``: the folder is the ``subregions`` entry
      paired by position with ``code``. Matching is case-insensitive, because
      callers pass codes upper-cased by ``codes_for_rgi_region`` while the
      config may store them in any case. Returns None if ``code`` is not one
      of the region's subregion codes.
    * Region without subregions: the folder is the region ``name`` and
      ``code`` is ignored.

    Raises KeyError if ``rid`` (zero-padded to two characters) is not in
    ``RGI_REGIONS``.
    """
    rid = str(rid).zfill(2)
    spec = RGI_REGIONS[rid]

    # Case 2: single-region (no subregions)
    sub_codes = spec.get("subregions_codes")
    if not sub_codes:
        return spec["name"]

    # Case 1: region has subregions; codes and names are paired by position
    mapping = {
        str(c).upper(): name
        for c, name in zip(sub_codes, spec["subregions"])
    }
    return mapping.get(str(code).upper())


def prepare_monthlies_for_all_regions(
    cfg,
    dfs,
    paths,
    vois_climate,
    vois_topographical,
    run_flag=True,
    only_rids=None,  # e.g. ["08"]
    only_codes=None,  # e.g. ["NOR"]
    test_glaciers_override=None,
):
    """Build (or load) the padded monthly datasets for every region/subregion.

    For each region in ``dfs`` and each code from ``codes_for_rgi_region``,
    the region frame is sliced on ``SOURCE_CODE`` (when that column exists)
    and passed to ``prepare_monthly_dfs_with_padding``. The code's test
    glaciers are passed along, and ``paths`` gains a ``csv_path`` under the
    WGMS country folder.

    Parameters
    ----------
    cfg : config object providing ``dataPath``.
    dfs : dict
        RGI region id -> stakes DataFrame (or None; None/empty are skipped).
    paths : dict
        Input paths; copied per code and extended with ``csv_path``.
    vois_climate, vois_topographical :
        Forwarded unchanged to ``prepare_monthly_dfs_with_padding``.
    run_flag : bool
        Recompute flag used when neither ``only_rids`` nor ``only_codes`` is
        given.
    only_rids, only_codes : iterables, optional
        When either is given, ``run_flag`` is ignored. Exactly the matching
        (rid, code) pairs get ``run_flag=True``; all others get False. rids
        are zero-padded and codes upper-cased before matching.
    test_glaciers_override : dict, optional
        code -> list of test glaciers, taking precedence over
        ``TEST_GLACIERS_BY_CODE``. Keys are upper-cased, like ``only_codes``.

    Returns
    -------
    dict
        "<rid>_<CODE>" -> result of ``prepare_monthly_dfs_with_padding``.

    Raises
    ------
    ValueError
        If no WGMS country folder is found for a code.
    """
    results = {}

    only_rids_set = {str(r).zfill(2) for r in only_rids} if only_rids else None
    only_codes_set = {c.upper() for c in only_codes} if only_codes else None
    # Upper-case override keys so "nor" matches code "NOR", as for only_codes
    overrides = {
        str(k).upper(): v
        for k, v in (test_glaciers_override or {}).items()
    }

    for rid, df_region in dfs.items():
        rid2 = str(rid).zfill(2)

        if df_region is None or len(df_region) == 0:
            print(f"Skipping RGI {rid2}: empty dataframe")
            continue

        region_id_int = int(rid2)
        codes = [c.upper() for c in codes_for_rgi_region(rid2)]
        has_source = "SOURCE_CODE" in df_region.columns

        for code in codes:

            # Decide whether this one should recompute
            should_run = run_flag
            if only_rids_set or only_codes_set:
                match_rid = (only_rids_set is None or rid2 in only_rids_set)
                match_code = (only_codes_set is None or code in only_codes_set)
                should_run = match_rid and match_code

            # Slice df
            if has_source:
                df_sub = df_region[df_region["SOURCE_CODE"] == code].copy()
            else:
                df_sub = df_region.copy()

            if len(df_sub) == 0:
                print(f"[RGI {rid2}] No rows for code={code}, skipping.")
                continue

            test_glaciers = overrides.get(
                code, TEST_GLACIERS_BY_CODE.get(code, []))

            # Fail with a clear message instead of a TypeError in os.path.join
            country = country_folder_for_code(rid2, code)
            if country is None:
                raise ValueError(
                    f"[RGI {rid2}] No WGMS country folder configured for code={code}"
                )

            paths_ = paths.copy()
            paths_["csv_path"] = os.path.join(cfg.dataPath, path_PMB_WGMS_csv,
                                              country, "csv")

            print(f"\nProcessing RGI {rid2} / {code} "
                  f"(country={country}, run_flag={should_run})")

            res = prepare_monthly_dfs_with_padding(
                cfg=cfg,
                df_region=df_sub,
                region_name=code,
                region_id=region_id_int,
                paths=paths_,
                test_glaciers=test_glaciers,
                vois_climate=vois_climate,
                vois_topographical=vois_topographical,
                run_flag=should_run,
            )

            results[f"{rid2}_{code}"] = res

    return results

In [None]:
# Build or load the monthly datasets of every region/subregion
# (run_flag=False: nothing is recomputed)
res_all = prepare_monthlies_for_all_regions(
    cfg=cfg,
    dfs=dfs,
    paths=paths,
    vois_climate=VOIS_CLIMATE,
    vois_topographical=VOIS_TOPOGRAPHICAL,
    run_flag=False,
)

# Optional: rerun parts only. Example: only recompute Norway
# (all other codes get run_flag=False):
# res_all = prepare_monthlies_for_all_regions(
#     cfg=cfg,
#     dfs=dfs,
#     paths=paths,
#     vois_climate=VOIS_CLIMATE,
#     vois_topographical=VOIS_TOPOGRAPHICAL,
#     run_flag=True,
#     only_codes=["NOR"],   # only Norway recomputes
# )

#### Feature overlap:

In [None]:
# Time-varying inputs, passed as monthly_cols when building the LSTM datasets
MONTHLY_COLS = ['t2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str',
                'ELEVATION_DIFFERENCE']
# Time-invariant topographic inputs, passed as static_cols
STATIC_COLS = ['aspect', 'slope', 'svf']

# Full feature list: monthly columns first, then static ones
feature_columns = [*MONTHLY_COLS, *STATIC_COLS]

In [None]:
def plot_overlap_for_all_results(
    results_dict,
    cfg,
    STATIC_COLS,
    MONTHLY_COLS,
    n_iter=1000,
):
    """Draw a t-SNE train/test feature-overlap figure for every prepared region.

    Parameters
    ----------
    results_dict : dict
        key -> result dict (or None) holding ``df_train`` and ``df_test``.
    cfg : config object; ``cfg.seed`` is used as the t-SNE random state.
    STATIC_COLS, MONTHLY_COLS : column lists forwarded to ``plot_tsne_overlap``.
    n_iter : int
        Forwarded to ``plot_tsne_overlap``.

    Returns
    -------
    dict
        key -> figure returned by ``plot_tsne_overlap``. Only keys whose
        ``df_train`` and ``df_test`` are both present and non-empty appear.
    """
    # Train in the first batlow colour, test in a fixed red
    palette = {'Train': get_cmap_hex(cm.batlow, 10)[0], 'Test': '#b2182b'}

    figs = {}
    for key, res in results_dict.items():
        if res is None:
            continue

        df_train, df_test = res.get("df_train"), res.get("df_test")
        usable = (df_train is not None and df_test is not None
                  and len(df_train) > 0 and len(df_test) > 0)
        if not usable:
            print(f"[{key}] Missing/empty df_train or df_test, skipping.")
            continue

        print(
            f"Plotting t-SNE overlap for {key}: train={len(df_train)}, test={len(df_test)}"
        )
        figs[key] = plot_tsne_overlap(
            df_train,
            df_test,
            STATIC_COLS,
            MONTHLY_COLS,
            sublabels=("a", "b", "c"),
            label_fmt="({})",
            label_xy=(0.02, 0.98),
            label_fontsize=14,
            n_iter=n_iter,
            random_state=cfg.seed,
            custom_palette=palette,
        )

    return figs


# t-SNE train/test overlap for every region prepared above
figs = plot_overlap_for_all_results(res_all,
                                    cfg,
                                    STATIC_COLS,
                                    MONTHLY_COLS,
                                    n_iter=1000)

In [None]:
import os
import matplotlib.pyplot as plt
from IPython.display import display


def plot_feature_overlap_all_regions(
    results_dict,
    STATIC_COLS,
    MONTHLY_COLS,
    output_dir="figures",
    include_target=True,
    save_figures=False,
):
    """Plot per-feature train/test KDE overlap for every prepared region.

    Parameters
    ----------
    results_dict : dict
        key -> result dict (or None) holding ``df_train`` and ``df_test``.
    STATIC_COLS, MONTHLY_COLS : list of str
        Features to compare (static first, then monthly).
    output_dir : str
        Directory used for saved figures when ``save_figures`` is True.
    include_target : bool
        Also compare the ``POINT_BALANCE`` target distribution.
    save_figures : bool
        When True, ``output_dir`` is created and
        ``<output_dir>/<key>_feature_kde_overlap.png`` is passed as
        ``outfile``. When False (default) ``outfile=None`` and no directory is
        created. Previously ``output_dir`` was always created even though
        nothing was ever written to it.

    Returns
    -------
    dict
        key -> figure returned by ``plot_feature_kde_overlap``, only for keys
        with non-empty train and test frames.
    """
    if save_figures:
        os.makedirs(output_dir, exist_ok=True)

    # Train in the first batlow colour, test in a fixed red
    palette = {'Train': get_cmap_hex(cm.batlow, 10)[0], 'Test': '#b2182b'}

    features = list(STATIC_COLS) + list(MONTHLY_COLS)
    if include_target:
        features = features + ["POINT_BALANCE"]

    figs = {}

    for key, res in results_dict.items():
        if res is None:
            continue

        df_train = res.get("df_train")
        df_test = res.get("df_test")

        if df_train is None or df_test is None:
            print(f"[{key}] Missing df_train/df_test, skipping.")
            continue

        if len(df_train) == 0 or len(df_test) == 0:
            print(f"[{key}] Empty train/test, skipping.")
            continue

        print(f"Plotting KDE overlap for {key}")

        # NOTE(review): assumes plot_feature_kde_overlap saves the figure to
        # `outfile` when it is not None — confirm in the plotting module.
        outfile = (os.path.join(output_dir, f"{key}_feature_kde_overlap.png")
                   if save_figures else None)

        figs[key] = plot_feature_kde_overlap(df_train,
                                             df_test,
                                             features,
                                             palette,
                                             outfile=outfile)

    return figs


figs_kde = plot_feature_overlap_all_regions(results_dict=res_all,
                                            STATIC_COLS=STATIC_COLS,
                                            MONTHLY_COLS=MONTHLY_COLS)

## Train LSTM:

In [None]:
import joblib


def _lstm_cache_paths(cfg, key: str, cache_dir: str):
    out_dir = os.path.join(cache_dir)
    os.makedirs(out_dir, exist_ok=True)
    train_p = os.path.join(out_dir, f"{key}_train.joblib")
    test_p = os.path.join(out_dir, f"{key}_test.joblib")
    split_p = os.path.join(out_dir, f"{key}_split.joblib")
    return train_p, test_p, split_p


def build_or_load_lstm_for_key(
    cfg,
    key: str,
    res: dict,
    MONTHLY_COLS,
    STATIC_COLS,
    val_ratio=0.2,
    cache_dir="logs/LSTM_cache",
    force_recompute=False,
    normalize_target=True,
    expect_target=True,
):
    """Build (or load from cache) the LSTM train/test datasets for one key.

    Three joblib files per ``key`` live in ``cache_dir`` (train dataset, test
    dataset, train/val split). They are reused unless ``force_recompute`` is
    set, or they were written with different settings: feature columns,
    ``val_ratio``, ``normalize_target``, ``expect_target`` or ``cfg.seed``.
    The settings are stored with the split. Caches written before this
    metadata existed are still loaded as-is.

    Parameters
    ----------
    cfg : config object providing ``seed``.
    key : str
        Cache-file prefix, e.g. "08_NOR".
    res : dict
        Monthly-preparation output. Only read when recomputing; must then
        provide ``df_train``, ``df_test``, ``df_train_aug``, ``df_test_aug``,
        ``months_head_pad`` and ``months_tail_pad``.
    MONTHLY_COLS, STATIC_COLS : list of str
        Feature columns forwarded to ``build_combined_LSTM_dataset``.
    val_ratio : float
        Validation fraction for the train/val index split.
    normalize_target, expect_target : bool
        Forwarded to ``build_combined_LSTM_dataset``.

    Returns
    -------
    tuple
        ``(ds_train, ds_test, train_idx, val_idx)``
    """
    train_p, test_p, split_p = _lstm_cache_paths(cfg, key, cache_dir=cache_dir)

    # Everything that changes the cached content. Stored with the split so a
    # later call with different settings does not silently reuse stale data.
    cache_meta = {
        "monthly_cols": list(MONTHLY_COLS),
        "static_cols": list(STATIC_COLS),
        "val_ratio": val_ratio,
        "normalize_target": normalize_target,
        "expect_target": expect_target,
        "seed": cfg.seed,
    }

    if (not force_recompute) and all(
            os.path.exists(p) for p in (train_p, test_p, split_p)):
        # NOTE: joblib.load unpickles; only point cache_dir at trusted files.
        split = joblib.load(split_p)
        cached_meta = split.get("meta")
        if cached_meta is None or cached_meta == cache_meta:
            ds_train = joblib.load(train_p)
            ds_test = joblib.load(test_p)
            return ds_train, ds_test, split["train_idx"], split["val_idx"]
        print(f"[{key}] LSTM cache was built with different settings; "
              "recomputing.")

    # required pieces from your monthly prep
    df_train = res["df_train"]
    df_test = res["df_test"]
    df_train_aug = res["df_train_aug"]
    df_test_aug = res["df_test_aug"]
    months_head_pad = res["months_head_pad"]
    months_tail_pad = res["months_tail_pad"]

    mbm.utils.seed_all(cfg.seed)

    ds_train = build_combined_LSTM_dataset(
        df_loss=df_train,
        df_full=df_train_aug,
        monthly_cols=MONTHLY_COLS,
        static_cols=STATIC_COLS,
        months_head_pad=months_head_pad,
        months_tail_pad=months_tail_pad,
        normalize_target=normalize_target,
        expect_target=expect_target,
    )

    ds_test = build_combined_LSTM_dataset(
        df_loss=df_test,
        df_full=df_test_aug,
        monthly_cols=MONTHLY_COLS,
        static_cols=STATIC_COLS,
        months_head_pad=months_head_pad,
        months_tail_pad=months_tail_pad,
        normalize_target=normalize_target,
        expect_target=expect_target,
    )

    train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
        len(ds_train), val_ratio=val_ratio, seed=cfg.seed)

    # save
    joblib.dump(ds_train, train_p, compress=3)
    joblib.dump(ds_test, test_p, compress=3)
    joblib.dump(
        {
            "train_idx": train_idx,
            "val_idx": val_idx,
            "meta": cache_meta,
        },
        split_p,
        compress=3)

    return ds_train, ds_test, train_idx, val_idx

In [None]:
def build_or_load_lstm_all(
    cfg,
    res_all: dict,  # e.g. {"07_SJM": res, "08_NOR": res, ...}
    MONTHLY_COLS,
    STATIC_COLS,
    cache_dir="logs/LSTM_cache",
    only_keys=None,  # e.g. ["08_NOR"] to recompute only Norway
    force_recompute=False,  # global default
    val_ratio=0.2,
):
    """Build or load the LSTM datasets for every non-None entry of ``res_all``.

    Recompute policy:
      * ``only_keys`` given: exactly those keys recompute, and all other keys
        load from cache when possible (``force_recompute`` is then ignored).
      * otherwise: every key uses ``force_recompute``.

    Returns
    -------
    dict
        key -> {"ds_train", "ds_test", "train_idx", "val_idx"}.
    """
    selected = set(only_keys) if only_keys else None

    outputs = {}
    for key, res in res_all.items():
        if res is None:
            continue

        recompute = force_recompute if selected is None else (key in selected)
        print(f"\nLSTM prep: {key} (force_recompute={recompute})")

        ds_train, ds_test, train_idx, val_idx = build_or_load_lstm_for_key(
            cfg=cfg,
            key=key,
            res=res,
            MONTHLY_COLS=MONTHLY_COLS,
            STATIC_COLS=STATIC_COLS,
            val_ratio=val_ratio,
            cache_dir=cache_dir,
            force_recompute=recompute,
        )

        outputs[key] = dict(
            ds_train=ds_train,
            ds_test=ds_test,
            train_idx=train_idx,
            val_idx=val_idx,
        )

    return outputs

In [None]:
# Build or load (from logs/LSTM_cache) the LSTM datasets and train/val split
# for every prepared region
lstm_assets = build_or_load_lstm_all(
    cfg=cfg,
    res_all=res_all,
    MONTHLY_COLS=MONTHLY_COLS,
    STATIC_COLS=STATIC_COLS,
    cache_dir="logs/LSTM_cache",
)

# Example: only recompute Norway, load others if cached
# lstm_assets = build_or_load_lstm_all(
#     cfg=cfg,
#     res_all=res_all,
#     MONTHLY_COLS=MONTHLY_COLS,
#     STATIC_COLS=STATIC_COLS,
#     only_keys=["08_NOR"],
# )

In [None]:
import os
import copy


def all_codes_from_config(RGI_REGIONS: dict) -> list[str]:
    """Return the sorted, upper-cased, de-duplicated processing codes of all regions.

    A region that lists non-empty ``subregions_codes`` contributes only those
    codes. Any other region contributes its own ``code``.
    """
    collected = {
        c.upper()
        for spec in RGI_REGIONS.values()
        for c in (spec.get("subregions_codes") or [spec["code"]])
    }
    return sorted(collected)


def build_lstm_params_by_code(
    default_params: dict,
    log_path_gs_results: dict,
    RGI_REGIONS: dict,
    select_by: str = "avg_test_loss",
):
    """Return a dict mapping each configured code to its LSTM hyper-parameters.

    Every code from ``all_codes_from_config(RGI_REGIONS)`` starts from a deep
    copy of ``default_params``, so the defaults are never mutated. When
    ``log_path_gs_results`` has an existing grid-search log for the code,
    the best parameters from that log (chosen by ``select_by``) override the
    defaults.
    """
    params_by_code = {}

    for code in all_codes_from_config(RGI_REGIONS):
        merged = copy.deepcopy(default_params)
        log_path = log_path_gs_results.get(code)

        if not (log_path and os.path.exists(log_path)):
            print(f"No grid-search log for {code}. Using default params.")
        else:
            print(f"Loading tuned params for {code} from {log_path}")
            merged.update(
                get_best_params_for_lstm(log_path, select_by=select_by))

        params_by_code[code] = merged

    return params_by_code


# Grid-search progress logs per code; codes without an entry use the defaults
log_path_gs_results = dict(
    ISL='logs/GS_results/lstm_param_search_progress_OOS_ISL_2026-02-11.csv',
    NOR='logs/GS_results/lstm_param_search_progress_OOS_NOR_2026-02-09.csv',
    FR='logs/GS_results/lstm_param_search_progress_OOS_FR_2026-02-06.csv',
)

# Fallback LSTM hyper-parameters (Fm/Fs match len(MONTHLY_COLS)/len(STATIC_COLS))
default_params = dict(
    Fm=8,
    Fs=3,
    hidden_size=128,
    num_layers=1,
    bidirectional=False,
    dropout=0.0,
    static_layers=1,
    static_hidden=128,
    static_dropout=0.1,
    lr=0.001,
    weight_decay=1e-05,
    loss_name='neutral',
    two_heads=False,
    head_dropout=0.1,
    loss_spec=None,
)

# Per-code LSTM hyper-parameters: tuned where a grid-search log exists,
# defaults otherwise
params_by_code = build_lstm_params_by_code(
    default_params,
    log_path_gs_results,
    RGI_REGIONS,
)

params_by_code.keys()