## Setting Up:

In [None]:
# --- Standard library
from concurrent.futures import ProcessPoolExecutor, as_completed
from contextlib import redirect_stdout
from datetime import datetime
import io
import logging
import multiprocessing as mp
import os
import sys
import warnings

# Make repo root importable (for MBM & scripts/*)
sys.path.append(os.path.join(os.getcwd(), '../../'))

# --- Third-party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from cmcrameri import cm
import torch
from tqdm.auto import tqdm
import xarray as xr
from matplotlib.lines import Line2D

import massbalancemachine as mbm

# --- Project-local
from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.nn_helpers import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.NN_networks import *
from scripts.geodata_plots import *

# --- Notebook settings
# Silence all Python warnings for the whole session. NOTE: this also hides any
# warnings.warn() output from the project helpers; use logging (or print) for
# anything that must stay visible.
warnings.filterwarnings('ignore')
# IPython magics: reload edited project modules before executing each cell.
%load_ext autoreload
%autoreload 2

# Swiss configuration object; provides cfg.seed, cfg.dataPath and cfg.metaData
# used by the cells below.
cfg = mbm.SwitzerlandConfig()

In [None]:
# Seed through the project helper and choose the compute device.
seed_all(cfg.seed)
print("Using seed:", cfg.seed)

# Decide the device first, then report it (and tidy GPU state when present).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("CUDA is available")
    free_up_cuda()  # project helper; presumably releases cached GPU memory
else:
    print("CUDA is NOT available")

In [None]:
# Plot styles: project style sheet plus a 10-colour palette sampled from the
# batlow colormap (cmcrameri) via the project helper get_cmap_hex.
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)

colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue, color_pink = colors[0], '#c51b7d'

## Input data:
### Input dataset:

In [None]:
# Climate and topographic variables passed to process_or_load_data
vois_climate = [
    't2m',
    'tp',
    'slhf',
    'sshf',
    'ssrd',
    'fal',
    'str',
]

vois_topographical = ["aspect_sgi", "slope_sgi", "svf"]

# Read GLAMOS stake data
data_glamos = getStakesData(cfg)

# Compute padding for monthly data
months_head_pad, months_tail_pad = mbm.data_processing.utils._compute_head_tail_pads_from_df(
    data_glamos)

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Transform data to monthly format (run or load data):
paths = {
    'csv_path': cfg.dataPath + path_PMB_GLAMOS_csv,
    'era5_climate_data':
    cfg.dataPath + path_ERA5_raw + 'era5_monthly_averaged_data.nc',
    'geopotential_data':
    cfg.dataPath + path_ERA5_raw + 'era5_geopotential_pressure.nc',
    'radiation_save_path': cfg.dataPath + path_pcsr + 'zarr/'
}
RUN = False  # presumably False loads `output_file` instead of reprocessing
data_monthly = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_LSTM_gs_no_oggm.csv')

# Create DataLoader
dataloader_gl = mbm.dataloader.DataLoader(cfg,
                                          data=data_monthly,
                                          random_seed=cfg.seed,
                                          meta_data_columns=cfg.metaData)

# Check that every test glacier exists in the dataset. Reported via logging
# because warnings are globally ignored in this notebook.
existing_glaciers = set(data_monthly.GLACIER.unique())
missing_glaciers = [g for g in TEST_GLACIERS if g not in existing_glaciers]
if missing_glaciers:
    logging.warning("Test glaciers missing from data_monthly: %s",
                    missing_glaciers)

# Training glaciers = all glaciers in the data that are not test glaciers
train_glaciers = [i for i in existing_glaciers if i not in TEST_GLACIERS]

# Hold out whole glaciers: TEST_GLACIERS form the test set.
splits, test_set, train_set = get_CV_splits(dataloader_gl,
                                            test_split_on='GLACIER',
                                            test_splits=TEST_GLACIERS,
                                            random_state=cfg.seed)

# Train / test feature frames with the target attached as column 'y':
data_train = train_set['df_X']
data_train['y'] = train_set['y']

data_test = test_set['df_X']
data_test['y'] = test_set['y']

seed_all(cfg.seed)

# Copies with a normalised PERIOD label (trimmed, lower-case)
df_train = data_train.copy()
df_train['PERIOD'] = df_train['PERIOD'].str.strip().str.lower()

df_test = data_test.copy()
df_test['PERIOD'] = df_test['PERIOD'].str.strip().str.lower()

In [None]:
# Build a second monthly dataset in which every period starts on 1 August:
# take the first four characters (the year) of FROM_DATE, append "0801" and
# convert back to an int (YYYYMMDD).
data_glamos_Aug_ = data_glamos.copy()
data_glamos_Aug_["FROM_DATE"] = (
    data_glamos_Aug_["FROM_DATE"].astype(str).str.slice(0, 4)  # year YYYY
    .astype(int).astype(str) + "0801"  # append "0801"
).astype(int)

# Same for full temporal resolution (run or load data):
# head/tail month padding for the August-start data
months_head_pad_Aug_, months_tail_pad_Aug_ = mbm.data_processing.utils._compute_head_tail_pads_from_df(
    data_glamos_Aug_)

# Initialize logging (no-op if the root logger is already configured)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

RUN = False
data_monthly_Aug_ = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos_Aug_,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_LSTM_gs_no_oggm_Aug_.csv')

# Create DataLoader
dataloader_gl_Aug_ = mbm.dataloader.DataLoader(cfg,
                                               data=data_monthly_Aug_,
                                               random_seed=cfg.seed,
                                               meta_data_columns=cfg.metaData)

# Same glacier hold-out as for data_monthly
splits_Aug_, test_set_Aug_, train_set_Aug_ = get_CV_splits(
    dataloader_gl_Aug_,
    test_split_on='GLACIER',
    test_splits=TEST_GLACIERS,
    random_state=cfg.seed)

# Train / test feature frames with the target attached as column 'y':
data_train_Aug_ = train_set_Aug_['df_X']
data_train_Aug_['y'] = train_set_Aug_['y']
data_test_Aug_ = test_set_Aug_['df_X']
data_test_Aug_['y'] = test_set_Aug_['y']

seed_all(cfg.seed)

# Normalise PERIOD from each frame's own column. df_train comes from a
# different dataset with a different row index, so using it here would
# misalign values through pandas index alignment.
df_train_Aug_ = data_train_Aug_.copy()
df_train_Aug_['PERIOD'] = df_train_Aug_['PERIOD'].str.strip().str.lower()

df_test_Aug_ = data_test_Aug_.copy()
df_test_Aug_['PERIOD'] = df_test_Aug_['PERIOD'].str.strip().str.lower()

## Blocking on glaciers:

## LSTM:

In [None]:
# Per-month (sequence) inputs for the LSTM: the climate variables (same set as
# vois_climate) followed by two additional monthly inputs.
MONTHLY_COLS = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str',
    'pcsr', 'ELEVATION_DIFFERENCE',
]
# Static inputs: the topographic variables (same set as vois_topographical).
STATIC_COLS = ['aspect_sgi', 'slope_sgi', 'svf']

feature_columns = [*MONTHLY_COLS, *STATIC_COLS]

#### Build sequences:

In [None]:
# Both sequence datasets share the feature columns, the August-start month
# padding and the target options; only the source frames differ.
_seq_kwargs = dict(
    monthly_cols=MONTHLY_COLS,
    static_cols=STATIC_COLS,
    months_head_pad=months_head_pad_Aug_,
    months_tail_pad=months_tail_pad_Aug_,
    normalize_target=True,
    expect_target=True,
)

ds_train = build_combined_LSTM_dataset(df_loss=data_train,
                                       df_full=data_train_Aug_,
                                       **_seq_kwargs)
ds_test = build_combined_LSTM_dataset(df_loss=data_test,
                                      df_full=data_test_Aug_,
                                      **_seq_kwargs)
del _seq_kwargs

# Seeded train/validation split of the training sequences (20% validation).
train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
    len(ds_train), val_ratio=0.2, seed=cfg.seed)

In [None]:
# Order month labels by their position inside the sequences, then print a few
# training samples against those labels.
month_list, month_pos = mbm.data_processing.utils._rebuild_month_index(
    months_head_pad_Aug_, months_tail_pad_Aug_)
month_order = sorted(month_pos, key=month_pos.get)
print("Month order used in sequences:", month_order)

for sample_idx in (0, 10, 150):
    inspect_LSTM_sample(ds_train, sample_idx, month_labels=month_order)

In [None]:
def safe_item(x):
    """Return ``x.item()``, or ``None`` when ``x`` is ``None``.

    Lets optional scalar attributes (which may be ``None``) be printed
    without a separate ``None`` check at every call site.
    """
    if x is None:
        return None
    return x.item()


def _print_target_stats(title, ds):
    """Print the target-scaler attributes and raw target statistics of ``ds``.

    ``y_mean``/``y_std`` go through ``safe_item`` because they may be ``None``;
    ``ds.y`` must support ``.mean().item()`` and ``.std().item()``.
    """
    print(title)
    print(f"  normalize_target = {ds.normalize_target}")
    print(f"  y_mean (scaler)  = {safe_item(ds.y_mean)}")
    print(f"  y_std  (scaler)  = {safe_item(ds.y_std)}")
    print(f"  Actual y.mean()  = {ds.y.mean().item():.4f}")
    print(f"  Actual y.std()   = {ds.y.std().item():.4f}")


# make_loaders / make_test_loader are only called later, in the grid search,
# on clones of these datasets, so this shows the datasets as built.
_print_target_stats("Train dataset (as built, before make_loaders):", ds_train)
_print_target_stats("\nTest dataset (as built, before make_test_loader):",
                    ds_test)

## Grid search:

In [None]:
from itertools import product

# Hyperparameter search space; the grid search evaluates every combination.
param_grid = dict(
    lr=[1e-3, 5e-4, 1e-4],
    weight_decay=[0.0, 1e-5, 1e-4],
    hidden_size=[64, 128],
    num_layers=[1, 2],
    # always use some dropout. NOTE(review): if this is passed to
    # torch.nn.LSTM's `dropout`, it has no effect when num_layers == 1 -
    # confirm in LSTM_MB.
    dropout=[0.1, 0.2],
    head_dropout=[0, 0.1],
)

# Static-feature encoder options as (static_layers, static_hidden,
# static_dropout). To also try the identity branch, add (0, 0, None).
static = [(2, [128, 64], 0.1)]  # small two-layer MLP


def pack(static_triplet):
    """Turn a ``(layers, hidden, dropout)`` tuple into static-encoder kwargs."""
    layers, hidden, dropout = static_triplet
    return {
        "static_layers": layers,
        "static_hidden": hidden,
        "static_dropout": dropout,
    }


# ---- settings shared by every sampled parameter set ----
const_params = dict(
    Fm=ds_train.Xm.shape[-1],  # number of monthly features (last axis of Xm)
    Fs=ds_train.Xs.shape[-1],  # number of static features (last axis of Xs)
    bidirectional=False,
    loss_name="neutral",
    loss_spec=None,
    two_heads=False,
)


def grid_iter_with_static_and_const(grid, static_list, const):
    """Yield one parameter dict per point of ``grid`` x ``static_list``.

    Each dict contains the grid hyperparameters (in ``grid`` key order), then
    the static-encoder kwargs produced by ``pack``, then ``const``. On key
    clashes the later source wins.
    """
    names = list(grid)
    axes = [grid[name] for name in names]
    for *hyper_values, static_cfg in product(*axes, static_list):
        yield {**dict(zip(names, hyper_values)), **pack(static_cfg), **const}


# ---- materialise the full grid so it can be counted and previewed ----
sampled_params = [
    *grid_iter_with_static_and_const(param_grid, static, const_params)
]
print(len(sampled_params))
print(sampled_params[0])  # preview one combo

In [None]:
import csv

# Grid search: train one LSTM per entry of `sampled_params`, append its best
# validation loss and test RMSEs to a CSV log, and track the best config by
# validation loss.
RUN = True
if RUN:
    os.makedirs("logs", exist_ok=True)
    os.makedirs("models", exist_ok=True)

    log_filename = f'logs/lstm_one_heads_param_search_progress_no_oggm_{datetime.now().strftime("%Y-%m-%d")}.csv'

    # Header and rows share one column list so every row lines up with the
    # header (DictWriter raises ValueError on an unexpected row key).
    log_fieldnames = list(sampled_params[0].keys()) + [
        'valid_loss', 'test_rmse_a', 'test_rmse_w'
    ]

    # create log with header (mode 'w': a log from the same day is replaced)
    with open(log_filename, mode='w', newline='') as log_file:
        csv.DictWriter(log_file, fieldnames=log_fieldnames).writeheader()

    # Checkpoint path shared by all configs. It is deleted before each run, so
    # after the loop it holds the LAST config's weights, not the best config's;
    # retrain with best_overall['params'] to recover the best model.
    model_filename = 'models/best_lstm_mb_gs_one_heads_svf_pcsr_OOS.pt'

    results = []
    best_overall = {"val": float('inf'), "row": None, "params": None}

    for i, params in enumerate(sampled_params):
        # Remove the previous config's checkpoint so torch.load below can only
        # read weights written by this run.
        if os.path.exists(model_filename):
            os.remove(model_filename)
            print(f"Deleted existing model file: {model_filename}")

        # --- loaders: work on clones of ds_train / ds_test; scalers are fitted
        # on the TRAIN indices and applied in place to the clones ---
        seed_all(cfg.seed)
        ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
            ds_train)
        ds_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
            ds_test)

        train_dl, val_dl = ds_train_copy.make_loaders(
            train_idx=train_idx,
            val_idx=val_idx,
            batch_size_train=64,
            batch_size_val=128,
            seed=cfg.seed,
            fit_and_transform=True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
            shuffle_train=True,
            use_weighted_sampler=True,  # use weighted sampler for training
        )

        # --- test loader (copies TRAIN scalers into ds_test_copy and transforms it) ---
        test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
            ds_test_copy, ds_train_copy, batch_size=128, seed=cfg.seed)

        print(f"\n--- Running config {i+1}/{len(sampled_params)} ---")
        print(params)

        # Build model
        seed_all(cfg.seed)
        model = mbm.models.LSTM_MB.build_model_from_params(cfg, params, device)
        # Choose loss
        loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(params)

        # Train
        history, best_val, best_state = model.train_loop(
            device=device,
            train_dl=train_dl,
            val_dl=val_dl,
            epochs=150,
            lr=params['lr'],
            weight_decay=params['weight_decay'],
            clip_val=1,
            # scheduler
            sched_factor=0.5,
            sched_patience=6,
            sched_threshold=0.01,
            sched_threshold_mode="rel",
            sched_cooldown=1,
            sched_min_lr=1e-6,
            # early stopping
            es_patience=15,
            es_min_delta=1e-4,
            # logging
            log_every=5,
            verbose=True,
            # checkpoint
            save_best_path=model_filename,
            loss_fn=loss_fn,
        )

        # Load the best weights of this run from the checkpoint and evaluate
        best_state = torch.load(model_filename, map_location=device)
        model.load_state_dict(best_state)
        test_metrics, test_df_preds = model.evaluate_with_preds(
            device, test_dl, ds_test_copy)
        test_rmse_a = test_metrics['RMSE_annual']
        test_rmse_w = test_metrics['RMSE_winter']

        # Log row
        row = {
            **params,
            'valid_loss': float(best_val),
            'test_rmse_a': float(test_rmse_a),
            'test_rmse_w': float(test_rmse_w),
        }

        print(test_rmse_a, test_rmse_w)

        # Append immediately so partial results survive an interrupted search
        with open(log_filename, mode='a', newline='') as log_file:
            csv.DictWriter(log_file, fieldnames=log_fieldnames).writerow(row)

        results.append(row)

        # Track best by validation loss
        if best_val < best_overall['val']:
            best_overall = {"val": best_val, "row": row, "params": params}

    print("\n=== Best config by validation loss ===")
    print(best_overall['params'])
    print(best_overall['row'])