In [None]:
import sys, os
sys.path.append(os.path.join(os.getcwd(), '../../')) # Add root of repo to import MBM
from typing import Optional, Iterable, Dict, List
import os

import pandas as pd
import numpy as np
import warnings
import re
import matplotlib.pyplot as plt
import seaborn as sns
from cmcrameri import cm
import massbalancemachine as mbm
import logging
import torch.nn as nn
from skorch.helper import SliceDataset
from datetime import datetime
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset
import pickle 
from scipy import stats
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import torch 
from matplotlib.lines import Line2D

from regions.TF_Europe.scripts.config_TF_Europe import *
from regions.TF_Europe.scripts.dataset import *
from regions.TF_Europe.scripts.plotting import *
from regions.TF_Europe.scripts.models import *
from regions.TF_Europe.scripts.models import *

warnings.filterwarnings('ignore')
%load_ext autoreload
%autoreload 2

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

cfg = mbm.EuropeTFConfig()
mbm.utils.seed_all(cfg.seed)
mbm.utils.free_up_cuda()
mbm.plots.use_mbm_style()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Cross-regional modelling: train on Switzerland (CH), test on the other European regions

### Read stake datasets:

In [None]:
# Examples of loading data:
#   # Load Switzerland only
#   df = load_stakes(cfg, "CH")
#
#   # Load all Central Europe (FR+CH+IT+AT when you add them)
#   df_ceu = load_stakes_for_rgi_region(cfg, "11")
#
#   # Load all Europe regions configured
#   dfs = {rid: load_stakes_for_rgi_region(cfg, rid) for rid in RGI_REGIONS.keys()}

# Load the stake data of every configured RGI region, keyed by region id,
# then show region "11".
dfs = {
    region_id: load_stakes_for_rgi_region(cfg, region_id)
    for region_id in RGI_REGIONS.keys()
}
dfs["11"]

### Monthly datasets:

In [None]:
# In the cross-regional setup every glacier of a target region is a test
# glacier, so each region's test list is simply all glaciers with stake data.
TEST_REGION_CODES = ["SJM", "ISL", "NOR", "FR", "IT_AT"]
TEST_GLACIERS_BY_CODE = {
    code: load_stakes(cfg, code).GLACIER.unique().tolist()
    for code in TEST_REGION_CODES
}

# Per-region names kept for backward compatibility with code that uses them
# directly (they reference the same list objects as the dict entries).
TEST_GLACIERS_SJM = TEST_GLACIERS_BY_CODE["SJM"]
TEST_GLACIERS_ISL = TEST_GLACIERS_BY_CODE["ISL"]
TEST_GLACIERS_NOR = TEST_GLACIERS_BY_CODE["NOR"]
TEST_GLACIERS_FR = TEST_GLACIERS_BY_CODE["FR"]
TEST_GLACIERS_IT_AT = TEST_GLACIERS_BY_CODE["IT_AT"]

In [None]:
# ERA5 inputs needed to transform the stake data to monthly format
# (the transformation is either run or loaded from disk).
paths = {
    'era5_climate_data':
    os.path.join(cfg.dataPath, path_ERA5_raw,
                 "era5_monthly_averaged_data_Europe.nc"),
    'geopotential_data':
    os.path.join(cfg.dataPath, path_ERA5_raw,
                 "era5_geopotential_pressure_Europe.nc")
}

# Fail early if any of the required files does not exist.
for key, path in paths.items():
    if not os.path.exists(path):
        raise FileNotFoundError(f"Required file for {key} not found at {path}")

# Climate and topographical variables passed to
# prepare_monthly_df_xreg_CH_to_EU. Defined once, at top level (previously
# they sat inside the loop above and were re-created on every iteration).
vois_climate = [
    "t2m",
    "tp",
    "slhf",
    "sshf",
    "ssrd",
    "fal",
    "str",
]

vois_topographical = ["aspect", "slope", "svf"]

In [None]:
# Reload the stake data of every configured RGI region so this cell can be
# re-run on its own.
dfs = {region_id: load_stakes_for_rgi_region(cfg, region_id)
       for region_id in RGI_REGIONS.keys()}

# Monthly train/test frames: CH glaciers for training, the other regions for
# testing. run_flag=False loads the result if it was already computed.
res_xreg, split_info = prepare_monthly_df_xreg_CH_to_EU(
    cfg=cfg,
    dfs=dfs,
    paths=paths,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    run_flag=False,
)

df_train, df_test = res_xreg["df_train"], res_xreg["df_test"]

print("Train glaciers (CH):", len(split_info["train_glaciers"]))
print("Test glaciers (non-CH):", len(split_info["test_glaciers"]))
print("Train rows:", len(df_train), "Test rows:", len(df_test))

#### Plot feature overlap:

In [None]:
# Model input columns, derived from the variable lists used to prepare the
# monthly dataset so the model features cannot silently drift from the
# prepared data.
# Monthly inputs: the ERA5 climate variables plus ELEVATION_DIFFERENCE.
MONTHLY_COLS = list(vois_climate) + ['ELEVATION_DIFFERENCE']
# Static inputs: the topographical variables.
STATIC_COLS = list(vois_topographical)

feature_columns = MONTHLY_COLS + STATIC_COLS

In [None]:
# t-SNE feature-overlap plots of CH against the other SOURCE_CODE groups,
# computed from the single res_xreg dict of the cross-regional monthly prep.
# figs_by_code presumably maps each region code to its figure.
figs_by_code = plot_tsne_overlap_xreg_from_single_res(
    res_xreg=res_xreg,
    cfg=cfg,
    MONTHLY_COLS=MONTHLY_COLS,
    STATIC_COLS=STATIC_COLS,
    ch_code="CH",
    group_col="SOURCE_CODE",
    n_iter=1000,
    use_aug=False,  # True to use the *_aug columns instead
    # only_codes=["IT_AT"],  # optionally restrict to some region codes
)

In [None]:
# Features compared in the per-feature KDE plots: all model inputs plus
# POINT_BALANCE.
FEATURES = [*MONTHLY_COLS, *STATIC_COLS, "POINT_BALANCE"]

# KDE overlap of CH vs. the other SOURCE_CODE groups; the *_aug columns are
# usually the better choice for judging feature overlap.
figs_kde = plot_feature_kde_overlap_xreg_ch_vs_codes(
    res_xreg=res_xreg,
    cfg=cfg,
    features=FEATURES,
    ch_code="CH",
    group_col="SOURCE_CODE",
    use_aug=True,
    output_dir="figures/xreg_kde",  # optional: where figures are written
    # only_codes=["IT_AT", "FR"],  # optionally restrict to some region codes
)


## LSTM model
### LSTM datasets:

In [None]:
# Cross-regional result sets keyed with the "XREG_CH_TO" prefix; with
# target_source_codes=None the target region codes are discovered from
# df_test's SOURCE_CODE column.
res_all_xreg = build_xreg_res_all(
    res_xreg=res_xreg,
    key_prefix="XREG_CH_TO",
    source_col="SOURCE_CODE",
    target_source_codes=None,
)
res_all_xreg.keys()

In [None]:
# Build (or load from the cache directory) the LSTM datasets for every
# cross-regional target.
outputs_xreg = build_or_load_lstm_all_xreg(
    cfg=cfg,
    res_xreg=res_xreg,
    STATIC_COLS=STATIC_COLS,
    MONTHLY_COLS=MONTHLY_COLS,
    cache_dir="logs/LSTM_cache",
)

### LSTM parameters:

In [None]:
# Default LSTM hyper-parameters used when training the cross-regional models.
default_params = {
    # Presumably the input widths of the monthly and static branches. Derived
    # from the feature lists (8 and 3 in the original) instead of being
    # hard-coded, so adding/removing a feature cannot desynchronise them.
    'Fm': len(MONTHLY_COLS),
    'Fs': len(STATIC_COLS),
    'hidden_size': 128,
    'num_layers': 1,
    'bidirectional': False,
    'dropout': 0.0,
    'static_layers': 1,
    'static_hidden': 128,
    'static_dropout': 0.1,
    'lr': 0.001,
    'weight_decay': 1e-05,
    'loss_name': 'neutral',
    'two_heads': False,
    'head_dropout': 0.1,
    'loss_spec': None
}

### Train model:

In [None]:
# Train one model per cross-regional target. lstm_assets_by_key comes from
# build_or_load_lstm_all_xreg; train_keys=None selects all targets.
# force_retrain=True presumably retrains even when a saved model exists in
# models_dir — set it to False to reuse saved models.
models_xreg, infos_xreg = train_crossregional_models_all(
    cfg=cfg,
    device=device,
    lstm_assets_by_key=outputs_xreg,
    default_params=default_params,
    train_keys=None,
    force_retrain=True,
    epochs=150,
    models_dir="models",
    prefix="lstm_xreg_CH_to",
)

### Evaluate on test:

In [None]:
# Evaluate every cross-regional model on its test data. custom_order is
# presumably the region order used in the summary grid (5 regions on 2x3).
custom_order = ["FR", "IT_AT", "NOR", "SJM", "ISL"]
df_metrics_xreg, preds_xreg, figs_xreg, fig_grid_xreg = evaluate_all_models_xreg(
    cfg=cfg,
    device=device,
    models_by_key=models_xreg,
    lstm_assets_by_key=outputs_xreg,
    order=custom_order,
    grid_shape=(2, 3),
    save_dir="figures/eval_xreg",
)