## Setting Up:

In [None]:
# --- Standard library
from concurrent.futures import ProcessPoolExecutor, as_completed
from contextlib import redirect_stdout
from datetime import datetime
import io
import logging
import multiprocessing as mp
import os
import sys
import warnings

# Make repo root importable (for MBM & scripts/*)
sys.path.append(os.path.join(os.getcwd(), '../../'))

# --- Third-party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from cmcrameri import cm
import torch
from tqdm.auto import tqdm
import xarray as xr
from matplotlib.lines import Line2D

import massbalancemachine as mbm

# --- Project-local
from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.nn_helpers import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.NN_networks import *
from scripts.geodata_plots import *

# --- Notebook settings
# NOTE(review): this silences *all* warnings for the whole session (including
# pandas chained-assignment and deprecation warnings) — consider narrowing it.
warnings.filterwarnings('ignore')
# IPython magics (notebook-only syntax): re-import edited project modules
# automatically before each cell runs.
%load_ext autoreload
%autoreload 2

# Swiss configuration object; this notebook reads cfg.seed, cfg.dataPath,
# cfg.metaData and cfg.fieldsNotFeatures from it.
cfg = mbm.SwitzerlandConfig()

In [None]:
# Reproducibility first, then choose where the model will run.
seed_all(cfg.seed)
print("Using seed:", cfg.seed)

cuda_ok = torch.cuda.is_available()
if not cuda_ok:
    print("CUDA is NOT available")
else:
    print("CUDA is available")
    free_up_cuda()

device = torch.device("cuda" if cuda_ok else "cpu")

In [None]:
# Plot styles: the style-sheet path is relative to the current working directory.
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)

# Ten colours sampled from the batlow colour map. The first one is used for
# annual series and pink for winter series in the per-glacier plots.
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue, color_pink = colors[0], '#c51b7d'

## Input data:

In [None]:
# ERA5 climate variables requested for every monthly record.
vois_climate = [
    't2m',
    'tp',
    'slhf',
    'sshf',
    'ssrd',
    'fal',
    'str',
]

# Topographical variables requested from process_or_load_data. Of these, only
# aspect_sgi, slope_sgi and svf appear in STATIC_COLS (the model inputs) below.
vois_topographical = [
    "aspect_sgi", "slope_sgi", "hugonnet_dhdt", "consensus_ice_thickness",
    "millan_v", "svf"
]

# Read GLAMOS stake data
data_glamos = getStakesData(cfg)

# Compute padding for monthly data: months added before/after each sequence,
# presumably so all sequences share a common length — confirm in mbm.
# NOTE(review): this is a private mbm helper (leading underscore) and may change.
months_head_pad, months_tail_pad = mbm.data_processing.utils._compute_head_tail_pads_from_df(
    data_glamos)

# Initialize logging
# NOTE(review): configured only after getStakesData has run, so INFO records
# emitted during that call are presumably dropped by the default setup.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Transform data to monthly format (run or load data):
# input files (stake CSV, ERA5 NetCDFs) and the radiation (pcsr) zarr location.
paths = {
    'csv_path': cfg.dataPath + path_PMB_GLAMOS_csv,
    'era5_climate_data':
    cfg.dataPath + path_ERA5_raw + 'era5_monthly_averaged_data.nc',
    'geopotential_data':
    cfg.dataPath + path_ERA5_raw + 'era5_geopotential_pressure.nc',
    'radiation_save_path': cfg.dataPath + path_pcsr + 'zarr/'
}
# RUN=True rebuilds the monthly dataset; presumably False loads the cached
# output_file instead — confirm in process_or_load_data.
RUN = True
data_monthly = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_LSTM_svf_IS.csv')

# Create DataLoader
# NOTE(review): dataloader_gl is only referenced by commented-out CV code below.
dataloader_gl = mbm.dataloader.DataLoader(cfg,
                                          data=data_monthly,
                                          random_seed=cfg.seed,
                                          meta_data_columns=cfg.metaData)

### Blocking on glaciers:

The model is trained on all glaciers ("within-sample" setup): no glacier is held out, so the evaluation below measures in-sample fit.


In [None]:
# Within-sample setup: every glacier present in the monthly data is used for training.
existing_glaciers = set(data_monthly.GLACIER.unique())
train_glaciers = existing_glaciers
# .copy() makes data_train an independent frame. Adding the 'y' column below
# then cannot hit pandas' view-vs-copy (chained assignment) ambiguity on a
# boolean-mask slice of data_monthly. That warning is hidden by the global
# warnings filter.
data_train = data_monthly[data_monthly.GLACIER.isin(train_glaciers)].copy()
print('Size of monthly train data:', len(data_train))

# splits, test_set, train_set = get_CV_splits(dataloader_gl,
#                                             test_split_on='GLACIER',
#                                             test_splits=TEST_GLACIERS,
#                                             random_state=cfg.seed)

# print('Test glaciers: ({}) {}'.format(len(test_set['splits_vals']),
#                                       test_set['splits_vals']))
# test_perc = (len(test_set['df_X']) / len(train_set['df_X'])) * 100
# print('Percentage of test size: {:.2f}%'.format(test_perc))
# print('Size of test set:', len(test_set['df_X']))
# print('Train glaciers: ({}) {}'.format(len(train_set['splits_vals']),
#                                        train_set['splits_vals']))
# print('Size of train set:', len(train_set['df_X']))

# Target column: a copy of the point mass balance. The train/validation split
# happens later, on sequence indices.
data_train['y'] = data_train['POINT_BALANCE']

## LSTM:

In [None]:
# Model inputs: time-varying monthly columns and static (non-monthly) columns.
MONTHLY_COLS = [
    # climate variables (same names as vois_climate)
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str',
    # additional monthly inputs
    'ELEVATION_DIFFERENCE', 'pcsr',
]
STATIC_COLS = ['aspect_sgi', 'slope_sgi', 'svf']

feature_columns = [*MONTHLY_COLS, *STATIC_COLS]

### Build LSTM dataloaders:

In [None]:
seed_all(cfg.seed)

# Work on a copy and normalise PERIOD labels (strip whitespace, lower-case) so
# they match lower-case keys such as 'winter' used below.
df_train = data_train.copy()
df_train['PERIOD'] = df_train['PERIOD'].str.strip().str.lower()

# --- build train dataset from dataframe ---
ds_train = mbm.data_processing.MBSequenceDataset.from_dataframe(
    df_train,
    MONTHLY_COLS,
    STATIC_COLS,
    months_tail_pad=months_tail_pad,
    months_head_pad=months_head_pad,
    expect_target=True)

# Seeded train/validation split of sequence indices (20% validation).
train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
    len(ds_train), val_ratio=0.2, seed=cfg.seed)

# Look at the padding for one example. The key layout appears to be
# (glacier, year, ID, period), inferred from this example — TODO confirm.
key = ('adler', 2009, 11, 'winter')

# find the index of this key
try:
    idx = ds_train.keys.index(key)
except ValueError:
    # `from None` hides the uninformative traceback raised by .index().
    raise ValueError(f"Key {key} not found in dataset.") from None

# fetch the corresponding sequence and display its 'mv', 'mw' and 'ma' entries
sequence = ds_train[idx]
sequence['mv'], sequence['mw'], sequence['ma']

### Define & train model:

In [None]:
# Hyper-parameter search log. best_params is chosen with select_by='test_rmse_a',
# while the table displayed below is ranked by the mean of the annual and
# winter test RMSE.
# NOTE(review): choosing hyper-parameters on a "test" metric leaks that split
# into model selection — confirm what test_rmse_* refers to in the search.
log_path = 'logs/lstm_two_heads_param_search_progress_no_oggm_IS_2025-10-22.csv'
best_params = get_best_params_for_lstm(log_path, select_by='test_rmse_a')

df = pd.read_csv(log_path)
df = df.assign(
    avg_test_loss=(df["test_rmse_a"] + df["test_rmse_w"]) / 2
).sort_values(by="avg_test_loss")

print(best_params)
df.head(10)

In [None]:
# custom_params = {
#     'Fm': 8,
#     'Fs': 6,
#     'hidden_size': 128,
#     'num_layers': 2,
#     'bidirectional': False,
#     'dropout': 0.1,
#     'static_layers': 2,
#     'static_hidden': [128, 64],
#     'static_dropout': 0.1,
#     'lr': 0.001,
#     'weight_decay': 0.0,
#     'loss_name': 'neutral',
#     'two_heads': True,
#     'head_dropout': 0.0,
#     'loss_spec': None
# }

custom_params = best_params

# Checkpoint path for a model trained in this session (dated today).
current_date = datetime.now().strftime("%Y-%m-%d")
model_filename = f"models/lstm_model_{current_date}_no_oggm_IS.pt"
# Previously trained checkpoint, evaluated when TRAIN is False.
PRETRAINED_MODEL_FILENAME = "models/lstm_model_2025-10-23_no_oggm_IS.pt"

# --- loaders (fit scalers on TRAIN, apply to whole ds_train) ---
ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train)

train_dl, val_dl = ds_train_copy.make_loaders(
    train_idx=train_idx,
    val_idx=val_idx,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True,  # use weighted sampler for training
)

# --- build model and resolve the loss from the hyper-parameters ---
model = mbm.models.LSTM_MB.build_model_from_params(cfg, custom_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(custom_params)

TRAIN = True
if TRAIN:
    # Remove any stale file, so the checkpoint at model_filename can only come
    # from this run's train_loop (save_best_path).
    if os.path.exists(model_filename):
        os.remove(model_filename)

    history, best_val, best_state = model.train_loop(
        device=device,
        train_dl=train_dl,
        val_dl=val_dl,
        epochs=150,
        lr=custom_params['lr'],
        weight_decay=custom_params['weight_decay'],
        clip_val=1,
        # scheduler
        sched_factor=0.5,
        sched_patience=6,
        sched_threshold=0.01,
        sched_threshold_mode="rel",
        sched_cooldown=1,
        sched_min_lr=1e-6,
        # early stopping
        es_patience=15,
        es_min_delta=1e-4,
        # logging
        log_every=5,
        verbose=True,
        # checkpoint
        save_best_path=model_filename,
        loss_fn=loss_fn,
    )
    plot_history_lstm(history)
else:
    # Only fall back to the pretrained checkpoint when not training.
    # Overriding unconditionally would evaluate (and pass downstream) the old
    # checkpoint instead of the model that was just trained.
    model_filename = PRETRAINED_MODEL_FILENAME

# Within-sample evaluation: the "test" loader wraps a fresh untransformed copy
# of the full training dataset. It is given ds_train_copy, presumably to reuse
# the scalers fitted on the training indices.
ds_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train)

test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test_copy, ds_train_copy, batch_size=128, seed=cfg.seed)

# Evaluate on test
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state)
test_metrics, test_df_preds = model.evaluate_with_preds(
    device, test_dl, ds_test_copy)
test_rmse_a, test_rmse_w = test_metrics['RMSE_annual'], test_metrics[
    'RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

In [None]:
# Scores split into annual and winter, then a summary figure of the predictions.
scores_annual, scores_winter = compute_seasonal_scores(
    test_df_preds, target_col='target', pred_col='pred')

for season_label, season_scores in (("Annual scores:", scores_annual),
                                    ("Winter scores:", scores_winter)):
    print(season_label, season_scores)

fig = plot_predictions_summary(
    grouped_ids=test_df_preds,
    scores_annual=scores_annual,
    scores_winter=scores_winter,
    ax_xlim=(-8, 6),
    ax_ylim=(-8, 6),
)

In [None]:
# Glacier areas; the 'clariden' entry reuses the 'claridenL' area.
gl_area = get_gl_area(cfg)
gl_area["clariden"] = gl_area["claridenL"]

# Mean elevation of annual stakes per glacier (highest first), attached to the predictions.
gl_per_el = data_glamos[data_glamos.PERIOD == 'annual'].groupby(
    ['GLACIER'])['POINT_ELEVATION'].mean()
gl_per_el = gl_per_el.sort_values(ascending=False)
test_df_preds['gl_elv'] = test_df_preds['GLACIER'].map(gl_per_el)

train_glaciers = {
    'adler', 'albigna', 'aletsch', 'allalin', 'basodino', 'clariden',
    'corbassiere', 'corvatsch', 'findelen', 'forno', 'gietro', 'gorner',
    'gries', 'hohlaub', 'joeri', 'limmern', 'morteratsch', 'murtel', 'oberaar',
    'otemma', 'pizol', 'plattalva', 'rhone', 'sanktanna', 'schwarzberg',
    'sexrouge', 'silvretta', 'tortin', 'tsanfleuron'
}
# Panel order: glaciers sorted by mean annual-stake elevation, lowest first.
# Keep only glaciers that have an elevation (plain label indexing on gl_per_el
# would raise KeyError otherwise) and that actually appear in the predictions.
glaciers_with_preds = set(test_df_preds['GLACIER'])
plotted_glaciers = [
    gl for gl in train_glaciers
    if gl in gl_per_el.index and gl in glaciers_with_preds
]
test_gl_per_el = gl_per_el[plotted_glaciers].sort_values().index

# One panel per glacier, three per row. A fixed 7x3 grid offers only 21 panels
# for the up-to-29 glaciers listed above. squeeze=False keeps axs 2-D, like the
# original multi-row grid, even when a single row is needed.
n_cols = 3
n_rows = max(1, int(np.ceil(len(test_gl_per_el) / n_cols)))
fig, axs = plt.subplots(n_rows,
                        n_cols,
                        figsize=(25, 30 / 7 * n_rows),
                        sharex=False,
                        squeeze=False)

axs = PlotIndividualGlacierPredVsTruth(test_df_preds,
                                       axs=axs,
                                       color_annual=color_dark_blue,
                                       color_winter=color_pink,
                                       custom_order=test_gl_per_el,
                                       add_text=True,
                                       ax_xlim=None,
                                       gl_area=gl_area)


## Extrapolate in space:

In [None]:
geodetic_mb = get_geodetic_MB(cfg)

# get years per glacier
# NOTE(review): years_start_per_gl / years_end_per_gl are not used in the
# visible notebook cells.
years_start_per_gl = geodetic_mb.groupby(
    'glacier_name')['Astart'].unique().apply(list).to_dict()
years_end_per_gl = geodetic_mb.groupby('glacier_name')['Aend'].unique().apply(
    list).to_dict()

# Per-glacier geodetic periods and geodetic MB values.
periods_per_glacier, geoMB_per_glacier = build_periods_per_glacier(geodetic_mb)

# NOTE(review): this list comes from the GLAMOS stake data (data_glamos), not
# from the pcsr files, so the "with pcsr" wording in the print may be misleading.
glacier_list = list(data_glamos.GLACIER.unique())
print('Number of glaciers with pcsr:', len(glacier_list))

geodetic_glaciers = periods_per_glacier.keys()
print('Number of glaciers with geodetic MB:', len(geodetic_glaciers))

# Intersection of both
common_glaciers = list(set(geodetic_glaciers) & set(glacier_list))
print('Number of common glaciers:', len(common_glaciers))

# Glacier areas used for sorting below; 'clariden' reuses the 'claridenL' area.
gl_area = get_gl_area(cfg)
gl_area['clariden'] = gl_area['claridenL']


# Sort the lists by area if available in gl_area
def sort_by_area(glacier_list, gl_area):
    """Return the glaciers in *glacier_list* ordered from smallest to largest area.

    Glaciers without an entry in *gl_area* count as area 0, so they come first;
    ties keep their input order because ``sorted`` is stable.
    """
    def area_of(name):
        return gl_area.get(name, 0)

    return sorted(glacier_list, key=area_of)


# Common glaciers ordered from smallest to largest area (missing areas count as 0).
glacier_list = sort_by_area(common_glaciers, gl_area)
glacier_list

In [None]:
from scripts.parallel_mb import MBJobConfig, run_glacier_mb

# 1) Output directory for the distributed MB grids of this model variant.
path_save_glw = os.path.join(cfg.dataPath, 'GLAMOS', 'distributed_MB_grids',
                             'MBM/testing_LSTM/glamos_dems_LSTM_no_oggm_IS')
# 2) Job description shared by all glaciers. ds_train (not the scaled
# ds_train_copy) is passed together with train_idx — presumably so scalers are
# refit on the training indices inside run_glacier_mb; confirm in
# scripts.parallel_mb.
job = MBJobConfig(
    cfg=cfg,
    MONTHLY_COLS=MONTHLY_COLS,
    STATIC_COLS=STATIC_COLS,
    fields_not_features=cfg.fieldsNotFeatures,
    model_filename=model_filename,
    custom_params=custom_params,
    ds_train=ds_train,
    train_idx=train_idx,
    months_head_pad=months_head_pad,
    months_tail_pad=months_tail_pad,
    data_path=cfg.dataPath,
    path_glacier_grid_glamos=path_glacier_grid_glamos,
    path_xr_grids=os.path.join(cfg.dataPath, 'GLAMOS', 'topo', 'GLAMOS_DEM',
                               'xr_masked_grids'),
    path_save_glw=path_save_glw,
    seed=cfg.seed,
    max_workers=None,  # or an int
    cpu_only=True,
    ONLY_GEODETIC=False)

# 3) Run
summary = run_glacier_mb(job, glacier_list, periods_per_glacier)
print("SUMMARY:", summary)

In [None]:
# Entries written to path_save_glw (presumably one per processed glacier).
glaciers_in_glamos = os.listdir(path_save_glw)

# NOTE(review): the following repeats the geodetic setup of an earlier cell so
# this cell can run on its own.
geodetic_mb = get_geodetic_MB(cfg)

# get years per glacier
years_start_per_gl = geodetic_mb.groupby(
    'glacier_name')['Astart'].unique().apply(list).to_dict()
years_end_per_gl = geodetic_mb.groupby('glacier_name')['Aend'].unique().apply(
    list).to_dict()

periods_per_glacier, geoMB_per_glacier = build_periods_per_glacier(geodetic_mb)

# Glaciers with geodetic MB data:
# Glacier areas used for sorting; 'clariden' reuses the 'claridenL' area.
gl_area = get_gl_area(cfg)
gl_area['clariden'] = gl_area['claridenL']


# Sort the lists by area if available in gl_area
def sort_by_area(glacier_list, gl_area):
    """Return the glaciers in *glacier_list* ordered from smallest to largest area.

    Glaciers without an entry in *gl_area* count as area 0, so they come first;
    ties keep their input order because ``sorted`` is stable.
    """
    def area_of(name):
        return gl_area.get(name, 0)

    return sorted(glacier_list, key=area_of)


# Glaciers that have geodetic periods AND an entry in the prediction folder,
# ordered from smallest to largest area (set lookup instead of list scans).
glaciers_with_grids = set(glaciers_in_glamos)
glacier_list = [g for g in periods_per_glacier if g in glaciers_with_grids]
glacier_list = sort_by_area(glacier_list, gl_area)
print('Number of glaciers:', len(glacier_list))
print('Glaciers:', glacier_list)

df_all_nn = process_geodetic_mass_balance_comparison(
    glacier_list=glacier_list,
    path_SMB_GLAMOS_csv=cfg.dataPath + path_SMB_GLAMOS_csv,
    periods_per_glacier=periods_per_glacier,
    geoMB_per_glacier=geoMB_per_glacier,
    gl_area=gl_area,
    test_glaciers=TEST_GLACIERS,
    path_predictions=path_save_glw,  # or another path if needed
    cfg=cfg)

# Drop rows where any required columns are NaN
df_all_nn = df_all_nn.dropna(subset=['Geodetic MB', 'MBM MB'])
df_all_nn = df_all_nn.sort_values(by='Area')
# Vectorised and NaN-safe, unlike calling str.capitalize per element.
df_all_nn['GLACIER'] = df_all_nn['GLACIER'].str.capitalize()

# Compute RMSE and Pearson correlation
rmse_nn = root_mean_squared_error(df_all_nn["Geodetic MB"],
                                  df_all_nn["MBM MB"])
corr_nn = np.corrcoef(df_all_nn["Geodetic MB"], df_all_nn["MBM MB"])[0, 1]
# Surface the metrics; they were previously computed but never shown.
print(f'MBM vs geodetic MB: RMSE = {rmse_nn:.3f}, Pearson r = {corr_nn:.3f}')

# NOTE(review): `bins` defines 5 intervals but max_bins=4. Presumably the last
# two intervals are merged under the '>10' label — confirm against
# plot_mbm_vs_geodetic_by_area_bin before changing the labels.
fig = plot_mbm_vs_geodetic_by_area_bin(df_all_nn,
                                       bins=[0, 1, 5, 10, 100, np.inf],
                                       labels=['<1', '1-5', '5–10', '>10', '>100'],
                                       max_bins=4)

## Permutation feature importance:

In [None]:
# Parallel PFI for MBM LSTM (CPU, Linux)
import os, sys, random, numpy as np, pandas as pd, torch, xarray as xr
from typing import Dict, List, Callable, Any, Tuple
from concurrent.futures import ProcessPoolExecutor, as_completed
from contextlib import redirect_stdout
import multiprocessing as mp
from tqdm.auto import tqdm

# ------------------------------- determinism helpers -------------------------------


def _set_cpu_env_threads():
    """Pin the current process to single-threaded CPU execution (best effort).

    Each environment variable is only filled in when it is not already set, so
    existing values win. torch's intra-op thread pool is then limited to one
    thread; failures of that call are ignored.
    """
    single_thread_env = (
        ("CUDA_VISIBLE_DEVICES", ""),  # hide GPUs -> force CPU
        ("OMP_NUM_THREADS", "1"),
        ("MKL_NUM_THREADS", "1"),
        ("OPENBLAS_NUM_THREADS", "1"),
        ("NUMEXPR_MAX_THREADS", "1"),
    )
    for name, value in single_thread_env:
        os.environ.setdefault(name, value)
    # NOTE(review): libraries initialised before this call (torch is imported at
    # module level) have presumably read their thread env vars already; the call
    # below is what takes effect immediately for torch.
    try:
        torch.set_num_threads(1)
    except Exception:
        pass  # best effort: keep the default thread count


def _set_seeds(seed: int):
    """Seed the Python, NumPy and torch RNGs and request deterministic kernels.

    PYTHONHASHSEED is exported as well. It only affects interpreters started
    afterwards; hash randomisation of the running process is fixed at startup.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # cuDNN flags only matter for GPU execution; harmless on CPU.
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False
    try:
        torch.use_deterministic_algorithms(True)
    except Exception:
        pass  # best effort: ignore failures


# ------------------------------- globals for worker -------------------------------
# Filled by initializer once per process
# Per-process PFI state. Every entry starts as None; _pfi_worker_init fills it
# inside each worker process.
_PFI_G = dict.fromkeys((
    "model",
    "device",
    "ds_train_copy",
    "MONTHLY_COLS",
    "STATIC_COLS",
    "months_head_pad",
    "months_tail_pad",
    "target_col",
    "baseline",
    "df_eval",
    "seed",
    "batch_size",
))


def _pfi_worker_init(cfg, custom_params: Dict[str, Any], model_filename: str,
                     ds_train, train_idx, MONTHLY_COLS: List[str],
                     STATIC_COLS: List[str], months_head_pad: int,
                     months_tail_pad: int, df_eval: pd.DataFrame,
                     target_col: str, seed: int, batch_size: int):
    """Initialise one PFI worker process (ProcessPoolExecutor initializer).

    Steps:
      1. Silence the worker's stdout/stderr.
      2. Pin the worker to a single CPU thread and seed its RNGs.
      3. Fit scalers on ``train_idx`` only.
      4. Load the trained model on CPU.
      5. Compute the unpermuted baseline RMSE on ``df_eval`` once.

    Everything ``_pfi_worker_task`` needs is stored in the module-level
    ``_PFI_G`` dict of this worker.

    Raises:
        ValueError: if ``target_col`` is neither in the prediction frame nor
            recoverable from ``df_eval`` via ``ID``.
    """
    # local import avoids pickling a module from parent
    import massbalancemachine as mbm

    # Silence worker prints; one devnull handle serves both streams for the
    # worker's lifetime.
    devnull = open(os.devnull, "w")
    sys.stdout = devnull
    sys.stderr = devnull

    _set_cpu_env_threads()
    _set_seeds(seed)

    # Fit scalers on TRAIN only
    ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
        ds_train)
    ds_train_copy.fit_scalers(train_idx)

    # Load model on CPU
    device = torch.device("cpu")
    model = mbm.models.LSTM_MB.build_model_from_params(cfg,
                                                       custom_params,
                                                       device,
                                                       verbose=False)
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state)
    model.eval()

    # Build eval ds/loader with targets
    ds_eval = mbm.data_processing.MBSequenceDataset.from_dataframe(
        df_eval,
        MONTHLY_COLS,
        STATIC_COLS,
        months_tail_pad=months_tail_pad,
        months_head_pad=months_head_pad,
        expect_target=True,
        show_progress=False)
    dl_eval = mbm.data_processing.MBSequenceDataset.make_test_loader(
        ds_eval, ds_train_copy, seed=seed, batch_size=batch_size)
    with torch.no_grad():
        df_pred_base = model.predict_with_keys(device, dl_eval, ds_eval)

    # Targets: take them from the prediction frame when present, otherwise
    # recover them from df_eval by ID.
    if target_col in df_pred_base.columns:
        y_true = df_pred_base[target_col].to_numpy()
    else:
        # NOTE(review): assumes "ID" alone identifies one sequence and that the
        # target is constant within it (drop_duplicates keeps the first row).
        merged = df_pred_base.merge(df_eval[["ID", target_col
                                             ]].drop_duplicates("ID"),
                                    on="ID",
                                    how="left")
        if merged[target_col].isna().any():
            raise ValueError(
                "Missing targets after merge; ensure df_eval contains per-ID targets."
            )
        y_true = merged[target_col].to_numpy()

    y_pred = df_pred_base["pred"].to_numpy()
    baseline = float(np.sqrt(np.mean((y_true - y_pred)**2)))

    # store globals
    _PFI_G.update(
        dict(model=model,
             device=device,
             ds_train_copy=ds_train_copy,
             MONTHLY_COLS=MONTHLY_COLS,
             STATIC_COLS=STATIC_COLS,
             months_head_pad=months_head_pad,
             months_tail_pad=months_tail_pad,
             target_col=target_col,
             baseline=baseline,
             df_eval=df_eval,
             seed=seed,
             batch_size=batch_size))


def _permute_within_groups(values: np.ndarray, groups: np.ndarray,
                           rng: np.random.Generator) -> np.ndarray:
    """Shuffle *values* independently inside each group of equal *groups* labels.

    With calendar months as groups, every month keeps its own value distribution
    while the link between rows and values is broken. Groups are processed in
    ascending label order, and positions within a group in ascending order.
    """
    permuted = np.empty_like(values)
    _, labels = np.unique(groups, return_inverse=True)
    # Stable sort keeps each group's positions ascending; the cumulative counts
    # are the end offsets of the consecutive group blocks.
    order = np.argsort(labels, kind="stable")
    block_ends = np.cumsum(np.bincount(labels))

    start = 0
    for stop in block_ends:
        members = order[start:stop]
        source = members.copy()
        rng.shuffle(source)
        permuted[members] = values[source]
        start = stop
    return permuted


def _pfi_worker_task(task: Tuple[str, str, int]) -> Tuple[str, str, float]:
    """Evaluate one permutation repeat inside a PFI worker.

    Args:
        task: ``(feature, ftype, seed_offset)``. ``ftype`` is ``"monthly"`` or
            ``"static"``; ``seed_offset`` is the repeat index and is added to
            the worker seed.

    Returns:
        ``(feature, ftype, delta)`` where ``delta`` is the permuted RMSE minus
        the baseline RMSE stored by ``_pfi_worker_init``.

    Raises:
        ValueError: missing ``MONTHS`` column for a monthly feature, unknown
            ``ftype``, or targets that cannot be recovered.
    """
    # local import
    import massbalancemachine as mbm

    feat, ftype, seed_offset = task
    state = _PFI_G
    rng = np.random.default_rng(int(state["seed"]) + int(seed_offset))

    permuted = state["df_eval"].copy()
    if ftype == "monthly":
        if "MONTHS" not in permuted.columns:
            raise ValueError(
                "MONTHS column required for monthly feature permutation.")
        # Shuffle only within the same calendar month.
        permuted[feat] = _permute_within_groups(permuted[feat].to_numpy(),
                                                permuted["MONTHS"].to_numpy(),
                                                rng)
    elif ftype == "static":
        # NOTE(review): rows are shuffled individually, so a static feature may
        # no longer be constant within one sequence; how from_dataframe reduces
        # it to one value per sequence is not visible here — confirm.
        row_order = np.arange(len(permuted))
        rng.shuffle(row_order)
        permuted[feat] = permuted[feat].to_numpy()[row_order]
    else:
        raise ValueError("ftype must be 'monthly' or 'static'.")

    # Rebuild the dataset/loader for the permuted frame, reusing the scalers.
    ds_perm = mbm.data_processing.MBSequenceDataset.from_dataframe(
        permuted,
        state["MONTHLY_COLS"],
        state["STATIC_COLS"],
        months_tail_pad=state["months_tail_pad"],
        months_head_pad=state["months_head_pad"],
        expect_target=True,
        show_progress=False)
    dl_perm = mbm.data_processing.MBSequenceDataset.make_test_loader(
        ds_perm,
        state["ds_train_copy"],
        seed=state["seed"],
        batch_size=state["batch_size"])
    with torch.no_grad():
        preds = state["model"].predict_with_keys(state["device"], dl_perm,
                                                 ds_perm)

    # Targets: from the prediction frame if present, else from the frame by ID.
    tcol = state["target_col"]
    if tcol in preds.columns:
        y_true = preds[tcol].to_numpy()
    else:
        merged = preds.merge(permuted[["ID", tcol]].drop_duplicates("ID"),
                             on="ID",
                             how="left")
        if merged[tcol].isna().any():
            raise ValueError(
                "Missing targets after merge; ensure df_eval contains per-ID targets."
            )
        y_true = merged[tcol].to_numpy()

    residual = y_true - preds["pred"].to_numpy()
    rmse = float(np.sqrt(np.mean(residual**2)))
    return feat, ftype, rmse - float(state["baseline"])


# ------------------------------- user-facing function -------------------------------


def _pfi_parent_baseline(cfg, custom_params, model_filename, df_eval,
                         MONTHLY_COLS, STATIC_COLS, ds_train, train_idx,
                         target_col, months_head_pad, months_tail_pad, seed,
                         batch_size) -> float:
    """Compute the unpermuted baseline RMSE in the calling (parent) process.

    Mirrors the worker initializer. The process-wide settings applied by
    ``_set_cpu_env_threads`` / ``_set_seeds`` are rolled back afterwards:
    environment variables, torch thread count, deterministic-algorithm mode and
    cuDNN flags. Otherwise the interactive session would stay single-threaded
    and in deterministic mode, where ops without a deterministic implementation
    raise on GPU. The Python/NumPy/torch RNGs remain seeded with ``seed``.
    """
    import massbalancemachine as mbm

    saved_env = os.environ.copy()
    saved_threads = torch.get_num_threads()
    saved_det = getattr(torch, "are_deterministic_algorithms_enabled",
                        lambda: False)()
    saved_cudnn = (torch.backends.cudnn.deterministic,
                   torch.backends.cudnn.benchmark)
    try:
        _set_cpu_env_threads()
        _set_seeds(seed)

        ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
            ds_train)
        ds_train_copy.fit_scalers(train_idx)
        device = torch.device("cpu")
        model = mbm.models.LSTM_MB.build_model_from_params(cfg,
                                                           custom_params,
                                                           device,
                                                           verbose=False)
        state = torch.load(model_filename, map_location=device)
        model.load_state_dict(state)
        model.eval()
        ds_eval_parent = mbm.data_processing.MBSequenceDataset.from_dataframe(
            df_eval,
            MONTHLY_COLS,
            STATIC_COLS,
            months_tail_pad=months_tail_pad,
            months_head_pad=months_head_pad,
            expect_target=True,
            show_progress=False)
        dl_eval_parent = mbm.data_processing.MBSequenceDataset.make_test_loader(
            ds_eval_parent, ds_train_copy, seed=seed, batch_size=batch_size)
        with torch.no_grad():
            df_pred_base = model.predict_with_keys(device, dl_eval_parent,
                                                   ds_eval_parent)
        if target_col in df_pred_base.columns:
            y_true = df_pred_base[target_col].to_numpy()
        else:
            merged = df_pred_base.merge(df_eval[["ID", target_col
                                                 ]].drop_duplicates("ID"),
                                        on="ID",
                                        how="left")
            if merged[target_col].isna().any():
                raise ValueError(
                    "Missing targets after merge; ensure df_eval has per-ID targets."
                )
            y_true = merged[target_col].to_numpy()
        y_pred = df_pred_base["pred"].to_numpy()
        return float(np.sqrt(np.mean((y_true - y_pred)**2)))
    finally:
        # Drop variables that were added, then restore previous values.
        for key in set(os.environ) - set(saved_env):
            del os.environ[key]
        os.environ.update(saved_env)
        torch.set_num_threads(saved_threads)
        try:
            torch.use_deterministic_algorithms(saved_det)
        except Exception:
            pass  # same best-effort policy as _set_seeds
        torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = saved_cudnn


def _pfi_summarize(results, baseline: float) -> pd.DataFrame:
    """Aggregate per-repeat ``(feature, type, delta)`` tuples into one row per feature."""
    import collections

    bucket: Dict[Tuple[str, str], List[float]] = collections.defaultdict(list)
    for feat, ftype, delta in results:
        bucket[(feat, ftype)].append(float(delta))

    rows = [{
        "feature": feat,
        "type": ftype,
        "mean_delta": float(np.mean(deltas)),
        # sample std across repeats; a single repeat has no spread
        "std_delta": float(np.std(deltas, ddof=1)) if len(deltas) > 1 else 0.0,
    } for (feat, ftype), deltas in bucket.items()]

    out = pd.DataFrame(rows,
                       columns=["feature", "type", "mean_delta", "std_delta"])
    out = out.sort_values("mean_delta",
                          ascending=False).reset_index(drop=True)
    out["baseline"] = baseline
    out["metric_name"] = "rmse"
    return out


def permutation_feature_importance_mbm_parallel(
    cfg,
    custom_params: Dict[str, Any],
    model_filename: str,
    df_eval: pd.DataFrame,
    MONTHLY_COLS: List[str],
    STATIC_COLS: List[str],
    ds_train,
    train_idx,
    target_col: str,
    months_head_pad: int,
    months_tail_pad: int,
    seed: int = 42,
    n_repeats: int = 5,
    batch_size: int = 256,
    max_workers: int = None,
) -> pd.DataFrame:
    """Parallel permutation feature importance (PFI) for the MBM LSTM on CPU.

    Each (feature, repeat) pair is evaluated in a forked worker process. The
    importance of a feature is the increase in RMSE over the unpermuted
    baseline. Monthly features are shuffled within calendar months; static
    features are shuffled across rows.

    Args:
        cfg, custom_params, model_filename: used to rebuild and load the model.
        df_eval: evaluation frame containing the features and targets.
        MONTHLY_COLS, STATIC_COLS: features to permute, by type.
        ds_train, train_idx: source dataset and indices used to fit scalers.
        target_col: target column name.
        months_head_pad, months_tail_pad: sequence padding.
        seed: base seed; the repeat index is added per task.
        n_repeats: permutations per feature.
        batch_size: evaluation batch size.
        max_workers: worker processes; default is ``cpu_count - 1``, capped
            at 32 and at the number of tasks.

    Returns:
        DataFrame with columns ``['feature', 'type', 'mean_delta',
        'std_delta', 'baseline', 'metric_name']``, sorted by ``mean_delta``
        descending. It is empty when there is nothing to permute.
    """
    # One task per (feature, repeat); the repeat index offsets the seed.
    feats = [(f, "monthly") for f in MONTHLY_COLS] + [(f, "static")
                                                      for f in STATIC_COLS]
    tasks = [(feat, ftype, r) for feat, ftype in feats
             for r in range(n_repeats)]
    if not tasks:
        # Nothing to permute: skip the worker pool and the baseline recompute.
        return pd.DataFrame(columns=[
            "feature", "type", "mean_delta", "std_delta", "baseline",
            "metric_name"
        ])

    # Use Linux fork so df_eval stays shared copy-on-write
    ctx = mp.get_context("fork")
    if max_workers is None:
        max_workers = min(max(1, (os.cpu_count() or 2) - 1), 32)
    # Every worker runs the costly initializer; never start more than there are tasks.
    max_workers = min(max_workers, len(tasks))

    results = []
    with ProcessPoolExecutor(
            max_workers=max_workers,
            mp_context=ctx,
            initializer=_pfi_worker_init,
            initargs=(
                cfg,
                custom_params,
                model_filename,
                ds_train,
                train_idx,
                MONTHLY_COLS,
                STATIC_COLS,
                months_head_pad,
                months_tail_pad,
                df_eval,
                target_col,
                seed,
                batch_size,
            ),
    ) as ex:
        futs = [ex.submit(_pfi_worker_task, t) for t in tasks]
        for fut in tqdm(as_completed(futs),
                        total=len(futs),
                        desc=f"PFI (workers={max_workers})"):
            results.append(fut.result())

    # Worker baselines live in the workers' own copies of _PFI_G, so compute
    # the reported baseline once here (without leaking global torch settings).
    baseline = _pfi_parent_baseline(cfg, custom_params, model_filename,
                                    df_eval, MONTHLY_COLS, STATIC_COLS,
                                    ds_train, train_idx, target_col,
                                    months_head_pad, months_tail_pad, seed,
                                    batch_size)
    return _pfi_summarize(results, baseline)

In [None]:
# Within-sample PFI: evaluated on the training dataframe (targets included).
pfi_parallel = permutation_feature_importance_mbm_parallel(
    cfg=cfg,
    custom_params=custom_params,
    model_filename=model_filename,
    df_eval=df_train,
    MONTHLY_COLS=MONTHLY_COLS,
    STATIC_COLS=STATIC_COLS,
    ds_train=ds_train,
    train_idx=train_idx,
    target_col="POINT_BALANCE",
    months_head_pad=months_head_pad,
    months_tail_pad=months_tail_pad,
    seed=cfg.seed,
    n_repeats=5,
    batch_size=256,
    max_workers=None,  # auto: n_cpus-1 (cap 32)
)

# Optional: quick plot
import matplotlib.pyplot as plt

plt.figure(figsize=(8, max(3, 0.35 * len(pfi_parallel))))
plt.barh(pfi_parallel["feature"],
         pfi_parallel["mean_delta"],
         xerr=pfi_parallel["std_delta"])
# Rows come sorted by decreasing mean_delta; flip the axis so the largest is on top.
plt.gca().invert_yaxis()
metric_name = pfi_parallel['metric_name'].iloc[0]
baseline_value = pfi_parallel['baseline'].iloc[0]
plt.title(
    f"Permutation Feature Importance (Δ{metric_name}; baseline={baseline_value:.3f})"
)
plt.xlabel(f"Increase in {metric_name} (higher = more important)")
plt.tight_layout()
plt.show()