In [None]:
import sys, os
sys.path.append(os.path.join(os.getcwd(), '../../')) # Add root of repo to import MBM
from typing import Optional, Iterable, Dict, List
import os

import pandas as pd
import numpy as np
import warnings
import re
import matplotlib.pyplot as plt
import seaborn as sns
from cmcrameri import cm
import massbalancemachine as mbm
import logging
import torch.nn as nn
from skorch.helper import SliceDataset
from datetime import datetime
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset
import pickle 
from scipy import stats
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import torch 
from matplotlib.lines import Line2D

from regions.TF_Europe.scripts.config_TF_Europe import *
from regions.TF_Europe.scripts.dataset import *
from regions.TF_Europe.scripts.plotting import *
from regions.TF_Europe.scripts.models import *

warnings.filterwarnings('ignore')
%load_ext autoreload
%autoreload 2

# Initialize logging: INFO level, each record prefixed with its timestamp.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Configuration object used by every helper call below; its seed is passed to
# mbm.utils.seed_all so the stochastic parts of the notebook are seeded once, up front.
cfg = mbm.EuropeConfig()
mbm.utils.seed_all(cfg.seed)
# NOTE(review): presumably releases cached GPU memory and applies the MBM
# matplotlib style, judging by the helper names — confirm in massbalancemachine.
mbm.utils.free_up_cuda()
mbm.plots.use_mbm_style()

# Use the GPU when torch reports one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Within-region modelling:

### Read stakes datasets:

In [None]:
# Examples of loading data:
#
#   Switzerland only:
#       df = load_stakes(cfg, "CH")
#
#   All of Central Europe (FR+CH+IT+AT when you add them):
#       df_ceu = load_stakes_for_rgi_region(cfg, "11")
#
#   All Europe regions configured (what this cell runs):
#       dfs = {rid: load_stakes_for_rgi_region(cfg, rid) for rid in RGI_REGIONS.keys()}

# One stakes dataset per configured RGI region id, keyed by that id.
dfs = {
    region_id: load_stakes_for_rgi_region(cfg, region_id)
    for region_id in RGI_REGIONS.keys()
}

# Display the entry for region "11" as the cell output.
dfs["11"]

In [None]:
# Run the project helper on the per-region stakes dict `dfs` built above.
# NOTE(review): judging by its name it prints/plots a summary for every region — see its docstring.
summarize_and_plot_all_regions(dfs)

In [None]:
# Run the project helper on the per-region stakes dict `dfs` built above.
# NOTE(review): judging by its name it plots the mass-balance distributions per region — see its docstring.
plot_mb_distributions_all_regions(dfs)

### Monthly datasets:

In [None]:
# Test glaciers, one list per region/country code.
# Commented-out lists are earlier selections, kept for reference.

# --- SJM ---
# TEST_GLACIERS_SJM = ['WERENSKIOLDBREEN']
TEST_GLACIERS_SJM = [
    "GROENFJORD E",
    "WERENSKIOLDBREEN",
]

# --- ISL ---
# TEST_GLACIERS_ISL = [
#     'RGI60-06.00311', 'RGI60-06.00305', 'Thjorsarjoekull (Hofsjoekull E)',
#     'RGI60-06.00445', 'RGI60-06.00474', 'RGI60-06.00425', 'RGI60-06.00480',
#     'Dyngjujoekull', 'RGI60-06.00478', 'Koeldukvislarjoekull',
#     'Oeldufellsjoekull', 'RGI60-06.00350', 'RGI60-06.00340'
# ]
TEST_GLACIERS_ISL = [
    "RGI60-06.00306",
    "RGI60-06.00296",
    "RGI60-06.00479",
    "RGI60-06.00425",
    "RGI60-06.00445",
    "RGI60-06.00474",
    "RGI60-06.00542",
    "Reykjafjardarjoekull",
    "RGI60-06.00350",
    "RGI60-06.00342",
    "RGI60-06.00301",
    "RGI60-06.00422",
    "RGI60-06.00320",
    "RGI60-06.00359",
    "RGI60-06.00349",
    "RGI60-06.00409",
    "RGI60-06.00413",
    "RGI60-06.00411",
    "Oeldufellsjoekull",
    "RGI60-06.00476",
    "RGI60-06.00549",
    "RGI60-06.00228",
    "RGI60-06.00303",
    "Kaldalonsjoekull",
    "RGI60-06.00328",
    "RGI60-06.00541",
    "Slettjoekull West",
    "RGI60-06.00232",
    "RGI60-06.00305",
]

# --- NOR ---
# TEST_GLACIERS_NOR = [
#     'Cainhavarre', 'Rundvassbreen', 'Svartisheibreen', 'Trollbergdalsbreen',
#     'Hansebreen', 'Tunsbergdalsbreen', 'Austdalsbreen', 'Hellstugubreen',
#     'Austre Memurubreen', 'Bondhusbrea', 'Svelgjabreen', 'Moesevassbrea',
#     'Blomstoelskardsbreen'
# ]
TEST_GLACIERS_NOR = [
    "Moesevassbrea",
    "Vetlefjordbreen",
    "Juvfonne",
    "Graasubreen",
    "Hellstugubreen",
    "Storglombreen N",
    "Blabreen",
    "Ruklebreen",
    "Vestre Memurubreen",
    "Cainhavarre",
    "Bondhusbrea",
]

# --- FR ---
# TEST_GLACIERS_FR = ['Talefre', 'Argentiere', 'Gebroulaz']
TEST_GLACIERS_FR = [
    "Grands Montets",
    "Sarennes",
    "Talefre",
    "Leschaux",
]

# --- CH ---
TEST_GLACIERS_CH = [
    "tortin",
    "plattalva",
    "schwarzberg",
    "hohlaub",
    "sanktanna",
    "corvatsch",
    "tsanfleuron",
    "forno",
]

# --- IT_AT ---
# TEST_GLACIERS_IT_AT = [
#     'GOLDBERG K.',
#     'HINTEREIS F.',
#     'JAMTAL F.',
#     'VERNAGT F.',
# ]
TEST_GLACIERS_IT_AT = [
    "CIARDONEY",
    "CARESER CENTRALE",
    "CAMPO SETT.",
    "ZETTALUNITZ/MULLWITZ K.",
    "HALLSTAETTER G.",
    "VENEDIGER K.",
    "SURETTA MERIDIONALE",
    "GOLDBERG K.",
    "CARESER OCCIDENTALE",
    "GRAND ETRET",
    "LUPO",
]

# Lookup from region/country code to its test-glacier list
# (passed to the monthly-preparation step as `test_glaciers_by_code`).
TEST_GLACIERS_BY_CODE = {
    "SJM": TEST_GLACIERS_SJM,
    "ISL": TEST_GLACIERS_ISL,
    "NOR": TEST_GLACIERS_NOR,
    "FR": TEST_GLACIERS_FR,
    "CH": TEST_GLACIERS_CH,
    "IT_AT": TEST_GLACIERS_IT_AT,
}

In [None]:
# Stakes for the combined "IT_AT" code.
df_IT_AT = load_stakes(cfg, "IT_AT")

# All glaciers: distinct values of the GLACIER column, shown as the cell output.
df_IT_AT["GLACIER"].unique()

In [None]:
# Transform data to monthly format (run or load data).

# ERA5 input files needed by the monthly transformation, located under
# cfg.dataPath / path_ERA5_raw.
paths = {
    'era5_climate_data':
    os.path.join(cfg.dataPath, path_ERA5_raw,
                 "era5_monthly_averaged_data_Europe.nc"),
    'geopotential_data':
    os.path.join(cfg.dataPath, path_ERA5_raw,
                 "era5_geopotential_pressure_Europe.nc")
}

# Fail early, with the offending key and path, if any required file is missing.
for key, path in paths.items():
    if not os.path.exists(path):
        raise FileNotFoundError(f"Required file for {key} not found at {path}")

# Climate variables of interest. Defined at cell level so the list does not
# depend on the file-check loop above having executed at least one iteration.
vois_climate = [
    "t2m",
    "tp",
    "slhf",
    "sshf",
    "ssrd",
    "fal",
    "str",
]

# Topographical variables of interest.
vois_topographical = ["aspect", "slope", "svf"]

# Prepare the monthly dataframes for all regions.
# NOTE(review): run_flag=False presumably loads previously prepared data instead
# of recomputing it — confirm in prepare_monthly_dfs_for_all_regions.
res_all = prepare_monthly_dfs_for_all_regions(
    cfg=cfg,
    dfs=dfs,
    paths=paths,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    run_flag=False,
    test_glaciers_by_code=TEST_GLACIERS_BY_CODE,
    # only_codes=["IT_AT"],  # only IT_AT recomputes (careful, only codes overrides run_flag)
)

# Optional: rerun parts only.
# Example: only recompute Norway
# res_all = prepare_monthly_dfs_for_all_regions(
#     cfg=cfg,
#     dfs=dfs,
#     paths=paths,
#     vois_climate=vois_climate,
#     vois_topographical=vois_topographical,
#     run_flag=True,
#     only_codes=["NOR"],   # only Norway recomputes
# )

#### Feature overlap:

In [None]:
# Monthly feature columns: the same seven climate variable names listed in
# vois_climate, plus ELEVATION_DIFFERENCE.
MONTHLY_COLS = [
    "t2m", "tp", "slhf", "sshf", "ssrd", "fal", "str",
    "ELEVATION_DIFFERENCE",
]

# Static feature columns (same names as vois_topographical).
STATIC_COLS = ["aspect", "slope", "svf"]

# Full feature list: monthly columns first, then static columns.
feature_columns = [*MONTHLY_COLS, *STATIC_COLS]

In [None]:
# Feature-overlap figures for every entry of res_all.
# res_all is what you got from prepare_monthly_dfs_for_all_regions(...)
# NOTE(review): n_iter=1000 is presumably the number of resampling iterations — confirm in the helper.
figs = plot_overlap_for_all_results(
    results_dict=res_all,
    cfg=cfg,
    STATIC_COLS=STATIC_COLS,
    MONTHLY_COLS=MONTHLY_COLS,
    n_iter=1000,
)

In [None]:
# Second overlap view over the same res_all and feature lists.
# NOTE(review): the `figs_kde` name suggests KDE plots — confirm in the helper.
figs_kde = plot_feature_overlap_all_regions(res_all, STATIC_COLS, MONTHLY_COLS)

## LSTM model
### LSTM datasets:

In [None]:
# Build the LSTM assets from res_all, or load them from cache_dir.
# NOTE(review): force_recompute=False presumably reuses whatever is already
# cached under "logs/LSTM_cache" — confirm in build_or_load_lstm_all.
lstm_assets = build_or_load_lstm_all(cfg=cfg,
                                     res_all=res_all,
                                     MONTHLY_COLS=MONTHLY_COLS,
                                     STATIC_COLS=STATIC_COLS,
                                     cache_dir="logs/LSTM_cache",
                                     force_recompute=False)

# Example (kept as comments so the cell does not echo it as its output):
# only recompute Norway, load others if cached
# lstm_assets = build_or_load_lstm_all(
#     cfg=cfg,
#     res_all=res_all,
#     MONTHLY_COLS=MONTHLY_COLS,
#     STATIC_COLS=STATIC_COLS,
#     only_keys=["08_NOR"],
# )

### LSTM parameters:

In [None]:
# Grid-search progress logs, one CSV per region/country code.
# Codes without an entry here (e.g. SJM, IT_AT) have no log path in this mapping.
# NOTE(review): the CH file name has no "OOS_" part unlike the other three — confirm it is intended.
log_path_gs_results = dict(
    ISL='logs/GS_results/lstm_param_search_progress_OOS_ISL_2026-02-11.csv',
    NOR='logs/GS_results/lstm_param_search_progress_OOS_NOR_2026-02-09.csv',
    FR='logs/GS_results/lstm_param_search_progress_OOS_FR_2026-02-06.csv',
    CH='logs/GS_results/lstm_param_search_progress_CH_2026-02-18.csv',
)

# Baseline LSTM hyper-parameters handed to build_lstm_params_by_key.
# Fm and Fs equal len(MONTHLY_COLS) == 8 and len(STATIC_COLS) == 3 respectively.
default_params = dict(
    Fm=8,
    Fs=3,
    hidden_size=128,
    num_layers=1,
    bidirectional=False,
    dropout=0.0,
    static_layers=1,
    static_hidden=128,
    static_dropout=0.1,
    lr=0.001,
    weight_decay=1e-05,
    loss_name='neutral',
    two_heads=False,
    head_dropout=0.1,
    loss_spec=None,
)

# Per-key parameter sets built from the defaults, the grid-search logs and RGI_REGIONS.
params_by_key = build_lstm_params_by_key(
    default_params=default_params,
    log_path_gs_results=log_path_gs_results,
    RGI_REGIONS=RGI_REGIONS,
)

# Show which keys received a parameter set.
params_by_key.keys()

### Train model:

In [None]:
# Alternative call without force_retrain / train_keys, kept for reference:
# models, infos = train_within_region_models_all(
#     cfg=cfg,
#     lstm_assets_by_key=lstm_assets,
#     params_by_key=params_by_key,
#     device=device,
#     epochs=150,
# )

# Active call: train_keys is commented out, so no key subset is selected here.
# To restrict training to one key, uncomment train_keys (example: "11_IT_AT").
# NOTE(review): force_retrain=False presumably loads already-trained models
# instead of retraining them — confirm in train_within_region_models_all.
models, infos = train_within_region_models_all(
    cfg=cfg,
    lstm_assets_by_key=lstm_assets,
    params_by_key=params_by_key,
    device=device,
    # train_keys=["11_IT_AT"],
    epochs=150,
    force_retrain=False
)

### Evaluate on test:

In [None]:
# Evaluate every trained model with its matching LSTM assets.
# save_dir is passed as "figures/eval_within_region"; both axis-limit arguments are set to (-10, 10).
# Returns a metrics table, per-key predictions, per-key figures and a grid figure (per the unpacked names).
df_metrics, preds_by_key, figs_by_key, fig_grid = evaluate_all_models(
    cfg=cfg,
    models_by_key=models,
    lstm_assets_by_key=lstm_assets,
    device=device,
    save_dir="figures/eval_within_region",
    ax_xlim=(-10, 10),
    ax_ylim=(-10, 10)
)

# Rich display of the metrics table as the cell output.
display(df_metrics)