## Setting Up:

In [None]:
# --- Standard library
from concurrent.futures import ProcessPoolExecutor, as_completed
from contextlib import redirect_stdout
from datetime import datetime
import io
import logging
import multiprocessing as mp
import os
import sys
import warnings

# Make repo root importable (for MBM & scripts/*)
sys.path.append(os.path.join(os.getcwd(), '../../'))

# --- Third-party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from cmcrameri import cm
import torch
from tqdm.auto import tqdm
import xarray as xr
from matplotlib.lines import Line2D

import massbalancemachine as mbm

# --- Project-local
from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.nn_helpers import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.NN_networks import *
from scripts.geodata_plots import *

# --- Notebook settings
# Silence every Python warning for the rest of the session.
warnings.filterwarnings('ignore')
# IPython magics: load the autoreload extension and reload modules before
# each cell execution (mode 2).
%load_ext autoreload
%autoreload 2

# Switzerland configuration object; the cells below read `seed`, `dataPath`,
# `metaData` and `fieldsNotFeatures` from it.
cfg = mbm.SwitzerlandConfig()

In [None]:
# Apply the configured seed (project helper) and report it.
seed_all(cfg.seed)
print("Using seed:", cfg.seed)

# Report whether a CUDA device can be used; `free_up_cuda` is only called
# when one is present.
_has_cuda = torch.cuda.is_available()
if not _has_cuda:
    print("CUDA is NOT available")
else:
    print("CUDA is available")
    free_up_cuda()

# Torch device used for the model: GPU when available, CPU otherwise.
device = torch.device("cuda") if _has_cuda else torch.device("cpu")

In [None]:
# Plot styles:
# Matplotlib style sheet; the path is relative to the working directory.
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)
# Presumably 10 hex colour codes sampled from the cmcrameri "batlow"
# colormap — confirm in get_cmap_hex.
colors = get_cmap_hex(cm.batlow, 10)
# Passed further down as the colour for annual values ...
color_dark_blue = colors[0]
# ... and as the colour for winter values.
color_pink = '#c51b7d'

## Input data:

In [None]:
# Climate variables of interest, handed to `process_or_load_data`.
vois_climate = ['t2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str']

# Topographical variables of interest, handed to `process_or_load_data`.
vois_topographical = [
    "aspect_sgi",
    "slope_sgi",
    "hugonnet_dhdt",
    "consensus_ice_thickness",
    "millan_v",
    "svf",
]

# Read GLAMOS stake data
data_glamos = getStakesData(cfg)

# Compute padding for monthly data (head pad first, tail pad second)
months_head_pad, months_tail_pad = (
    mbm.data_processing.utils._compute_head_tail_pads_from_df(data_glamos))

# Initialize logging
logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)

# Transform data to monthly format (run or load data):
# all input/output locations live below cfg.dataPath.
_era5_raw_dir = cfg.dataPath + path_ERA5_raw
paths = {
    'csv_path': cfg.dataPath + path_PMB_GLAMOS_csv,
    'era5_climate_data': _era5_raw_dir + 'era5_monthly_averaged_data.nc',
    'geopotential_data': _era5_raw_dir + 'era5_geopotential_pressure.nc',
    'radiation_save_path': cfg.dataPath + path_pcsr + 'zarr/',
}

# Forwarded as `run_flag`; presumably True recomputes the monthly dataset
# and False loads the stored file — confirm in process_or_load_data.
RUN = True
data_monthly = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_LSTM_svf_IS.csv',
)

# Create DataLoader
dataloader_gl = mbm.dataloader.DataLoader(
    cfg,
    data=data_monthly,
    random_seed=cfg.seed,
    meta_data_columns=cfg.metaData,
)

### Blocking on glaciers:

The model is trained on all glaciers, so no glacier is held out: this is the "within-sample" setting.


In [None]:
# Within-sample setting: every glacier present in the monthly dataset is a
# training glacier, none is held out.
existing_glaciers = set(data_monthly.GLACIER.unique())
train_glaciers = existing_glaciers

# Explicit copy: a column is added below, and writing into a filtered
# selection of `data_monthly` would be a chained assignment.
data_train = data_monthly[data_monthly.GLACIER.isin(train_glaciers)].copy()
print('Size of monthly train data:', len(data_train))

# Target column `y` is a copy of POINT_BALANCE. The train/validation split
# itself is done later, on the sequence indices.
data_train['y'] = data_train['POINT_BALANCE']

## LSTM:

In [None]:
# Inputs that vary month by month (9 columns).
MONTHLY_COLS = [
    't2m',
    'tp',
    'slhf',
    'sshf',
    'ssrd',
    'fal',
    'str',
    'ELEVATION_DIFFERENCE',
    'pcsr',
]

# Inputs that are constant for a measurement (6 columns).
STATIC_COLS = [
    'aspect_sgi', 'slope_sgi', "svf", "hugonnet_dhdt",
    "consensus_ice_thickness", "millan_v"
]

# Monthly columns first, static columns after.
feature_columns = [*MONTHLY_COLS, *STATIC_COLS]

### Build LSTM dataloaders:

In [None]:
seed_all(cfg.seed)

# Copy of the training frame with PERIOD stripped of surrounding whitespace
# and lower-cased.
df_train = data_train.assign(
    PERIOD=data_train['PERIOD'].str.strip().str.lower())

# --- build train dataset from dataframe ---
ds_train = mbm.data_processing.MBSequenceDataset.from_dataframe(
    df_train,
    MONTHLY_COLS,
    STATIC_COLS,
    months_tail_pad=months_tail_pad,
    months_head_pad=months_head_pad,
    expect_target=True,
    normalize_target=False,
)

# 80 % / 20 % train/validation split over the sequence indices.
train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
    len(ds_train),
    val_ratio=0.2,
    seed=cfg.seed,
)

# Look at the padding for one example
key = ('adler', 2009, 11, 'winter')

# find the index of this key (fail with a readable message if it is absent)
try:
    idx = ds_train.keys.index(key)
except ValueError:
    raise ValueError(f"Key {key} not found in dataset.")

# fetch the corresponding sequence and display three of its entries
# (presumably masks — confirm in MBSequenceDataset)
sequence = ds_train[idx]
tuple(sequence[_name] for _name in ('mv', 'mw', 'ma'))

### Define & train model:

In [None]:
log_path = 'logs/lstm_two_heads_param_search_progress_no_oggm_IS_2025-10-22.csv'

# Parameter set selected by the project helper on the 'test_rmse_a' column.
best_params = get_best_params_for_lstm(log_path, select_by='test_rmse_a')

# Full search log, ranked by the mean of the annual and winter test RMSE.
df = pd.read_csv(log_path)
df = df.assign(avg_test_loss=(df["test_rmse_a"] + df["test_rmse_w"]) / 2)
df = df.sort_values(by="avg_test_loss")

print(best_params)
# Ten best-ranked runs, shown as the cell output.
df.head(10)

In [None]:
# Reference only: example of a hand-written parameter dictionary.
# custom_params = {
#     'Fm': 8,
#     'Fs': 6,
#     'hidden_size': 128,
#     'num_layers': 2,
#     'bidirectional': False,
#     'dropout': 0.1,
#     'static_layers': 2,
#     'static_hidden': [128, 64],
#     'static_dropout': 0.1,
#     'lr': 0.001,
#     'weight_decay': 0.0,
#     'loss_name': 'neutral',
#     'two_heads': True,
#     'head_dropout': 0.0,
#     'loss_spec': None
# }

# Start from the best logged configuration. The copy keeps `best_params`
# itself unchanged by the overrides below.
custom_params = dict(best_params)
# 9 and 6 equal len(MONTHLY_COLS) and len(STATIC_COLS); presumably the
# monthly / static input sizes — confirm in build_model_from_params.
custom_params['Fm'] = 9
custom_params['Fs'] = 6
# Single output head.
custom_params['two_heads'] = False


# --- checkpoint file name: stamped with today's date ---
current_date = datetime.now().strftime("%Y-%m-%d")
model_filename = f"models/lstm_model_{current_date}_with_oggm_IS_one_head.pt"

# --- loaders (fit scalers on TRAIN, apply to whole ds_train) ---
# Work on a clone of ds_train; presumably so the in-place scaling below
# leaves ds_train itself untransformed — confirm in MBSequenceDataset.
ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train)

train_dl, val_dl = ds_train_copy.make_loaders(
    train_idx=train_idx,
    val_idx=val_idx,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=
    True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True  # use weighted sampler for training
)

# --- build model, resolve loss, train ---
model = mbm.models.LSTM_MB.build_model_from_params(cfg, custom_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(custom_params)

# Set to False to skip training; a checkpoint must then already exist at
# `model_filename` for the evaluation that follows.
TRAIN = True
if TRAIN:
    # Make sure the checkpoint folder exists before training starts, so the
    # best state can be written, and drop any stale checkpoint of that name.
    os.makedirs(os.path.dirname(model_filename) or '.', exist_ok=True)
    if os.path.exists(model_filename):
        os.remove(model_filename)

    history, best_val, best_state = model.train_loop(
        device=device,
        train_dl=train_dl,
        val_dl=val_dl,
        epochs=150,
        lr=custom_params['lr'],
        weight_decay=custom_params['weight_decay'],
        clip_val=1,
        # scheduler
        sched_factor=0.5,
        sched_patience=6,
        sched_threshold=0.01,
        sched_threshold_mode="rel",
        sched_cooldown=1,
        sched_min_lr=1e-6,
        # early stopping
        es_patience=15,
        es_min_delta=1e-4,
        # logging
        log_every=5,
        verbose=True,
        # checkpoint
        save_best_path=model_filename,
        loss_fn=loss_fn,
    )
    plot_history_lstm(history)

#model_filename = f"models/lstm_model_2025-10-31_no_oggm_IS.pt"

# The "test" dataset is a fresh clone of ds_train, i.e. the evaluation is
# in-sample. `make_test_loader` also receives ds_train_copy — presumably to
# reuse the scalers fitted on it; confirm in MBSequenceDataset.
ds_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train)
test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test_copy,
    ds_train_copy,
    batch_size=128,
    seed=cfg.seed,
)

# Evaluate on test: reload the saved weights, then predict
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state)
test_metrics, test_df_preds = model.evaluate_with_preds(
    device, test_dl, ds_test_copy)
test_rmse_a = test_metrics['RMSE_annual']
test_rmse_w = test_metrics['RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

In [None]:
def safe_item(x):
    """Return ``x.item()``, or ``None`` when ``x`` itself is ``None``."""
    if x is None:
        return None
    return x.item()


# Same five-line report for both datasets: the stored target scaler values
# next to the mean/std actually found in `y`.
for _header, _ds in (
    ("Train dataset (after make_loaders):", ds_train_copy),
    ("\nTest dataset (after make_test_loader):", ds_test_copy),
):
    print(_header)
    print(f"  normalize_target = {_ds.normalize_target}")
    print(f"  y_mean (scaler)  = {safe_item(_ds.y_mean)}")
    print(f"  y_std  (scaler)  = {safe_item(_ds.y_std)}")
    print(f"  Actual y.mean()  = {_ds.y.mean().item():.4f}")
    print(f"  Actual y.std()   = {_ds.y.std().item():.4f}")

In [None]:
# Targets of both datasets as NumPy arrays.
y_train = ds_train_copy.y.cpu().numpy()
y_test = ds_test_copy.y.cpu().numpy()

# The title states whether the targets are normalized.
_target_units = 'normalized' if ds_train_copy.normalize_target else 'physical'

plt.figure(figsize=(6, 4))
# Overlaid density histograms, train first.
for _values, _label in ((y_train, "Train"), (y_test, "Test")):
    plt.hist(_values, bins=30, alpha=0.6, label=_label, density=True)
# Dashed vertical line at the training mean.
plt.axvline(y_train.mean(),
            color='k',
            linestyle='--',
            lw=1,
            label='mean (train)')
plt.xlabel("Target (y)")
plt.ylabel("Density")
plt.title(f"Target distribution ({_target_units} units)")
plt.legend()
plt.show()

In [None]:
# Scores of the predictions, returned as an (annual, winter) pair.
scores_annual, scores_winter = compute_seasonal_scores(
    test_df_preds, target_col='target', pred_col='pred')

print("Annual scores:", scores_annual)
print("Winter scores:", scores_winter)

# Summary figure; both axes use the same limits.
_axis_limits = (-8, 6)
fig = plot_predictions_summary(
    grouped_ids=test_df_preds,
    scores_annual=scores_annual,
    scores_winter=scores_winter,
    ax_xlim=_axis_limits,
    ax_ylim=_axis_limits,
)

In [None]:
# Glacier areas; "clariden" is given the value stored under "claridenL".
gl_area = get_gl_area(cfg)
gl_area["clariden"] = gl_area["claridenL"]

# Mean POINT_ELEVATION of the annual stake measurements, per glacier.
gl_per_el = data_glamos[data_glamos.PERIOD == 'annual'].groupby(
    ['GLACIER'])['POINT_ELEVATION'].mean()
gl_per_el = gl_per_el.sort_values(ascending=False)
# Attach that mean elevation to every prediction row.
test_df_preds['gl_elv'] = test_df_preds['GLACIER'].map(gl_per_el)

# NOTE(review): this rebinds `train_glaciers`, which an earlier cell set to
# all glaciers found in data_monthly — confirm the two are meant to agree.
train_glaciers = {
    'adler', 'albigna', 'aletsch', 'allalin', 'basodino', 'clariden',
    'corbassiere', 'corvatsch', 'findelen', 'forno', 'gietro', 'gorner',
    'gries', 'hohlaub', 'joeri', 'limmern', 'morteratsch', 'murtel', 'oberaar',
    'otemma', 'pizol', 'plattalva', 'rhone', 'sanktanna', 'schwarzberg',
    'sexrouge', 'silvretta', 'tortin', 'tsanfleuron'
}
# Glacier names ordered by increasing mean stake elevation. Selecting by a
# list of labels requires every listed glacier to be present in gl_per_el.
test_gl_per_el = gl_per_el[list(train_glaciers)].sort_values().index

# NOTE(review): 29 glaciers are listed above but the grid has 7 x 3 = 21
# axes — confirm how PlotIndividualGlacierPredVsTruth handles the surplus.
fig, axs = plt.subplots(7, 3, figsize=(25, 30), sharex=False)

# One panel per glacier, in the elevation order computed above.
axs = PlotIndividualGlacierPredVsTruth(test_df_preds,
                                       axs=axs,
                                       color_annual=color_dark_blue,
                                       color_winter=color_pink,
                                       custom_order=test_gl_per_el,
                                       add_text=True,
                                       ax_xlim=None,
                                       gl_area=gl_area)


## Extrapolate in space:

In [None]:
geodetic_mb = get_geodetic_MB(cfg)

# get years per glacier: unique start ('Astart') and end ('Aend') years
_geodetic_by_glacier = geodetic_mb.groupby('glacier_name')
years_start_per_gl = (
    _geodetic_by_glacier['Astart'].unique().apply(list).to_dict())
years_end_per_gl = (
    _geodetic_by_glacier['Aend'].unique().apply(list).to_dict())

periods_per_glacier, geoMB_per_glacier = build_periods_per_glacier(geodetic_mb)

# Glaciers present in the stake data
glacier_list = list(data_glamos.GLACIER.unique())
print('Number of glaciers with pcsr:', len(glacier_list))

# Glaciers that have at least one geodetic period
geodetic_glaciers = periods_per_glacier.keys()
print('Number of glaciers with geodetic MB:', len(geodetic_glaciers))

# Intersection of both
_with_geodetic = set(geodetic_glaciers)
_with_stakes = set(glacier_list)
common_glaciers = list(_with_geodetic & _with_stakes)
print('Number of common glaciers:', len(common_glaciers))

# Sort glaciers by area ("clariden" reuses the "claridenL" entry)
gl_area = get_gl_area(cfg)
gl_area['clariden'] = gl_area['claridenL']


# Sort the lists by area if available in gl_area
def sort_by_area(glacier_list, gl_area):
    """Return the glaciers ordered from smallest to largest area.

    ``gl_area`` maps a glacier name to its area and is queried with
    ``.get``; glaciers without an entry count as area 0. Glaciers with
    equal area keep their input order.
    """

    def area_of(name):
        return gl_area.get(name, 0)

    return sorted(glacier_list, key=area_of)


# Glaciers with both stake data and geodetic MB, ordered by sort_by_area.
glacier_list = sort_by_area(common_glaciers, gl_area)
# Bare expression: shown as the cell output in the notebook.
glacier_list

In [None]:
from scripts.parallel_mb import MBJobConfig, run_glacier_mb

# Output folder of the distributed mass-balance grids for this model run.
path_save_glw = os.path.join(cfg.dataPath, 'GLAMOS', 'distributed_MB_grids',
                             'MBM/testing_LSTM/LSTM_with_oggm_IS_original_y_one_head')
# Set to False to skip the computation and keep whatever is already stored
# under `path_save_glw`.
RUN = True
if RUN:
    # Everything the job needs, bundled into one configuration object:
    # feature lists, trained checkpoint, training dataset/indices, paddings
    # and input/output paths.
    job = MBJobConfig(
        cfg=cfg,
        MONTHLY_COLS=MONTHLY_COLS,
        STATIC_COLS=STATIC_COLS,
        fields_not_features=cfg.fieldsNotFeatures,
        model_filename=model_filename,
        custom_params=custom_params,
        ds_train=ds_train,
        train_idx=train_idx,
        months_head_pad=months_head_pad,
        months_tail_pad=months_tail_pad,
        data_path=cfg.dataPath,
        path_glacier_grid_glamos=path_glacier_grid_glamos,
        path_xr_grids=os.path.join(cfg.dataPath, 'GLAMOS', 'topo',
                                   'GLAMOS_DEM', 'xr_masked_grids'),
        path_save_glw=path_save_glw,
        seed=cfg.seed,
        max_workers=20,  # NOTE(review): presumably the number of parallel workers — confirm in MBJobConfig
        cpu_only=True,
        ONLY_GEODETIC=True,
        # Follows the target-normalization flag of the training dataset.
        denorm=ds_train_copy.normalize_target,
        save_monthly=True)

    # Run the job for the selected glaciers and their geodetic periods.
    summary = run_glacier_mb(job, glacier_list, periods_per_glacier)
    print("SUMMARY:", summary)

In [None]:
# Month labels used in the file names, October through September.
hydro_months = [
    'oct',
    'nov',
    'dec',
    'jan',
    'feb',
    'mar',
    'apr',
    'may',
    'jun',
    'jul',
    'aug',
    'sep',
]

glaciers = os.listdir(path_save_glw)

# Initialize final storage for all glacier data
all_glacier_data = []

# Loop over glaciers
for glacier_name in tqdm(glaciers):
    glacier_path = os.path.join(path_save_glw, glacier_name)
    if not os.path.isdir(glacier_path):
        continue  # skip non-directories

    # Files are named "<glacier>_<year>_<month>.zarr". The glacier name is
    # escaped so that any regex metacharacter in it is matched literally.
    pattern = re.compile(
        rf'{re.escape(glacier_name)}_(\d{{4}})_[a-z]{{3}}\.zarr')

    # Extract available years
    years = sorted({
        int(match.group(1))
        for match in map(pattern.match, os.listdir(glacier_path)) if match
    })

    # Collect all year-month data
    all_years_data = []
    for year in years:
        monthly_data = {}
        for month in hydro_months:
            zarr_path = os.path.join(glacier_path,
                                     f'{glacier_name}_{year}_{month}.zarr')
            if not os.path.exists(zarr_path):
                continue

            # Both tables are materialised inside the `with`, so the store
            # can be closed as soon as the block is left.
            with xr.open_dataset(zarr_path) as ds:
                df = ds.pred_masked.to_dataframe().drop(
                    ['x', 'y'], axis=1).reset_index()
                df_el = ds.masked_elev.to_dataframe().drop(
                    ['x', 'y'], axis=1).reset_index()

            # Keep only cells with a prediction; the same mask is applied to
            # the elevation table so both stay aligned row by row.
            has_prediction = df.pred_masked.notna()
            # .copy(): a column is added next, avoid writing into a slice.
            df_pred_months = df[has_prediction].copy()
            df_elv_months = df_el[has_prediction]

            df_pred_months['elevation'] = df_elv_months.masked_elev.values

            monthly_data[month] = df_pred_months.pred_masked.values

        if monthly_data:
            # One column per month; pandas requires all months of a year to
            # have the same number of cells here.
            df_months = pd.DataFrame(monthly_data)
            df_months['year'] = year
            df_months['glacier'] = glacier_name  # add glacier name
            # Elevation comes from the last month loaded for this year.
            df_months['elevation'] = df_pred_months.elevation.values
            all_years_data.append(df_months)

    # Concatenate this glacier's data
    if all_years_data:
        df_glacier = pd.concat(all_years_data, axis=0, ignore_index=True)
        all_glacier_data.append(df_glacier)

# Final full DataFrame for all glaciers
if not all_glacier_data:
    # pd.concat would raise the same exception type with an opaque message.
    raise ValueError(
        f"No monthly prediction grids found under {path_save_glw}")
df_months_LSTM = pd.concat(all_glacier_data, axis=0, ignore_index=True)

In [None]:
def pick_file_glamos(glacier, year, period="winter"):
    """Locate the GLAMOS grid file of one glacier, year and period.

    The file is searched in ``<dataPath>/<path_distributed_MB_glamos>/GLAMOS/
    <glacier>``. ``period == "annual"`` selects the ``ann`` files, any other
    value the ``win`` files. The ``lv95`` variant is preferred over ``lv03``.

    Returns:
        ``(path, "lv95")`` or ``(path, "lv03")`` for the first file that
        exists, ``(None, None)`` when neither does.
    """
    suffix = "win"
    if period == "annual":
        suffix = "ann"

    glacier_dir = os.path.join(cfg.dataPath, path_distributed_MB_glamos,
                               "GLAMOS", glacier)
    for coord_system in ("lv95", "lv03"):
        candidate = os.path.join(
            glacier_dir, f"{year}_{suffix}_fix_{coord_system}.grid")
        if os.path.exists(candidate):
            return candidate, coord_system
    return None, None


def load_glamos_wgs84(glacier, year, period):
    """Load one GLAMOS .grid file and return it as an xarray in WGS84.

    The file and its coordinate system are chosen by ``pick_file_glamos``.
    Returns ``None`` when no file exists (or the coordinate system is
    neither ``"lv95"`` nor ``"lv03"``).
    """
    grid_path, coord_system = pick_file_glamos(glacier, year, period)
    if grid_path is None:
        return None

    meta, arr = load_grid_file(grid_path)
    grid = convert_to_xarray_geodata(arr, meta)

    # Use the transform that matches the file's coordinate system.
    if coord_system == "lv95":
        return transform_xarray_coords_lv95_to_wgs84(grid)
    if coord_system == "lv03":
        return transform_xarray_coords_lv03_to_wgs84(grid)
    return None


def _glamos_period_frame(ds_dem, glacier_name, year, period):
    """Return the per-cell MB/elevation table of one glacier, year and period.

    Returns ``None`` when no GLAMOS grid exists for that combination.
    """
    da_mb = load_glamos_wgs84(glacier_name, year, period=period)
    if da_mb is None:
        return None

    # Sample the DEM elevation at the MB grid cells (nearest neighbour), then
    # give it the MB grid's own coordinates so the merge aligns cell by cell.
    masked_elev = ds_dem["masked_elev"].interp_like(da_mb, method="nearest")
    masked_elev = masked_elev.assign_coords(x=da_mb.x, y=da_mb.y)

    merged = xr.merge(
        [
            da_mb.to_dataset(name="mb"),
            masked_elev.to_dataset(name="masked_elev"),
        ],
        compat="override",
    )

    frame = merged.to_dataframe().reset_index()
    # Keep only cells that have both a mass balance and an elevation.
    has_both = frame["mb"].notna() & frame["masked_elev"].notna()
    # .copy(): columns are added below, avoid writing into a slice.
    frame = frame.loc[has_both, ["x", "y", "mb", "masked_elev"]].copy()
    frame["year"] = year
    frame["glacier"] = glacier_name
    frame["period"] = period
    return frame


def _combine_glamos_frames(frames, mb_column):
    """Stack per-year tables and rename ``mb`` / ``masked_elev`` for output."""
    if not frames:
        return pd.DataFrame()
    combined = pd.concat(frames, ignore_index=True)
    if combined.empty:
        return combined
    # --- Drop x/y and rename elevation column ---
    return combined.drop(["x", "y", "period"], axis=1).rename(
        columns={
            "masked_elev": "elevation",
            "mb": mb_column
        })


def load_all_glamos(glacier_years, path_glamos):
    """
    Loads both annual and winter GLAMOS grids for all glaciers and years,
    interpolates DEM elevation onto MB grids, and returns two DataFrames:
    df_GLAMOS_w, df_GLAMOS_a.

    Args:
        glacier_years: mapping of glacier name to an iterable of years.
        path_glamos: folder with one sub-folder per glacier; glaciers
            without a sub-folder are skipped.

    Returns:
        ``(df_GLAMOS_w, df_GLAMOS_a)``. Each has one row per grid cell with
        the columns ``elevation``, ``year``, ``glacier`` and the mass balance
        stored as ``apr`` (winter table) or ``sep`` (annual table). A table
        is an empty DataFrame when nothing could be loaded for it.

    A glacier/year is skipped when its DEM store
    ``xr_masked_grids/<glacier>_<year>.zarr`` does not exist.
    """
    frames_w, frames_a = [], []

    for glacier_name in tqdm(glacier_years.keys(), desc="Loading GLAMOS data"):
        if not os.path.isdir(os.path.join(path_glamos, glacier_name)):
            continue

        for year in glacier_years[glacier_name]:
            # --- Load DEM ---
            dem_path = (cfg.dataPath + path_GLAMOS_topo +
                        f"xr_masked_grids/{glacier_name}_{year}.zarr")
            if not os.path.exists(dem_path):
                continue

            # Both tables are fully built inside the `with`, so the DEM
            # store is released before the next year is processed.
            with xr.open_zarr(dem_path) as ds_dem:
                df_w = _glamos_period_frame(ds_dem, glacier_name, year,
                                            "winter")
                df_a = _glamos_period_frame(ds_dem, glacier_name, year,
                                            "annual")

            if df_w is not None:
                frames_w.append(df_w)
            if df_a is not None:
                frames_a.append(df_a)

    # --- Final combined DataFrames ---
    df_GLAMOS_w = _combine_glamos_frames(frames_w, "apr")
    df_GLAMOS_a = _combine_glamos_frames(frames_a, "sep")

    return df_GLAMOS_w, df_GLAMOS_a


# Root folder of the GLAMOS distributed MB grids; load_all_glamos expects
# one sub-folder per glacier in it.
PATH_GLAMOS = os.path.join(cfg.dataPath, path_distributed_MB_glamos, 'GLAMOS')
# NOTE(review): `glaciers` is rebound here but not used in the rest of this
# cell — confirm whether a later cell relies on it.
glaciers = os.listdir(PATH_GLAMOS)

# Sorted list of years with LSTM predictions, per glacier.
glacier_years = (
    df_months_LSTM.groupby('glacier')['year'].unique().apply(sorted).to_dict())

# Winter and annual GLAMOS tables for exactly those glaciers and years.
df_GLAMOS_w, df_GLAMOS_a = load_all_glamos(glacier_years, PATH_GLAMOS)

In [None]:
# --- 1. Glacier-wide means: average over all grid cells per glacier and year ---
_group_keys = ['glacier', 'year']
glwd_months_LSTM = df_months_LSTM.groupby(_group_keys).mean().reset_index()
glwd_months_GLAMOS_w = df_GLAMOS_w.groupby(_group_keys).mean().reset_index()
glwd_months_GLAMOS_a = df_GLAMOS_a.groupby(_group_keys).mean().reset_index()

# --- 2. Compute the intersection of valid glacier–year pairs across all datasets ---
valid_pairs = set.intersection(*(
    set(zip(_frame['glacier'], _frame['year']))
    for _frame in (glwd_months_LSTM, glwd_months_GLAMOS_w,
                   glwd_months_GLAMOS_a)))

# --- 3. Helper function for filtering by glacier–year pairs ---
def filter_to_valid(df):
    """Keep only the rows whose (glacier, year) pair is in ``valid_pairs``.

    ``valid_pairs`` is the module-level set of ``(glacier, year)`` tuples.
    Returns a new DataFrame with a fresh ``0..n-1`` index; an empty input
    gives an empty result with the same columns.
    """
    # Zip the two columns rather than building a tuple per row with
    # ``apply(..., axis=1)``: same membership test without creating a Series
    # for every row, and it yields a proper (empty) mask for an empty frame.
    keep = np.fromiter(
        (pair in valid_pairs for pair in zip(df['glacier'], df['year'])),
        dtype=bool,
        count=len(df))
    return df.loc[keep].reset_index(drop=True)

# --- 4. Apply consistent filtering to all datasets ---
# After this step all three tables cover the same glacier–year pairs.
glwd_months_LSTM_filtered = filter_to_valid(glwd_months_LSTM)
glwd_months_GLAMOS_filtered_w = filter_to_valid(glwd_months_GLAMOS_w)
glwd_months_GLAMOS_filtered_a = filter_to_valid(glwd_months_GLAMOS_a)

# Row counts of the three filtered tables (winter, annual, LSTM).
print(
    len(glwd_months_GLAMOS_filtered_w),
    len(glwd_months_GLAMOS_filtered_a),
    len(glwd_months_LSTM_filtered),
)

# --- 5. Prepare for plotting ---
# NOTE(review): the LSTM table is passed for both of the first two
# arguments — confirm against prepare_monthly_long_df that this is intended.
df_months_long = prepare_monthly_long_df(
    glwd_months_LSTM_filtered,
    glwd_months_LSTM_filtered,
    glwd_months_GLAMOS_filtered_w,
    glwd_months_GLAMOS_filtered_a,
)

# First two rows, shown as the cell output.
df_months_long.head(2)


In [None]:
# Common x-range: overall minimum / maximum of the two mass-balance columns.
# The columns are selected before reducing, so unrelated (possibly
# non-numeric) columns of the long table are never passed to min()/max().
_mb_cols = ['mb_nn', 'mb_glamos']
min_ = df_months_long[_mb_cols].min().min()
max_ = df_months_long[_mb_cols].max().max()

# NOTE(review): the range is taken from 'mb_nn' while 'mb_lstm' is plotted —
# confirm both columns hold the same values here.
# NOTE(review): `color_annual` is not assigned in the cells visible here;
# presumably it comes from one of the `scripts.*` star imports — confirm.
fig = plot_monthly_joyplot_single(df_months_long,
                                  variable="mb_lstm",
                                  color_model=color_annual,
                                  x_range=(np.floor(min_), np.ceil(max_)))

# Make sure the output folder exists before writing the figure.
_fig_path = 'figures/CH_LSTM_vs_NN_monthly_joyplot_glwd.png'
os.makedirs(os.path.dirname(_fig_path), exist_ok=True)
fig.savefig(_fig_path, dpi=300, bbox_inches='tight')