## Setting Up:

In [None]:
import os
import warnings
import logging
from collections import defaultdict

import numpy as np
import pandas as pd
import xarray as xr
import geopandas as gpd
import rasterio
from rasterio.warp import calculate_default_transform, reproject, Resampling

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from cmcrameri import cm

import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER

import massbalancemachine as mbm

from regions.TF_Europe.scripts.config_TF_Europe import *
from regions.TF_Europe.scripts.dataset import *
from regions.TF_Europe.scripts.plotting import *
from regions.TF_Europe.scripts.models import *

# Silence every Python warning for the rest of the session.
# NOTE(review): this blanket filter also hides deprecation and runtime
# warnings (e.g. invalid values passed to np.sqrt) — narrow it when debugging.
warnings.filterwarnings('ignore')
# IPython magics: with mode 2, autoreload re-imports all modules before each
# cell executes, so edits to the project scripts are picked up without a restart.
%load_ext autoreload
%autoreload 2

# Project configuration; cfg.seed and cfg.dataPath are read further down.
cfg = mbm.EuropeConfig()
# Presumably seeds the random number generators for reproducibility — confirm in mbm.utils.
mbm.utils.seed_all(cfg.seed)
# Presumably releases cached GPU memory — confirm in mbm.utils.
mbm.utils.free_up_cuda()
# Presumably applies the package-wide matplotlib style — confirm in mbm.plots.
mbm.plots.use_mbm_style()

## Read GL data:

In [None]:
# Stake measurements for RGI region 11 (Central Europe).
df_ceu = load_stakes_for_rgi_region(cfg, "11")

# Display names for the glaciers: plain capitalisation, except for the two
# Clariden stakes whose suffix is separated by an underscore.
_SPECIAL_DISPLAY_NAMES = {
    'claridenu': 'Clariden_U',
    'claridenl': 'Clariden_L',
}
glacierCap = {}
for glacier_name in df_ceu['GLACIER'].unique():
    # Non-string entries (e.g. NaN) cannot be capitalised: report and skip.
    if not isinstance(glacier_name, str):
        print(f"Warning: Non-string glacier name encountered: {glacier_name}")
        continue
    glacierCap[glacier_name] = _SPECIAL_DISPLAY_NAMES.get(
        glacier_name.lower(), glacier_name.capitalize())

# Split by measurement period, then report the number of observations.
is_annual = df_ceu['PERIOD'] == 'annual'
is_winter = df_ceu['PERIOD'] == 'winter'
data_annual = df_ceu[is_annual]
data_winter = df_ceu[is_winter]
print("Total observations:", len(df_ceu))
print("Annual observations:", len(data_annual))
print("Winter observations:", len(data_winter))

In [None]:
# Held-out (test) glaciers per data source. Names are matched exactly against
# df_ceu['GLACIER'], so the spelling and case of each source must be kept.
TEST_GLACIERS_CH = [
    "tortin",
    "plattalva",
    "schwarzberg",
    "hohlaub",
    "sanktanna",
    "corvatsch",
    "tsanfleuron",
    "forno",
]

TEST_GLACIERS_IT_AT = [
    'GOLDBERG K.',
    'HINTEREIS F.',
    'JAMTAL F.',
    'VERNAGT F.',
]

TEST_GLACIERS_FR = ['Talefre', 'Argentiere', 'Gebroulaz']

TEST_GLACIERS_ALL = TEST_GLACIERS_CH + TEST_GLACIERS_IT_AT + TEST_GLACIERS_FR

# Glacier outlines (RGI v6, region 11). os.path.join is used, as for the ERA5
# paths further down, so the path is valid with or without a trailing
# separator on cfg.dataPath.
glacier_outline_rgi = gpd.read_file(
    os.path.join(cfg.dataPath, "RGI_v6", "RGI_11_CentralEurope",
                 "11_rgi60_CentralEurope.shp"))

# Number of measurements per glacier, indexed by glacier name. Naming the
# column explicitly avoids relying on the auto-generated column label 0.
glacier_info = df_ceu.groupby('GLACIER').size().sort_values(
    ascending=False).to_frame('Nb. measurements')

# Mean stake position per glacier, used as the marker location on the map.
glacier_loc = df_ceu.groupby('GLACIER')[['POINT_LAT', 'POINT_LON']].mean()

glacier_info = glacier_loc.merge(glacier_info, on='GLACIER')

# Number of measurements per glacier and period (one column per period,
# 0 where a glacier has no measurement for that period).
glacier_period = df_ceu.groupby(['GLACIER', 'PERIOD'
                                 ]).size().unstack().fillna(0).astype(int)

glacier_info = glacier_info.merge(glacier_period, on='GLACIER')

# Label each glacier (the index) as 'Test' or 'Train' in one vectorised step
# instead of a row-wise apply.
glacier_info['Train/Test glacier'] = np.where(
    glacier_info.index.isin(TEST_GLACIERS_ALL), 'Test', 'Train')
glacier_info.head(2)

In [None]:
# Test glaciers grouped by data-source code; the keys are compared against
# df_ceu["SOURCE_CODE"] in print_test_share_by_source below.
TEST_GLACIERS_BY_CODE = {
    "CH": TEST_GLACIERS_CH,
    "IT_AT": TEST_GLACIERS_IT_AT,
    "FR": TEST_GLACIERS_FR,
}


def print_test_share_by_source(df_ceu, test_glaciers_by_code):
    """Print and tabulate the train/test split of measurements per source.

    Parameters
    ----------
    df_ceu : pandas.DataFrame
        Measurements with at least the columns "SOURCE_CODE" and "GLACIER".
    test_glaciers_by_code : dict
        Mapping source code -> iterable of test glacier names. Names are
        matched exactly against the "GLACIER" column.

    Returns
    -------
    pandas.DataFrame
        One row per source code that has data, indexed by "SOURCE_CODE"
        (sorted), with the columns "n_total", "n_train", "n_test" and
        "test_%_of_total". Sources without rows are reported on stdout and
        left out; if no source has rows the frame is empty but keeps the
        same columns.
    """
    columns = [
        "SOURCE_CODE", "n_total", "n_train", "n_test", "test_%_of_total"
    ]
    rows = []

    for code, test_gls in test_glaciers_by_code.items():
        df_src = df_ceu[df_ceu["SOURCE_CODE"] == code]
        n_total = len(df_src)

        if n_total == 0:
            print(f"{code}: no rows in df_ceu")
            continue

        n_test = int(df_src["GLACIER"].isin(test_gls).sum())
        n_train = n_total - n_test

        # n_total > 0 is guaranteed by the guard above.
        pct_test_of_total = 100 * n_test / n_total

        rows.append({
            "SOURCE_CODE": code,
            "n_total": n_total,
            "n_train": n_train,
            "n_test": n_test,
            "test_%_of_total": pct_test_of_total,
        })

        print(
            f"{code}: train={n_train}, test={n_test} | test/total={pct_test_of_total:.1f}%"
        )

    # Passing the column names keeps set_index valid even when rows is empty.
    return pd.DataFrame(
        rows, columns=columns).set_index("SOURCE_CODE").sort_index()


# Summarise the train/test share per source; the bare name on the last line
# displays the resulting table as the cell output.
df_shares = print_test_share_by_source(df_ceu, TEST_GLACIERS_BY_CODE)
df_shares

## Intro & methods:

### Geoplots (Fig 1):


In [None]:
# ---- 1. Preprocessing ----
# Square-root scaling of number of measurements, so that marker sizes follow
# sqrt(count) rather than the raw count
glacier_info['sqrt_size'] = np.sqrt(glacier_info['Nb. measurements'])

# Cache dataset-wide min and max (read as globals by scaled_size below)
sqrt_min = glacier_info['sqrt_size'].min()
sqrt_max = glacier_info['sqrt_size'].max()

# Define the desired marker size range in points^2
sizes = (100, 1500)  # min and max scatter size


# Function to scale individual values consistently
def scaled_size(val, min_out=sizes[0], max_out=sizes[1]):
    """Map a measurement count to a marker size.

    The size varies linearly with sqrt(val) between min_out and max_out,
    using the module-level sqrt_min / sqrt_max as the input range. When
    sqrt_min equals sqrt_max the midpoint of the output range is returned.
    """
    root = np.sqrt(val)
    if sqrt_min == sqrt_max:
        return (min_out + max_out) / 2
    fraction = (root - sqrt_min) / (sqrt_max - sqrt_min)
    return min_out + (max_out - min_out) * fraction


# Apply scaling to full dataset for the actual plot
glacier_info['scaled_size'] = glacier_info['Nb. measurements'].apply(
    scaled_size)

# ---- 2. Create figure and base map ----
fig = plt.figure(figsize=(18, 10))

# Map extent in degrees (north/south latitude, west/east longitude).
#latN, latS = 48, 45.8
latN, latS = 48, 44
lonW, lonE = 5.5, 14
projPC = ccrs.PlateCarree()
ax2 = plt.axes(projection=projPC)
ax2.set_extent([lonW, lonE, latS, latN], crs=ccrs.Geodetic())

# Background layers provided by cartopy.
ax2.add_feature(cfeature.COASTLINE)
ax2.add_feature(cfeature.LAKES)
ax2.add_feature(cfeature.RIVERS)
ax2.add_feature(cfeature.BORDERS, linestyle='-', linewidth=1)
ax2.add_feature(cfeature.LAND, facecolor='lightgray', alpha=0.5)

# Optional raster background, currently disabled. `destination` and `extent`
# are not defined in this cell and must be provided before re-enabling it.
# Add the image to the cartopy map
# masked_destination = np.ma.masked_where(destination == 0, destination)
# cmap = plt.cm.gray
# cmap.set_bad(color='white')  # Set masked (bad) values to white
# ax2.imshow(
#     masked_destination,
#     origin='upper',
#     extent=extent,
#     transform=ccrs.PlateCarree(),  # Assuming raster is in WGS84
#     cmap=cmap,  # or any other colormap
#     alpha=0.4,  # transparency
#     zorder=0)

# Glacier outlines
glacier_outline_rgi.plot(ax=ax2, transform=projPC, color='black', alpha=0.7)

# ---- 3. Scatterplot ----
# One marker per glacier at its mean stake position, coloured by train/test
# membership and sized by the scaled number of measurements.
# custom_palette = {'Train': '#35978f', 'Test': '#8c510a'}
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue = colors[0]
custom_palette = {'Train': color_dark_blue, 'Test': '#b2182b'}

g = sns.scatterplot(
    data=glacier_info,
    x='POINT_LON',
    y='POINT_LAT',
    size='scaled_size',
    hue='Train/Test glacier',
    sizes=sizes,
    alpha=0.6,
    palette=custom_palette,
    transform=projPC,
    ax=ax2,
    zorder=10,
    legend=True  # kept so its hue handles can be reused by the custom legend below
)

# ---- 4. Gridlines ----
gl = ax2.gridlines(draw_labels=True,
                   linewidth=1,
                   color='gray',
                   alpha=0.5,
                   linestyle='--')
gl.xformatter = LONGITUDE_FORMATTER
gl.yformatter = LATITUDE_FORMATTER
gl.xlabel_style = {'size': 16, 'color': 'black'}
gl.ylabel_style = {'size': 16, 'color': 'black'}
# Label only the bottom and left edges of the map.
gl.top_labels = gl.right_labels = False

# ---- 5. Custom Combined Legend ----

# Reuse the handles of seaborn's automatic legend, keeping only the
# Train/Test hue entries; the size entries are rebuilt by hand below.
handles, labels = g.get_legend_handles_labels()
expected_labels = list(custom_palette.keys())
hue_entries = [(h, l) for h, l in zip(handles, labels) if l in expected_labels]

# Size legend: representative measurement counts drawn with the same
# scaling as the scatter markers.
size_values = [30, 100, 1000, 6000]
size_handles = [
    Line2D(
        [],
        [],
        marker='o',
        linestyle='None',
        # Scatter sizes are areas in points^2 whereas Line2D markersize is a
        # length in points, hence the square root.
        markersize=np.sqrt(scaled_size(val)),
        markerfacecolor='gray',
        alpha=0.6,
        label=f'{val}') for val in size_values
]

# Hue entries first, then the size entries.
combined_handles = [h for h, _ in hue_entries] + size_handles
combined_labels = [l for _, l in hue_entries] + [str(v) for v in size_values]

# Final legend (replaces the automatic seaborn legend on ax2)
ax2.legend(combined_handles,
           combined_labels,
           title='Number of measurements',
           loc='lower right',
           frameon=True,
           fontsize=18,
           title_fontsize=18,
           borderpad=1.2,
           labelspacing=1.2,
           ncol=3)
# ax2.set_title('Glacier measurement locations', fontsize = 25)
plt.tight_layout()
plt.show()

# save figure
# fig.savefig('figures/paper/fig1_ch_map.png', dpi=300, bbox_inches='tight')

In [None]:
# Fixed colours for the two measurement periods.
colors = get_cmap_hex(cm.batlow, 10)
color_annual, color_winter = "#c51b7d", colors[0]

fig, ax = plt.subplots(figsize=(18, 10))

# Measurements per year, one column per period, drawn as stacked bars.
# The colour list follows the column order produced by unstack().
counts_per_year = df_ceu.groupby(['YEAR',
                                  'PERIOD'])['POINT_ID'].count().unstack()
counts_per_year.plot.bar(stacked=True,
                         figsize=(20, 5),
                         color=[color_annual, color_winter],
                         ax=ax)
# ax.set_title('Number of measurements per year for all glaciers', fontsize = 25)
ax.legend(title='Period', fontsize=18, title_fontsize=20, ncol=2)
# save figure
# fig.savefig('figures/paper/fig1_num_year.png', dpi=300, bbox_inches='tight')

In [None]:
# Number of non-null POINT_ID entries per year (rows) and period (columns).
meas_period = df_ceu.groupby(['YEAR', 'PERIOD']).count()['POINT_ID'].unstack()
# Column-wise sum: total number of measurements per period over all years.
meas_period.sum()

### Input data:

#### Heatmap annual (Fig 2):

##### PMB (Fig 2a):

In [None]:
# Annual heatmap over all data sources, with the full test-glacier list
# (presumably used by plot_heatmap to mark test glaciers — confirm in plotting.py).
fig = plot_heatmap(TEST_GLACIERS_ALL,
                   df_ceu,
                   glacierCap,
                   period='annual',
                   cbar_label="Mean PMB [m w.e. $a^{-1}$]")

# save figure
# NOTE(review): the per-source heatmap cells use this same file name; choose a
# distinct name per cell before enabling, or the figures overwrite each other.
# fig.savefig('figures/paper/fig_heatmap.png', dpi=300, bbox_inches='tight')

In [None]:
# Annual heatmap restricted to the rows with SOURCE_CODE == 'CH'.
fig = plot_heatmap(TEST_GLACIERS_CH,
                   df_ceu[df_ceu.SOURCE_CODE == 'CH'],
                   glacierCap,
                   period='annual',
                   cbar_label="Mean PMB [m w.e. $a^{-1}$]")

# save figure
# NOTE(review): the other heatmap cells use this same file name; choose a
# distinct name per cell before enabling, or the figures overwrite each other.
# fig.savefig('figures/paper/fig_heatmap.png', dpi=300, bbox_inches='tight')

In [None]:
# Annual heatmap restricted to the rows with SOURCE_CODE == 'FR'.
fig = plot_heatmap(TEST_GLACIERS_FR,
                   df_ceu[df_ceu.SOURCE_CODE == 'FR'],
                   glacierCap,
                   period='annual',
                   cbar_label="Mean PMB [m w.e. $a^{-1}$]")

# save figure
# NOTE(review): the other heatmap cells use this same file name; choose a
# distinct name per cell before enabling, or the figures overwrite each other.
# fig.savefig('figures/paper/fig_heatmap.png', dpi=300, bbox_inches='tight')

In [None]:
# Annual heatmap restricted to the rows with SOURCE_CODE == 'IT_AT'.
fig = plot_heatmap(TEST_GLACIERS_IT_AT,
                   df_ceu[df_ceu.SOURCE_CODE == 'IT_AT'],
                   glacierCap,
                   period='annual',
                   cbar_label="Mean PMB [m w.e. $a^{-1}$]")

# save figure
# NOTE(review): the other heatmap cells use this same file name; choose a
# distinct name per cell before enabling, or the figures overwrite each other.
# fig.savefig('figures/paper/fig_heatmap.png', dpi=300, bbox_inches='tight')

## Feature distribution per test glacier:

In [None]:
# Stake measurements of every configured RGI region, keyed by region id.
dfs = {rid: load_stakes_for_rgi_region(cfg, rid) for rid in RGI_REGIONS.keys()}

# Transform data to monthly format (run or load data):
# ERA5 input files required for that transformation.
paths = {
    'era5_climate_data':
    os.path.join(cfg.dataPath, path_ERA5_raw,
                 "era5_monthly_averaged_data_Europe.nc"),
    'geopotential_data':
    os.path.join(cfg.dataPath, path_ERA5_raw,
                 "era5_geopotential_pressure_Europe.nc")
}

# Fail early if any of these files is missing.
for key, path in paths.items():
    if not os.path.exists(path):
        raise FileNotFoundError(f"Required file for {key} not found at {path}")

# Climate variables of interest. Defined at top level (not inside the loop
# above) so the list exists regardless of how many paths are checked.
vois_climate = [
    "t2m",
    "tp",
    "slhf",
    "sshf",
    "ssrd",
    "fal",
    "str",
]

# Topographical variables of interest.
vois_topographical = ["aspect", "slope", "svf"]

# Prepare the monthly dataframes for all regions.
# NOTE(review): run_flag=False presumably loads previously prepared data
# instead of recomputing it — confirm in prepare_monthly_dfs_for_all_regions.
# To recompute a single source only, enable `only_codes` below.
res_all = prepare_monthly_dfs_for_all_regions(
    cfg=cfg,
    dfs=dfs,
    paths=paths,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    run_flag=False,
    test_glaciers_by_code=TEST_GLACIERS_BY_CODE,
    # only_codes=["IT_AT"],  # only IT_AT recomputes (careful, only codes overrides run_flag)
)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import math

def _fit_kde_or_none(values, kde_factory):
    """Return kde_factory(values), or None when no KDE can be estimated.

    A sample with fewer than two points or with zero spread has no usable
    bandwidth, so it is reported as None instead of raising.
    """
    if values.size < 2 or values.min() == values.max():
        return None
    try:
        return kde_factory(values)
    except (np.linalg.LinAlgError, ValueError):
        return None


def plot_feature_overlap_per_test_glacier_vs_train_pool(
    results_dict,
    code: str,
    test_glaciers_by_code: dict,
    feature: str = "ELEVATION_DIFFERENCE",
    use_aug: bool = True,      # True -> df_*_aug, False -> df_*
    bins_fallback: int = 30,   # used if no KDE can be computed
    ncols: int = 3,
    figsize_per_col: float = 5.5,
    figsize_per_row: float = 3.6,
):
    """
    For a given subregion code (e.g. IT_AT), plot per test glacier the feature
    distribution vs the TRAIN pool distribution.

    Each subplot shows:
      - KDE (or hist-density fallback) of the TRAIN pool (all non-test glaciers)
      - KDE (or hist-density fallback) of one TEST glacier; a test glacier
        with a single value or zero spread is drawn as a vertical line

    The histogram fallback is used for all panels when scipy cannot be
    imported or when no KDE can be fitted to the TRAIN pool.

    Parameters
    ----------
    results_dict : dict
        Monthly-prep results dict keyed like "11_IT_AT", containing df_train/df_test
        and optionally df_train_aug/df_test_aug.
    code : str
        Subregion code (e.g. "IT_AT", "CH", "FR", "NOR", ...). Case-insensitive.
    test_glaciers_by_code : dict
        Mapping code -> list of test glacier names.
    feature : str
        Feature column to plot.
    use_aug : bool
        Use df_train_aug/df_test_aug if True, else df_train/df_test.
    bins_fallback : int
        Number of histogram bins used by the hist-density fallback.
    ncols : int
        Number of subplot columns.
    figsize_per_col, figsize_per_row : float
        Figure width per subplot column and height per subplot row, in inches.

    Returns
    -------
    matplotlib.figure.Figure or None
        The figure, or None when there is nothing to plot (no test glaciers
        defined for the code, all feature values NaN, or none of the listed
        test glaciers present in the data).

    Raises
    ------
    KeyError
        If no key of results_dict ends with "_<code>".
    ValueError
        If the matching result is None, the train/test frames are missing or
        empty, the GLACIER or feature column is missing, or the TRAIN pool is
        empty after excluding the test glaciers.
    """

    code = code.upper()
    test_glaciers = list(test_glaciers_by_code.get(code, []))
    if not test_glaciers:
        print(f"[warn] No test glaciers defined for code={code}")
        return None

    # find matching keys like "11_IT_AT"
    keys = [k for k in results_dict if k.endswith(f"_{code}")]
    if not keys:
        raise KeyError(f"No results key endswith _{code} (expected e.g. '11_{code}').")
    if len(keys) > 1:
        print(f"[info] Multiple keys match code={code}: {keys}. Using {keys[0]}.")
    key = keys[0]

    res = results_dict[key]
    if res is None:
        raise ValueError(f"{key} result is None.")

    frame_suffix = "_aug" if use_aug else ""
    df_label = f"df_*{frame_suffix}"
    df_train = res.get(f"df_train{frame_suffix}")
    df_test = res.get(f"df_test{frame_suffix}")

    if df_train is None or df_test is None or len(df_train) == 0 or len(df_test) == 0:
        raise ValueError(f"{key}: missing or empty train/test frames for {df_label}.")

    # Train/test membership is re-derived from the glacier names below, so
    # both frames are pooled first.
    df_all = pd.concat([df_train, df_test], ignore_index=True)

    if "GLACIER" not in df_all.columns:
        raise ValueError(f"{key}: no GLACIER column in {df_label}.")
    if feature not in df_all.columns:
        raise ValueError(f"{key}: feature '{feature}' not in {df_label}.")

    df_all = df_all.dropna(subset=[feature]).copy()
    if len(df_all) == 0:
        print(f"{key}: all values are NaN for {feature}")
        return None

    is_test_row = df_all["GLACIER"].isin(set(test_glaciers))
    train_pool = df_all.loc[~is_test_row, feature].astype(float).values
    if train_pool.size == 0:
        raise ValueError(f"{key}: TRAIN pool empty after excluding test glaciers.")

    # Feature values per listed test glacier, extracted once and reused for
    # the presence check, the common x-range and the panels.
    test_values = {
        g: df_all.loc[df_all["GLACIER"] == g, feature].astype(float).values
        for g in test_glaciers
    }

    # Keep only test glaciers that actually exist in the dataframe
    present_tests = [g for g in test_glaciers if test_values[g].size > 0]
    missing_tests = [g for g in test_glaciers if g not in present_tests]
    if missing_tests:
        print(f"[warn] {key}: these test glaciers not found in data and will be skipped:\n  {missing_tests}")

    if not present_tests:
        print(f"{key}: none of the listed test glaciers are present in the data.")
        return None

    # Common x-range for comparability across panels
    xmin = float(min(train_pool.min(), *(test_values[g].min() for g in present_tests)))
    xmax = float(max(train_pool.max(), *(test_values[g].max() for g in present_tests)))
    if np.isclose(xmin, xmax):
        # Widen a degenerate range so linspace yields distinct points/bins.
        xmin -= 1e-6
        xmax += 1e-6
    xgrid = np.linspace(xmin, xmax, 400)

    # Try KDE via scipy, otherwise fallback to hist-density
    try:
        from scipy.stats import gaussian_kde
    except ImportError:
        gaussian_kde = None
    kde_train = None
    if gaussian_kde is not None:
        kde_train = _fit_kde_or_none(train_pool, gaussian_kde)
    use_kde = kde_train is not None
    if use_kde:
        train_y = kde_train(xgrid)
    else:
        # Same bins for train and test in every panel.
        bins = np.linspace(xmin, xmax, bins_fallback + 1)

    n = len(present_tests)
    nrows = math.ceil(n / ncols)

    fig, axes = plt.subplots(
        nrows, ncols,
        figsize=(figsize_per_col * ncols, figsize_per_row * nrows),
        sharex=True,
        sharey=True,
    )
    axes = np.array(axes).reshape(-1)

    train_label = f"Train pool (n={train_pool.size})"
    for ax, gl in zip(axes, present_tests):
        test_vals = test_values[gl]
        test_label = f"{gl} (n={test_vals.size})"

        if use_kde:
            ax.plot(xgrid, train_y, label=train_label)
            kde_test = _fit_kde_or_none(test_vals, gaussian_kde)
            if kde_test is not None:
                ax.plot(xgrid, kde_test(xgrid), label=test_label)
            else:
                # Degenerate sample (single value or zero spread): show it as
                # a vertical line instead of a density curve.
                ax.axvline(test_vals[0], linestyle="--", label=test_label)
        else:
            ax.hist(train_pool, bins=bins, density=True, alpha=0.35, label=train_label)
            ax.hist(test_vals, bins=bins, density=True, alpha=0.55, label=test_label)

        ax.set_title(gl)
        ax.set_xlabel(feature)
        ax.set_ylabel("Density")
        ax.grid(True, alpha=0.2)

        # keep legend small
        ax.legend(fontsize=8)

    # Hide the unused panels of the grid.
    for ax in axes[n:]:
        ax.axis("off")

    fig.suptitle(f"{key}: {feature} — each TEST glacier vs TRAIN pool ({df_label})", fontsize=14)
    fig.tight_layout()
    plt.show()

    return fig

# Example: compare every IT_AT test glacier with the IT_AT train pool for the
# ELEVATION_DIFFERENCE feature, using the prepared monthly results in res_all.
fig = plot_feature_overlap_per_test_glacier_vs_train_pool(
    results_dict=res_all,
    code="IT_AT",
    test_glaciers_by_code=TEST_GLACIERS_BY_CODE,
    feature="ELEVATION_DIFFERENCE",
    use_aug=True,   # uses df_train_aug/df_test_aug
    ncols=3,
)
