## Setting Up:

In [None]:
# --- Standard library
from concurrent.futures import ProcessPoolExecutor, as_completed
from contextlib import redirect_stdout
from datetime import datetime
import io
import logging
import multiprocessing as mp
import os
import sys
import warnings

# Make repo root importable (for MBM & scripts/*)
sys.path.append(os.path.join(os.getcwd(), '../../'))

# --- Third-party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from cmcrameri import cm
import torch
from tqdm.auto import tqdm
import xarray as xr
from matplotlib.lines import Line2D

import massbalancemachine as mbm

# --- Project-local
from scripts.utils import *
from scripts.glamos import *
from scripts.models import *
from scripts.geo_data import *
from scripts.dataset import *
from scripts.geodetic import *
from scripts.plotting import *

# --- Notebook settings
warnings.filterwarnings('ignore')
%load_ext autoreload
%autoreload 2

cfg = mbm.SwitzerlandConfig()

# Plot styles:
use_mbm_style()

seed_all(cfg.seed)
print("Using seed:", cfg.seed)

if torch.cuda.is_available():
    print("CUDA is available")
    free_up_cuda()
else:
    print("CUDA is NOT available")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Input data:
### Input dataset:

In [None]:
# Read GLAMOS stake data
data_glamos = get_stakes_data(cfg)

# Head/tail month padding used when building monthly sequences
months_head_pad, months_tail_pad = mbm.data_processing.utils._compute_head_tail_pads_from_df(
    data_glamos)

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Transform data to monthly format (run or load data):
paths = {
    'csv_path': cfg.dataPath + path_PMB_GLAMOS_csv,
    'era5_climate_data':
    cfg.dataPath + path_ERA5_raw + 'era5_monthly_averaged_data.nc',
    'geopotential_data':
    cfg.dataPath + path_ERA5_raw + 'era5_geopotential_pressure.nc',
    'radiation_save_path': cfg.dataPath + path_pcsr + 'zarr/'
}
RUN = False
data_monthly = process_or_load_data(
    run_flag=RUN,
    df=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=VOIS_CLIMATE,
    vois_topographical=VOIS_TOPOGRAPHICAL,
    output_file='CH_wgms_dataset_monthly_LSTM.csv')

# NOTE(review): this DataLoader is built BEFORE 2025 is removed below, so it
# still contains 2025 rows — confirm that is intended.
dataloader_gl = mbm.dataloader.DataLoader(cfg,
                                          data=data_monthly,
                                          random_seed=cfg.seed,
                                          meta_data_columns=cfg.metaData)

# Remove 2025 (used elsewhere for validation)
data_monthly = data_monthly[data_monthly['YEAR'] < 2025]

# Blocking on glaciers: the model is trained on all glaciers --> "within sample"
existing_glaciers = set(data_monthly.GLACIER.unique())
train_glaciers = existing_glaciers
# .copy() makes data_train an independent frame, so adding the target column
# below is a plain assignment rather than a write into a boolean-indexed slice
# of data_monthly (pandas' SettingWithCopy pattern; its warning is hidden by
# warnings.filterwarnings('ignore')).
data_train = data_monthly[data_monthly.GLACIER.isin(train_glaciers)].copy()
print('Size of monthly train data:', len(data_train))

# Target column. The train/validation split itself is done later on the
# sequence dataset (split_indices).
data_train['y'] = data_train['POINT_BALANCE']

In [None]:
# Variant of the stake data with every FROM_DATE moved to 1 August of its
# year: keep the YYYY prefix of the date and append "0801" (YYYYMMDD as int).
data_glamos_Aug_ = data_glamos.copy()
data_glamos_Aug_["FROM_DATE"] = (
    data_glamos_Aug_["FROM_DATE"].astype(str).str.slice(0, 4) + "0801"
).astype(int)

# Same monthly processing at full temporal resolution (run or load data).
# Head/tail month padding for the August-start sequences:
months_head_pad_Aug_, months_tail_pad_Aug_ = mbm.data_processing.utils._compute_head_tail_pads_from_df(
    data_glamos_Aug_)

# Initialize logging (no-op if the root logger is already configured)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

RUN = False
data_monthly_Aug_ = process_or_load_data(
    run_flag=RUN,
    # Same keyword as the other process_or_load_data call in this notebook
    # (this call used `data_glamos=`, which cannot match the same signature).
    # NOTE(review): confirm the parameter is named `df`.
    df=data_glamos_Aug_,
    paths=paths,
    cfg=cfg,
    vois_climate=VOIS_CLIMATE,
    vois_topographical=VOIS_TOPOGRAPHICAL,
    output_file='CH_wgms_dataset_monthly_LSTM_Aug_.csv')

# Create DataLoader (built before 2025 is removed below)
dataloader_gl_Aug_ = mbm.dataloader.DataLoader(cfg,
                                               data=data_monthly_Aug_,
                                               random_seed=cfg.seed,
                                               meta_data_columns=cfg.metaData)

# Remove 2025 (used elsewhere for validation)
data_monthly_Aug_ = data_monthly_Aug_[data_monthly_Aug_['YEAR'] < 2025]

# Blocking on glaciers: the model is trained on all glaciers --> "within sample"
existing_glaciers = set(data_monthly_Aug_.GLACIER.unique())
train_glaciers = existing_glaciers
# .copy() so the target column below is added to an independent frame, not
# written into a boolean-indexed slice of data_monthly_Aug_.
data_train_Aug_ = data_monthly_Aug_[data_monthly_Aug_.GLACIER.isin(
    train_glaciers)].copy()
print('Size of monthly train data:', len(data_train_Aug_))

# Target column (the split into train/validation happens on the sequence dataset)
data_train_Aug_['y'] = data_train_Aug_['POINT_BALANCE']

## Blocking on glaciers:

## LSTM:

In [None]:
# Time-varying inputs, passed as `monthly_cols` to the sequence builder.
MONTHLY_COLS = ['t2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str', 'pcsr',
                'ELEVATION_DIFFERENCE']
# Per-location inputs, passed as `static_cols` to the sequence builder.
STATIC_COLS = ['aspect_sgi', 'slope_sgi', 'svf']

# Full feature list: monthly columns first, then the static ones.
feature_columns = [*MONTHLY_COLS, *STATIC_COLS]

### Build sequences:

In [None]:
seed_all(cfg.seed)

# One sequence dataset from both monthly tables: `data_train` is passed as
# df_loss and the August-start table as df_full. The August padding sets the
# month axis, and the target is normalised (normalize_target=True).
ds_train = build_combined_LSTM_dataset(
    df_loss=data_train,
    df_full=data_train_Aug_,
    monthly_cols=MONTHLY_COLS,
    static_cols=STATIC_COLS,
    months_head_pad=months_head_pad_Aug_,
    months_tail_pad=months_tail_pad_Aug_,
    normalize_target=True,
    expect_target=True,
)

# Seeded split of sample indices, 20% held out for validation.
n_sequences = len(ds_train)
train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
    n_sequences, val_ratio=0.2, seed=cfg.seed)

In [None]:
# Month axis implied by the August-start padding; month_pos presumably maps
# month label -> position in the sequence (it is sorted by value below).
month_list, month_pos = mbm.data_processing.utils._rebuild_month_index(
    months_head_pad_Aug_, months_tail_pad_Aug_)
month_order = sorted(month_pos, key=month_pos.get)
print("Month order used in sequences:", month_order)

# Spot-check a few individual sequences against the month labels.
for sample_idx in (0, 10, 150):
    inspect_LSTM_sample(ds_train, sample_idx, month_labels=month_order)

In [None]:
# Diagnostic helper from scripts.* — presumably summarises which months in
# ds_train's sequences are padding; verify against its implementation.
inspect_LSTM_padded_months(ds_train)

## Grid search:

In [None]:
from itertools import product
from copy import deepcopy


def generate_param_sets(const_params, param_grid):
    """
    Expand a hyper-parameter grid into unique, valid configurations.

    Every combination in the Cartesian product of the ``param_grid`` values is
    built and normalised by the conditional rules below. Those rules overwrite
    values (e.g. ``static_hidden`` -> None when there is no static MLP), so
    several raw combinations collapse onto the same effective configuration.
    Such duplicates are dropped, so each configuration is returned (and later
    trained) only once. First-seen order is preserved.

    Args:
        const_params: dict of parameters shared by every configuration. Its
            entries come first in each returned dict; grid entries override
            same-named constants.
        param_grid: dict mapping parameter name -> list of candidate values.
            Key order determines the enumeration order.

    Returns:
        List[dict]: one ``{**const_params, **grid_params}`` dict per unique
        valid configuration.
    """
    keys = list(param_grid)
    seen = set()
    param_sets = []

    for combo in product(*(param_grid[k] for k in keys)):
        params = dict(zip(keys, combo))

        # ---- conditional logic (rules only fire when the key is in the grid) ----

        # 1) Without a static MLP its width/dropout are meaningless.
        if params.get("static_layers") == 0:
            params["static_hidden"] = None
            params["static_dropout"] = None

        # 2) Inter-layer LSTM dropout only applies to stacked LSTMs.
        if params.get("num_layers") == 1:
            params["dropout"] = 0.0

        # 3) Head dropout sanity: reject negative values.
        if params.get("head_dropout", 0.0) < 0.0:
            continue

        # The rules above make distinct raw combos identical; keep the first.
        # Only grid params vary between combos, so they suffice as the key.
        key = repr(tuple(params.items()))
        if key in seen:
            continue
        seen.add(key)

        # Merge constants (constants first, grid values override)
        param_sets.append({**const_params, **params})

    return param_sets

In [None]:
def sample_param_sets(
    const_params,
    param_grid,
    n_samples: int,
    seed: int = cfg.seed,
):
    """
    Draw a reproducible random subset of the valid configurations.

    Expands the grid with ``generate_param_sets``. It then samples
    ``n_samples`` configurations without replacement, using a private
    ``random.Random(seed)`` instance so the global ``random`` state is not
    touched.

    Args:
        const_params: parameters shared by every configuration.
        param_grid: mapping of parameter name -> candidate values.
        n_samples: number of configurations to draw.
        seed: sampler seed. The default is ``cfg.seed`` as evaluated once,
            when this function was defined.

    Returns:
        List[dict]: the sampled configurations, or all of them when
        ``n_samples`` is at least the number available.

    Raises:
        ValueError: from ``Random.sample`` if ``n_samples`` is negative.
    """
    # Import explicitly instead of relying on `random` leaking into the
    # namespace through the `from scripts.* import *` star imports.
    import random

    all_params = generate_param_sets(const_params, param_grid)

    if n_samples >= len(all_params):
        print(f"Requested {n_samples}, but only {len(all_params)} available. Using all.")
        return all_params

    rng = random.Random(seed)
    return rng.sample(all_params, n_samples)

In [None]:
# Settings shared by every configuration. Fm / Fs are read from the last axis
# of the dataset's Xm / Xs arrays (presumably the monthly and static feature
# counts — confirm against MBSequenceDataset).
const_params = dict(
    Fm=ds_train.Xm.shape[-1],
    Fs=ds_train.Xs.shape[-1],
    bidirectional=False,
    two_heads=False,
    loss_name="neutral",
    loss_spec=None,
)

# Search space. Key order matters:
# - it fixes the order in which combinations are enumerated, and therefore
#   which ones the seeded sampler picks;
# - it fixes the column order of the results log.
param_grid = {
    # LSTM core
    "hidden_size": [64, 96, 128, 160],
    "num_layers": [1, 2],
    "dropout": [0.1, 0.2, 0.3],
    # static-feature MLP
    "static_layers": [0, 1, 2],
    "static_hidden": [32, 64, 128],
    "static_dropout": [0.1, 0.2, 0.3],
    # output head
    "head_dropout": [0.0, 0.05, 0.1],
    # optimiser
    "lr": [5e-4, 1e-3, 2e-3],
    "weight_decay": [1e-5, 1e-4],
}

# Train only a random subset of the full grid (e.g. 50, 200, 300 ...).
N_SAMPLES = 300
sampled_params = sample_param_sets(const_params, param_grid,
                                   n_samples=N_SAMPLES, seed=cfg.seed)
print("Running {} random configurations".format(len(sampled_params)))

In [None]:
import csv
import hashlib, json
RUN = True
if RUN:
    os.makedirs("logs", exist_ok=True)
    # Checkpoints are written to models/GS_past/ (see model_filename below).
    # Create that exact sub-directory; os.makedirs also creates `models/`.
    os.makedirs("models/GS_past", exist_ok=True)

    if not sampled_params:
        raise ValueError("sampled_params is empty: nothing to run.")

    log_filename = f'logs/lstm_param_search_progress_IS_past_{datetime.now().strftime("%Y-%m-%d")}.csv'

    # One fixed column order for the header and for every row. Sampled
    # configs share the same keys (const_params + param_grid); the metrics
    # follow them.
    fieldnames = list(sampled_params[0].keys()) + [
        'valid_loss', 'test_rmse_a', 'test_rmse_w'
    ]

    # Create (or overwrite) today's log with just the header.
    with open(log_filename, mode='w', newline='') as log_file:
        writer = csv.DictWriter(log_file, fieldnames=fieldnames)
        writer.writeheader()

    def param_hash(params):
        """Return an 8-char hex ID from the sorted-key JSON dump of ``params``."""
        s = json.dumps(params, sort_keys=True)
        return hashlib.md5(s.encode()).hexdigest()[:8]

    results = []  # one logged row per configuration
    best_overall = {"val": float('inf'), "row": None, "params": None}

    for i, params in enumerate(sampled_params):
        seed_all(cfg.seed)

        run_id = param_hash(params)
        model_filename = f"models/GS_past/best_lstm_IS_{run_id}.pt"

        # Remove any checkpoint left by an earlier search, so the
        # torch.load below cannot pick up stale weights.
        if os.path.exists(model_filename):
            os.remove(model_filename)
            print(f"Deleted existing model file: {model_filename}")

        # --- loaders (fit scalers on TRAIN, apply to whole ds_train) ---
        ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
            ds_train)

        # NOTE: the "test" set is a clone of ds_train, so the test RMSEs below
        # are in-sample (computed on the same sequences used for train/val).
        ds_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
            ds_train)

        train_dl, val_dl = ds_train_copy.make_loaders(
            train_idx=train_idx,
            val_idx=val_idx,
            batch_size_train=64,
            batch_size_val=128,
            seed=cfg.seed,
            fit_and_transform=
            True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
            shuffle_train=True,
            use_weighted_sampler=True  # use weighted sampler for training
        )

        # --- test loader (copies TRAIN scalers into ds_test and transforms it) ---
        test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
            ds_test_copy, ds_train_copy, batch_size=128, seed=cfg.seed)

        print(f"\n--- Running config {i+1}/{len(sampled_params)} ---")
        print(params)

        # Build model
        model = mbm.models.LSTM_MB.build_model_from_params(cfg, params, device)

        # Choose loss
        loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(params)

        # Train
        history, best_val, best_state = model.train_loop(
            device=device,
            train_dl=train_dl,
            val_dl=val_dl,
            epochs=150,
            lr=params['lr'],
            weight_decay=params['weight_decay'],
            clip_val=1,
            # scheduler
            sched_factor=0.5,
            sched_patience=6,
            sched_threshold=0.01,
            sched_threshold_mode="rel",
            sched_cooldown=1,
            sched_min_lr=1e-6,
            # early stopping
            es_patience=15,
            es_min_delta=1e-4,
            # logging
            log_every=5,
            verbose=True,
            # checkpoint
            save_best_path=model_filename,
            loss_fn=loss_fn,
        )

        # Load the best weights from the checkpoint written during training
        best_state = torch.load(model_filename, map_location=device)
        model.load_state_dict(best_state)
        test_metrics, test_df_preds = model.evaluate_with_preds(
            device, test_dl, ds_test_copy)
        test_rmse_a, test_rmse_w = test_metrics['RMSE_annual'], test_metrics[
            'RMSE_winter']

        # Log row
        row = {
            **params, 'valid_loss': float(best_val),
            'test_rmse_a': float(test_rmse_a),
            'test_rmse_w': float(test_rmse_w)
        }
        results.append(row)

        print(test_rmse_a, test_rmse_w)

        with open(log_filename, mode='a', newline='') as log_file:
            writer = csv.DictWriter(log_file, fieldnames=fieldnames)
            writer.writerow(row)

        # Track best by validation loss
        if best_val < best_overall['val']:
            best_overall = {"val": best_val, "row": row, "params": params}

    print("\n=== Best config by validation loss ===")
    print(best_overall['params'])
    print(best_overall['row'])