## Setting up:

In [None]:
import os, sys
sys.path.append(os.path.join(os.getcwd(), '../../')) # Add root of repo to import MBM

import pandas as pd
import warnings
import massbalancemachine as mbm
import pyproj
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from cmcrameri import cm
from oggm import utils

# from scripts.helpers import *
from scripts.norway_preprocess import *
from scripts.config_NOR import *

from regions.Switzerland.scripts.oggm import initialize_oggm_glacier_directories, export_oggm_grids
from regions.Switzerland.scripts.glamos import merge_pmb_with_oggm_data, rename_stakes_by_elevation, check_point_ids_contain_glacier, remove_close_points, check_multiple_rgi_ids

from regions.French_Alps.scripts.glacioclim_preprocess import add_svf_from_rgi_zarr, plot_missing_svf_for_all_glaciers, add_svf_nearest_valid
# NOTE(review): this suppresses every warning for the rest of the notebook
# (including pandas diagnostics); consider narrowing the filter.
warnings.filterwarnings('ignore')
# Reload imported modules automatically before each cell is executed, so
# edits to the project scripts are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Norway configuration object; cfg.dataPath and cfg.seed are read further down.
cfg = mbm.NorwayConfig()

# Seed with the configured value (presumably for reproducibility - see mbm.utils).
mbm.utils.seed_all(cfg.seed)
# Presumably releases cached GPU memory - confirm in mbm.utils.
mbm.utils.free_up_cuda()
# Apply the project's plotting style to subsequent figures.
mbm.plots.use_mbm_style()

## Pre-processing:
Load the stake measurements, fill missing start dates, split them into winter and annual measurements, and transform them to WGMS format. The dataset was acquired from https://doi.org/10.58059/sjse-6w92.

### Process dates:

In [None]:
# Read the raw point mass balance table and harmonise the RGI column name
raw_csv_path = os.path.join(cfg.dataPath, path_PMB_WGMS_raw,
                            'glaciological_point_mass_balance_Norway.csv')
df_stakes = pd.read_csv(raw_csv_path).rename(columns={'rgiid': 'RGIId'})

# Free-text column used to keep track of manual changes made to a record
df_stakes['DATA_MODIFICATION'] = ''

# FROM_DATE is missing for some glaciers despite having pmb measurements;
# fill it with the start of the hydrological year
df_stakes = fill_missing_dates(df_stakes)

# One row per winter / annual measurement
df_stakes = split_stake_measurements(df_stakes)

# WGMS naming for the coordinate, elevation and glacier columns
wgms_column_names = {
    'lat': 'POINT_LAT',
    'lon': 'POINT_LON',
    'altitude': 'POINT_ELEVATION',
    'breid': 'GLACIER',
}
# Columns that are kept for the rest of the pre-processing
wgms_columns = [
    'POINT_LAT', 'POINT_LON', 'POINT_ELEVATION', 'FROM_DATE', 'TO_DATE',
    'POINT_BALANCE', 'PERIOD', 'RGIId', 'YEAR', 'GLACIER', 'DATA_MODIFICATION',
    'approx_loc', 'approx_altitude'
]
df_stakes = df_stakes.rename(columns=wgms_column_names)[wgms_columns].copy()

# Dates arrive as dd.mm.yyyy and are stored as yyyymmdd strings
for date_col in ('FROM_DATE', 'TO_DATE'):
    parsed_dates = pd.to_datetime(df_stakes[date_col], format='%d.%m.%Y')
    df_stakes[date_col] = parsed_dates.dt.strftime('%Y%m%d')

df_stakes.head()

### Add glacier names from RGIId

In [None]:
# Initialise the OGGM glacier directories for RGI region 08 (RGI version 62)
gdirs, rgidf = initialize_oggm_glacier_directories(
    cfg,
    rgi_region="08",
    rgi_version="62",
    base_url=
    "https://cluster.klima.uni-bremen.de/~oggm/gdirs/oggm_v1.6/L1-L2_files/2025.6/elev_bands_w_data/",
    log_level='WARNING',
    task_list=None,
)

export_oggm_grids(cfg, gdirs, rgi_region="08")

# Map every stake's RGIId to the glacier name listed in the RGI table
rgi_to_name_dict = dict(zip(rgidf.RGIId, rgidf.Name))
df_stakes['GLACIER'] = df_stakes['RGIId'].map(rgi_to_name_dict)

# RGI60-08.02966 has no glacier name in the RGI map so directly give it name Blåbreen
# (written 'Blabreen'). The fill below applies to EVERY unnamed stake, so report
# any other RGI ID that would silently receive the same label.
unnamed_mask = df_stakes['GLACIER'].isna()
unexpected_unnamed = sorted(
    set(df_stakes.loc[unnamed_mask, 'RGIId'].dropna()) - {'RGI60-08.02966'})
if unexpected_unnamed:
    # print() rather than warnings.warn(): warnings are globally silenced above
    print("WARNING: these RGI IDs also have no glacier name and are being "
          f"labelled 'Blabreen': {unexpected_unnamed}")
df_stakes.loc[unnamed_mask, 'GLACIER'] = 'Blabreen'

### Create unique POINT_ID:

In [None]:
# Build POINT_ID as GLACIER_YEAR_PERIOD_LAT_LON_approxloc_approxalt_rowindex.
# The row index is appended last, so the hard-coded IDs used in later cells
# depend on the row order of df_stakes at this point.
point_id = df_stakes['GLACIER']
for id_part in ('YEAR', 'PERIOD', 'POINT_LAT', 'POINT_LON', 'approx_loc',
                'approx_altitude'):
    point_id = point_id + '_' + df_stakes[id_part].astype(str)
df_stakes['POINT_ID'] = point_id + '_' + df_stakes.index.astype(str)

# The approximate-location flags are only needed for the ID
df_stakes = df_stakes.drop(columns=['approx_loc', 'approx_altitude'])

### Fix wrong date ranges:

In [None]:
# Rows whose measurement period is flagged by the project's consistency check
annual_inconsistent, winter_inconsistent = check_period_consistency(df_stakes)

# Lift the display limits so the full tables (and long POINT_IDs) are shown
pd.set_option('display.max_rows', None)
pd.set_option('display.max_colwidth', None)

report_columns = [
    'GLACIER', 'FROM_DATE', 'TO_DATE', 'MONTH_DIFF', 'PERIOD', 'YEAR',
    'RGIId', 'POINT_ID'
]
for heading, flagged in (
    ("\nInconsistent annual periods:", annual_inconsistent),
    ("\nInconsistent winter periods:", winter_inconsistent),
):
    # Only print a section when there is something to show
    if len(flagged) > 0:
        print(heading)
        display(flagged[report_columns])

The first fix corrects all dates whose month was wrongly recorded as 01 (January) instead of 10 (October):

In [None]:
# Correct the dates where 01 (Jan) was entered as the month instead of 10 (Oct)
df_stakes_dates_fix = fix_january_to_october_dates(
    df_stakes, annual_inconsistent, winter_inconsistent)

# Re-run the consistency check on the corrected table
annual_inconsistent, winter_inconsistent = check_period_consistency(
    df_stakes_dates_fix)

report_columns = [
    'GLACIER', 'FROM_DATE', 'TO_DATE', 'MONTH_DIFF', 'PERIOD', 'YEAR',
    'RGIId', 'POINT_ID'
]
for heading, flagged in (
    ("\nInconsistent annual periods:", annual_inconsistent),
    ("\nInconsistent winter periods:", winter_inconsistent),
):
    # Only print a section when there is something to show
    if len(flagged) > 0:
        print(heading)
        display(flagged[report_columns])

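The second fix handles the remaining outliers: a few are corrected by hand, the rest have a wrong year in TO_DATE, and stakes with nonsensical periods are removed: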

In [None]:
## Fix outliers that don't have a common explanation by hand.
# All corrections are listed as data and applied in a loop instead of one
# copy-pasted .loc assignment per stake.

# (POINT_ID, corrected TO_DATE as yyyymmdd, note stored in DATA_MODIFICATION)
to_date_corrections = [
    # May instead of September
    ('Svartisheibreen_1994_annual_66.55012_13.72724_N_N_883', '19940915',
     'Changed TO_DATE month from May to September'),
    ('Svartisheibreen_1994_annual_66.54826_13.73128_N_N_884', '19940915',
     'Changed TO_DATE month from May to September'),
    # TO_DATE annual wrong year
    ('Aalfotbreen_1974_annual_61.74236_5.64623_N_N_1386', '19740920',
     'Changed TO_DATE year from 1975 to 1974'),
    ('Aalfotbreen_1971_annual_61.75213_5.63165_N_N_1493', '19711124',
     'Changed TO_DATE year from 1970 to 1971'),
    ('Graafjellsbrea_2009_annual_60.06923_6.38925_N_N_3545', '20091013',
     'Changed TO_DATE year from 2019 to 2009'),
    ('Bondhusbrea_1981_annual_60.03108_6.31014_N_N_3738', '19810827',
     'Changed TO_DATE year from 1980 to 1981'),
]

# TO_DATE winter wrong year: the TO_DATE year was one too low, so YEAR and the
# year embedded in POINT_ID are moved forward by one as well.
# (POINT_ID, corrected TO_DATE as yyyymmdd, corrected YEAR)
winter_year_corrections = [
    ('Langfjordjoekulen_2019_winter_70.12528_21.71827_N_N_4019', '20200526',
     2020),
    ('Blaaisen_1966_winter_68.33479_17.85005_N_N_4155', '19670520', 1967),
    ('Nigardsbreen_1963_winter_61.71461_7.11601_N_N_5802', '19640507', 1964),
    ('Vesledalsbreen_1967_winter_61.84804_7.25335_N_N_6694', '19680418', 1968),
    ('Hellstugubreen_2010_winter_61.57329_8.44438_N_N_6935', '20110505', 2011),
]

# Write the corrected YEAR in the type the column already uses, so a numeric
# YEAR column does not end up holding a mix of numbers and strings.
year_is_numeric = pd.api.types.is_numeric_dtype(df_stakes_dates_fix['YEAR'])

# POINT_IDs that matched no row (e.g. the row order of the raw data changed)
unmatched_ids = []

for point_id, to_date, note in to_date_corrections:
    row_mask = df_stakes_dates_fix['POINT_ID'] == point_id
    if not row_mask.any():
        unmatched_ids.append(point_id)
        continue
    df_stakes_dates_fix.loc[row_mask, 'TO_DATE'] = to_date
    df_stakes_dates_fix.loc[row_mask, 'DATA_MODIFICATION'] = note

for point_id, to_date, new_year in winter_year_corrections:
    row_mask = df_stakes_dates_fix['POINT_ID'] == point_id
    if not row_mask.any():
        unmatched_ids.append(point_id)
        continue
    old_year = new_year - 1
    df_stakes_dates_fix.loc[row_mask, 'TO_DATE'] = to_date
    df_stakes_dates_fix.loc[row_mask, 'DATA_MODIFICATION'] = (
        f'Changed TO_DATE year from {old_year} to {new_year}')
    df_stakes_dates_fix.loc[row_mask, 'YEAR'] = (
        new_year if year_is_numeric else str(new_year))
    # Rename last: row_mask was computed from the old POINT_ID
    df_stakes_dates_fix.loc[row_mask, 'POINT_ID'] = point_id.replace(
        f'_{old_year}_', f'_{new_year}_', 1)

if unmatched_ids:
    print("WARNING: no row found for these POINT_IDs, corrections skipped:")
    for point_id in unmatched_ids:
        print("   ", point_id)

# These stakes have nonsensical periods, remove them out of df and index list
stakes_to_remove = [
    'Austdalsbreen_2017_annual_61.81113_7.36766_Y_N_3038',
    'Austdalsbreen_2017_annual_61.80888_7.38239_Y_N_3065',
    'Aalfotbreen_1967_winter_61.74294_5.6365_N_N_5379',
    'Hansebreen_2012_winter_61.74307_5.66278_N_N_5625',
    'Austdalsbreen_2017_winter_61.81113_7.36766_Y_N_6792',
    'Austdalsbreen_2017_winter_61.80888_7.38239_Y_N_6819'
]
# .copy() so that later in-place corrections act on an independent frame
df_stakes_dates_fix = df_stakes_dates_fix[
    ~df_stakes_dates_fix['POINT_ID'].isin(stakes_to_remove)].copy()

# Re-run the consistency check and show what is still flagged
annual_inconsistent, winter_inconsistent = check_period_consistency(
    df_stakes_dates_fix)

report_columns = [
    'GLACIER', 'FROM_DATE', 'TO_DATE', 'MONTH_DIFF', 'PERIOD', 'YEAR',
    'RGIId', 'POINT_ID'
]
for heading, flagged in (
    ("\nInconsistent annual periods:", annual_inconsistent),
    ("\nInconsistent winter periods:", winter_inconsistent),
):
    if len(flagged) > 0:
        print(heading)
        display(flagged[report_columns])

# Restore the default pandas display limits
pd.reset_option('display.max_rows')
pd.reset_option('display.max_colwidth')


All remaining inconsistencies are caused by a wrong year in FROM_DATE; it is replaced by the year before the measurement YEAR:

In [None]:
# Index labels of every record still flagged by the last consistency check
remaining_indices = [*annual_inconsistent.index, *winter_inconsistent.index]

# For each flagged record, replace the year in FROM_DATE by the year before YEAR
for row_label in remaining_indices:
    previous_year = int(df_stakes_dates_fix.loc[row_label, 'YEAR']) - 1

    # FROM_DATE is a yyyymmdd string: characters 4-7 hold the month and day
    current_from_date = df_stakes_dates_fix.loc[row_label, 'FROM_DATE']
    corrected_from_date = f"{previous_year}{current_from_date[4:8]}"

    df_stakes_dates_fix.loc[row_label, 'FROM_DATE'] = corrected_from_date
    df_stakes_dates_fix.loc[row_label, 'DATA_MODIFICATION'] = (
        "Changed faulty year in FROM_DATE to previous year of TO_DATE")

# Check once more and show anything that is still flagged
annual_inconsistent, winter_inconsistent = check_period_consistency(
    df_stakes_dates_fix)

report_columns = [
    'GLACIER', 'FROM_DATE', 'TO_DATE', 'MONTH_DIFF', 'PERIOD', 'YEAR',
    'RGIId', 'POINT_ID'
]
for heading, flagged in (
    ("\nInconsistent annual periods:", annual_inconsistent),
    ("\nInconsistent winter periods:", winter_inconsistent),
):
    if len(flagged) > 0:
        print(heading)
        display(flagged[report_columns])

# Make sure the pandas display limits are back at their defaults
pd.reset_option('display.max_rows')
pd.reset_option('display.max_colwidth')


### Merge close stakes:

In [None]:
# Merge close stakes glacier by glacier. The per-glacier results are collected
# in a list and concatenated once, instead of growing a DataFrame inside the
# loop (which copies all previous rows on every iteration).
# NOTE(review): `tqdm` is not imported explicitly in this notebook; it is
# presumably provided by one of the star imports - verify.
merged_per_glacier = []
for gl in tqdm(df_stakes_dates_fix.GLACIER.unique(), desc='Merging stakes'):
    print(f'-- {gl.capitalize()}:')
    df_gl = df_stakes_dates_fix[df_stakes_dates_fix.GLACIER == gl]
    merged_per_glacier.append(remove_close_points(df_gl))

# Drop the 'x'/'y' columns present in the merged result and renumber the rows
df_stakes_merged = (pd.concat(merged_per_glacier)
                    .drop(columns=['x', 'y'])
                    .reset_index(drop=True))

### Add OGGM data:

In [None]:
# NOTE(review): unique_rgis is not used again in the visible part of the notebook.
unique_rgis = df_stakes_merged['RGIId'].unique()
# Attach OGGM data to every stake. Later cells read the columns
# 'within_glacier_shape' and 'consensus_ice_thickness' from the result.
df_pmb_topo = merge_pmb_with_oggm_data(df_pmb=df_stakes_merged,
                                       gdirs=gdirs,
                                       rgi_region="08",
                                       rgi_version="62")

In [None]:
# Example: plot the stakes of one glacier on top of its OGGM glacier mask
glacierName = 'Langfjordjoekulen'

# Stakes of that glacier (filter first, then copy, so new columns can be
# added without touching df_pmb_topo)
df_pmb_topo_1 = df_pmb_topo[df_pmb_topo['GLACIER'] == glacierName].copy()
if df_pmb_topo_1.empty:
    raise ValueError(f"No stakes found for glacier '{glacierName}'")
RGIId = df_pmb_topo_1['RGIId'].unique()[0]
print(RGIId)

# Open the OGGM grid exported for that RGI ID
ds_oggm = xr.open_dataset(
    os.path.join(cfg.dataPath, 'OGGM', 'rgi_region_08', 'xr_grids',
                 f'{RGIId}.zarr'))

# Coordinate transformation from WGS84 lon/lat to the CRS of the OGGM grid
transf = pyproj.Transformer.from_proj(
    pyproj.CRS.from_user_input("EPSG:4326"),  # Input CRS (WGS84)
    pyproj.CRS.from_user_input(ds_oggm.pyproj_srs),  # Output CRS from dataset
    always_xy=True)

# Project all stake coordinates of the glacier
lon, lat = df_pmb_topo_1["POINT_LON"].values, df_pmb_topo_1["POINT_LAT"].values
x_stake, y_stake = transf.transform(lon, lat)
df_pmb_topo_1['x'] = x_stake
df_pmb_topo_1['y'] = y_stake

# Plot the stakes; the palette is a dict so the colours stay the same even
# when only one of the two classes is present (inside = blue, outside = red)
fig, ax = plt.subplots(figsize=(8, 6))
ds_oggm.glacier_mask.plot(cmap='binary', ax=ax)
sns.scatterplot(data=df_pmb_topo_1,
                x='x',
                y='y',
                hue='within_glacier_shape',
                palette={False: 'r', True: 'b'},
                ax=ax)
ax.set_title(f'Stakes on {glacierName} (OGGM)')
plt.tight_layout()

Only keep stakes that lie within the RGI glacier outline, and drop rows with a missing (NaN) `consensus_ice_thickness`.

In [None]:
# Number of stakes before any filtering, used for the total at the end
n_start = len(df_pmb_topo)

# Step 1: keep only the stakes flagged as lying inside the glacier shape;
# the flag column is not needed afterwards
mask_within = df_pmb_topo['within_glacier_shape'] == True
n_drop_shape = (~mask_within).sum()
df_pmb_topo = df_pmb_topo.loc[mask_within].drop(
    columns=['within_glacier_shape'])

print(f"Dropped {n_drop_shape} points outside glacier shape")

# Step 2: discard stakes without a consensus ice thickness value
mask_nan_thick = df_pmb_topo['consensus_ice_thickness'].isna()
n_drop_thick = mask_nan_thick.sum()
df_pmb_topo = df_pmb_topo.dropna(subset=['consensus_ice_thickness']).copy()

print(f"Dropped {n_drop_thick} points with NaN consensus_ice_thickness")

# Summary of what is left
n_annual = int((df_pmb_topo.PERIOD == 'annual').sum())
n_winter = int((df_pmb_topo.PERIOD == 'winter').sum())
print(f"Total dropped: {n_start - len(df_pmb_topo)}")
print('Number of winter and annual samples:', len(df_pmb_topo))
print('Number of annual samples:', n_annual)
print('Number of winter samples:', n_winter)

# Alphabetical list of the glaciers that still have stakes
glacier_list = sorted(df_pmb_topo.GLACIER.unique())
print(f"Number of glaciers: {len(glacier_list)}")
print(f"Glaciers: {glacier_list}")

In [None]:
# Two stacked panels: measurements per year (top) and per glacier (bottom)
fig, axs = plt.subplots(2, 1, figsize=(20, 15))
ax_years, ax_glaciers = axs.flatten()

# Top: yearly counts, one stacked bar segment per PERIOD
counts_per_year = df_pmb_topo.groupby(['YEAR', 'PERIOD']).size().unstack()
counts_per_year.plot(kind='bar',
                     stacked=True,
                     color=[mbm.plots.COLOR_ANNUAL, mbm.plots.COLOR_WINTER],
                     ax=ax_years)
ax_years.set_title('Number of measurements per year for all glaciers')

# Bottom: total count per glacier, sorted from fewest to most
num_gl = df_pmb_topo.groupby(['GLACIER']).size().sort_values()
num_gl.plot(kind='bar', ax=ax_glaciers)
ax_glaciers.set_title('Number of total measurements per glacier since 1951')
plt.tight_layout()

### Check for wrong elevation:

In [None]:
# Run the project's elevation check with a threshold of 400 (presumably metres
# between stake and grid elevation - confirm in flag_elevation_mismatch).
# NOTE(review): df_checked and df_bad are not used again in the visible part of
# the notebook; the following cells keep working on df_pmb_topo.
df_checked, df_bad = flag_elevation_mismatch(df_pmb_topo, threshold=400)

### Add Skyview factor:

In [None]:
# Example of one svf file
# Positional access: df_pmb_topo was filtered without resetting its index, so
# the row labelled 0 may no longer exist and .loc[0] could raise a KeyError.
# NOTE(review): rgi_id is not used below; the plot uses the fixed rgi_gl.
rgi_id = df_pmb_topo.iloc[0].RGIId

rgi_gl = "RGI60-08.00010"

# Folder with the masked grids that contain the sky-view factor (svf)
path_masked_xr = os.path.join(cfg.dataPath,
                              'RGI_v6/RGI_08_Scandinavia/xr_masked_grids/')

# Join the file name with os.path.join so the result does not depend on a
# trailing separator in path_masked_xr
xr.open_zarr(os.path.join(path_masked_xr, f'{rgi_gl}.zarr')).svf.plot()

In [None]:
# Folder holding one masked zarr grid per RGI ID
path_masked_xr = os.path.join(cfg.dataPath,
                              "RGI_v6/RGI_08_Scandinavia/xr_masked_grids")

# Read the "svf" variable of each glacier's grid into a new "svf" column
df_pmb_topo_svf = add_svf_from_rgi_zarr(df_pmb_topo,
                                        path_masked_xr,
                                        rgi_col="RGIId",
                                        lon_col="POINT_LON",
                                        lat_col="POINT_LAT",
                                        svf_var="svf",
                                        out_col="svf")

# Stakes that did not receive a value
svf_is_missing = df_pmb_topo_svf["svf"].isna()
df_missing = df_pmb_topo_svf.loc[svf_is_missing].copy()
print("Missing SVF points:", len(df_missing))
print("Glaciers affected:", sorted(df_missing["RGIId"].unique()))

In [None]:
# Diagnostic plots for the stakes without an svf value, using the same masked
# grids as above; with save_dir=None no output folder is passed to the helper.
plot_missing_svf_for_all_glaciers(
    df_with_svf=df_pmb_topo_svf,
    path_masked_xr=path_masked_xr,
    plot_valid_points=True,
    save_dir=
    None  # or e.g. os.path.join(cfg.dataPath, "diagnostics/svf_missing")
)

In [None]:
# Second pass on df_pmb_topo with the nearest-valid helper, which takes a
# search radius in addition to the arguments used for the first pass
df_pmb_topo_svf_new = add_svf_nearest_valid(
    df_pmb_topo,
    path_masked_xr,
    rgi_col="RGIId",
    lon_col="POINT_LON",
    lat_col="POINT_LAT",
    svf_var="svf",
    out_col="svf",
    max_radius=30,  # ~30 grid cells search; adjust if needed
)

# How many stakes are still without a value
n_still_missing = df_pmb_topo_svf_new["svf"].isna().sum()
print("Missing SVF points after nearest-valid fill:", n_still_missing)

# Same diagnostic plots as before, now for the filled table
plot_missing_svf_for_all_glaciers(
    df_with_svf=df_pmb_topo_svf_new,
    path_masked_xr=path_masked_xr,
    plot_valid_points=True,
    save_dir=None,  # or e.g. os.path.join(cfg.dataPath, "diagnostics/svf_missing")
)

### Assign new stake IDs:

In [None]:
# New stake IDs from the project helper, followed by its sanity check on them
df_pmb_new_ids = rename_stakes_by_elevation(df_pmb_topo_svf_new)
check_point_ids_contain_glacier(df_pmb_new_ids)

# Sample counts in total and per period
n_annual_ids = int((df_pmb_new_ids.PERIOD == 'annual').sum())
n_winter_ids = int((df_pmb_new_ids.PERIOD == 'winter').sum())
print('Number of winter and annual samples:', len(df_pmb_new_ids))
print('Number of annual samples:', n_annual_ids)
print('Number of winter samples:', n_winter_ids)

# Distribution of the point mass balance values
df_pmb_new_ids['POINT_BALANCE'].hist(bins=20)
plt.xlabel('Mass balance [m w.e.]')

## Final cleaning:

In [None]:
# Work on a copy so df_pmb_new_ids stays untouched
df_pmb_clean = df_pmb_new_ids.copy()

# For both dates: force an 8-character YYYYMMDD string, then take the MM part
for date_col, month_col in (("FROM_DATE", "MONTH_START"),
                            ("TO_DATE", "MONTH_END")):
    df_pmb_clean[date_col] = df_pmb_clean[date_col].astype(str).str.zfill(8)
    df_pmb_clean[month_col] = df_pmb_clean[date_col].str[4:6]


def print_months(df, label):
    """Print the distinct start and end months of winter and annual rows.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with the columns "PERIOD", "MONTH_START" and "MONTH_END".
    label : str
        Heading printed (after a blank line) above the report.
    """
    print(f"\n{label}")

    # (section heading, PERIOD value); the annual heading starts with a blank line
    sections = (
        ("Winter measurement months:", "winter"),
        ("\nAnnual measurement months:", "annual"),
    )
    for heading, period in sections:
        rows = df[df["PERIOD"] == period]
        start_months = sorted(rows["MONTH_START"].unique())
        end_months = sorted(rows["MONTH_END"].unique())

        print(heading)
        print("  Unique start months:", start_months)
        print("  Unique end months:  ", end_months)


# Month overview before anything is removed
print_months(df_pmb_clean, "Before filtering")

# Drop every row that starts or ends in July or December
bad_months = {"07", "12"}
starts_in_bad_month = df_pmb_clean["MONTH_START"].isin(bad_months)
ends_in_bad_month = df_pmb_clean["MONTH_END"].isin(bad_months)
mask_bad_months = starts_in_bad_month | ends_in_bad_month
n_removed = mask_bad_months.sum()

df_pmb_clean = df_pmb_clean.loc[~mask_bad_months].copy()
print(
    f"\nRemoved {n_removed} rows with MONTH_START or MONTH_END in {sorted(bad_months)}."
)

# Relabel winter rows that end in June with a negative balance as annual
period_normalised = df_pmb_clean["PERIOD"].str.strip().str.lower()
is_winter = period_normalised == "winter"
ends_in_june = df_pmb_clean["MONTH_END"] == "06"
is_negative = df_pmb_clean["POINT_BALANCE"] < 0
mask_fix = is_winter & ends_in_june & is_negative
print("Rows to relabel winter -> annual:", int(mask_fix.sum()))

df_pmb_clean.loc[mask_fix, "PERIOD"] = "annual"

# Month overview of the final table
print_months(df_pmb_clean, "After filtering + relabeling")

In [None]:
# Write the cleaned dataset to disk (without the index column)
output_csv_path = os.path.join(cfg.dataPath, path_PMB_WGMS_csv,
                               'NOR_wgms_dataset_all.csv')
df_pmb_clean.to_csv(output_csv_path, index=False)

# Distribution of the point mass balance values in the saved dataset
df_pmb_clean['POINT_BALANCE'].hist(bins=20)
plt.xlabel('Mass balance [m w.e.]')

In [None]:
import folium


def plot_stakes_folium(
    df_pmb_clean,
    glacier_col="GLACIER",
    lat_col=None,
    lon_col=None,
    elev_col=None,
    id_col=None,
    center=None,
    zoom_start=10,
    color_map=None,
):
    """
    Create an interactive Folium map of stake points grouped by glacier.

    Parameters
    ----------
    df_pmb_clean : pandas.DataFrame
        One row per stake measurement.
    glacier_col : str
        Column holding the glacier name; rows are grouped by it.
    lat_col, lon_col, elev_col, id_col : str or None
        Column names for latitude, longitude, elevation and stake id. When
        None, "lat" / "lon" / "altitude" / "stake_number" are used if present,
        otherwise "POINT_LAT" / "POINT_LON" / "POINT_ELEVATION" / "POINT_ID".
    center : (lat, lon) or None
        Map centre; defaults to the median latitude and longitude.
    zoom_start : int
        Initial zoom level passed to folium.Map.
    color_map : dict or None
        Optional mapping glacier -> colour. Glaciers without an entry get a
        colour from a default cycle. The dictionary is not modified.

    Returns
    -------
    folium.Map
        Map with one toggleable layer per glacier and an HTML legend.
    """
    columns = df_pmb_clean.columns

    def _resolve(explicit, preferred, fallback):
        # An explicitly given name wins, then the preferred raw-data name
        if explicit is not None:
            return explicit
        return preferred if preferred in columns else fallback

    lat_col = _resolve(lat_col, "lat", "POINT_LAT")
    lon_col = _resolve(lon_col, "lon", "POINT_LON")
    elev_col = _resolve(elev_col, "altitude", "POINT_ELEVATION")
    id_col = _resolve(id_col, "stake_number", "POINT_ID")

    # Compute center if not provided
    if center is None:
        center_lat = float(df_pmb_clean[lat_col].median())
        center_lon = float(df_pmb_clean[lon_col].median())
    else:
        center_lat, center_lon = center

    m = folium.Map(location=[center_lat, center_lon], zoom_start=zoom_start)

    # Default colors (cycled) for glaciers without an explicit colour
    default_colors = [
        "red", "blue", "green", "purple", "orange", "darkred", "cadetblue",
        "darkgreen", "darkpurple", "pink", "gray", "black"
    ]

    glaciers = sorted(df_pmb_clean[glacier_col].dropna().unique())

    # Work on a private copy so the caller's dictionary is never mutated
    colors = dict(color_map) if color_map is not None else {}
    for i, g in enumerate(glaciers):
        if g not in colors:
            # Honour a user entry keyed by the string form of the name,
            # otherwise fall back to the default colour cycle
            colors[g] = colors.get(str(g),
                                   default_colors[i % len(default_colors)])

    # Add markers for each glacier
    for glacier_name, df_g in df_pmb_clean.groupby(glacier_col):
        if pd.isna(glacier_name):
            continue

        fg = folium.FeatureGroup(name=str(glacier_name))
        # Look up by the original key: str(glacier_name) is not a key of the
        # colour table when the glacier column holds non-string values
        color = colors[glacier_name]

        for _, row in df_g.iterrows():
            stake_id = row.get(id_col, "NA")
            altitude = row.get(elev_col, "NA")

            folium.CircleMarker(
                location=[row[lat_col], row[lon_col]],
                radius=5,
                color=color,
                fill=True,
                fill_color=color,
                fill_opacity=0.9,
                popup=f"{glacier_name} - Stake {stake_id}: {altitude} m",
            ).add_to(fg)

        fg.add_to(m)

    folium.LayerControl(collapsed=False).add_to(m)

    # Legend (auto-generated), one coloured bullet per glacier
    legend_rows = "\n".join(
        f'<p><span style="color: {colors[g]};">●</span> {g}</p>'
        for g in glaciers)

    legend_html = f"""
    <div style="
        position: fixed; bottom: 50px; left: 50px; z-index: 1000;
        background-color: white; padding: 10px; border-radius: 5px;
        border: 1px solid #999;
    ">
        <p><strong>Glaciers</strong></p>
        {legend_rows}
    </div>
    """
    m.get_root().html.add_child(folium.Element(legend_html))

    return m


# Build the stake map for the cleaned dataset with the default colour cycle
m = plot_stakes_folium(df_pmb_clean, color_map=None)
# Last expression of the cell, so the map is rendered as the cell output
m