# Glacier grids from RGI:

Creates the monthly grid files that the MassBalanceMachine (MBM) needs to predict point mass balance (PMB) over the whole glacier grid. The grids come from the RGI (Randolph Glacier Inventory) glaciers combined with OGGM topography. Expect long run times: the conversion of each glacier grid to monthly format is the slow step.
## Setting up:

In [None]:
import sys, os

sys.path.append(os.path.join(os.getcwd(),
                             '../../'))  # Add root of repo to import MBM
import csv
from functools import partial

import pandas as pd
import warnings
from tqdm.notebook import tqdm
import re
import matplotlib.pyplot as plt
import seaborn as sns
from cmcrameri import cm
import xarray as xr
import massbalancemachine as mbm
from collections import defaultdict
import logging
from skorch.helper import SliceDataset
from datetime import datetime
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint
import itertools
import random
import pickle
from collections import Counter
import ast
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.nn_helpers import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.NN_networks import *
from scripts.geodata_plots import *

warnings.filterwarnings('ignore')
%load_ext autoreload
%autoreload 2

cfg = mbm.SwitzerlandConfig()

In [None]:
# Seed with the value stored in the config and echo it for the record.
# NOTE(review): seed_all comes from the star-imported helper modules;
# presumably it seeds the Python/NumPy/torch RNGs - confirm.
seed_all(cfg.seed)
print("Using seed:", cfg.seed)

from torch.utils.data import Subset
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset
from torch.utils.data import WeightedRandomSampler, SubsetRandomSampler
import torch.nn as nn

# Pick the compute device once, then report on it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    print("CUDA is available")
    # free_up_cuda is a project helper (star-imported); only called on GPU.
    free_up_cuda()
else:
    print("CUDA is NOT available")

In [None]:
# Re-seed and call the project's CUDA clean-up helper before the heavy cells.
seed_all(cfg.seed)
free_up_cuda()  # in case no memory

# Matplotlib style sheet applied to every figure below.
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)

# Climate variables of interest (passed as `vois_climate` further down).
vois_climate = [
    't2m',
    'tp',
    'slhf',
    'sshf',
    'ssrd',
    'fal',
    'str',
    'u10',
    'v10',
]

# Topographical variables of interest; a grid may lack some of them, so the
# processing cells filter this list against the columns actually present.
vois_topographical = [
    "aspect", "slope", "hugonnet_dhdt",
    "consensus_ice_thickness", "millan_v", "topo",
]

# Glacier outlines, drawn on the overview map later in the notebook.
glacier_outline_rgi = gpd.read_file(cfg.dataPath + path_rgi_outlines)


In [None]:
# Set up the OGGM glacier directories for RGI region 11, inventory version 6,
# starting from the pre-processed directories published under `base_url`.
# NOTE(review): presumably downloads/caches data on first use - confirm in
# initialize_oggm_glacier_directories.
gdirs, rgidf = initialize_oggm_glacier_directories(
    cfg,
    rgi_region="11",
    rgi_version="6",
    base_url=
    "https://cluster.klima.uni-bremen.de/~oggm/gdirs/oggm_v1.6/L3-L5_files/2023.1/elev_bands/W5E5_w_data/",
    log_level='WARNING',
    task_list=None,
)
# Save OGGM xr for all needed glaciers in RGI region 11 (version 6):
export_oggm_grids(cfg, gdirs)

In [None]:
# Glacier ID lookup table: header names are stripped of stray whitespace,
# rows are sorted by and indexed on the glacier short name.
rgi_df = (pd.read_csv(cfg.dataPath + path_glacier_ids, sep=',')
          .rename(columns=lambda col: col.strip())
          .sort_values(by='short_name')
          .set_index('short_name'))
# Display the entry of one glacier as a quick check.
rgi_df.loc['rhone']

## Create RGI grids for all glaciers:

In [None]:
path_RGIs = cfg.dataPath + path_OGGM + 'xr_grids/'
glaciers = os.listdir(path_RGIs)

print(f"Found {len(glaciers)} glaciers in RGI region 11.6")

# Example glacier for a visual check (alternative: gdirs[0].rgi_id)
rgi_gl = 'RGI60-11.01238'
ds = xr.open_dataset(path_RGIs + rgi_gl + '.zarr')

# NaN wherever glacier_mask == 0, the mask value itself everywhere else, so
# multiplying a field by it blanks out the off-glacier cells.
glacier_mask = np.where(ds['glacier_mask'].values == 0, np.nan,
                        ds['glacier_mask'].values)

# Fields that every grid is expected to carry
ds = ds.assign(
    masked_slope=glacier_mask * ds['slope'],
    masked_elev=glacier_mask * ds['topo'],
    masked_aspect=glacier_mask * ds['aspect'],
    masked_dis=glacier_mask * ds['dis_from_border'],
)

# Fields that are only masked when the grid provides them
for _masked_name, _source_name in (
    ('masked_hug', 'hugonnet_dhdt'),
    ('masked_cit', 'consensus_ice_thickness'),
    ('masked_miv', 'millan_v'),
):
    if _source_name in ds:
        ds = ds.assign(**{_masked_name: glacier_mask * ds[_source_name]})

glacier_indices = np.where(ds['glacier_mask'].values == 1)

# One panel per field: (variable, colormap, title)
fig, axs = plt.subplots(1, 4, figsize=(16, 8), sharey=True)
for _ax, (_field, _cmap, _title) in zip(axs, (
    ('masked_aspect', 'twilight_shifted', "Aspect OGGM"),
    ('masked_slope', 'cividis', "Slope OGGM"),
    ('masked_elev', 'terrain', "DEM OGGM"),
    ('glacier_mask', 'binary', "Glacier mask OGGM"),
)):
    ds[_field].plot(ax=_ax, cmap=_cmap, add_colorbar=False)
    _ax.set_title(_title)

In [None]:
def create_masked_glacier(path_RGIs, rgi_gl):
    """Open one glacier grid and add glacier-masked copies of its fields.

    Parameters
    ----------
    path_RGIs : str
        Path prefix of the grid stores; the dataset is opened at
        ``path_RGIs + rgi_gl + '.zarr'``.
    rgi_gl : str
        Glacier identifier, used as the file stem.

    Returns
    -------
    ds : xarray.Dataset
        The opened dataset with ``masked_slope``, ``masked_elev``,
        ``masked_aspect`` and ``masked_dis`` added, plus ``masked_hug``,
        ``masked_cit`` and ``masked_miv`` when their source variables exist.
    glacier_indices : tuple of numpy.ndarray
        Result of ``np.where(glacier_mask == 1)`` on the raw mask.

    Raises
    ------
    ValueError
        If the dataset has no ``glacier_mask`` variable.
    """
    ds = xr.open_dataset(path_RGIs + rgi_gl + '.zarr')

    if 'glacier_mask' not in ds:
        raise ValueError(
            f"'glacier_mask' variable not found in dataset {rgi_gl}")

    # NaN where the mask is 0, the mask value elsewhere: multiplying by it
    # blanks out every off-glacier cell.
    raw_mask = ds['glacier_mask'].values
    nan_mask = np.where(raw_mask == 0, np.nan, raw_mask)

    # (output variable, source variable, must exist?)
    layers = (
        ('masked_slope', 'slope', True),
        ('masked_elev', 'topo', True),
        ('masked_aspect', 'aspect', True),
        ('masked_dis', 'dis_from_border', True),
        ('masked_hug', 'hugonnet_dhdt', False),
        ('masked_cit', 'consensus_ice_thickness', False),
        ('masked_miv', 'millan_v', False),
    )
    for out_name, src_name, required in layers:
        # Required sources are accessed unconditionally (a missing one
        # raises); optional ones are skipped silently when absent.
        if required or src_name in ds:
            ds = ds.assign(**{out_name: nan_mask * ds[src_name]})

    on_glacier = np.where(ds['glacier_mask'].values == 1)
    return ds, on_glacier

### Create masked grids:

In [None]:
# Output folder of the masked lat/lon grids (one .zarr store per glacier).
path_xr_grids = os.path.join(cfg.dataPath, 'GLAMOS/topo/RGI_v6_11/',
                             'xr_masked_grids/')
RUN = False
if RUN:
    # Start from an empty folder: to_zarr below writes one fresh store per
    # glacier.
    emptyfolder(path_xr_grids)

    for gdir in tqdm(gdirs):
        rgi_gl = gdir.rgi_id

        try:
            # Only the masked dataset is needed here, not the index arrays.
            ds, _ = create_masked_glacier(path_RGIs, rgi_gl)
        except ValueError as e:
            print(f"Skipping {rgi_gl}: {e}")
            continue  # Skip to next glacier

        # Coarsen to 50 m only when the native resolution lies strictly
        # between 20 m and 50 m; all other grids are kept as they are.
        # NOTE(review): grids at <= 20 m are therefore not coarsened -
        # confirm this is intended.
        dx_m, dy_m = get_res_from_projected(ds)
        if 20 < dx_m < 50:
            ds = coarsenDS_mercator(ds, target_res_m=50)

        # Attach the projection stored on the dataset, then reproject to
        # geographic coordinates and use lon/lat dimension names.
        ds = ds.rio.write_crs(ds.pyproj_srs)
        ds_latlon = ds.rio.reproject("EPSG:4326").rename({
            'x': 'lon',
            'y': 'lat'
        })

        # Save xarray dataset
        save_path = os.path.join(path_xr_grids, f"{rgi_gl}.zarr")
        ds_latlon.to_zarr(save_path)

# Pick the example glacier directory by its RGI id
for gdir in gdirs:
    if gdir.rgi_id != 'RGI60-11.01238':
        continue
    gdir_rhone = gdir

rgi_gl_rhone = gdir_rhone.rgi_id
ds = xr.open_dataset(path_xr_grids + rgi_gl_rhone + '.zarr')

# One panel per field: (variable, colormap, colorbar?, title)
fig, axs = plt.subplots(1, 4, figsize=(15, 6))
for _ax, (_field, _cmap, _with_cbar, _title) in zip(axs, (
    ('masked_aspect', 'twilight_shifted', True, "Aspect"),
    ('masked_slope', 'cividis', True, "Slope"),
    ('masked_elev', 'terrain', True, "DEM"),
    ('glacier_mask', 'binary', False, "Glacier mask"),
)):
    ds[_field].plot(ax=_ax, cmap=_cmap, add_colorbar=_with_cbar)
    _ax.set_title(_title)
plt.tight_layout()

In [None]:
# Second example glacier. It gets its own names so that `gdir_rhone` /
# `rgi_gl_rhone` (bound to RGI60-11.01238 in the other cells) are not
# silently overwritten with a different glacier.
example_rgi_id = 'RGI60-11.00878'
gdir_example = next((g for g in gdirs if g.rgi_id == example_rgi_id), None)
if gdir_example is None:
    # Fail with a clear message instead of a NameError or a stale variable.
    raise ValueError(f"{example_rgi_id} not found in gdirs")

ds = xr.open_dataset(path_xr_grids + gdir_example.rgi_id + '.zarr')
fig, axs = plt.subplots(1, 4, figsize=(15, 6))
ds.masked_aspect.plot(ax=axs[0], cmap='twilight_shifted', add_colorbar=True)
ds.masked_slope.plot(ax=axs[1], cmap='cividis', add_colorbar=True)
ds.masked_elev.plot(ax=axs[2], cmap='terrain', add_colorbar=True)
ds.glacier_mask.plot(ax=axs[3], cmap='binary', add_colorbar=False)

axs[0].set_title("Aspect")
axs[1].set_title("Slope")
axs[2].set_title("DEM")
axs[3].set_title("Glacier mask")
plt.tight_layout()

### Create monthly dataframes:

In [None]:
RUN = False
# Root folder of the monthly grids: one sub-folder per glacier, one parquet
# file per year inside it.
path_rgi_alps = os.path.join(cfg.dataPath,
                             'GLAMOS/topo/gridded_topo_inputs/RGI_v6_11/')

if RUN:
    years = range(2000, 2024)

    # The output root must exist before it is listed below (first run).
    os.makedirs(path_rgi_alps, exist_ok=True)
    # Uncomment to force a full recompute:
    #emptyfolder(path_rgi_alps)

    # Sets give O(1) membership tests inside the glacier loop.
    valid_rgis = {
        f.replace('.zarr', '')
        for f in os.listdir(path_xr_grids) if f.endswith('.zarr')
    }

    # A glacier counts as processed as soon as its output folder exists.
    # NOTE(review): the folder is created before any year is written, so a
    # glacier whose years all failed is skipped on the next run as well.
    processed_rgis = set(os.listdir(path_rgi_alps))
    rest_rgis = list(valid_rgis - processed_rgis)
    print(f"Number of glaciers to process: {len(rest_rgis)}")

    # Input paths are identical for every glacier and year: build them once.
    pmb_csv_path = os.path.join(cfg.dataPath, path_PMB_GLAMOS_csv)
    era5_climate_data = os.path.join(cfg.dataPath, path_ERA5_raw,
                                     'era5_monthly_averaged_data_Alps.nc')
    geopotential_data = os.path.join(cfg.dataPath, path_ERA5_raw,
                                     'era5_geopotential_pressure_Alps.nc')

    for gdir in tqdm(gdirs, desc="Processing glaciers"):
        # for gdir in [gdir_rhone]:  # For testing, only process one glacier
        rgi_gl = gdir.rgi_id

        if rgi_gl not in valid_rgis:
            print(f"Skipping {rgi_gl}: not found in valid RGI glaciers")
            continue
        if rgi_gl in processed_rgis:
            continue
        try:
            file_path = os.path.join(path_xr_grids, f"{rgi_gl}.zarr")
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Missing file: {file_path}")

            # Try consolidated metadata first, then a plain open.
            try:
                ds = xr.open_zarr(file_path, consolidated=True)
            except Exception:
                ds = xr.open_zarr(file_path)

            # Create glacier grid
            try:
                df_grid = create_glacier_grid_RGI(ds, years, rgi_gl)
            except Exception as e:
                print(f"Failed creating glacier grid for {rgi_gl}: {e}")
                continue

            df_grid.reset_index(drop=True, inplace=True)

            # Add GLWD_ID: hash of "<RGIId>_<YEAR>", stored as string
            df_grid['GLWD_ID'] = [
                mbm.data_processing.utils.get_hash(f"{r}_{y}") for r, y in zip(
                    df_grid['RGIId'].astype(str), df_grid['YEAR'].astype(str))
            ]
            df_grid['GLWD_ID'] = df_grid['GLWD_ID'].astype(str)
            df_grid['GLACIER'] = df_grid['RGIId']

            # Prepare output folder
            folder_path = os.path.join(path_rgi_alps, rgi_gl)
            os.makedirs(folder_path, exist_ok=True)

            # Process each year independently so that one failing year does
            # not discard the others.
            for year in years:
                try:
                    df_grid_y = df_grid[df_grid.YEAR == year].copy()
                    if df_grid_y.empty:
                        continue

                    # Wrap Dataset creation & climate feature extraction
                    try:
                        dataset_grid_yearly = mbm.data_processing.Dataset(
                            cfg=cfg,
                            data=df_grid_y,
                            region_name='CH',
                            data_path=pmb_csv_path)

                        dataset_grid_yearly.get_climate_features(
                            climate_data=era5_climate_data,
                            geopotential_data=geopotential_data,
                            change_units=True,
                            smoothing_vois={
                                'vois_climate': vois_climate,
                                'vois_other': ['ALTITUDE_CLIMATE']
                            })
                    except Exception as e:
                        print(
                            f"Failed adding climate features for {rgi_gl}: {e}"
                        )
                        continue

                    # Keep only the topographical variables this grid has.
                    vois_topographical_sub = [
                        voi for voi in vois_topographical
                        if voi in df_grid_y.columns
                    ]

                    dataset_grid_yearly.convert_to_monthly(
                        meta_data_columns=cfg.metaData,
                        vois_climate=vois_climate,
                        vois_topographical=vois_topographical_sub)

                    save_path = os.path.join(folder_path,
                                             f"{rgi_gl}_grid_{year}.parquet")
                    dataset_grid_yearly.data.to_parquet(save_path,
                                                        engine="pyarrow",
                                                        compression="snappy")

                except Exception as e:
                    print(f"Failed processing {rgi_gl} for year {year}: {e}")
                    continue

        except Exception as e:
            print(f"Error with glacier {rgi_gl}: {e}")
            continue

In [None]:
for gdir in gdirs:
    if gdir.rgi_id == 'RGI60-11.01238':
        gdir_rhone = gdir

# Example glacier whose monthly grids are inspected
rgi_gl = gdir_rhone.rgi_id

# Print the distinct September t2m values of two different years
for year in (2000, 2008):
    df = pd.read_parquet(
        os.path.join(path_rgi_alps, rgi_gl, f"{rgi_gl}_grid_{year}.parquet"))
    df = df[df.MONTHS == 'sep']
    print(df['t2m'].unique())

In [None]:
for gdir in gdirs:
    if gdir.rgi_id == 'RGI60-11.01238':
        gdir_rhone = gdir

# September rows of the example glacier's 2008 monthly grid
year = 2008
rgi_gl = gdir_rhone.rgi_id
df = pd.read_parquet(
    os.path.join(path_rgi_alps, rgi_gl, f"{rgi_gl}_grid_{year}.parquet"))
df = df[df.MONTHS == 'sep']

# Variables to map, one panel each (2 x 3 layout)
voi = [
    't2m',
    'tp',
    'ALTITUDE_CLIMATE',
    'ELEVATION_DIFFERENCE',
    'hugonnet_dhdt',
    'consensus_ice_thickness',
]
fig, axs = plt.subplots(2, 3, figsize=(15, 10))
axs = axs.flatten()
for _ax, var in zip(axs, voi):
    # Grid points at their lon/lat position, coloured by the variable
    sns.scatterplot(df,
                    x='POINT_LON',
                    y='POINT_LAT',
                    hue=var,
                    s=5,
                    alpha=0.5,
                    palette='twilight_shifted',
                    ax=_ax)

### Location of all glaciers:

In [None]:
# Mean lat/lon of every processed glacier, taken from one reference year.
# The year is pinned here (2008, the value the preceding cell leaves in
# `year`) instead of depending on whichever cell last assigned `year`.
year_ref = 2008
rgi_ids = os.listdir(path_rgi_alps)
pos_gl = []
for rgi_gl in tqdm(rgi_ids):
    grid_file = os.path.join(path_rgi_alps, rgi_gl,
                             f"{rgi_gl}_grid_{year_ref}.parquet")
    if not os.path.isfile(grid_file):
        # Stray entries or glaciers without that year must not abort the loop.
        print(f"No {year_ref} grid for {rgi_gl}, skipping...")
        continue
    # Only the two coordinate columns are needed for the mean position.
    df = pd.read_parquet(grid_file, columns=['POINT_LAT', 'POINT_LON'])
    # The id is stored with the position so rows stay aligned when a glacier
    # is skipped.
    pos_gl.append((df.POINT_LAT.mean(), df.POINT_LON.mean(), rgi_gl))
df_pos_all = pd.DataFrame(pos_gl, columns=['lat', 'lon', 'rgi_id'])

In [None]:
print('Number of glaciers in RGI region 11.6:', len(df_pos_all))

# ---- Figure and base map ----
fig = plt.figure(figsize=(18, 10))

# Map window in degrees
latN, latS = 48, 44
lonW, lonE = 4, 14
projPC = ccrs.PlateCarree()
ax2 = plt.axes(projection=projPC)
ax2.set_extent([lonW, lonE, latS, latN], crs=ccrs.Geodetic())

# Background features; borders get an explicit line style
for _feature in (cfeature.COASTLINE, cfeature.LAKES, cfeature.RIVERS):
    ax2.add_feature(_feature)
ax2.add_feature(cfeature.BORDERS, linestyle='-', linewidth=1)

# ---- Glacier positions and outlines ----
g = sns.scatterplot(data=df_pos_all,
                    x='lon',
                    y='lat',
                    alpha=0.6,
                    transform=projPC,
                    ax=ax2,
                    zorder=10,
                    legend=True)

glacier_outline_rgi.plot(ax=ax2, transform=projPC, color='black')

# ---- Gridlines: labels on the left and bottom edges only ----
gl = ax2.gridlines(draw_labels=True,
                   linewidth=1,
                   color='gray',
                   alpha=0.5,
                   linestyle='--')
gl.xformatter = LONGITUDE_FORMATTER
gl.yformatter = LATITUDE_FORMATTER
gl.xlabel_style = dict(size=16, color='black')
gl.ylabel_style = dict(size=16, color='black')
gl.right_labels = False
gl.top_labels = False

## Train LSTM model:

In [None]:
# Stake measurements (loaded by the project helper getStakesData).
data_glamos = getStakesData(cfg)

# Head/tail month padding derived from the stake data; reused below by every
# MBSequenceDataset.from_dataframe call.
# NOTE(review): relies on a private (underscore-prefixed) MBM utility.
months_head_pad, months_tail_pad = mbm.data_processing.utils._compute_head_tail_pads_from_df(
    data_glamos)

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Transform data to monthly format (run or load data):
# input/output locations handed to process_or_load_data.
paths = {
    'csv_path': cfg.dataPath + path_PMB_GLAMOS_csv,
    'era5_climate_data':
    cfg.dataPath + path_ERA5_raw + 'era5_monthly_averaged_data.nc',
    'geopotential_data':
    cfg.dataPath + path_ERA5_raw + 'era5_geopotential_pressure.nc',
    'radiation_save_path': cfg.dataPath + path_pcsr + 'zarr/'
}
# RUN is forwarded as run_flag.
# NOTE(review): presumably False loads the saved `output_file` instead of
# recomputing it - confirm in process_or_load_data.
RUN = False
data_monthly = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_LSTM.csv')

# Create DataLoader (seeded from the config; used for the train/test split)
dataloader_gl = mbm.dataloader.DataLoader(cfg,
                                          data=data_monthly,
                                          random_seed=cfg.seed,
                                          meta_data_columns=cfg.metaData)

In [None]:
# Warn about test glaciers that do not occur in the monthly dataset
existing_glaciers = set(data_monthly.GLACIER.unique())
missing_glaciers = [
    name for name in TEST_GLACIERS if name not in existing_glaciers
]
if missing_glaciers:
    print(
        f"Warning: The following test glaciers are not in the dataset: {missing_glaciers}"
    )

# Training glaciers: every glacier present that is not a test glacier
train_glaciers = [
    name for name in existing_glaciers if name not in TEST_GLACIERS
]

data_test = data_monthly[data_monthly.GLACIER.isin(TEST_GLACIERS)]
print('Size of monthly test data:', len(data_test))

data_train = data_monthly[data_monthly.GLACIER.isin(train_glaciers)]
print('Size of monthly train data:', len(data_train))

# Size of the test set relative to the train set
if len(data_train) != 0:
    test_perc = (len(data_test) / len(data_train)) * 100
    print(f'Percentage of test size: {test_perc:.2f}%')
else:
    print("Warning: No training data available!")

# Glacier-wise split into test set, train set and CV folds
splits, test_set, train_set = get_CV_splits(dataloader_gl,
                                            test_split_on='GLACIER',
                                            test_splits=TEST_GLACIERS,
                                            random_state=cfg.seed)

print(
    f"Test glaciers: ({len(test_set['splits_vals'])}) {test_set['splits_vals']}"
)
test_perc = (len(test_set['df_X']) / len(train_set['df_X'])) * 100
print(f'Percentage of test size: {test_perc:.2f}%')
print('Size of test set:', len(test_set['df_X']))
print(
    f"Train glaciers: ({len(train_set['splits_vals'])}) {train_set['splits_vals']}"
)
print('Size of train set:', len(train_set['df_X']))

# From here on data_train / data_test are the split feature frames with the
# target attached as column 'y'.
data_train = train_set['df_X']
data_train['y'] = train_set['y']

data_test = test_set['df_X']
data_test['y'] = test_set['y']

### Load trained model:

#### Simple model:

In [None]:
# Inputs of the "simple" model: monthly series plus two static descriptors.
MONTHLY_COLS = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str',
    'ELEVATION_DIFFERENCE',
]
STATIC_COLS = ['aspect_sgi', 'slope_sgi']

feature_columns = MONTHLY_COLS + STATIC_COLS
seed_all(cfg.seed)

# Copies of the splits with PERIOD stripped and lower-cased
df_train = data_train.assign(
    PERIOD=data_train['PERIOD'].str.strip().str.lower())
df_test = data_test.assign(
    PERIOD=data_test['PERIOD'].str.strip().str.lower())

# Short alias for the sequence-dataset class used throughout this cell
_SeqDS = mbm.data_processing.MBSequenceDataset

# --- sequence datasets built from the dataframes ---
ds_train_simple = _SeqDS.from_dataframe(df_train,
                                        MONTHLY_COLS,
                                        STATIC_COLS,
                                        months_tail_pad=months_tail_pad,
                                        months_head_pad=months_head_pad,
                                        expect_target=True)

ds_test_simple = _SeqDS.from_dataframe(df_test,
                                       MONTHLY_COLS,
                                       STATIC_COLS,
                                       months_tail_pad=months_tail_pad,
                                       months_head_pad=months_head_pad,
                                       expect_target=True)

# --- 80/20 train/validation index split ---
train_idx_simple, val_idx_simple = _SeqDS.split_indices(len(ds_train_simple),
                                                        val_ratio=0.2,
                                                        seed=cfg.seed)

# --- hyperparameters ---
# Fm / Fs equal the lengths of MONTHLY_COLS / STATIC_COLS above
# (presumably the monthly and static feature counts - confirm in LSTM_MB).
custom_params = {
    'Fm': 8,
    'Fs': 2,
    'hidden_size': 128,
    'num_layers': 1,
    'bidirectional': False,
    'dropout': 0.0,
    'static_layers': 0,
    'static_hidden': None,
    'static_dropout': None,
    'lr': 0.001,
    'weight_decay': 0.0,
    'loss_name': 'neutral',
    'loss_spec': None,
    'two_heads': True,
    'head_dropout': 0.0,
}

# Snapshot kept for the extrapolation step further down
params_simple_model = dict(custom_params)

# --- checkpoint file name, stamped with today's date ---
current_date = datetime.now().strftime("%Y-%m-%d")
model_filename = f"models/lstm_model_{current_date}_CA_simple.pt"

# --- untransformed clones that the loaders may scale in place ---
ds_train_simple_copy = _SeqDS._clone_untransformed_dataset(ds_train_simple)
ds_test_simple_copy = _SeqDS._clone_untransformed_dataset(ds_test_simple)

# --- train/val loaders (scalers fitted on TRAIN, applied in place) ---
train_dl, val_dl = ds_train_simple_copy.make_loaders(
    train_idx=train_idx_simple,
    val_idx=val_idx_simple,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=True,
    shuffle_train=True,
    use_weighted_sampler=True,
)

# --- test loader (receives the TRAIN dataset for its scalers) ---
test_dl_simple = _SeqDS.make_test_loader(ds_test_simple_copy,
                                         ds_train_simple_copy,
                                         batch_size=128,
                                         seed=cfg.seed)

# --- model and loss function ---
model = mbm.models.LSTM_MB.build_model_from_params(cfg, custom_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(custom_params)

# Train from scratch (TRAIN = True) or load a previously saved checkpoint.
TRAIN = False
if TRAIN:
    # Start from a clean checkpoint file
    if os.path.exists(model_filename):
        os.remove(model_filename)

    history, best_val, best_state = model.train_loop(
        device=device,
        train_dl=train_dl,
        val_dl=val_dl,
        epochs=150,
        lr=custom_params['lr'],
        weight_decay=custom_params['weight_decay'],
        clip_val=1,
        # scheduler
        sched_factor=0.5,
        sched_patience=6,
        sched_threshold=0.01,
        sched_threshold_mode="rel",
        sched_cooldown=1,
        sched_min_lr=1e-6,
        # early stopping
        es_patience=15,
        es_min_delta=1e-4,
        # logging
        log_every=5,
        verbose=True,
        # checkpoint
        save_best_path=model_filename,
        loss_fn=loss_fn,
    )
    plot_history_lstm(history)
elif not os.path.exists(model_filename):
    # model_filename carries today's date, so without training today it
    # usually does not exist. Fall back to the newest saved checkpoint; the
    # ISO date in the name makes lexicographic order chronological.
    import glob

    saved_checkpoints = sorted(glob.glob("models/lstm_model_*_CA_simple.pt"))
    if not saved_checkpoints:
        raise FileNotFoundError(
            "No checkpoint matching models/lstm_model_*_CA_simple.pt; "
            "set TRAIN = True to train one.")
    model_filename = saved_checkpoints[-1]
    print(f"Loading most recent checkpoint: {model_filename}")

# Load the best weights into the model
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state)

# Test-set metrics and per-sample predictions
test_metrics, test_df_preds = model.evaluate_with_preds(
    device, test_dl_simple, ds_test_simple_copy)
test_rmse_a = test_metrics['RMSE_annual']
test_rmse_w = test_metrics['RMSE_winter']

print(f'Test RMSE annual: {test_rmse_a:.3f} | winter: {test_rmse_w:.3f}')

# Scores split by season, then the summary figure
scores_annual, scores_winter = compute_seasonal_scores(test_df_preds,
                                                       target_col='target',
                                                       pred_col='pred')

fig = plot_predictions_summary(grouped_ids=test_df_preds,
                               scores_annual=scores_annual,
                               scores_winter=scores_winter,
                               ax_xlim=(-8, 6),
                               ax_ylim=(-8, 6))

#### Full model (with OGGM variables):

In [None]:
# Inputs of the "full" model: the same monthly series as the simple model,
# plus three additional static variables.
MONTHLY_COLS = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str',
    'ELEVATION_DIFFERENCE',
]
STATIC_COLS = [
    'aspect_sgi',
    'slope_sgi',
    'hugonnet_dhdt',
    'consensus_ice_thickness',
    'millan_v',
]

feature_columns = MONTHLY_COLS + STATIC_COLS
seed_all(cfg.seed)

# Copies of the splits with PERIOD stripped and lower-cased
df_train = data_train.assign(
    PERIOD=data_train['PERIOD'].str.strip().str.lower())
df_test = data_test.assign(
    PERIOD=data_test['PERIOD'].str.strip().str.lower())

# Short alias for the sequence-dataset class used throughout this cell
_SeqDS = mbm.data_processing.MBSequenceDataset

# --- sequence datasets built from the dataframes ---
ds_train_full = _SeqDS.from_dataframe(df_train,
                                      MONTHLY_COLS,
                                      STATIC_COLS,
                                      months_tail_pad=months_tail_pad,
                                      months_head_pad=months_head_pad,
                                      expect_target=True)

ds_test_full = _SeqDS.from_dataframe(df_test,
                                     MONTHLY_COLS,
                                     STATIC_COLS,
                                     months_tail_pad=months_tail_pad,
                                     months_head_pad=months_head_pad,
                                     expect_target=True)

# --- 80/20 train/validation index split ---
train_idx_full, val_idx_full = _SeqDS.split_indices(len(ds_train_full),
                                                    val_ratio=0.2,
                                                    seed=cfg.seed)

# --- hyperparameters ---
# Fm / Fs equal the lengths of MONTHLY_COLS / STATIC_COLS above
# (presumably the monthly and static feature counts - confirm in LSTM_MB).
custom_params = {
    'Fm': 8,
    'Fs': 5,
    'hidden_size': 128,
    'num_layers': 1,
    'bidirectional': False,
    'dropout': 0.0,
    'static_layers': 0,
    'static_hidden': None,
    'static_dropout': None,
    'lr': 0.001,
    'weight_decay': 0.0,
    'loss_name': 'neutral',
    'loss_spec': None,
    'two_heads': True,
    'head_dropout': 0.0,
}

# Snapshot kept for the extrapolation step further down
params_full_model = dict(custom_params)

# --- checkpoint file name, stamped with today's date ---
current_date = datetime.now().strftime("%Y-%m-%d")
model_filename = f"models/lstm_model_{current_date}_CA_full.pt"

# --- untransformed clones that the loaders may scale in place ---
ds_train_full_copy = _SeqDS._clone_untransformed_dataset(ds_train_full)
ds_test_full_copy = _SeqDS._clone_untransformed_dataset(ds_test_full)

# --- train/val loaders (scalers fitted on TRAIN, applied in place) ---
train_dl, val_dl = ds_train_full_copy.make_loaders(
    train_idx=train_idx_full,
    val_idx=val_idx_full,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=True,
    shuffle_train=True,
    use_weighted_sampler=True,
)

# --- test loader (receives the TRAIN dataset for its scalers) ---
test_dl_full = _SeqDS.make_test_loader(ds_test_full_copy,
                                       ds_train_full_copy,
                                       batch_size=128,
                                       seed=cfg.seed)

# --- model and loss function ---
model = mbm.models.LSTM_MB.build_model_from_params(cfg, custom_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(custom_params)

# Train from scratch (TRAIN = True) or load a previously saved checkpoint.
TRAIN = False
if TRAIN:
    # Start from a clean checkpoint file
    if os.path.exists(model_filename):
        os.remove(model_filename)

    history, best_val, best_state = model.train_loop(
        device=device,
        train_dl=train_dl,
        val_dl=val_dl,
        epochs=150,
        lr=custom_params['lr'],
        weight_decay=custom_params['weight_decay'],
        clip_val=1,
        # scheduler
        sched_factor=0.5,
        sched_patience=6,
        sched_threshold=0.01,
        sched_threshold_mode="rel",
        sched_cooldown=1,
        sched_min_lr=1e-6,
        # early stopping
        es_patience=15,
        es_min_delta=1e-4,
        # logging
        log_every=5,
        verbose=True,
        # checkpoint
        save_best_path=model_filename,
        loss_fn=loss_fn,
    )
    plot_history_lstm(history)
elif not os.path.exists(model_filename):
    # model_filename carries today's date, so without training today it
    # usually does not exist. Fall back to the newest saved checkpoint; the
    # ISO date in the name makes lexicographic order chronological.
    import glob

    saved_checkpoints = sorted(glob.glob("models/lstm_model_*_CA_full.pt"))
    if not saved_checkpoints:
        raise FileNotFoundError(
            "No checkpoint matching models/lstm_model_*_CA_full.pt; "
            "set TRAIN = True to train one.")
    model_filename = saved_checkpoints[-1]
    print(f"Loading most recent checkpoint: {model_filename}")

# Load the best weights into the model
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state)

# Test-set metrics and per-sample predictions
test_metrics, test_df_preds = model.evaluate_with_preds(
    device, test_dl_full, ds_test_full_copy)
test_rmse_a = test_metrics['RMSE_annual']
test_rmse_w = test_metrics['RMSE_winter']

print(f'Test RMSE annual: {test_rmse_a:.3f} | winter: {test_rmse_w:.3f}')

# Scores split by season, then the summary figure
scores_annual, scores_winter = compute_seasonal_scores(test_df_preds,
                                                       target_col='target',
                                                       pred_col='pred')

fig = plot_predictions_summary(grouped_ids=test_df_preds,
                               scores_annual=scores_annual,
                               scores_winter=scores_winter,
                               ax_xlim=(-8, 6),
                               ax_ylim=(-8, 6))

### Extrapolate in space to RGI glaciers:

In [None]:
# One entry per item found in the monthly-grid root folder; each is used as
# an RGI id / sub-folder name by the extrapolation loop.
rgi_id_list = os.listdir(path_rgi_alps)

# Monthly input columns (same list as in the two model cells above).
MONTHLY_COLS = [
    't2m',
    'tp',
    'slhf',
    'sshf',
    'ssrd',
    'fal',
    'str',
    'ELEVATION_DIFFERENCE',
]

# Define paths
# path_save_glw: output folder (emptied when the extrapolation runs);
# path_xr_grids: masked lat/lon grids (same value as defined earlier).
path_save_glw = os.path.join(
    cfg.dataPath, 'GLAMOS/distributed_MB_grids/MBM/central_europe/')
path_xr_grids = os.path.join(cfg.dataPath, 'GLAMOS/topo/RGI_v6_11/',
                             'xr_masked_grids/')

# Required by the dataset builder regardless of your feature list
REQUIRED = ['GLACIER', 'YEAR', 'ID', 'PERIOD', 'MONTHS']

# Paths to your trained weights
# NOTE: pinned to checkpoints dated 2025-09-26; update both after retraining,
# since the training cells stamp their files with the current date.
MODEL_SIMPLE_PATH = "models/lstm_model_2025-09-26_CA_simple.pt"
MODEL_FULL_PATH = "models/lstm_model_2025-09-26_CA_full.pt"

# simple cache keyed by kind -> model instance
# (the key ignores params and device, so clear it when either changes)
_MODEL_CACHE = {}

def get_model(kind: str, cfg, device, params_simple_model: dict,
              params_full_model: dict):
    """
    Return the trained LSTM for ``kind``, building it on first request.

    The model is created from the hyperparameter dict that belongs to the
    requested kind, its weights are read from the matching checkpoint
    (MODEL_SIMPLE_PATH or MODEL_FULL_PATH), it is switched to eval mode and
    stored in _MODEL_CACHE. Later calls with the same kind return the cached
    instance; the cache key is the kind only.

    Parameters
    ----------
    kind : {'simple','full'}
        Model variant; matched case-insensitively.
    cfg : your MBM config
    device : torch.device
        Passed to the model builder and used as ``map_location`` on load.
    params_simple_model : dict
        Hyperparameters used when kind is 'simple'.
    params_full_model : dict
        Hyperparameters used when kind is 'full'.

    Returns
    -------
    model : torch.nn.Module (eval mode)

    Raises
    ------
    ValueError
        If ``kind`` is neither 'simple' nor 'full'.
    FileNotFoundError
        If the checkpoint of the requested kind does not exist.
    """
    kind = kind.lower()
    if kind not in ("simple", "full"):
        raise ValueError(f"Unknown model kind: {kind}")

    # Serve from the cache when this kind was built before
    try:
        return _MODEL_CACHE[kind]
    except KeyError:
        pass

    # Hyperparameters and checkpoint that belong to the requested kind
    is_simple = kind == "simple"
    params = params_simple_model if is_simple else params_full_model
    ckpt_path = MODEL_SIMPLE_PATH if is_simple else MODEL_FULL_PATH

    # Check the checkpoint before paying for the model construction
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Model checkpoint not found: {ckpt_path}")

    model = mbm.models.LSTM_MB.build_model_from_params(cfg,
                                                       params,
                                                       device,
                                                       verbose=False)
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval()

    _MODEL_CACHE[kind] = model
    return model


# Rename helper that ignores mapping keys missing from the frame
def safe_rename(df, mapping):
    """Rename the columns of ``df`` listed in ``mapping`` that actually exist.

    Keys of ``mapping`` that are not columns of ``df`` are ignored. When no
    key applies, ``df`` itself is returned (same object, no copy); otherwise
    the result of ``df.rename`` is returned and ``df`` is left unchanged.
    """
    applicable = {}
    for old_name, new_name in mapping.items():
        if old_name in df.columns:
            applicable[old_name] = new_name

    if not applicable:
        return df
    return df.rename(columns=applicable)


# Robust year regex: RGI60-11.01238_grid_2003.parquet
YEAR_RE = re.compile(r"_grid_(\d{4})\.parquet$")

# Prediction loop: for every glacier in rgi_id_list and every yearly parquet
# grid found for it, predict with the LSTM ('full' when the OGGM columns are
# present, otherwise 'simple') and append the mean of the predictions to a
# CSV with one row per (glacier, year).
RUN = True
if RUN:
    emptyfolder(path_save_glw)

    # CSV header. Create the target folder first: open(..., 'w') does not
    # create missing parent directories and would fail on a fresh checkout.
    output_file = os.path.join("logs/glacier_mean_MB.csv")
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    with open(output_file, 'w') as f:
        f.write("Index,RGIId,Year,Mean_MB\n")

    index_counter = 0

    # Columns whose presence selects the 'full' model (loop-invariant).
    oggm_columns = [
        'hugonnet_dhdt', 'consensus_ice_thickness', 'millan_v'
    ]

    for rgi_gl in tqdm(rgi_id_list):
        seed_all(cfg.seed)
        glacier_path = os.path.join(path_rgi_alps, rgi_gl)
        if not os.path.exists(glacier_path):
            print(f"Folder not found for {rgi_gl}, skipping...")
            continue

        # Collect years robustly: only parquet files of this glacier whose
        # name ends in _grid_<YYYY>.parquet contribute a year.
        years = []
        for fname in os.listdir(glacier_path):
            if not (fname.endswith(".parquet") and rgi_gl in fname):
                continue
            m = YEAR_RE.search(fname)
            if m:
                years.append(int(m.group(1)))
        years = sorted(set(years))
        if not years:
            print(f"No parquet years for {rgi_gl}, skipping...")
            continue

        # Optionally clear model cache per glacier
        _MODEL_CACHE.clear()

        for year in years:
            file_name = f"{rgi_gl}_grid_{year}.parquet"
            file_path = os.path.join(glacier_path, file_name)

            try:
                # Load parquet
                df_grid_monthly = pd.read_parquet(file_path)
            except Exception as e:
                # Best effort: report the unreadable file and move on.
                print(f"[{rgi_gl} {year}] read error: {e}")
                continue

            df_grid_monthly.drop_duplicates(inplace=True)

            # safe rename if present
            df_grid_monthly = safe_rename(df_grid_monthly, {
                'aspect': 'aspect_sgi',
                'slope': 'slope_sgi'
            })

            # The 'full' model needs all OGGM columns; otherwise fall back
            # to the 'simple' model that only uses aspect and slope.
            has_oggm = all(c in df_grid_monthly.columns for c in oggm_columns)
            if has_oggm:
                kind = "full"
                STATIC_COLS = ['aspect_sgi', 'slope_sgi'] + oggm_columns
            else:
                kind = "simple"
                STATIC_COLS = ['aspect_sgi', 'slope_sgi']

            # Build column list in correct schema
            feature_columns = MONTHLY_COLS + STATIC_COLS
            all_columns = feature_columns + cfg.fieldsNotFeatures

            # Keep required + feature columns, in the frame's own column
            # order (iterating a set would give a session-dependent order).
            wanted_columns = set(all_columns) | set(REQUIRED)
            keep = [c for c in df_grid_monthly.columns if c in wanted_columns]
            df_grid_monthly = df_grid_monthly[keep]

            # Ensure required columns exist
            missing_req = [
                c for c in REQUIRED if c not in df_grid_monthly.columns
            ]
            if missing_req:
                print(
                    f"[{rgi_gl} {year}] missing required cols: {missing_req}, skipping..."
                )
                continue

            # Minimal NaN clean-up
            df_grid_monthly_a = df_grid_monthly.dropna(subset=['ID', 'MONTHS'])
            if df_grid_monthly_a.empty:
                print(
                    f"[{rgi_gl} {year}] empty after NaN cleanup, skipping...")
                continue

            # Fit the scalers on an untransformed copy of the training set
            # matching the model kind. Done only after the cheap validity
            # checks above so skipped years do not pay for it.
            # NOTE(review): this is re-fitted for every (glacier, year); it
            # could probably be done once per kind if make_test_loader does
            # not modify ds_train_copy - confirm before caching.
            if has_oggm:
                ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
                    ds_train_full)
                ds_train_copy.fit_scalers(train_idx_full)
            else:
                ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
                    ds_train_simple)
                ds_train_copy.fit_scalers(train_idx_simple)

            # Build inference dataset
            ds_gl_a = mbm.data_processing.MBSequenceDataset.from_dataframe(
                df_grid_monthly_a,
                MONTHLY_COLS,
                STATIC_COLS,
                months_tail_pad=months_tail_pad,
                months_head_pad=months_head_pad,
                expect_target=False,
                show_progress=False,
            )
            test_gl_dl_a = mbm.data_processing.MBSequenceDataset.make_test_loader(
                ds_gl_a, ds_train_copy, seed=cfg.seed, batch_size=128)

            # Load model (cached) and predict
            try:
                model = get_model(kind, cfg, device, params_simple_model,
                                  params_full_model)
            except Exception as e:
                print(f"[{rgi_gl} {year}] model load error: {e}")
                continue

            df_preds_a = model.predict_with_keys(device, test_gl_dl_a, ds_gl_a)

            # Join preds for output: one row per ID with the first value of
            # each available metadata column, the prediction and the months.
            data_a = df_preds_a[['ID', 'pred']].set_index('ID')
            meta_cols = [
                c for c in ['YEAR', 'POINT_LAT', 'POINT_LON', 'GLWD_ID']
                if c in df_grid_monthly_a.columns
            ]
            grouped_ids_a = (
                df_grid_monthly_a.groupby('ID')[meta_cols].first().merge(
                    data_a, left_index=True, right_index=True, how='left'))
            months_per_id_a = df_grid_monthly_a.groupby(
                'ID')['MONTHS'].unique()
            grouped_ids_a = grouped_ids_a.merge(months_per_id_a,
                                                left_index=True,
                                                right_index=True)

            grouped_ids_a.reset_index(inplace=True)
            grouped_ids_a.sort_values(by='ID', inplace=True)

            pred_y_annual = grouped_ids_a.copy()
            pred_y_annual['PERIOD'] = 'annual'
            pred_y_annual = pred_y_annual.drop(columns=['YEAR'],
                                               errors='ignore')

            # Mean of the per-ID predictions for this glacier and year
            mean_MB = pred_y_annual['pred'].mean()

            # Append immediately so finished rows survive an interrupted run.
            with open(output_file, 'a') as f:
                f.write(f"{index_counter},{rgi_gl},{year},{mean_MB:.4f}\n")

            index_counter += 1

In [None]:
# Compare the mean MB written by the prediction loop with a hard-coded
# "LSTM_full" series for glacier RGI60-11.01238.
csv_path = os.path.join("logs/glacier_mean_MB.csv")
df = pd.read_csv(csv_path)

# .copy() so the column assignments below act on an independent frame
# instead of on a filtered slice of the full table.
df = df[df.RGIId == 'RGI60-11.01238'].copy()

# ensure Year is integer (just in case)
df["Year"] = df["Year"].astype(int)

# set Year as index
df.set_index("Year", inplace=True)

# your values
vals = [
    -0.764493, -0.85534433, -0.61580243, -0.6876, -1.81396141, -1.17730185,
    -0.28465775, -0.41800559, -0.96153416, -0.74692587, -1.41592501,
    -0.97360032, -0.97257272, -1.23755494, -0.51568465, -3.05461539,
    -2.76094933
]

# build matching years dynamically (starts at 2007, matches length of vals)
years = list(range(2007, 2007 + len(vals)))  # 2007..2023 inclusive

# add the new column, aligned on the Year index: years that are missing from
# the CSV are ignored and CSV years without a value stay NaN, instead of
# failing when the two sets of years differ.
df["LSTM_full"] = pd.Series(vals, index=years,
                            dtype=float).reindex(df.index).to_numpy()

# optional: plot
ax = df.plot(y=["Mean_MB", "LSTM_full"], marker="o")
ax.set_xlabel("Year")
ax.set_ylabel("Mass balance (m w.e.)")
plt.tight_layout()
plt.show()


In [None]:
# Display the per-year table (Mean_MB and LSTM_full) built in the cell above
df

In [None]:
# Display the path of the CSV written by the prediction loop
# (the name is only defined when that cell ran with RUN = True)
output_file

### Mean predicted MB:

In [None]:
# Open the per-glacier, per-year mean MB table. The prediction loop writes it
# to "logs/glacier_mean_MB.csv" (and passes path_save_glw to emptyfolder()),
# so it has to be read from there rather than from path_save_glw.
mean_mb_csv = "logs/glacier_mean_MB.csv"
output_df = pd.read_csv(mean_mb_csv)

# Sum of the per-glacier Mean_MB values for each year.
# NOTE(review): the heading says "mean" but this aggregates with 'sum' -
# confirm which one is intended.
output_df.groupby('Year').agg({'Mean_MB': 'sum'})