## Setting Up:

In [None]:
import os
import warnings
import logging
from collections import defaultdict

import numpy as np
import pandas as pd
import xarray as xr
import geopandas as gpd
import rasterio
from rasterio.warp import calculate_default_transform, reproject, Resampling

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from cmcrameri import cm

import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER

import massbalancemachine as mbm

from scripts.utils import *
from scripts.config_CH import *
from scripts.glamos import *
from scripts.dataset import *
from scripts.plotting import *


warnings.filterwarnings('ignore')
%load_ext autoreload
%autoreload 2

cfg = mbm.SwitzerlandConfig()

# Plot styles:
use_mbm_style()

seed_all(cfg.seed)
print("Using seed:", cfg.seed)

if torch.cuda.is_available():
    print("CUDA is available")
    free_up_cuda()
else:
    print("CUDA is NOT available")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Read GL data:

In [None]:
data_glamos = pd.read_csv(cfg.dataPath + path_PMB_GLAMOS_csv +
                          'CH_wgms_dataset_all.csv')

# Display names for glaciers (raw name -> label). The two Clariden parts get
# explicit "_U"/"_L" suffixes; every other name is simply capitalized.
# Built from ALL glaciers in the file, i.e. before the pcsr filter below.
special_names = {'claridenu': 'Clariden_U', 'claridenl': 'Clariden_L'}
glacierCap = {}
for raw_name in data_glamos['GLACIER'].unique():
    if not isinstance(raw_name, str):  # e.g. missing names
        print(f"Warning: Non-string glacier name encountered: {raw_name}")
        continue
    glacierCap[raw_name] = special_names.get(raw_name.lower(),
                                             raw_name.capitalize())

# Cut to glaciers with pcsr:
glacier_list = [
    'adler', 'albigna', 'aletsch', 'allalin', 'basodino', 'clariden',
    'corbassiere', 'corvatsch', 'findelen', 'forno', 'gietro', 'gorner',
    'gries', 'hohlaub', 'joeri', 'limmern', 'morteratsch', 'murtel', 'oberaar',
    'otemma', 'pizol', 'plattalva', 'rhone', 'sanktanna', 'schwarzbach',
    'schwarzberg', 'sexrouge', 'silvretta', 'tortin', 'tsanfleuron'
]
data_glamos = data_glamos[data_glamos['GLACIER'].isin(glacier_list)]

# Overall observation counts (all periods / annual / winter):
n_annual_all = len(data_glamos[data_glamos['PERIOD'] == 'annual'])
n_winter_all = len(data_glamos[data_glamos['PERIOD'] == 'winter'])
print("Total observations:", len(data_glamos))
print("Annual observations:", n_annual_all)
print("Winter observations:", n_winter_all)

# Same counts restricted to the test glaciers (years before 2025 only):
data_glamos_test = data_glamos[data_glamos['GLACIER'].isin(TEST_GLACIERS)
                               & (data_glamos.YEAR < 2025)]
data_annual = data_glamos_test[data_glamos_test['PERIOD'] == 'annual']
data_winter = data_glamos_test[data_glamos_test['PERIOD'] == 'winter']
print("Total test observations:", len(data_glamos_test))
print("Annual test observations:", len(data_annual))
print("Winter test observations:", len(data_winter))

# Test share of all observations. NOTE: the test counts exclude YEAR >= 2025,
# the denominators do not.
print('Percentage of test data:', len(data_glamos_test) / len(data_glamos) * 100)
print('Percentage of annual test data:', len(data_annual) / n_annual_all * 100)
print('Percentage of winter test data:', len(data_winter) / n_winter_all * 100)

In [None]:
# === Load RGI glacier IDs ===
rgi_df = pd.read_csv(cfg.dataPath + path_glacier_ids)
rgi_df.columns = rgi_df.columns.str.strip()
rgi_df = rgi_df.sort_values(by='short_name').set_index('short_name')

# Glacier outlines:
glacier_outline_sgi = gpd.read_file(
    os.path.join(cfg.dataPath, path_SGI_topo, 'inventory_sgi2016_r2020',
                 'SGI_2016_glaciers_copy.shp'))  # Load the shapefile
glacier_outline_rgi = gpd.read_file(cfg.dataPath + path_rgi_outlines)

# get number of measurements per glacier:
glacier_info = data_glamos.groupby('GLACIER').size().sort_values(
    ascending=False).reset_index()
glacier_info.rename(columns={0: 'Nb. measurements'}, inplace=True)
glacier_info.set_index('GLACIER', inplace=True)

glacier_loc = data_glamos.groupby('GLACIER')[['POINT_LAT', 'POINT_LON']].mean()

glacier_info = glacier_loc.merge(glacier_info, on='GLACIER')

glacier_period = data_glamos.groupby(['GLACIER', 'PERIOD'
                                      ]).size().unstack().fillna(0).astype(int)

glacier_info = glacier_info.merge(glacier_period, on='GLACIER')

glacier_info['Train/Test glacier'] = glacier_info.apply(
    lambda x: 'Test' if x.name in TEST_GLACIERS else 'Train', axis=1)
glacier_info.head(2)

## Intro & methods:

### Geoplots (Fig 1):


#### sqrt scaling:

In [None]:
# Open the original raster
tif_name = "landesforstinventar-vegetationshoehenmodell_relief_sentinel_2024_2056.tif"
tif_path = os.path.join(cfg.dataPath, 'GLAMOS/RGI/', tif_name)

# Desired output resolution (in degrees)
# Approx. 100 m in degrees: ~0.0009 deg
target_res = 0.0009
output_crs = "EPSG:4326"  # WGS84

with rasterio.open(tif_path) as src:
    # Calculate transform and shape with coarser resolution
    transform, width, height = calculate_default_transform(
        src.crs,
        output_crs,
        src.width,
        src.height,
        *src.bounds,
        resolution=target_res)

    # Set up destination array and metadata
    kwargs = src.meta.copy()
    kwargs.update({
        'crs': output_crs,
        'transform': transform,
        'width': width,
        'height': height
    })

    # Prepare empty destination array
    destination = np.empty((height, width), dtype=src.dtypes[0])

    # Reproject with coarsening
    reproject(
        source=rasterio.band(src, 1),
        destination=destination,
        src_transform=src.transform,
        src_crs=src.crs,
        dst_transform=transform,
        dst_crs=output_crs,
        resampling=Resampling.
        average  # average to reduce noise when downsampling
    )

    extent = [
        transform[2], transform[2] + transform[0] * width,
        transform[5] + transform[4] * height, transform[5]
    ]

In [None]:
# ---- 1. Preprocessing ----
# Square-root scaling of number of measurements
glacier_info['sqrt_size'] = np.sqrt(glacier_info['Nb. measurements'])

# Cache dataset-wide min and max
sqrt_min = glacier_info['sqrt_size'].min()
sqrt_max = glacier_info['sqrt_size'].max()

# Define the desired marker size range in points^2
sizes = (100, 1500)  # min and max scatter size


# Function to scale individual values consistently
def scaled_size(val, min_out=sizes[0], max_out=sizes[1]):
    """Map a raw measurement count to a scatter-marker area.

    The count is square-rooted and linearly rescaled from the dataset range
    [sqrt_min, sqrt_max] (module-level globals) onto [min_out, max_out].
    Counts outside the dataset range extrapolate beyond [min_out, max_out].

    Args:
        val: Raw number of measurements.
        min_out: Marker area (points^2) for the smallest count in the dataset.
        max_out: Marker area (points^2) for the largest count in the dataset.

    Returns:
        Marker area in points^2. If every glacier has the same count (so the
        sqrt range is degenerate), returns the midpoint of [min_out, max_out].
    """
    root = np.sqrt(val)
    if sqrt_max == sqrt_min:
        return (min_out + max_out) / 2
    fraction = (root - sqrt_min) / (sqrt_max - sqrt_min)
    return min_out + (max_out - min_out) * fraction


# Apply scaling to full dataset for the actual plot
glacier_info['scaled_size'] = glacier_info['Nb. measurements'].apply(
    scaled_size)

# ---- 2. Create figure and base map ----
fig = plt.figure(figsize=(18, 10))

#latN, latS = 48, 45.8
latN, latS = 47.1, 45.8
lonW, lonE = 5.8, 10.5
projPC = ccrs.PlateCarree()
ax2 = plt.axes(projection=projPC)
ax2.set_extent([lonW, lonE, latS, latN], crs=ccrs.Geodetic())

ax2.add_feature(cfeature.COASTLINE)
ax2.add_feature(cfeature.LAKES)
ax2.add_feature(cfeature.RIVERS)
# ax2.add_feature(cfeature.BORDERS, linestyle='-', linewidth=1)

# Relief background. Pixels equal to 0 are masked and drawn white.
masked_destination = np.ma.masked_where(destination == 0, destination)
# Work on a copy: calling set_bad() on plt.cm.gray itself would mutate the
# shared, globally registered colormap for every later plot in the session.
cmap = plt.cm.gray.copy()
cmap.set_bad(color='white')  # Set masked (bad) values to white
ax2.imshow(
    masked_destination,
    origin='upper',
    extent=extent,
    transform=projPC,  # raster was reprojected to EPSG:4326 (WGS84)
    cmap=cmap,
    alpha=0.4,  # transparency
    zorder=0)

# Glacier outlines
glacier_outline_sgi.plot(ax=ax2, transform=projPC, color='black', alpha=0.7)

# ---- 3. Scatterplot ----
# custom_palette = {'Train': '#35978f', 'Test': '#8c510a'}
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue = colors[0]
custom_palette = {'Train': color_dark_blue, 'Test': '#b2182b'}

# Marker areas come from the 'scaled_size' column; `sizes` is its exact range.
g = sns.scatterplot(
    data=glacier_info,
    x='POINT_LON',
    y='POINT_LAT',
    size='scaled_size',
    hue='Train/Test glacier',
    sizes=sizes,
    alpha=0.6,
    palette=custom_palette,
    transform=projPC,
    ax=ax2,
    zorder=10,
    # Needed: provides the labelled hue handles reused below. The seaborn
    # legend itself is replaced by the custom ax2.legend() call.
    legend=True)

# ---- 4. Gridlines ----
gl = ax2.gridlines(draw_labels=True,
                   linewidth=1,
                   color='gray',
                   alpha=0.5,
                   linestyle='--')
gl.xformatter = LONGITUDE_FORMATTER
gl.yformatter = LATITUDE_FORMATTER
gl.xlabel_style = {'size': 16, 'color': 'black'}
gl.ylabel_style = {'size': 16, 'color': 'black'}
gl.top_labels = gl.right_labels = False

# ---- 5. Custom Combined Legend ----

# Hue legend handles (Train / Test only)
handles, labels = g.get_legend_handles_labels()
expected_labels = list(custom_palette.keys())
hue_entries = [(h, l) for h, l in zip(handles, labels) if l in expected_labels]

# Size legend values and handles
size_values = [30, 100, 1000, 6000]
size_handles = [
    Line2D(
        [],
        [],
        marker='o',
        linestyle='None',
        # scatter sizes are areas (points^2), while Line2D markersize is a
        # length (points), so take the square root to match the map markers.
        markersize=np.sqrt(scaled_size(val)),
        markerfacecolor='gray',
        alpha=0.6,
        label=f'{val}') for val in size_values
]

# Separator label
separator_handle = Patch(facecolor='none',
                         edgecolor='none',
                         label='Nb. measurements')

# Combine all legend entries
# combined_handles = [h for h, _ in hue_entries] + [separator_handle] + size_handles
# combined_labels = [l for _, l in hue_entries] + ['Nb. measurements'] + [str(v) for v in size_values]

# same but without separator
combined_handles = [h for h, _ in hue_entries] + size_handles
combined_labels = [l for _, l in hue_entries] + [str(v) for v in size_values]

# Final legend
ax2.legend(combined_handles,
           combined_labels,
           title='Number of measurements',
           loc='lower right',
           frameon=True,
           fontsize=18,
           title_fontsize=18,
           borderpad=1.2,
           labelspacing=1.2,
           ncol=3)
# ax2.set_title('Glacier measurement locations', fontsize = 25)
plt.tight_layout()

# Save BEFORE plt.show(): some backends close/clear the figure on show, which
# would leave an empty image on disk.
fig.savefig('figures/paper/fig1_ch_map.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# CONSTANT COLORS FOR PLOTS
colors = get_cmap_hex(cm.batlow, 10)
color_winter = colors[0]
color_annual = "#c51b7d"

fig = plt.figure(figsize=(18, 10))
ax = plt.subplot(1, 1, 1)
# Number of measurements per year:
data_glamos.groupby(['YEAR', 'PERIOD']).count()['POINT_ID'].unstack().plot(
    kind='bar',
    stacked=True,
    figsize=(20, 5),
    color=[color_annual, color_winter],
    ax=ax)
# plt.title('Number of measurements per year for all glaciers', fontsize = 25)
# get legend
plt.legend(title='Period', fontsize=18, title_fontsize=20, ncol=2)
# save figure
fig.savefig('figures/paper/fig1_num_year.png', dpi=300, bbox_inches='tight')

In [None]:
# Measurements per (year, period); the cell output is the total per period.
meas_period = (data_glamos
               .groupby(['YEAR', 'PERIOD'])['POINT_ID']
               .count()
               .unstack())
meas_period.sum()

### Input data:

In [None]:
# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Transform data to monthly format (run or load data):
paths = {
    'csv_path': cfg.dataPath + path_PMB_GLAMOS_csv,
    'era5_climate_data':
    cfg.dataPath + path_ERA5_raw + 'era5_monthly_averaged_data.nc',
    'geopotential_data':
    cfg.dataPath + path_ERA5_raw + 'era5_geopotential_pressure.nc',
    'radiation_save_path': cfg.dataPath + path_pcsr + 'zarr/'
}
RUN = False
data_monthly = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=VOIS_CLIMATE,
    vois_topographical=VOIS_TOPOGRAPHICAL,
    output_file='CH_wgms_dataset_monthly_LSTM.csv')

# Create DataLoader
dataloader_gl = mbm.dataloader.DataLoader(cfg,
                                          data=data_monthly,
                                          random_seed=cfg.seed,
                                          meta_data_columns=cfg.metaData)
months_head_pad, months_tail_pad = mbm.data_processing.utils.build_head_tail_pads_from_monthly_df(
    data_monthly)

# Ensure all test glaciers exist in the dataset
existing_glaciers = set(dataloader_gl.data.GLACIER.unique())
missing_glaciers = [g for g in TEST_GLACIERS if g not in existing_glaciers]

if missing_glaciers:
    print(
        f"Warning: The following test glaciers are not in the dataset: {missing_glaciers}"
    )

# Define training glaciers correctly
train_glaciers = [i for i in existing_glaciers if i not in TEST_GLACIERS]

data_test = dataloader_gl.data[dataloader_gl.data.GLACIER.isin(TEST_GLACIERS)]
print('Size of test data:', len(data_test))

data_train = dataloader_gl.data[dataloader_gl.data.GLACIER.isin(
    train_glaciers)]
print('Size of train data:', len(data_train))

if len(data_train) == 0:
    print("Warning: No training data available!")
else:
    test_perc = (len(data_test) / len(data_train)) * 100
    print('Percentage of test size: {:.2f}%'.format(test_perc))

# Number of annual versus winter measurements:
print('Train:')
print('Number of winter and annual samples:', len(data_train))
print('Number of annual samples:',
      len(data_train[data_train.PERIOD == 'annual']))
print('Number of winter samples:',
      len(data_train[data_train.PERIOD == 'winter']))

# Same for test
data_test_annual = data_test[data_test.PERIOD == 'annual']
data_test_winter = data_test[data_test.PERIOD == 'winter']

print('Test:')
print('Number of winter and annual samples:', len(data_test))
print('Number of annual samples:', len(data_test_annual))
print('Number of winter samples:', len(data_test_winter))

print('Total:')
print('Number of monthly rows:', len(dataloader_gl.data))
print('Number of annual rows:',
      len(dataloader_gl.data[dataloader_gl.data.PERIOD == 'annual']))
print('Number of winter rows:',
      len(dataloader_gl.data[dataloader_gl.data.PERIOD == 'winter']))

#### Heatmap annual (Fig 2):

##### SMB:

In [None]:
def parse_glwd_csv(path):
    """Parse the GLAMOS glacier-wide mass-balance observation CSV.

    Each data row has 13 fixed comma-separated fields followed by a variable
    number (0..n) of free-text fields, so rows are split by hand. The first
    line of the file is skipped and blank lines are ignored. Every row is
    normalised to exactly 13 fixed + 3 trailing fields: missing fields become
    "" and trailing fields beyond the third are dropped.

    No quoting is handled and no type conversion is done; all values are
    returned as stripped strings.

    Args:
        path: Path to the CSV file (read as UTF-8).

    Returns:
        pd.DataFrame with 16 columns (``glacier`` ... ``source``).
    """
    n_fixed, n_tail = 13, 3
    cols = [
        "glacier",
        "glacier_id",
        "date_start",
        "date_end_winter",
        "date_end",
        "Bw",
        "Bs",
        "Ba",
        "ELA",
        "AAR",
        "area",
        "h_min",
        "h_max",
        "source_observer",
        "observer",
        "source",
    ]

    rows = []
    with open(path, encoding="utf-8") as f:
        next(f, None)  # skip first line entirely
        for line in f:
            stripped = line.strip()
            # A blank line (e.g. a trailing newline at EOF) would otherwise
            # become a too-short row and break the DataFrame construction.
            if not stripped:
                continue

            parts = [p.strip() for p in stripped.split(",")]

            # first 13 fixed columns, padded if the row is short
            fixed = parts[:n_fixed]
            fixed += [""] * (n_fixed - len(fixed))

            # remaining text fields (could be 0,1,2,n): pad / trim to 3
            tail = parts[n_fixed:n_fixed + n_tail]
            tail += [""] * (n_tail - len(tail))

            rows.append(fixed + tail)

    return pd.DataFrame(rows, columns=cols)


path = os.path.join(cfg.dataPath, 'GLAMOS/glacier-wide',
                    'massbalance_observation_2025_r2025.csv')
glwd_csv = parse_glwd_csv(path)

# Print the glacier-wide records of two example glaciers, matched via SGI id.
for example_glacier in ('morteratsch', 'gorner'):
    sgi_id = rgi_df.loc[example_glacier, 'sgi-id']
    print(glwd_csv[glwd_csv['glacier_id'] == sgi_id])

# Collect the per-glacier glacier-wide observation files (one CSV per glacier).
obs_dir = os.path.join(cfg.dataPath, 'GLAMOS/glacier-wide/csv/obs')
all_glwd = []  # list of dataframes to concat
for gl in data_glamos.GLACIER.unique():
    obs_file = os.path.join(obs_dir, f'{gl}_obs.csv')
    if not os.path.exists(obs_file):
        print(f"Warning: Glacier CSV for {gl} not found. Skipping.")
        continue

    df = pd.read_csv(obs_file)

    # keep only years >= 1951
    df = df[df.YEAR >= 1951].copy()

    # MB in m w.e. (assumes Annual.Balance is given in mm w.e. — confirm)
    df["MB"] = df["Annual.Balance"] / 1000

    # add a GLACIER column holding the name
    df["GLACIER"] = gl

    all_glwd.append(df[["YEAR", "GLACIER", "MB"]])

# pd.concat([]) would fail with an unhelpful "No objects to concatenate".
if not all_glwd:
    raise FileNotFoundError(
        f"No glacier-wide observation CSVs found in {obs_dir}")
glwd_all = pd.concat(all_glwd, ignore_index=True)

# Attach the glacier-wide annual MB to every point measurement of the same
# glacier and year (rows without a glacier-wide value are dropped).
data_glamos_compl = pd.merge(data_glamos,
                             glwd_all,
                             on=['GLACIER', 'YEAR'],
                             how='inner')
fig = plot_heatmap(TEST_GLACIERS,
                   data_glamos_compl,
                   glacierCap,
                   period='annual',
                   var_to_plot='MB')


##### PMB (Fig 2a):

In [None]:
# Annual point-mass-balance heatmap for the test glaciers, saved for the paper.
fig = plot_heatmap(
    TEST_GLACIERS,
    data_glamos,
    glacierCap,
    period='annual',
    cbar_label="Mean PMB [m w.e. $a^{-1}$]",
)
fig.savefig('figures/paper/fig_heatmap.png', dpi=300, bbox_inches='tight')

In [None]:
# Work on a copy to avoid chained-assignment warnings
data_annual = data_glamos.loc[data_glamos.PERIOD == 'annual'].copy()

# Parse FROM/TO into proper datetimes
data_annual['FROM_date'] = pd.to_datetime(data_annual['FROM_DATE'].astype(str),
                                          format='%Y%m%d',
                                          errors='coerce')

# Work on a copy to avoid chained-assignment warnings
data_winter = data_glamos.loc[data_glamos.PERIOD == 'winter'].copy()

# Parse FROM/TO into proper datetimes
data_winter['TO_date'] = pd.to_datetime(data_winter['TO_DATE'].astype(str),
                                        format='%Y%m%d',
                                        errors='coerce')


# Helper: compute mean date and std (in days) for a date Series, using a dummy year (2000)
def mean_date_and_std(date_series: pd.Series, circular: bool = False):
    """Return the mean calendar date and its spread in days, ignoring the year.

    Every date is mapped onto the dummy year 2000 and converted to its
    day-of-year. 2000 is a leap year, so 29 February is representable and the
    day-of-year runs from 1 to 366.

    Args:
        date_series: Datetime Series. NaT entries are dropped.
        circular: If False, use the arithmetic mean and the sample standard
            deviation of the day-of-year. If True, treat the day-of-year as an
            angle on a 366-day circle (correct for dates wrapping around New
            Year) and return the circular mean and circular standard deviation.

    Returns:
        ``(mean_date, std_days)``. ``mean_date`` is a Timestamp expressed in
        the dummy year 2000 (it may carry a fractional-day time component).
        ``std_days`` is a float. Returns ``(pd.NaT, nan)`` if there are no
        valid dates.
    """
    # Length of the dummy year 2000 (leap year). The circular statistics must
    # use the same period as the day-of-year values, otherwise 31 Dec (doy 366)
    # would coincide with 1 Jan.
    days_in_dummy_year = 366

    s = date_series.dropna()
    if s.empty:
        return pd.NaT, np.nan

    # Map to a dummy, fixed year so we can compute day-of-year
    dummy = pd.to_datetime({
        'year': 2000,
        'month': s.dt.month,
        'day': s.dt.day
    },
                           errors='coerce').dropna()

    doy = dummy.dt.dayofyear.astype(float)

    if not circular:
        mean_doy = doy.mean()
        std_doy = doy.std()
    else:
        # Circular mean/std over the year (useful if dates wrap around New Year)
        theta = 2 * np.pi * (doy - 1) / days_in_dummy_year
        C = np.mean(np.cos(theta))
        S = np.mean(np.sin(theta))
        mean_ang = np.arctan2(S, C)
        if mean_ang < 0:
            mean_ang += 2 * np.pi
        mean_doy = (mean_ang / (2 * np.pi)) * days_in_dummy_year + 1
        R = np.hypot(C, S)  # mean resultant length
        # Circular std sqrt(-2 ln R) in radians, converted to days
        std_ang = np.sqrt(-2 * np.log(max(R, 1e-12)))
        std_doy = std_ang * days_in_dummy_year / (2 * np.pi)

    mean_date = pd.Timestamp('2000-01-01') + pd.to_timedelta(mean_doy - 1,
                                                             unit='D')
    return mean_date, std_doy

# Compute stats for FROM and TO
# Mean start date of annual periods and mean end date of winter periods
# (non-circular statistics on the day-of-year).
mean_from_annual, std_from_annual = mean_date_and_std(
    data_annual['FROM_date'], circular=False)
mean_from_winter, std_from_winter = mean_date_and_std(
    data_winter['TO_date'], circular=False)

# Header labels are padded so the "->" arrows line up in the output.
for header, mean_date_, std_days in (
    ("ANNUAL FROM_DATE", mean_from_annual, std_from_annual),
    ("WINTER TO_DATE  ", mean_from_winter, std_from_winter),
):
    print(
        f"{header} -> mean: {mean_date_.strftime('%m-%d')} | std: {std_days:.2f} days"
    )

#### Elevations (Fig 2b):

In [None]:
# Lowest and highest measurement elevations over all points.
min_elev = data_glamos.POINT_ELEVATION.min()
max_elev = data_glamos.POINT_ELEVATION.max()
rows_min = data_glamos[data_glamos.POINT_ELEVATION == min_elev]
rows_max = data_glamos[data_glamos.POINT_ELEVATION == max_elev]

print('Min elevation measurement:', min_elev, 'on glacier',
      rows_min.GLACIER.values[0])
print('Max elevation measurement:', max_elev, 'on glacier',
      rows_max.GLACIER.values[0])


def _pmb_summary(subset, header):
    """Print mean/min/max POINT_BALANCE of `subset` under `header`.

    Min/max are reported with the glacier and year of the first matching row.

    Returns:
        (mean, min, max, rows_at_min, rows_at_max)
    """
    balance = subset.POINT_BALANCE
    avg, lowest, highest = balance.mean(), balance.min(), balance.max()
    at_lowest = subset[balance == lowest]
    at_highest = subset[balance == highest]

    print(header)
    print('Mean PMB (m w.e.): {:.2f}'.format(avg))
    print('Min PMB (m w.e.): {:.2f}'.format(lowest), 'on glacier',
          at_lowest.GLACIER.values[0], 'in', at_lowest.YEAR.values[0])
    print('Max PMB (m w.e.): {:.2f}'.format(highest), 'on glacier',
          at_highest.GLACIER.values[0], 'in', at_highest.YEAR.values[0])
    return avg, lowest, highest, at_lowest, at_highest


# Mean, min and max PMB, separately for annual and winter measurements.
data_glamos_a = data_glamos[data_glamos.PERIOD == 'annual']
mean_pmb, min_pmb, max_pmb, rows_min, rows_max = _pmb_summary(
    data_glamos_a, 'Annual:')

data_glamos_w = data_glamos[data_glamos.PERIOD == 'winter']
mean_pmb, min_pmb, max_pmb, rows_min, rows_max = _pmb_summary(
    data_glamos_w, 'Winter:')

# Mean measurement elevation per glacier (all periods), highest first.
gl_per_el = (data_glamos.groupby(['GLACIER'])['POINT_ELEVATION']
             .mean()
             .sort_values(ascending=False))

# Thin strip plot of the per-glacier mean elevation (ascending); tick labels
# and axis labels are blanked.
fig, ax = plt.subplots(figsize=(10, 2))
sns.lineplot(gl_per_el.sort_values(ascending=True),
             ax=ax,
             color='gray',
             marker='v')
ax.set_xticklabels('', rotation=90)
ax.set_ylabel('')
ax.set_xlabel('')

In [None]:
# Show the per-glacier mean measurement elevation (sorted, highest first) as
# the cell output.
gl_per_el

## Winter MB before and after 2003:

In [None]:
from scipy.stats import gaussian_kde

# Year at which winter point mass balances are split into "before" / "after".
split_year = 2003


def _kde_or_none(values):
    """Fit a 1-D Gaussian KDE to `values`, or return None if impossible.

    gaussian_kde needs at least two distinct values; otherwise it raises
    (too few points or a singular covariance). Returning None lets the plot
    skip that curve instead of aborting the whole cell when a
    train/test x period subset is (nearly) empty.
    """
    if np.unique(values).size < 2:
        return None
    try:
        return gaussian_kde(values)
    except np.linalg.LinAlgError:
        return None


# Identify training glaciers
train_glaciers = [g for g in data_glamos.GLACIER.unique()
                  if g not in TEST_GLACIERS]

# Split dataset (winter measurements only)
is_winter = data_glamos.PERIOD == 'winter'
train = data_glamos[data_glamos.GLACIER.isin(train_glaciers) & is_winter]
test = data_glamos[data_glamos.GLACIER.isin(TEST_GLACIERS) & is_winter]

# Before / after split_year
train_bef = train[train.YEAR < split_year]
train_aft = train[train.YEAR >= split_year]

test_bef = test[test.YEAR < split_year]
test_aft = test[test.YEAR >= split_year]

# Extract POINT_BALANCE
train_bef_x = train_bef['POINT_BALANCE'].dropna()
train_aft_x = train_aft['POINT_BALANCE'].dropna()

test_bef_x = test_bef['POINT_BALANCE'].dropna()
test_aft_x = test_aft['POINT_BALANCE'].dropna()

# KDEs (None when a subset has too few distinct values)
train_kde_bef = _kde_or_none(train_bef_x)
train_kde_aft = _kde_or_none(train_aft_x)

test_kde_bef = _kde_or_none(test_bef_x)
test_kde_aft = _kde_or_none(test_aft_x)

# Common evaluation grid spanning all four subsets
all_x = pd.concat([train_bef_x, train_aft_x, test_bef_x, test_aft_x])
xgrid = np.linspace(all_x.min(), all_x.max(), 500)

# Plot: one panel for training glaciers, one for test glaciers. Colours are
# fixed per curve (C0 = before, C1 = after) so both panels stay consistent
# even if a curve is skipped.
fig, axs = plt.subplots(1, 2, figsize=(16, 6), sharey=True)
panels = [
    (axs[0], "Training Glaciers", "Train", train_kde_bef, train_kde_aft),
    (axs[1], "Test Glaciers", "Test", test_kde_bef, test_kde_aft),
]
for panel_ax, title, prefix, kde_bef, kde_aft in panels:
    curves = [
        (kde_bef, f'{prefix} < {split_year}', 'C0'),
        (kde_aft, f'{prefix} ≥ {split_year}', 'C1'),
    ]
    for kde, label, color in curves:
        if kde is None:
            print(f"Warning: not enough distinct values for '{label}'; "
                  "KDE skipped.")
            continue
        panel_ax.plot(xgrid, kde(xgrid), label=label, linewidth=2,
                      color=color)
    panel_ax.set_title(title)
    panel_ax.set_xlabel("POINT_BALANCE")
    panel_ax.grid(alpha=0.3)
    panel_ax.legend()
axs[0].set_ylabel("Density")

plt.tight_layout()
plt.show()