## Setting up:

In [None]:
import os, sys
sys.path.append(os.path.join(os.getcwd(), '../../')) # Add root of repo to import MBM

import pandas as pd
import warnings
import massbalancemachine as mbm
import pyproj
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from cmcrameri import cm
from oggm import utils

# from scripts.helpers import *
# from regions.Norway_mb.scripts.norway_preprocess import *
from regions.Svalbard.scripts.config_SVA import *

from regions.Switzerland.scripts.oggm import initialize_oggm_glacier_directories, export_oggm_grids
from regions.Switzerland.scripts.glamos import merge_pmb_with_oggm_data, rename_stakes_by_elevation, check_point_ids_contain_glacier, remove_close_points, check_multiple_rgi_ids

from regions.French_Alps.scripts.glacioclim_preprocess import add_svf_from_rgi_zarr, plot_missing_svf_for_all_glaciers, add_svf_nearest_valid

warnings.filterwarnings('ignore')
%load_ext autoreload
%autoreload 2

# NOTE(review): this notebook imports config_SVA and works on RGI region 07,
# yet the Norway configuration class is instantiated here — confirm that this
# is the intended config for the Svalbard data.
cfg = mbm.NorwayConfig()

# Seed with the value stored on the config object.
mbm.utils.seed_all(cfg.seed)
# Presumably releases cached GPU memory left over from earlier runs — TODO confirm.
mbm.utils.free_up_cuda()
# Presumably applies the project's matplotlib styling to later figures — TODO confirm.
mbm.plots.use_mbm_style()

## Pre-processing:

### Load WGMS data:

In [None]:
# Read the raw WGMS point mass-balance table and the glacier metadata table.
# Both paths are built with os.path.join so they do not depend on
# cfg.dataPath ending with a path separator.
df_stakes = pd.read_csv(
    os.path.join(cfg.dataPath, 'WGMS/raw/mass_balance_point.csv'))

df_RGIId = pd.read_csv(os.path.join(cfg.dataPath, 'WGMS/raw/glacier.csv'))

# Filter df_stakes to include only rows where country is SJ
# (ISO 3166 code for Svalbard and Jan Mayen)
df_stakes = df_stakes[df_stakes['country'].isin(['SJ'])].reset_index(drop=True)

# Create a mapping dictionary from id to rgi60_ids
id_to_rgi_map = dict(zip(df_RGIId['id'], df_RGIId['rgi60_ids']))

# Add the RGIId column to the filtered DataFrame using glacier_id instead of id
df_stakes['RGIId'] = df_stakes['glacier_id'].map(id_to_rgi_map)

# Display glacier names with NaN RGIId
display(f"Number of rows with NaN RGIId: {df_stakes['RGIId'].isna().sum()}")
display(df_stakes[df_stakes['RGIId'].isna()]['glacier_name'].unique())

# Rename the WGMS columns to the upper-case names used downstream
df_stakes_renamed = df_stakes.rename(
    columns={
        'point_id': 'POINT_ID',
        'latitude': 'POINT_LAT',
        'longitude': 'POINT_LON',
        'elevation': 'POINT_ELEVATION',
        'begin_date': 'FROM_DATE',
        'end_date': 'TO_DATE',
        'balance': 'POINT_BALANCE',
        'glacier_name': 'GLACIER',
        'year': 'YEAR',
        'country': 'COUNTRY',
        'balance_code': 'PERIOD'
    })

# Create new POINT_ID column: <glacier>_<year>_<WGMS row id>_<country>.
# 'id' is not renamed above, so it is read from the same frame as the others.
df_stakes_renamed['POINT_ID'] = (df_stakes_renamed['GLACIER'] + '_' +
                                 df_stakes_renamed['YEAR'].astype(str) + '_' +
                                 df_stakes_renamed['id'].astype(str) + '_' +
                                 df_stakes_renamed['COUNTRY'])

# Only keep relevant columns in df; copy so the column assignments below act
# on an independent frame rather than on a slice of the renamed table
df_stakes_renamed = df_stakes_renamed[[
    'POINT_ID', 'POINT_LAT', 'POINT_LON', 'POINT_ELEVATION', 'FROM_DATE',
    'TO_DATE', 'POINT_BALANCE', 'GLACIER', 'PERIOD', 'RGIId', 'YEAR',
    'begin_date_unc', 'end_date_unc'
]].copy()

# Remove rows with NaN values in POINT_LAT, POINT_LON, and POINT_ELEVATION
df_stakes_renamed = df_stakes_renamed.dropna(
    subset=['POINT_LAT', 'POINT_LON', 'POINT_ELEVATION'])

# Change date format to YYYYMMDD by removing the dashes (literal replacement,
# not a regular expression)
for date_col in ('FROM_DATE', 'TO_DATE'):
    df_stakes_renamed[date_col] = df_stakes_renamed[date_col].astype(
        str).str.replace('-', '', regex=False)

# Add data modification column to keep track of manual changes
df_stakes_renamed['DATA_MODIFICATION'] = ''

display(df_stakes_renamed.head(2))

In [None]:
# Show every row of the renamed stake table that has a NaN in any column
has_missing = df_stakes_renamed.isna().any(axis=1)
display(df_stakes_renamed.loc[has_missing])

In [None]:
import folium

def plot_stakes_folium(
    df_pmb_clean,
    glacier_col="GLACIER",
    lat_col=None,
    lon_col=None,
    elev_col=None,
    id_col=None,
    center=None,
    zoom_start=10,
    color_map=None,
):
    """
    Create an interactive Folium map of stake points grouped by glacier.

    Parameters
    ----------
    df_pmb_clean : pandas.DataFrame
        One row per stake point. Must contain `glacier_col` and the latitude
        and longitude columns.
    glacier_col : str
        Column whose values define the per-glacier layers and legend entries.
    lat_col, lon_col, elev_col, id_col : str, optional
        Column names. When None, "lat" / "lon" / "altitude" / "stake_number"
        are used if present in the frame, otherwise "POINT_LAT" /
        "POINT_LON" / "POINT_ELEVATION" / "POINT_ID".
    center : sequence of two floats, optional
        (latitude, longitude) of the initial view. Defaults to the median
        latitude and longitude of the points.
    zoom_start : int
        Initial zoom level passed to ``folium.Map``.
    color_map : dict, optional
        Glacier -> colour. A glacier is looked up by its value first and by
        its string form second; glaciers found under neither key get a colour
        from a cycled default palette. The dict is read only, never modified.

    Returns
    -------
    folium.Map
        Map with one feature group per glacier, a layer control and a legend.
    """
    available = df_pmb_clean.columns

    def _resolve(explicit, preferred, fallback):
        # An explicitly given name always wins; otherwise use the short
        # column name when the frame has it and the POINT_* name when not.
        if explicit is not None:
            return explicit
        return preferred if preferred in available else fallback

    lat_col = _resolve(lat_col, "lat", "POINT_LAT")
    lon_col = _resolve(lon_col, "lon", "POINT_LON")
    elev_col = _resolve(elev_col, "altitude", "POINT_ELEVATION")
    id_col = _resolve(id_col, "stake_number", "POINT_ID")

    # Compute center if not provided
    if center is None:
        center_lat = float(df_pmb_clean[lat_col].median())
        center_lon = float(df_pmb_clean[lon_col].median())
    else:
        center_lat, center_lon = center

    m = folium.Map(location=[center_lat, center_lon], zoom_start=zoom_start)

    # Default colors (cycled) if user doesn't give explicit mapping
    default_colors = [
        "red", "blue", "green", "purple", "orange", "darkred", "cadetblue",
        "darkgreen", "darkpurple", "pink", "gray", "black"
    ]

    glaciers = sorted(df_pmb_clean[glacier_col].dropna().unique())

    # Resolve one colour per glacier into a fresh dict. Working on a new dict
    # keeps the caller's color_map untouched, and keying it by the raw glacier
    # value lets markers and legend share the same lookup even when the
    # glacier labels are not strings.
    requested = color_map if color_map is not None else {}
    colors = {}
    for i, g in enumerate(glaciers):
        if g in requested:
            colors[g] = requested[g]
        elif str(g) in requested:
            colors[g] = requested[str(g)]
        else:
            colors[g] = default_colors[i % len(default_colors)]

    # Add markers for each glacier
    for glacier_name, df_g in df_pmb_clean.groupby(glacier_col):
        if pd.isna(glacier_name):
            continue

        fg = folium.FeatureGroup(name=str(glacier_name))
        color = colors[glacier_name]

        for _, row in df_g.iterrows():
            stake_id = row.get(id_col, "NA")
            altitude = row.get(elev_col, "NA")

            folium.CircleMarker(
                location=[row[lat_col], row[lon_col]],
                radius=5,
                color=color,
                fill=True,
                fill_color=color,
                fill_opacity=0.9,
                popup=f"{glacier_name} - Stake {stake_id}: {altitude} m",
            ).add_to(fg)

        fg.add_to(m)

    folium.LayerControl(collapsed=False).add_to(m)

    # Legend (auto-generated), one coloured bullet per glacier
    legend_rows = "\n".join(
        f'<p><span style="color: {colors[g]};">●</span> {g}</p>'
        for g in glaciers)

    legend_html = f"""
    <div style="
        position: fixed; bottom: 50px; left: 50px; z-index: 1000;
        background-color: white; padding: 10px; border-radius: 5px;
        border: 1px solid #999;
    ">
        <p><strong>Glaciers</strong></p>
        {legend_rows}
    </div>
    """
    m.get_root().html.add_child(folium.Element(legend_html))

    return m


# Map the renamed stake table; color_map=None makes the function assign its
# cycled default colours to the glaciers.
m = plot_stakes_folium(df_stakes_renamed, color_map=None)
# Bare last expression so the notebook displays the returned map object.
m

### Add OGGM data:

In [None]:
# RGI identifiers and download location shared by the OGGM calls below
RGI_REGION_SVA = "07"  # Svalbard
RGI_VERSION_SVA = "62"
OGGM_GDIRS_URL_SVA = "https://cluster.klima.uni-bremen.de/~oggm/gdirs/oggm_v1.6/L1-L2_files/2025.6/elev_bands_w_data/"

# Set up the OGGM glacier directories for the region
gdirs, rgidf = initialize_oggm_glacier_directories(
    cfg,
    rgi_region=RGI_REGION_SVA,
    rgi_version=RGI_VERSION_SVA,
    base_url=OGGM_GDIRS_URL_SVA,
    log_level='WARNING',
    task_list=None,
)

# Export the per-glacier grids from those directories
export_oggm_grids(cfg, gdirs, rgi_region=RGI_REGION_SVA)

# RGI ids that occur in the stake table
unique_rgis = df_stakes_renamed['RGIId'].unique()

# Combine the stake table with the OGGM data
df_stakes_topo = merge_pmb_with_oggm_data(
    df_pmb=df_stakes_renamed,
    gdirs=gdirs,
    rgi_region=RGI_REGION_SVA,
    rgi_version=RGI_VERSION_SVA,
)

In [None]:
# Keep only the points flagged as within the glacier shape, then discard the
# flag column
inside_shape = df_stakes_topo['within_glacier_shape'].eq(True)
df_stakes_topo = (df_stakes_topo.loc[inside_shape]
                  .drop(columns=['within_glacier_shape']))

# Show the remaining rows that contain a NaN in any column
rows_with_nan = df_stakes_topo.isna().any(axis=1)
display(df_stakes_topo.loc[rows_with_nan])

# Drop 3 rows where consensus_ice_thickness is NaN
#df_stakes_topo_dropped = df_stakes_topo.dropna(subset=['consensus_ice_thickness'])

In [None]:
# List the rows of the topo-merged table that contain a NaN in any column
topo_has_nan = df_stakes_topo.isna().any(axis=1)
display(df_stakes_topo.loc[topo_has_nan])

### Merge close stakes:

In [None]:
from tqdm import tqdm

# Process the stakes glacier by glacier with remove_close_points. The
# per-glacier results are collected in a list and concatenated once, instead
# of re-concatenating a growing frame on every iteration.
cleaned_per_glacier = []
for gl in tqdm(df_stakes_topo.GLACIER.unique(), desc='Merging stakes'):
    print(f'-- {gl.capitalize()}:')
    df_gl = df_stakes_topo[df_stakes_topo.GLACIER == gl]
    df_gl_cleaned = remove_close_points(df_gl)
    cleaned_per_glacier.append(df_gl_cleaned)

# Drop the 'x' and 'y' columns and renumber the rows; done without inplace
# operations so re-running the cell always rebuilds df_pmb_topo from scratch.
df_pmb_topo = (pd.concat(cleaned_per_glacier)
               .drop(columns=['x', 'y'])
               .reset_index(drop=True))

### Add skyview factor:

In [None]:
# Example: open the masked grid of one glacier and plot its svf variable.
# rgi_id is read from the first stake row, but the plot uses the hard-coded
# rgi_gl below.
rgi_id = df_pmb_topo.loc[0, 'RGIId']

rgi_gl = "RGI60-07.01407"

# Folder holding one zarr store per RGI id
path_masked_xr = os.path.join(cfg.dataPath,
                              'RGI_v6/RGI_07_Svalbard/xr_masked_grids/')

example_store = os.path.join(path_masked_xr, f'{rgi_gl}.zarr')
xr.open_zarr(example_store).svf.plot()

In [None]:
# Folder with the masked per-glacier zarr grids
path_masked_xr = os.path.join(cfg.dataPath,
                              "RGI_v6/RGI_07_Svalbard/xr_masked_grids")

# Attach an "svf" column to the stake table, read from the "svf" variable of
# the grids
svf_options = dict(
    rgi_col="RGIId",
    lon_col="POINT_LON",
    lat_col="POINT_LAT",
    svf_var="svf",
    out_col="svf",
)
df_pmb_topo_svf = add_svf_from_rgi_zarr(df_pmb_topo, path_masked_xr,
                                        **svf_options)

# Report the points that ended up without an svf value
svf_is_missing = df_pmb_topo_svf["svf"].isna()
df_missing = df_pmb_topo_svf.loc[svf_is_missing].copy()
print("Missing SVF points:", len(df_missing))
print("Glaciers affected:", sorted(df_missing["RGIId"].unique()))

In [None]:
# Diagnostic plots of the points without an svf value; save_dir is None here
# (alternative: e.g. os.path.join(cfg.dataPath, "diagnostics/svf_missing"))
plot_missing_svf_for_all_glaciers(
    path_masked_xr=path_masked_xr,
    df_with_svf=df_pmb_topo_svf,
    save_dir=None,
    plot_valid_points=True,
)

### Give new stake ids:

In [None]:
# Assign the new stake ids and run the id check on the result
df_pmb_new_ids = rename_stakes_by_elevation(df_pmb_topo_svf)
check_point_ids_contain_glacier(df_pmb_new_ids)

# Sample counts, overall and per measurement period
is_annual = df_pmb_new_ids.PERIOD == 'annual'
is_winter = df_pmb_new_ids.PERIOD == 'winter'
print('Number of winter and annual samples:', len(df_pmb_new_ids))
print('Number of annual samples:', int(is_annual.sum()))
print('Number of winter samples:', int(is_winter.sum()))

# Distribution of the point mass balance values
df_pmb_new_ids.POINT_BALANCE.hist(bins=20)
plt.xlabel('Mass balance [m w.e.]')

### Final cleaning:

In [None]:
# Work on a copy so df_pmb_new_ids stays as produced by the previous step
df_pmb_clean = df_pmb_new_ids.copy()

# For both dates: left-pad the string form with zeros to 8 characters
# (YYYYMMDD) and store characters 4-5, the month, in a separate column
for date_col, month_col in (("FROM_DATE", "MONTH_START"),
                            ("TO_DATE", "MONTH_END")):
    df_pmb_clean[date_col] = df_pmb_clean[date_col].astype(str).str.zfill(8)
    df_pmb_clean[month_col] = df_pmb_clean[date_col].str[4:6]

def print_months(df, label):
    """Print the sorted unique start and end months of winter and annual rows.

    `df` needs the columns PERIOD, MONTH_START and MONTH_END; `label` is
    printed as a heading (preceded by a blank line) before the two listings.
    """
    print(f"\n{label}")

    sections = (
        ("winter", "Winter measurement months:"),
        ("annual", "\nAnnual measurement months:"),
    )
    for period, heading in sections:
        rows = df[df.PERIOD == period]
        print(heading)
        print("  Unique start months:", sorted(rows["MONTH_START"].unique()))
        print("  Unique end months:  ", sorted(rows["MONTH_END"].unique()))

# --- Month overview before any row is touched ---
print_months(df_pmb_clean, "Before filtering")

# -----------------------
# Removal: rows labelled winter whose end month is July
# -----------------------
period_norm = df_pmb_clean["PERIOD"].astype(str).str.strip().str.lower()
mask_winter_end_07 = (period_norm.eq("winter")
                      & df_pmb_clean["MONTH_END"].eq("07"))

# Count on the unfiltered frame; this is the only removal rule, so the total
# equals its count
n_winter_end_07 = int(mask_winter_end_07.sum())
n_total_removed = n_winter_end_07

df_pmb_clean = df_pmb_clean.loc[~mask_winter_end_07].copy()

# -----------------------
# Relabel: winter rows ending in June with a negative balance become annual
# -----------------------
# Recomputed because the frame above lost rows
period_norm = df_pmb_clean["PERIOD"].astype(str).str.strip().str.lower()
mask_fix = (period_norm.eq("winter")
            & df_pmb_clean["MONTH_END"].eq("06")
            & df_pmb_clean["POINT_BALANCE"].lt(0))
n_relabel = int(mask_fix.sum())
df_pmb_clean.loc[mask_fix, "PERIOD"] = "annual"

summary = (
    f"\nRemoved {n_total_removed} rows total.\n"
    f"  - winter-end-07 rows removed: {n_winter_end_07}\n"
    f"Relabeled winter -> annual: {n_relabel}"
)
print(summary)

print_months(df_pmb_clean, "After filtering + relabeling")

In [None]:
# Save to csv. The output folder is created first if it is missing, because
# to_csv raises an error when asked to write into a non-existent directory.
out_dir = os.path.join(cfg.dataPath, path_PMB_WGMS_csv)
os.makedirs(out_dir, exist_ok=True)
df_pmb_clean.to_csv(os.path.join(out_dir, 'SVA_wgms_dataset_all.csv'),
                    index=False)

# Histogram of mass balance
df_pmb_clean['POINT_BALANCE'].hist(bins=20)
plt.xlabel('Mass balance [m w.e.]')

In [None]:
import folium

# plot_stakes_folium is already defined in the earlier mapping cell of this
# notebook. Defining it a second time here would only shadow that definition
# with an identical copy, so the existing function is called on the cleaned
# dataset instead.
m = plot_stakes_folium(df_pmb_clean, color_map=None)
m

In [None]:
# Sample counts of the cleaned dataset, overall and per period
is_annual = df_pmb_clean.PERIOD == 'annual'
is_winter = df_pmb_clean.PERIOD == 'winter'
print('Number of winter and annual samples:', len(df_pmb_clean))
print('Number of annual samples:', int(is_annual.sum()))
print('Number of winter samples:', int(is_winter.sum()))

# Two stacked panels: counts per year (top) and per glacier (bottom)
fig, axs = plt.subplots(2, 1, figsize=(20, 15))
ax_per_year, ax_per_glacier = axs

# Top panel: measurements per year, stacked by period
counts_year_period = df_pmb_clean.groupby(['YEAR', 'PERIOD']).size().unstack()
counts_year_period.plot(
    kind='bar',
    stacked=True,
    color=[mbm.plots.COLOR_ANNUAL, mbm.plots.COLOR_WINTER],
    ax=ax_per_year)
ax_per_year.set_title('Number of measurements per year for all glaciers')

# Bottom panel: total measurements per glacier, in ascending order
num_gl = df_pmb_clean.groupby(['GLACIER']).size().sort_values()
num_gl.plot(kind='bar', ax=ax_per_glacier)
ax_per_glacier.set_title('Number of total measurements per glacier since 1951')
plt.tight_layout()