## Setting Up:

In [None]:
# Notebook setup: imports, notebook-wide settings, configuration, RNG seeding
# and torch device selection used by every later cell.

# --- Standard library
from concurrent.futures import ProcessPoolExecutor, as_completed
from contextlib import redirect_stdout
from datetime import datetime
import io
import logging
import multiprocessing as mp
import os
import sys
import warnings

# Make repo root importable (for MBM & scripts/*): appends the directory two
# levels above the current working directory to sys.path. Must run before the
# `scripts.*` imports below (assumes the notebook is run from its own folder —
# TODO confirm).
sys.path.append(os.path.join(os.getcwd(), '../../'))

# --- Third-party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from cmcrameri import cm
import torch
from tqdm.auto import tqdm
import xarray as xr
from matplotlib.lines import Line2D

import massbalancemachine as mbm

# --- Project-local
# Star imports: presumably the source of the helpers and path constants used
# later without an explicit import (get_stakes_data, process_or_load_data,
# plot_* functions, path_* constants, PARAMS_*, ...) — exact module per name
# is not visible here.
from scripts.utils import *
from scripts.glamos import *
from scripts.models import *
from scripts.geo_data import *
from scripts.dataset import *
from scripts.geodetic import *
from scripts.physical import *
from scripts.plotting import *

# --- Notebook settings
# Silences ALL Python warnings for the session (including pandas
# SettingWithCopyWarning / FutureWarning), so such issues will not surface.
warnings.filterwarnings('ignore')
# IPython: automatically reload edited modules (e.g. scripts/*) before
# executing each cell.
%load_ext autoreload
%autoreload 2

# Swiss configuration; later cells read cfg.seed, cfg.dataPath, cfg.metaData.
cfg = mbm.SwitzerlandConfig()

# Seed the RNGs for reproducibility, then apply the MBM plotting style.
mbm.utils.seed_all(cfg.seed)
mbm.plots.use_mbm_style()

print("Using seed:", cfg.seed)

# Report GPU availability; when a GPU is present, call mbm's helper to free
# CUDA memory (presumably clears cached allocations — confirm).
if torch.cuda.is_available():
    print("CUDA is available")
    mbm.utils.free_up_cuda()
else:
    print("CUDA is NOT available")

# Device used to build and evaluate the LSTM in the feature-importance cells.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Input data:

In [None]:
# Read GLAMOS stake (point mass-balance) data.
data_glamos = get_stakes_data(cfg)

# Head/tail month padding for monthly sequences built from the original
# measurement periods of `data_glamos`.
months_head_pad, months_tail_pad = mbm.data_processing.utils._compute_head_tail_pads_from_df(
    data_glamos)

# Configure logging once for the whole notebook. A repeated basicConfig call
# is a no-op once the root logger has handlers, so it is not repeated below.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Raw inputs used when (re)building the monthly datasets.
paths = {
    'csv_path': cfg.dataPath + path_PMB_GLAMOS_csv,
    'era5_climate_data':
    cfg.dataPath + path_ERA5_raw + 'era5_monthly_averaged_data.nc',
    'geopotential_data':
    cfg.dataPath + path_ERA5_raw + 'era5_geopotential_pressure.nc',
    'radiation_save_path': cfg.dataPath + path_pcsr + 'zarr/'
}

# True: rebuild the monthly CSVs from raw data; False: load the saved CSVs.
RUN = False

# Rows with YEAR >= HOLDOUT_YEAR are dropped from the working monthly tables
# (2025 is used elsewhere for validation).
HOLDOUT_YEAR = 2025


def _run_or_load_monthly(glamos_df, output_file):
    """Run or load the monthly transform of ``glamos_df`` and wrap it in a DataLoader.

    Shared by the original-period and the 1-August-start datasets so both are
    built with the same paths, variables and RUN flag.

    Args:
        glamos_df: Stake data passed to ``process_or_load_data``.
        output_file: CSV file name passed to ``process_or_load_data``.

    Returns:
        Tuple ``(monthly_df, dataloader)``. The DataLoader is created on the
        full monthly table, i.e. before the hold-out year is removed.
    """
    monthly_df = process_or_load_data(
        run_flag=RUN,
        data_glamos=glamos_df,
        paths=paths,
        cfg=cfg,
        vois_climate=VOIS_CLIMATE,
        vois_topographical=VOIS_TOPOGRAPHICAL,
        output_file=output_file)
    # NOTE(review): the DataLoader is built before the YEAR filter applied by
    # the caller, so it still contains the hold-out year — confirm intended.
    dataloader = mbm.dataloader.DataLoader(cfg,
                                           data=monthly_df,
                                           random_seed=cfg.seed,
                                           meta_data_columns=cfg.metaData)
    return monthly_df, dataloader


# --- Monthly data on the original measurement periods
data_monthly, dataloader_gl = _run_or_load_monthly(
    data_glamos, 'CH_wgms_dataset_monthly_LSTM.csv')
data_monthly = data_monthly[data_monthly['YEAR'] < HOLDOUT_YEAR]

# --- Same data with every FROM_DATE moved to 1 August of its year.
# FROM_DATE is treated as a YYYYMMDD value: keep its first four characters
# (the year) and append "0801".
data_glamos_Aug_ = data_glamos.copy()
from_year = data_glamos_Aug_["FROM_DATE"].astype(str).str.slice(0, 4)
data_glamos_Aug_["FROM_DATE"] = (from_year.astype(int).astype(str) +
                                 "0801").astype(int)

# Padding for the August-start sequences.
months_head_pad_Aug_, months_tail_pad_Aug_ = mbm.data_processing.utils._compute_head_tail_pads_from_df(
    data_glamos_Aug_)

data_monthly_Aug_, dataloader_gl_Aug_ = _run_or_load_monthly(
    data_glamos_Aug_, 'CH_wgms_dataset_monthly_LSTM_Aug_.csv')
data_monthly_Aug_ = data_monthly_Aug_[data_monthly_Aug_['YEAR'] < HOLDOUT_YEAR]

## Monthly distributions:

In [None]:
# Leading os.path.join components shared by every prediction folder.
_grids_root = (cfg.dataPath, "GLAMOS", "distributed_MB_grids")

# Per-model folders holding the distributed (gridded) MB predictions.
PATH_PREDICTIONS_LSTM_IS = os.path.join(*_grids_root,
                                        "MBM/paper/LSTM_IS_ORIGINAL_Y_PAST")
PATH_PREDICTIONS_NN = os.path.join(*_grids_root, "MBM/paper/NN")
PATH_PREDICTIONS_XGB = os.path.join(*_grids_root, "MBM/paper/XGB")

# Month abbreviations in hydrological-year order (October to September).
hydro_months = "oct nov dec jan feb mar apr may jun jul aug sep".split()

In [None]:
# On-disk parquet cache for the per-model monthly prediction DataFrames and
# the two GLAMOS grid DataFrames (presumably winter "w" / annual "a"), so the
# slow loaders only run when the cache is missing or REBUILD_CACHE is set.
CACHE_DIR = os.path.join(
    cfg.dataPath,
    "GLAMOS/distributed_MB_grids/MBM/paper/processed_dfs",
)
os.makedirs(CACHE_DIR, exist_ok=True)

REBUILD_CACHE = False  # set True only when inputs change

# One parquet file per cached DataFrame. NOTE: this rebinds `paths`; the raw
# input-paths dict from the data-loading cell is no longer reachable by it.
paths = {
    "LSTM": os.path.join(CACHE_DIR, "df_months_LSTM.parquet"),
    "NN": os.path.join(CACHE_DIR, "df_months_NN.parquet"),
    "XGB": os.path.join(CACHE_DIR, "df_months_XGB.parquet"),
    "GW": os.path.join(CACHE_DIR, "df_GLAMOS_w.parquet"),
    "GA": os.path.join(CACHE_DIR, "df_GLAMOS_a.parquet"),
}

if REBUILD_CACHE or not all(os.path.exists(p) for p in paths.values()):

    print("Building monthly prediction DataFrames...")

    df_months_LSTM = load_glwd_lstm_predictions(PATH_PREDICTIONS_LSTM_IS,
                                                hydro_months)
    df_months_NN = load_glwd_nn_predictions(PATH_PREDICTIONS_NN, hydro_months)
    # XGB grids are read with the NN loader (presumably same file layout).
    df_months_XGB = load_glwd_nn_predictions(PATH_PREDICTIONS_XGB,
                                             hydro_months)

    PATH_GLAMOS = os.path.join(cfg.dataPath, path_distributed_MB_glamos,
                               "GLAMOS")

    # Sorted years per glacier present in the LSTM predictions; passed to
    # load_glamos_grids (presumably to pick which GLAMOS grids to read).
    glacier_years = (df_months_LSTM.groupby("glacier")["year"].unique().apply(
        sorted).to_dict())

    df_GLAMOS_w, df_GLAMOS_a = load_glamos_grids(cfg, glacier_years,
                                                 PATH_GLAMOS)

    # Save cache (same keys as `paths`).
    for key, frame in (
        ("LSTM", df_months_LSTM),
        ("NN", df_months_NN),
        ("XGB", df_months_XGB),
        ("GW", df_GLAMOS_w),
        ("GA", df_GLAMOS_a),
    ):
        frame.to_parquet(paths[key])

    print(f"Cached DataFrames to {CACHE_DIR}")

else:
    print("Loading cached monthly prediction DataFrames...")

    df_months_LSTM = pd.read_parquet(paths["LSTM"])
    df_months_NN = pd.read_parquet(paths["NN"])
    df_months_XGB = pd.read_parquet(paths["XGB"])
    df_GLAMOS_w = pd.read_parquet(paths["GW"])
    df_GLAMOS_a = pd.read_parquet(paths["GA"])

    print(f"Loaded DataFrames from {CACHE_DIR}")

#### Glacier-wide:

In [None]:
# --- 1. Glacier-wide mean of every numeric column per glacier and year ---
# numeric_only=True averages only numeric columns, as the elevation-band
# cells do; with pandas >= 2.0 a non-numeric column would otherwise make
# mean() raise a TypeError.
glwd_months_NN = df_months_NN.groupby(['glacier', 'year']).mean(
    numeric_only=True).reset_index()
glwd_months_XGB = df_months_XGB.groupby(['glacier', 'year']).mean(
    numeric_only=True).reset_index()
glwd_months_LSTM = df_months_LSTM.groupby(['glacier', 'year']).mean(
    numeric_only=True).reset_index()
glwd_months_GLAMOS_w = df_GLAMOS_w.groupby(['glacier', 'year']).mean(
    numeric_only=True).reset_index()
glwd_months_GLAMOS_a = df_GLAMOS_a.groupby(['glacier', 'year']).mean(
    numeric_only=True).reset_index()

# --- 2. (glacier, year) pairs present in all five datasets ---
valid_pairs = set.intersection(*[
    set(zip(frame['glacier'], frame['year'])) for frame in (
        glwd_months_NN,
        glwd_months_XGB,
        glwd_months_LSTM,
        glwd_months_GLAMOS_w,
        glwd_months_GLAMOS_a,
    )
])


# --- 3. Helper function for filtering by glacier–year pairs ---
def filter_to_valid(df, pairs=None):
    """Keep only the rows whose ``(glacier, year)`` pair is in ``pairs``.

    Args:
        df: DataFrame with ``glacier`` and ``year`` columns.
        pairs: Allowed ``(glacier, year)`` tuples. When omitted, the
            module-level ``valid_pairs`` is read at call time (the previous
            behaviour). Passing it explicitly avoids depending on whichever
            cell last rebound that global.

    Returns:
        The matching rows, re-indexed from 0.
    """
    if pairs is None:
        pairs = valid_pairs
    # Build the (glacier, year) keys with zip rather than a row-wise
    # apply(tuple, axis=1): same tuples, much less per-row overhead.
    keys = pd.Series(list(zip(df["glacier"], df["year"])),
                     index=df.index,
                     dtype=object)
    return df[keys.isin(pairs)].reset_index(drop=True)


# --- 4. Restrict every dataset to the common glacier–year pairs ---
(glwd_months_NN_filtered,
 glwd_months_XGB_filtered,
 glwd_months_LSTM_filtered,
 glwd_months_GLAMOS_filtered_w,
 glwd_months_GLAMOS_filtered_a) = map(filter_to_valid, (
     glwd_months_NN,
     glwd_months_XGB,
     glwd_months_LSTM,
     glwd_months_GLAMOS_w,
     glwd_months_GLAMOS_a,
 ))

# --- 5. Long-format table for the joyplots ---
# Argument order: LSTM, NN, XGB, GLAMOS (w), GLAMOS (a).
df_months_nn_long = prepare_monthly_long_df(
    glwd_months_LSTM_filtered,
    glwd_months_NN_filtered,
    glwd_months_XGB_filtered,
    glwd_months_GLAMOS_filtered_w,
    glwd_months_GLAMOS_filtered_a,
)

# Preview the first two rows.
df_months_nn_long.head(2)

In [None]:
# ---- x-axis range spanning every model drawn in the joyplot ----
model_cols = ['mb_nn', 'mb_lstm', 'mb_xgb']
col_extremes = df_months_nn_long[model_cols].agg(['min', 'max'])
min_ = col_extremes.loc['min'].min()
max_ = col_extremes.loc['max'].max()

# ---- Plot: one colour per model, GLAMOS in grey ----
fig = plot_monthly_joyplot(
    df_months_nn_long,
    x_range=(np.floor(min_), np.ceil(max_)),
    color_lstm=mbm.plots.COLOR_ANNUAL,
    color_nn=mbm.plots.COLOR_WINTER,
    color_xgb="darkgreen",
    color_glamos="gray",
)

In [None]:
# x-axis range over every model column and GLAMOS, so the plotted `mb_lstm`
# distribution always lies inside it. The previous range used only mb_nn and
# mb_glamos and could cut off mb_lstm. These are the same columns the
# elevation-band figures use.
range_cols = ["mb_nn", "mb_lstm", "mb_xgb", "mb_glamos"]
min_ = df_months_nn_long[range_cols].min().min()
max_ = df_months_nn_long[range_cols].max().max()

fig = plot_monthly_joyplot_single(df_months_nn_long,
                                  variable="mb_lstm",
                                  color_model=mbm.plots.COLOR_WINTER,
                                  x_range=(np.floor(min_), np.ceil(max_)),
                                  model_name='MBM')
fig.savefig('figures/paper/CH_LSTM_vs_GLAMOS_monthly_joyplot_glwd.png',
            dpi=300,
            bbox_inches='tight')

### Elevation bands:

#### Highest:

In [None]:
# Elevation bands of width `bin`. np.arange excludes the stop value, so the
# edges run 1200, 1400, ..., 4400. `bin` shadows the Python builtin but is
# kept because the statements after extract_highest_band read it.
# NOTE(review): pd.cut (right-closed intervals) gives NaN for elevations
# outside (1200, 4400]; confirm no grid cells lie above 4400 if elev_band is
# used downstream.
bin = 200
bins = np.arange(1200, 4500, bin)
labels = [f"{lower}-{lower + bin}" for lower in bins[:-1]]

# Work on copies so the loaded DataFrames keep their original columns.
(df_months_NN_, df_months_XGB_, df_months_LSTM_, df_GLAMOS_a_,
 df_GLAMOS_w_) = (frame.copy() for frame in (df_months_NN, df_months_XGB,
                                             df_months_LSTM, df_GLAMOS_a,
                                             df_GLAMOS_w))

# Tag every row with its elevation-band label.
for df_ in (df_months_NN_, df_months_XGB_, df_months_LSTM_, df_GLAMOS_a_,
            df_GLAMOS_w_):
    df_["elev_band"] = pd.cut(df_["elevation"], bins=bins, labels=labels)


def extract_highest_band(df, bin_width):
    """Average each glacier's highest elevation band per glacier and year.

    The band is every row whose ``elevation`` is at least that glacier's
    maximum ``elevation`` minus ``bin_width``. The maximum is taken over all
    of the glacier's rows, all years included, and the boundary is inclusive.

    Args:
        df: DataFrame with at least ``glacier``, ``year`` and ``elevation``.
        bin_width: Band thickness, in the units of ``elevation``.

    Returns:
        DataFrame with one row per ``(glacier, year)`` holding the mean of
        every numeric column over the band rows. Non-numeric, non-key columns
        are dropped.
    """
    glacier_top = df.groupby("glacier")["elevation"].transform("max")
    in_band = df["elevation"] >= glacier_top - bin_width
    per_year = df.loc[in_band].groupby(["glacier", "year"])
    return per_year.mean(numeric_only=True).reset_index()


# Per-glacier, per-year means over each glacier's highest band.
(glwd_high_NN, glwd_high_XGB, glwd_high_LSTM, glwd_high_GLAMOS_a,
 glwd_high_GLAMOS_w) = (extract_highest_band(frame, bin)
                        for frame in (df_months_NN_, df_months_XGB_,
                                      df_months_LSTM_, df_GLAMOS_a_,
                                      df_GLAMOS_w_))

# (glacier, year) pairs present in all five highest-band datasets.
valid_pairs = set.intersection(*(set(zip(frame["glacier"], frame["year"]))
                                 for frame in (glwd_high_NN, glwd_high_XGB,
                                               glwd_high_LSTM,
                                               glwd_high_GLAMOS_w,
                                               glwd_high_GLAMOS_a)))


def filter_to_valid(df, pairs=None):
    """Keep only the rows whose ``(glacier, year)`` pair is in ``pairs``.

    Args:
        df: DataFrame with ``glacier`` and ``year`` columns.
        pairs: Allowed ``(glacier, year)`` tuples. When omitted, the
            module-level ``valid_pairs`` is read at call time (the previous
            behaviour). Passing it explicitly avoids depending on whichever
            cell last rebound that global.

    Returns:
        The matching rows, re-indexed from 0.
    """
    if pairs is None:
        pairs = valid_pairs
    # Build the (glacier, year) keys with zip rather than a row-wise
    # apply(tuple, axis=1): same tuples, much less per-row overhead.
    keys = pd.Series(list(zip(df["glacier"], df["year"])),
                     index=df.index,
                     dtype=object)
    return df[keys.isin(pairs)].reset_index(drop=True)


# Restrict every highest-band dataset to the common (glacier, year) pairs.
glwd_high_NN_filt = filter_to_valid(glwd_high_NN)
glwd_high_XGB_filt = filter_to_valid(glwd_high_XGB)
glwd_high_LSTM_filt = filter_to_valid(glwd_high_LSTM)
glwd_high_GLAMOS_a_filt = filter_to_valid(glwd_high_GLAMOS_a)
glwd_high_GLAMOS_w_filt = filter_to_valid(glwd_high_GLAMOS_w)

# Row counts (GLAMOS w, GLAMOS a, NN, XGB, LSTM); expected to be equal.
print(
    len(glwd_high_GLAMOS_w_filt),
    len(glwd_high_GLAMOS_a_filt),
    len(glwd_high_NN_filt),
    len(glwd_high_XGB_filt),
    len(glwd_high_LSTM_filt),
)

# Long-format table for the highest band. It is stored under its own name
# (like df_months_nn_long_low for the lowest band) rather than overwriting the
# glacier-wide `df_months_nn_long`. Re-running the glacier-wide plot cell
# afterwards therefore still plots glacier-wide data.
df_months_nn_long_high = prepare_monthly_long_df(
    glwd_high_LSTM_filt,
    glwd_high_NN_filt,
    glwd_high_XGB_filt,
    glwd_high_GLAMOS_w_filt,
    glwd_high_GLAMOS_a_filt,
)

# x-axis limits from all models and GLAMOS.
min_, max_ = (
    df_months_nn_long_high[["mb_nn", "mb_lstm", "mb_xgb",
                            "mb_glamos"]].min().min(),
    df_months_nn_long_high[["mb_nn", "mb_lstm", "mb_xgb",
                            "mb_glamos"]].max().max(),
)

fig = plot_monthly_joyplot_single(
    df_months_nn_long_high,
    variable="mb_lstm",
    color_model=mbm.plots.COLOR_WINTER,
    x_range=(np.floor(min_), np.ceil(max_)),
    model_name="MBM",
)

fig.savefig(
    "figures/paper/CH_LSTM_vs_GLAMOS_monthly_joyplot_high_elv.png",
    dpi=300,
    bbox_inches="tight",
)

#### Lowest:

In [None]:
# =========================
# Lowest-elevation band
# =========================

# Elevation bands of width `bin`. np.arange excludes the stop value, so the
# edges run 1200, 1400, ..., 4400. `bin` shadows the Python builtin but is
# kept because the statements after extract_lowest_band read it.
bin = 200
bins = np.arange(1200, 4500, bin)
labels = [f"{lower}-{lower + bin}" for lower in bins[:-1]]

# --- Work on copies so the loaded DataFrames keep their original columns ---
(df_months_NN_, df_months_XGB_, df_months_LSTM_, df_GLAMOS_a_,
 df_GLAMOS_w_) = (frame.copy() for frame in (df_months_NN, df_months_XGB,
                                             df_months_LSTM, df_GLAMOS_a,
                                             df_GLAMOS_w))

# --- Tag every row with its elevation-band label ---
for df_ in (df_months_NN_, df_months_XGB_, df_months_LSTM_, df_GLAMOS_a_,
            df_GLAMOS_w_):
    df_["elev_band"] = pd.cut(df_["elevation"], bins=bins, labels=labels)


# --- Helper: extract lowest-elevation band per glacier ---
def extract_lowest_band(df, bin_width):
    """Average each glacier's lowest elevation band per glacier and year.

    The band is every row whose ``elevation`` is at most that glacier's
    minimum ``elevation`` plus ``bin_width``. The minimum is taken over all of
    the glacier's rows, all years included, and the boundary is inclusive.

    Args:
        df: DataFrame with at least ``glacier``, ``year`` and ``elevation``.
        bin_width: Band thickness, in the units of ``elevation``.

    Returns:
        DataFrame with one row per ``(glacier, year)`` holding the mean of
        every numeric column over the band rows. Non-numeric, non-key columns
        are dropped.
    """
    glacier_floor = df.groupby("glacier")["elevation"].transform("min")
    in_band = df["elevation"] <= glacier_floor + bin_width
    per_year = df.loc[in_band].groupby(["glacier", "year"])
    return per_year.mean(numeric_only=True).reset_index()


# --- Per-glacier, per-year means over each glacier's lowest band ---
(glwd_low_NN, glwd_low_XGB, glwd_low_LSTM, glwd_low_GLAMOS_a,
 glwd_low_GLAMOS_w) = (extract_lowest_band(frame, bin)
                       for frame in (df_months_NN_, df_months_XGB_,
                                     df_months_LSTM_, df_GLAMOS_a_,
                                     df_GLAMOS_w_))

# --- (glacier, year) pairs present in all five lowest-band datasets ---
valid_pairs = set.intersection(*(set(zip(frame["glacier"], frame["year"]))
                                 for frame in (glwd_low_NN, glwd_low_XGB,
                                               glwd_low_LSTM,
                                               glwd_low_GLAMOS_w,
                                               glwd_low_GLAMOS_a)))


def filter_to_valid(df, pairs=None):
    """Keep only the rows whose ``(glacier, year)`` pair is in ``pairs``.

    Args:
        df: DataFrame with ``glacier`` and ``year`` columns.
        pairs: Allowed ``(glacier, year)`` tuples. When omitted, the
            module-level ``valid_pairs`` is read at call time (the previous
            behaviour). Passing it explicitly avoids depending on whichever
            cell last rebound that global.

    Returns:
        The matching rows, re-indexed from 0.
    """
    if pairs is None:
        pairs = valid_pairs
    # Build the (glacier, year) keys with zip rather than a row-wise
    # apply(tuple, axis=1): same tuples, much less per-row overhead.
    keys = pd.Series(list(zip(df["glacier"], df["year"])),
                     index=df.index,
                     dtype=object)
    return df[keys.isin(pairs)].reset_index(drop=True)


# --- Apply consistent filtering ---
(glwd_low_NN_filt, glwd_low_XGB_filt, glwd_low_LSTM_filt,
 glwd_low_GLAMOS_a_filt, glwd_low_GLAMOS_w_filt) = (
     filter_to_valid(frame)
     for frame in (glwd_low_NN, glwd_low_XGB, glwd_low_LSTM,
                   glwd_low_GLAMOS_a, glwd_low_GLAMOS_w))

# Row counts (GLAMOS w, GLAMOS a, NN, XGB, LSTM); expected to be equal.
print(len(glwd_low_GLAMOS_w_filt), len(glwd_low_GLAMOS_a_filt),
      len(glwd_low_NN_filt), len(glwd_low_XGB_filt), len(glwd_low_LSTM_filt))

# --- Long-format dataframe for plotting ---
df_months_nn_long_low = prepare_monthly_long_df(glwd_low_LSTM_filt,
                                                glwd_low_NN_filt,
                                                glwd_low_XGB_filt,
                                                glwd_low_GLAMOS_w_filt,
                                                glwd_low_GLAMOS_a_filt)

# --- x-axis limits from all models and GLAMOS ---
lowest_subset = df_months_nn_long_low[["mb_nn", "mb_lstm", "mb_xgb",
                                       "mb_glamos"]]
min_, max_ = lowest_subset.min().min(), lowest_subset.max().max()

# The lower limit is then fixed at -10 regardless of the data minimum
# (manual choice for the ablation-dominated lowest band).
min_ = -10

# --- Plot and save ---
fig = plot_monthly_joyplot_single(df_months_nn_long_low,
                                  variable="mb_lstm",
                                  color_model=mbm.plots.COLOR_WINTER,
                                  x_range=(np.floor(min_), np.ceil(max_)),
                                  model_name="MBM",
                                  y_offset=0.15)
fig.savefig(
    "figures/paper/CH_LSTM_vs_GLAMOS_monthly_joyplot_low_elv.png",
    dpi=300,
    bbox_inches="tight",
)

## Feature importance:

#### Aggregated:

In [None]:
# Within-sample setting: the model was trained on all glaciers, so no
# glacier is held out here. 2025 is excluded (used elsewhere for validation).
data_monthly_train = data_monthly[data_monthly.YEAR < 2025]

existing_glaciers = set(data_monthly_train.GLACIER.unique())
train_glaciers = existing_glaciers
# .copy() makes data_train an independent DataFrame, so adding the 'y'
# column below is an unambiguous write rather than a chained assignment on a
# boolean selection (SettingWithCopyWarning case).
data_train = data_monthly_train[data_monthly_train.GLACIER.isin(
    train_glaciers)].copy()
print('Size of monthly train data:', len(data_train))

# Target column.
data_train['y'] = data_train['POINT_BALANCE']

# Same selection for the August-start monthly data.
data_monthly_train_Aug_ = data_monthly_Aug_[data_monthly_Aug_.YEAR < 2025]

existing_glaciers = set(data_monthly_train_Aug_.GLACIER.unique())
train_glaciers = existing_glaciers
data_train_Aug_ = data_monthly_train_Aug_[data_monthly_train_Aug_.GLACIER.isin(
    train_glaciers)].copy()
print('Size of monthly train data:', len(data_train_Aug_))

data_train_Aug_['y'] = data_train_Aug_['POINT_BALANCE']

# Monthly (sequence) and static feature columns passed to the dataset builder.
MONTHLY_COLS = [
    't2m',
    'tp',
    'slhf',
    'sshf',
    'ssrd',
    'fal',
    'str',
    'pcsr',
    'ELEVATION_DIFFERENCE',
]
STATIC_COLS = ['aspect_sgi', 'slope_sgi', 'svf']

feature_columns = MONTHLY_COLS + STATIC_COLS

seed_all(cfg.seed)
# NOTE(review): the "Monthly" PFI cell rebuilds this dataset for the same
# checkpoint (LSTM_IS_ORIGIN_Y_PAST) with normalize_target=False. The two
# cells disagree — confirm which matches how the checkpoint was trained.
ds_train = build_combined_LSTM_dataset(df_loss=data_train,
                                       df_full=data_train_Aug_,
                                       monthly_cols=MONTHLY_COLS,
                                       static_cols=STATIC_COLS,
                                       months_head_pad=months_head_pad_Aug_,
                                       months_tail_pad=months_tail_pad_Aug_,
                                       normalize_target=True,
                                       expect_target=True)
train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
    len(ds_train), val_ratio=0.2, seed=cfg.seed)

# Fit the scalers on the TRAIN indices and transform ds_train_copy in place
# (fit_and_transform=True). The returned loaders are not used in this cell;
# the call is kept for that side effect.
ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train)

train_dl, val_dl = ds_train_copy.make_loaders(
    train_idx=train_idx,
    val_idx=val_idx,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=
    True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True  # use weighted sampler for training
)

# --- Build the model and load the saved weights (no training here) ---
model = mbm.models.LSTM_MB.build_model_from_params(cfg, PARAMS_LSTM_IS_past,
                                                   device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(PARAMS_LSTM_IS_past)  # unused below

# Test loader over an untransformed clone, with ds_train_copy as reference
# (presumably to reuse its fitted scalers — confirm).
ds_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train)

test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test_copy, ds_train_copy, batch_size=128, seed=cfg.seed)

state = torch.load(LSTM_IS_ORIGIN_Y_PAST, map_location=device)
model.load_state_dict(state)
test_metrics, test_df_preds = model.evaluate_with_preds(
    device, test_dl, ds_test_copy)
test_rmse_a, test_rmse_w = test_metrics['RMSE_annual'], test_metrics[
    'RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

# --- Permutation feature importance on the test dataset, plotted for the
# annual and the winter metric ---
pfi_full = PFI_LSTM_full(
    model=model,
    device=device,
    ds_test=ds_test_copy,
    monthly_cols=MONTHLY_COLS,
    static_cols=STATIC_COLS,
    n_repeats=8,
    seed=cfg.seed,
    batch_size=128,
)
plot_pfi_annual(pfi_full)
plot_pfi_winter(pfi_full)

### Monthly:

In [None]:
# Within-sample setting: the model was trained on all glaciers, so no
# glacier is held out here. 2025 is excluded (used elsewhere for validation).
data_monthly_train = data_monthly[data_monthly.YEAR < 2025]

existing_glaciers = set(data_monthly_train.GLACIER.unique())
train_glaciers = existing_glaciers
# .copy() makes data_train an independent DataFrame, so adding the 'y'
# column below is an unambiguous write rather than a chained assignment on a
# boolean selection (SettingWithCopyWarning case).
data_train = data_monthly_train[data_monthly_train.GLACIER.isin(
    train_glaciers)].copy()
print('Size of monthly train data:', len(data_train))

# Target column.
data_train['y'] = data_train['POINT_BALANCE']

# Same selection for the August-start monthly data.
data_monthly_train_Aug_ = data_monthly_Aug_[data_monthly_Aug_.YEAR < 2025]

existing_glaciers = set(data_monthly_train_Aug_.GLACIER.unique())
train_glaciers = existing_glaciers
data_train_Aug_ = data_monthly_train_Aug_[data_monthly_train_Aug_.GLACIER.isin(
    train_glaciers)].copy()
print('Size of monthly train data:', len(data_train_Aug_))

data_train_Aug_['y'] = data_train_Aug_['POINT_BALANCE']

# Monthly (sequence) and static feature columns passed to the dataset builder.
MONTHLY_COLS = [
    't2m',
    'tp',
    'slhf',
    'sshf',
    'ssrd',
    'fal',
    'str',
    'pcsr',
    'ELEVATION_DIFFERENCE',
]
STATIC_COLS = ['aspect_sgi', 'slope_sgi', 'svf']

feature_columns = MONTHLY_COLS + STATIC_COLS

seed_all(cfg.seed)
# NOTE(review): the "Aggregated" PFI cell builds this dataset for the same
# checkpoint (LSTM_IS_ORIGIN_Y_PAST) with normalize_target=True. The two
# cells disagree — confirm which matches how the checkpoint was trained.
ds_train = build_combined_LSTM_dataset(df_loss=data_train,
                                       df_full=data_train_Aug_,
                                       monthly_cols=MONTHLY_COLS,
                                       static_cols=STATIC_COLS,
                                       months_head_pad=months_head_pad_Aug_,
                                       months_tail_pad=months_tail_pad_Aug_,
                                       normalize_target=False,
                                       expect_target=True)
train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
    len(ds_train), val_ratio=0.2, seed=cfg.seed)

# Fit the scalers on the TRAIN indices and transform ds_train_copy in place
# (fit_and_transform=True). The returned loaders are not used in this cell;
# the call is kept for that side effect.
ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train)

train_dl, val_dl = ds_train_copy.make_loaders(
    train_idx=train_idx,
    val_idx=val_idx,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=
    True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True  # use weighted sampler for training
)

# --- Build the model and load the saved weights (no training here) ---
model = mbm.models.LSTM_MB.build_model_from_params(cfg, PARAMS_LSTM_IS_past,
                                                   device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(PARAMS_LSTM_IS_past)  # unused below

# Test loader over an untransformed clone, with ds_train_copy as reference
# (presumably to reuse its fitted scalers — confirm).
ds_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train)

test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test_copy, ds_train_copy, batch_size=128, seed=cfg.seed)

state = torch.load(LSTM_IS_ORIGIN_Y_PAST, map_location=device)
model.load_state_dict(state)
test_metrics, test_df_preds = model.evaluate_with_preds(
    device, test_dl, ds_test_copy)
test_rmse_a, test_rmse_w = test_metrics['RMSE_annual'], test_metrics[
    'RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

In [None]:
# Labels for the monthly time steps of each sequence, in order. The
# underscore-marked entries (aug_, sep_, oct_) presumably denote the padded
# months outside the Oct–Sep hydrological year — confirm against the pads.
month_names = [
    "aug_", "sep_", "oct", "nov", "dec", "jan", "feb", "mar", "apr", "may",
    "jun", "jul", "aug", "sep", "oct_"
]

# Monthly PFI is expensive, so results are cached as CSV.
PFI_MONTHLY_CACHE = 'cache/pfi_LSTM_IS_monthly_absolute.csv'

# Set True to force recomputation. The PFI is also computed when the cache
# file does not exist yet (same pattern as the parquet cache above).
RUN_PFI_MONTHLY = False
if RUN_PFI_MONTHLY or not os.path.exists(PFI_MONTHLY_CACHE):

    pfi_monthly = PFI_LSTM_monthly_parallel(
        model=model,
        device=device,
        ds_test=ds_test_copy,
        monthly_cols=MONTHLY_COLS,
        static_cols=STATIC_COLS,
        month_names=month_names,
        n_repeats=8,
        n_jobs=12,
        seed=cfg.seed,
    )

    # Save to cache. Create the directory first so the expensive result is
    # not lost to a missing-directory error in to_csv.
    os.makedirs(os.path.dirname(PFI_MONTHLY_CACHE), exist_ok=True)
    pfi_monthly.to_csv(PFI_MONTHLY_CACHE, index=False)

else:
    pfi_monthly = pd.read_csv(PFI_MONTHLY_CACHE)

In [None]:
# Climate variables whose monthly permutation importance is plotted.
vars_to_plot = ['t2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str', 'pcsr']

# One ridge figure per metric: (metric, output file, title).
pfi_figures = [
    ("winter", "figures/paper/CH_LSTM_monthly_PFI_winter.png",
     "Monthly Permutation Feature Importance – Winter"),
    ("annual", "figures/paper/CH_LSTM_monthly_PFI_annual.png",
     "Monthly Permutation Feature Importance – Annual"),
    ("global", "figures/paper/fig9_CH_LSTM_monthly_PFI_global.png",
     "Monthly Permutation Feature Importance"),
]

# NOTE(review): the pads are passed tail-first, and they are the pads of the
# original-period data. The PFI dataset was built with
# months_head_pad_Aug_/months_tail_pad_Aug_. Confirm both points against
# plot_monthly_pfi_ridges' signature.
for metric, fname, title in pfi_figures:
    plot_monthly_pfi_ridges(pfi_monthly,
                            vars_to_plot,
                            vois_climate_long_name,
                            months_tail_pad,
                            months_head_pad,
                            metric=metric,
                            drop_padded_months=True,
                            fname=fname,
                            title=title)