# Pre-processing of GLAMOS MB data:

Pre-processes the point mass balance (PMB) measurements from GLAMOS (annual and winter).

# Point Mass Balance data:

## Setting up:

In [None]:
import os, sys
sys.path.append(os.path.join(os.getcwd(), '../../')) # Add root of repo to import MBM
import matplotlib as mpl

import pandas as pd
import warnings
import massbalancemachine as mbm
from shapely.geometry import Point
import pyproj
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from cmcrameri import cm
from pathlib import Path

from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.geodata import *

# Suppress every warning for the rest of the session.
warnings.filterwarnings('ignore')
# IPython autoreload extension in mode 2: imported modules are reloaded
# before each cell is executed.
%load_ext autoreload
%autoreload 2

# Configuration object; later cells read cfg.dataPath and cfg.seed from it.
cfg = mbm.SwitzerlandConfig()

In [None]:
# Directory (below the data path) for the distributed MB grids
_glw_subdir = 'MBM/testing_LSTM/LSTM_two_heads'
path_save_glw = os.path.join(cfg.dataPath, 'GLAMOS', 'distributed_MB_grids',
                             _glw_subdir)
path_save_glw

In [None]:
# Seed the random number generators and release CUDA memory
seed_all(cfg.seed)
free_up_cuda()

# Matplotlib style sheet applied to all figures below
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)

# Default colour map
cmap = cm.devon

# Colours for bars and lines
color_diff_xgb = '#4d4d4d'
colors = get_cmap_hex(cm.batlow, 10)
color_1, color_2 = colors[0], '#c51b7d'

## Transform .dat files to .csv:

Transform the seasonal and winter PMB .dat files to .csv for simplicity. 

In [None]:
# Run the .dat -> .csv conversion described above. The input and output
# locations are resolved inside the helper (presumably from cfg; not visible
# in this notebook).
process_pmb_dat_files(cfg)

##  Assemble measurement periods:
### Annual measurements: 
Process annual measurements and put all stakes into one csv file

In [None]:
# Process the annual stake measurements into a single dataframe
path_annual_csv = cfg.dataPath + path_PMB_GLAMOS_csv_a
df_annual_raw = process_annual_stake_data(path_annual_csv)

# Preview the first two rows
df_annual_raw.head(2)

### Winter measurements:
For each point in the annual measurements, take the winter measurement that was taken closest to it:

In [None]:
# Winter csv location and the location for its cleaned counterpart
path_winter_csv = cfg.dataPath + path_PMB_GLAMOS_csv_w
path_winter_csv_clean = cfg.dataPath + path_PMB_GLAMOS_csv_w_clean
process_winter_stake_data(df_annual_raw, path_winter_csv,
                          path_winter_csv_clean)

### Assemble both periods:

In [None]:
# Combine the annual dataframe with the cleaned winter measurements
path_winter_csv_clean = cfg.dataPath + path_PMB_GLAMOS_csv_w_clean
path_all_csv = cfg.dataPath + path_PMB_GLAMOS_csv
df_all_raw = assemble_all_stake_data(df_annual_raw, path_winter_csv_clean,
                                     path_all_csv)

In [None]:
# Stacked bar chart: number of measurements per year, split by period
counts_year_period = df_all_raw.groupby(['YEAR', 'PERIOD']).size()
df_measurements_per_year = counts_year_period.unstack()

ax_counts = df_measurements_per_year.plot(kind='bar',
                                          stacked=True,
                                          figsize=(20, 5),
                                          color=[color_1, color_2])
ax_counts.set_title('Number of measurements per year for all glaciers')
ax_counts.set_ylabel('Number of Measurements')
ax_counts.set_xlabel('Year')
ax_counts.legend(title='Period')
plt.tight_layout()
plt.show()

## Add RGI IDs:

For each PMB measurement, we want to add the RGI ID (v6) of the shapefile it belongs to. 

In [None]:
# Attach an RGIId to every measurement using the RGI outlines
df_pmb = add_rgi_ids_to_df(df_all_raw, cfg.dataPath + path_rgi_outlines)

msg_single_rgi = "-- All glaciers are correctly associated with a single RGIId."

# One row per (glacier, RGIId) pair; more than one row per glacier is a conflict
rgiids6 = df_pmb[['GLACIER', 'RGIId']].drop_duplicates()
if not check_multiple_rgi_ids(rgiids6):
    print(msg_single_rgi)
    df_pmb_clean = df_pmb
else:
    print(
        "-- Alert: The following glaciers have more than one RGIId. Cleaning up."
    )
    df_pmb_clean = clean_rgi_ids(df_pmb.copy())
    df_pmb_clean.reset_index(drop=True, inplace=True)

    # Re-check after the clean-up
    rgiids6_clean = df_pmb_clean[['GLACIER', 'RGIId']].drop_duplicates()
    still_ambiguous = check_multiple_rgi_ids(rgiids6_clean)
    print("-- Error: Some glaciers still have more than one RGIId."
          if still_ambiguous else msg_single_rgi)

## Cut from 1951:

In [None]:
# Keep measurements after 1950 (MS data start in 1951, ERA5-Land in 1950)
after_1950 = df_pmb_clean.YEAR > 1950
df_pmb_50s = df_pmb_clean[after_1950].sort_values(by=['GLACIER', 'YEAR'],
                                                  ascending=[True, True])

# Unit conversion: mm w.e. -> m w.e.
df_pmb_50s['POINT_BALANCE'] = df_pmb_50s['POINT_BALANCE'] / 1000

# ClaridenL and ClaridenU are treated as one glacier named 'clariden'
is_clariden_part = df_pmb_50s.GLACIER.isin(['claridenU', 'claridenL'])
df_pmb_50s.loc[is_clariden_part, 'GLACIER'] = 'clariden'

# Sample counts, in total and per period
annual_rows = df_pmb_50s[df_pmb_50s.PERIOD == 'annual']
winter_rows = df_pmb_50s[df_pmb_50s.PERIOD == 'winter']
print('Number of winter and annual samples:', len(df_pmb_50s))
print('Number of annual samples:', len(annual_rows))
print('Number of winter samples:', len(winter_rows))

# Two stacked panels: counts per year (top) and per glacier (bottom)
fig, axs = plt.subplots(2, 1, figsize=(20, 15))

ax = axs[0]
counts_year_period = df_pmb_50s.groupby(['YEAR', 'PERIOD']).size().unstack()
counts_year_period.plot(kind='bar',
                        stacked=True,
                        color=[color_1, color_2],
                        ax=ax)
ax.set_title('Number of measurements per year for all glaciers')

ax = axs[1]
num_gl = df_pmb_50s.groupby(['GLACIER']).size().sort_values()
num_gl.plot(kind='bar', ax=ax)
ax.set_title('Number of total measurements per glacier since 1951')
plt.tight_layout()

### Merge stakes that are close: 
Especially with winter probes, a lot of measurements were done at the same place in the raw data and this leads to noise. We merge the stakes that are very close and keep the mean of the measurement.


In [None]:
# Merge stakes that lie very close to each other, one glacier at a time.
# The per-glacier results are collected in a list and concatenated once:
# concatenating inside the loop copies the growing frame on every iteration
# and starts from an empty DataFrame, which newer pandas versions warn about.
cleaned_per_glacier = []
for gl in tqdm(df_pmb_50s.GLACIER.unique(), desc='Merging stakes'):
    print(f'-- {gl.capitalize()}:')
    df_gl = df_pmb_50s[df_pmb_50s.GLACIER == gl]
    cleaned_per_glacier.append(remove_close_points(df_gl))
df_pmb_50s_clean_pts = pd.concat(cleaned_per_glacier)

# 'x' and 'y' are not needed downstream (presumably helper columns added by
# remove_close_points -- confirm against its implementation).
df_pmb_50s_clean_pts.drop(['x', 'y'], axis=1, inplace=True)

### Correct for wrong elevations:

In [None]:
import os
import numpy as np
import pandas as pd
import xarray as xr
from pathlib import Path

import re
from pathlib import Path
from typing import Optional, Tuple

import re
from pathlib import Path
from collections import defaultdict
from typing import Iterable, Optional, Tuple

import numpy as np
import pandas as pd
import xarray as xr


def find_mismatch_by_year(
        df_gl: pd.DataFrame,
        path_xr_grids: str,
        var_name: str = "masked_elev",
        lon_name: str = "lon",
        lat_name: str = "lat",
        year_col: str = "YEAR",
        glacier_col: str = "GLACIER",
        threshold: float = 500.0,
        file_pattern: str = "{glacier}_{year}.zarr",  # pattern of your files
        strict: bool = False,  # if True, raise if any glacier-year file is missing
):
    """
    Compare each point's elevation to the nearest DEM cell from the glacier-year Zarr.

    Rows are grouped by (glacier, year); for every group the Zarr named by
    `file_pattern` inside `path_xr_grids` is opened, the variable `var_name`
    is sampled at the nearest grid cell of every point, and the dataset is
    closed again. Rows whose sampled DEM value is NaN are ignored.

    Args:
        df_gl: Points to check. Must contain the columns
            ['POINT_LON', 'POINT_LAT', 'POINT_ELEVATION', year_col, glacier_col].
        path_xr_grids: Directory holding the glacier-year Zarr stores.
        var_name: Name of the elevation variable in the Zarr dataset.
        lon_name, lat_name: Names of the longitude/latitude coordinates.
        year_col, glacier_col: Column names used for grouping.
        threshold: A row is a mismatch when |POINT_ELEVATION - DEM| >= threshold.
        file_pattern: Format string with `{glacier}` and `{year}` fields.
        strict: If True, raise when a glacier-year Zarr is missing; otherwise
            that group is skipped silently.

    Returns:
        mismatch_idx (pd.Index): index labels of df_gl for the mismatching rows,
            in group order.
        mismatch_df (pd.DataFrame): the mismatching rows with the columns
            ['DEM_elv', 'elev_diff'] appended, sorted by elev_diff (ascending).

    Raises:
        KeyError: a required column, coordinate or variable is missing.
        FileNotFoundError: `path_xr_grids` does not exist, or a glacier-year
            Zarr is missing while `strict` is True.
    """
    required_cols = {
        "POINT_LON", "POINT_LAT", "POINT_ELEVATION", year_col, glacier_col
    }
    missing = required_cols - set(df_gl.columns)
    if missing:
        raise KeyError(f"df_gl is missing required columns: {sorted(missing)}")

    base = Path(path_xr_grids)
    if not base.exists():
        raise FileNotFoundError(f"Directory not found: {base}")

    mismatch_frames = []

    # Group on the caller's index (no reset) so returned labels match df_gl.
    # Each (glacier, year) pair occurs exactly once, so no dataset cache is needed.
    for (glacier, year), df_group in df_gl.groupby([glacier_col, year_col],
                                                   sort=False):
        zarr_path = base / file_pattern.format(glacier=str(glacier),
                                               year=int(year))

        if not zarr_path.exists():
            if strict:
                raise FileNotFoundError(
                    f"Missing DEM for glacier '{glacier}', year {year}: {zarr_path}"
                )
            # Non-strict mode: skip groups without a DEM
            continue

        # The context manager closes the store once the values are in memory
        with xr.open_zarr(str(zarr_path)) as ds:
            if lon_name not in ds.coords or lat_name not in ds.coords:
                raise KeyError(
                    f"Dataset {zarr_path} missing coords '{lon_name}'/'{lat_name}'. "
                    "Rename coords or pass lon_name/lat_name correctly.")
            if var_name not in ds.variables:
                raise KeyError(
                    f"Variable '{var_name}' not found in dataset {zarr_path}.")

            # Point-wise nearest-neighbour sampling: both indexers share the
            # 'points' dimension, so one value per row is returned.
            lons = xr.DataArray(df_group["POINT_LON"].to_numpy(),
                                dims="points")
            lats = xr.DataArray(df_group["POINT_LAT"].to_numpy(),
                                dims="points")
            xr_elev = ds[var_name].sel({
                lon_name: lons,
                lat_name: lats
            },
                                       method="nearest").to_numpy()

        res = df_group.copy()
        res["DEM_elv"] = xr_elev
        res["elev_diff"] = df_group["POINT_ELEVATION"].to_numpy() - xr_elev

        # Rows without a DEM value (or without a point elevation) cannot be judged
        res = res.dropna(subset=["DEM_elv", "elev_diff"])

        # Boolean-mask selection stays correct even if index labels repeat
        flagged = res[res["elev_diff"].abs() >= threshold]
        if not flagged.empty:
            mismatch_frames.append(flagged)

    if not mismatch_frames:
        # Nothing found
        return pd.Index([], dtype=df_gl.index.dtype), pd.DataFrame(
            columns=list(df_gl.columns) + ["DEM_elv", "elev_diff"])

    mismatch_df = pd.concat(mismatch_frames, axis=0)
    mismatch_idx = mismatch_df.index  # group order, before sorting
    mismatch_df = mismatch_df.sort_values(by="elev_diff", ascending=True)

    return mismatch_idx, mismatch_df


def reconcile_points_by_year(
    df: pd.DataFrame,
    path_xr_grids: str,
    var_name: str = "masked_elev",
    lon_name: str = "lon",
    lat_name: str = "lat",
    year_col: str = "YEAR",
    glacier_col: str = "GLACIER",
    point_elev_col: str = "POINT_ELEVATION",
    threshold: float = 500.0,
    file_pattern: str = "{glacier}_{year}.zarr",
    replace_glaciers: Optional[Iterable[str]] = None,  # e.g., {"aletsch"}
    strict: bool = False,
    verbose: bool = True,
):
    """
    Reconcile point elevations against glacier-year DEMs stored as Zarrs.

    For each (GLACIER, YEAR) group:
      - Open <path_xr_grids>/<file_pattern> for that glacier and year if present.
      - If missing, fall back to the earliest available <glacier>_YYYY.zarr in
        that folder (the fallback always uses this naming, case-insensitively,
        regardless of `file_pattern`).
      - Sample the nearest DEM cell at (POINT_LON, POINT_LAT).
      - Rows whose sampled DEM value is NaN are left untouched.
      - If |POINT_ELEVATION - DEM| >= threshold:
          * If GLACIER (case-insensitive) is in replace_glaciers: REPLACE the
            point elevation with the DEM value.
          * Else: DROP the row.

    Args:
        df: Input points; must contain glacier_col, year_col, 'POINT_LON',
            'POINT_LAT' and point_elev_col. It is not modified.
        path_xr_grids: Directory holding the Zarr stores.
        var_name, lon_name, lat_name: Variable/coordinate names in the Zarrs.
        threshold: Absolute elevation difference that counts as a mismatch.
        file_pattern: Format string with `{glacier}` and `{year}` fields.
        replace_glaciers: Glaciers whose mismatching elevations are replaced
            instead of dropped.
        strict: If True, raise when a glacier has no DEM at all.
        verbose: If True, print fallback/missing notices and a per-glacier
            summary. The returned counts do not depend on this flag.

    Returns:
      df_clean    : DataFrame after drops/replacements, with a fresh 0..n-1 index.
      df_mismatch : All mismatching rows with ['DEM_elv','elev_diff'] added,
                    sorted by elev_diff; its index refers to positions in `df`.
      summary     : Per-glacier counts (dropped, replaced, fallback_groups_used,
                    missing_dem_groups).

    Raises:
        KeyError: a required column, coordinate or variable is missing.
        FileNotFoundError: `path_xr_grids` does not exist, or (strict only) a
            glacier has no DEM.
    """
    # --- validations ---
    required = {
        glacier_col, year_col, "POINT_LON", "POINT_LAT", point_elev_col
    }
    missing = required - set(df.columns)
    if missing:
        raise KeyError(f"Input df is missing columns: {sorted(missing)}")

    base = Path(path_xr_grids)
    if not base.exists():
        raise FileNotFoundError(f"Directory not found: {base}")

    replace_set = {g.lower() for g in (replace_glaciers or ())}

    # Work on a unique-index copy so drops/replacements are precise
    df_clean = df.reset_index(drop=True).copy()

    # Caches and accumulators
    ds_cache = {}  # zarr path -> opened dataset
    earliest_cache = {}  # glacier -> (year, path) of earliest DEM, or None
    mismatch_frames = []
    dropped_indices = []
    drop_counts = defaultdict(int)
    replace_counts = defaultdict(int)
    missing_dem_groups = defaultdict(int)
    fallback_counts = defaultdict(int)

    def _earliest_dem(glacier_name: str):
        """Return (year, path) of the earliest <glacier>_YYYY.zarr, or None."""
        if glacier_name not in earliest_cache:
            # The directory is scanned once per glacier, not once per group
            rgx = re.compile(
                rf"^{re.escape(glacier_name)}_(\d{{4}})\.zarr$",
                re.IGNORECASE)
            found = []
            for entry in base.iterdir():
                if not entry.name.lower().endswith(".zarr"):
                    continue
                m = rgx.match(entry.name)
                if m:
                    found.append((int(m.group(1)), entry))
            earliest_cache[glacier_name] = (min(found, key=lambda t: t[0])
                                            if found else None)
        return earliest_cache[glacier_name]

    try:
        # Iterate by (GLACIER, YEAR) groups; indices remain aligned with df_clean
        for (glacier, year), grp in df_clean.groupby([glacier_col, year_col],
                                                     sort=False):
            gkey = str(glacier)
            requested_year = int(year)

            # Resolve the DEM path: exact glacier-year first, then the fallback
            zarr_path = base / file_pattern.format(glacier=gkey,
                                                   year=requested_year)
            used_year = requested_year
            used_fallback = False
            if not zarr_path.exists():
                fallback = _earliest_dem(gkey)
                if fallback is None:
                    if verbose:
                        print(
                            f"[WARN] No DEMs found at all for glacier='{glacier}' in {base}"
                        )
                    missing_dem_groups[gkey] += 1
                    if strict:
                        raise FileNotFoundError(
                            f"Missing DEM for glacier '{glacier}' (no files like {glacier}_YYYY.zarr)"
                        )
                    continue
                used_year, zarr_path = fallback
                used_fallback = True

            if used_fallback:
                # Counted regardless of `verbose` so the summary stays correct
                fallback_counts[gkey] += 1
                if verbose:
                    print(
                        f"[INFO] Fallback DEM for glacier='{glacier}': requested {year} -> using {used_year} ({zarr_path.name})"
                    )

            if zarr_path not in ds_cache:
                ds_cache[zarr_path] = xr.open_zarr(str(zarr_path))
            ds = ds_cache[zarr_path]

            # sanity checks
            if lon_name not in ds.coords or lat_name not in ds.coords:
                raise KeyError(
                    f"{zarr_path} missing coords '{lon_name}'/'{lat_name}'. "
                    "Rename dataset coords or pass correct names.")
            if var_name not in ds.variables:
                raise KeyError(
                    f"Variable '{var_name}' not found in {zarr_path}.")

            # Point-wise nearest sampling (both indexers share dim 'points')
            lons = xr.DataArray(grp["POINT_LON"].to_numpy(), dims="points")
            lats = xr.DataArray(grp["POINT_LAT"].to_numpy(), dims="points")
            xr_elev = ds[var_name].sel({
                lon_name: lons,
                lat_name: lats
            },
                                       method="nearest").to_numpy()

            res = grp.copy()
            res["DEM_elv"] = xr_elev
            res["elev_diff"] = grp[point_elev_col].to_numpy() - xr_elev

            # Remove rows where DEM or diff is NaN before thresholding
            res = res.dropna(subset=["DEM_elv", "elev_diff"])

            # mismatches for this (glacier, year); index = positions in df_clean
            mm = res[res["elev_diff"].abs() >= threshold]
            if mm.empty:
                continue
            mismatch_frames.append(mm)

            if gkey.lower() in replace_set:
                # Replace the point elevation with the DEM value for those rows
                df_clean.loc[mm.index, point_elev_col] = mm["DEM_elv"].values
                replace_counts[gkey] += len(mm)
            else:
                # Remember mismatched rows; they are dropped in one go below
                dropped_indices.extend(mm.index.tolist())
                drop_counts[gkey] += len(mm)
    finally:
        # All sampled values are already in memory; release the stores
        for opened in ds_cache.values():
            opened.close()

    # Apply drops once
    if dropped_indices:
        df_clean = df_clean.drop(index=dropped_indices)

    # Build mismatch table
    if mismatch_frames:
        df_mismatch = pd.concat(mismatch_frames,
                                axis=0).sort_values("elev_diff")
    else:
        df_mismatch = pd.DataFrame(columns=list(df.columns) +
                                   ["DEM_elv", "elev_diff"])

    # Tidy
    df_clean = df_clean.sort_index().reset_index(drop=True)

    # Summary dataframe: one row per glacier that appears in any counter
    glaciers = sorted(
        set(drop_counts) | set(replace_counts) | set(missing_dem_groups)
        | set(fallback_counts))
    summary = pd.DataFrame({
        "GLACIER":
        glaciers,
        "dropped": [drop_counts[g] for g in glaciers],
        "replaced": [replace_counts[g] for g in glaciers],
        "fallback_groups_used": [fallback_counts[g] for g in glaciers],
        "missing_dem_groups": [missing_dem_groups[g] for g in glaciers],
    })

    # Print per-glacier info
    if verbose and len(summary):
        print("\n=== Reconcile summary (per glacier) ===")
        for _, row in summary.iterrows():
            g = row["GLACIER"]
            d = int(row["dropped"])
            r = int(row["replaced"])
            f = int(row["fallback_groups_used"])
            m = int(row["missing_dem_groups"])
            msg = f"{g}: removed {d} point(s)"
            if r:
                msg += f", replaced {r} point(s)"
            if f:
                msg += f", fallback DEM groups: {f}"
            if m:
                msg += f", missing DEM groups: {m}"
            print(msg)

    return df_clean, df_mismatch, summary


def first_year_per_glacier(path_xr_grids: str) -> pd.DataFrame:
    """
    List, for every glacier found in a directory, its earliest Zarr store.

    Entries are expected to be named <glacier>_<year>.zarr, where <year> has
    four digits and <glacier> is everything before the last underscore
    (e.g. "aletsch_1951.zarr", "some_glacier_name_2008.zarr"). The ".zarr"
    suffix is matched case-insensitively; other entries are ignored.

    Returns a DataFrame sorted by glacier with the columns:
      - glacier
      - first_year (int)
      - first_year_path (str, resolved path of that entry)
    It has no rows when nothing matches.

    Raises FileNotFoundError if the directory does not exist.
    """
    root = Path(path_xr_grids)
    if not root.exists():
        raise FileNotFoundError(f"Directory not found: {root}")

    name_rgx = re.compile(r"^(?P<glacier>.+)_(?P<year>\d{4})\.zarr$",
                          re.IGNORECASE)
    columns = ["glacier", "first_year", "first_year_path"]

    # glacier -> (earliest year seen so far, resolved path of that entry)
    earliest = {}
    for item in root.iterdir():
        # Zarr stores may be directories or files; only the name is inspected
        is_zarr = item.name.lower().endswith(".zarr")
        hit = name_rgx.match(item.name) if is_zarr else None
        if hit is None:
            continue
        try:
            yr = int(hit.group("year"))
        except ValueError:
            continue
        name = hit.group("glacier")
        known = earliest.get(name)
        if known is None or yr < known[0]:
            earliest[name] = (yr, str(item.resolve()))

    if not earliest:
        # No valid matches found
        return pd.DataFrame(columns=columns)

    records = [(name, yr, location)
               for name, (yr, location) in earliest.items()]
    table = pd.DataFrame(records, columns=columns)
    return table.sort_values(["glacier", "first_year"]).reset_index(drop=True)

In [None]:
# Make a unique-index working copy
df_clean = df_pmb_50s_clean_pts.reset_index(drop=True).copy()
print("Initial number of rows:", len(df_clean))

path_xr_grids = os.path.join(cfg.dataPath, path_GLAMOS_topo,
                             "xr_masked_grids/")

# Check every point elevation against the glacier-year DEM. The working copy
# prepared above is what gets reconciled.
df_clean, df_mismatch, summary = reconcile_points_by_year(
    df=df_clean,
    path_xr_grids=path_xr_grids,
    var_name="masked_elev",
    lon_name="lon",
    lat_name="lat",
    year_col="YEAR",
    glacier_col="GLACIER",
    point_elev_col="POINT_ELEVATION",
    threshold=400.0,
    file_pattern="{glacier}_{year}.zarr",
    replace_glaciers={"aletsch"},  # replace for Aletsch, drop for others
    strict=False,
    verbose=True,  # prints counts per glacier
)

print("Final number of rows:", len(df_clean))

# Save mismatches to CSV, largest positive difference first
out_csv = os.path.join(cfg.dataPath, path_PMB_GLAMOS_csv,
                       "GLAMOS_elev_mismatch.csv")
df_mismatch.sort_values(by="elev_diff", ascending=False, inplace=True)
df_mismatch.to_csv(out_csv, index=False)
print("Saved mismatches to:", out_csv)

# df_clean is the final cleaned dataframe (all glaciers, mismatches handled)
df_pmb_50s_clean_elv = df_clean

# reset_index
df_pmb_50s_clean_elv.reset_index(drop=True, inplace=True)

# Save intermediate output (f-string so the target folder is actually printed)
print(f'Saving intermediate output df_pmb_50s.csv to {path_PMB_GLAMOS_csv}')
df_pmb_50s_clean_elv.to_csv(os.path.join(cfg.dataPath, path_PMB_GLAMOS_csv,
                                         'df_pmb_50s.csv'),
                            index=False)
df_pmb_50s_clean_elv[[
    'GLACIER', 'POINT_ID', 'POINT_LAT', 'POINT_LON', 'PERIOD'
]].to_csv(os.path.join(cfg.dataPath, path_PMB_GLAMOS_csv,
                       'coordinate_50s.csv'),
          index=False)

In [None]:
# Inspect the elevation mismatches of a single glacier (before reconciliation)
glacier_name = 'rhone'
df_clean = df_pmb_50s_clean_pts.reset_index(drop=True).copy()
df_first = first_year_per_glacier(path_xr_grids)

df_gl = df_clean[(df_clean.GLACIER == glacier_name)]
# DEM of the first available year for this glacier (used for plotting below)
ds = xr.open_zarr(
    df_first[df_first.glacier == glacier_name].first_year_path.values[0])

# Must be a plain float: a trailing comma here would create the tuple
# (400.0,) and break the numeric comparison inside find_mismatch_by_year.
threshold = 400.0
mismatch_idx, df_with_diffs = find_mismatch_by_year(
    df_gl=df_gl,  # must include GLACIER and YEAR columns
    path_xr_grids=path_xr_grids,
    var_name="masked_elev",
    lon_name="lon",
    lat_name="lat",
    year_col="YEAR",  # change if your year column is named differently
    glacier_col="GLACIER",
    threshold=threshold,  # meters
    file_pattern="{glacier}_{year}.zarr",  # e.g. "aletsch_1951.zarr"
    strict=False,
)
print(
    f"Number of POINT indices with >={threshold} m mismatch: {len(mismatch_idx)}"
)


def pick_ann_file(cfg, glacier_name, year, period="annual"):
    """
    Locate the GLAMOS distributed-MB grid file of a glacier for one year.

    Looks in <cfg.dataPath>/<path_distributed_MB_glamos>/GLAMOS/<glacier_name>
    for "<year>_<suffix>_fix_lv95.grid" and then
    "<year>_<suffix>_fix_lv03.grid", where <suffix> is "ann" for
    period="annual" and "win" for period="winter".

    Returns:
        (path, coord_system): full path of the first file that exists and
        "lv95" or "lv03" accordingly; (None, None) if neither file exists.

    Raises:
        ValueError: if `period` is neither "annual" nor "winter".
    """
    suffix_by_period = {"annual": "ann", "winter": "win"}
    # Validate up front; otherwise the suffix would be undefined further down
    if period not in suffix_by_period:
        raise ValueError(
            f"Unknown period {period!r}; expected 'annual' or 'winter'.")
    suffix = suffix_by_period[period]

    base = os.path.join(cfg.dataPath, path_distributed_MB_glamos, "GLAMOS",
                        glacier_name)
    # LV95 is tried before LV03
    for coord_system in ("lv95", "lv03"):
        candidate = os.path.join(base,
                                 f"{year}_{suffix}_fix_{coord_system}.grid")
        if os.path.exists(candidate):
            return candidate, coord_system
    return None, None


# Year of the first available DEM for the glacier selected above
year = df_first[df_first.glacier == glacier_name].first_year.values[0]
period = 'annual'
file_ann, coord_system = pick_ann_file(cfg, glacier_name, year, period)
if file_ann is None:
    raise FileNotFoundError(
        f"No GLAMOS {period} grid found for glacier '{glacier_name}', year {year}"
    )
# pick_ann_file already returns the complete path; joining it onto the data
# directory again would duplicate the prefix whenever cfg.dataPath is relative.
grid_path_ann = file_ann

# Load GLAMOS data and convert to WGS84
metadata_ann, grid_data_ann = load_grid_file(grid_path_ann)
ds_glamos_ann = convert_to_xarray_geodata(grid_data_ann, metadata_ann)
if coord_system == "lv03":
    ds_glamos_wgs84_ann = transform_xarray_coords_lv03_to_wgs84(ds_glamos_ann)
elif coord_system == "lv95":
    ds_glamos_wgs84_ann = transform_xarray_coords_lv95_to_wgs84(ds_glamos_ann)

figure = plt.figure(figsize=(20, 6))

# Shared normalization across both plots (point elevations and DEM)
vmin = min(df_with_diffs["POINT_ELEVATION"].min(), float(ds.masked_elev.min()))
vmax = max(df_with_diffs["POINT_ELEVATION"].max(), float(ds.masked_elev.max()))
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
cmap = plt.cm.terrain

# ---- First subplot: GLAMOS grid with the mismatching points on top ----
ax1 = plt.subplot(1, 2, 1)
ds_glamos_wgs84_ann.plot.imshow(
    ax=ax1,
    cmap="Greys",
    cbar_kwargs={"label": "Mass Balance [m w.e.]"},
)

# scatter using same cmap + norm
sc = ax1.scatter(
    df_with_diffs["POINT_LON"],
    df_with_diffs["POINT_LAT"],
    c=df_with_diffs["POINT_ELEVATION"],
    cmap=cmap,
    norm=norm,
    s=25,
)
ax1.set_title(f"{glacier_name.capitalize()} {year} GLAMOS glacier-wide MB")

# ---- Second subplot: DEM of the same year ----
ax2 = plt.subplot(1, 2, 2)
im = ds.masked_elev.plot(
    ax=ax2,
    cmap=cmap,
    norm=norm,
    add_colorbar=False,  # the shared colorbar is added below
)
ax2.set_title(f"{glacier_name.capitalize()} {year} DEM")

# ---- Shared colorbar for elevation ----
cbar = figure.colorbar(
    mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
    ax=ax2,
    orientation="vertical",
    fraction=0.02,
    pad=0.02,
)
cbar.set_label("Elevation [m a.s.l.]")

plt.tight_layout()

df_with_diffs.head(10)

### Barplots:

In [None]:
# Measurement counts after the elevation check:
# per year and period (top panel), per glacier (bottom panel)
fig, axs = plt.subplots(2, 1, figsize=(20, 15))

ax = axs[0]
counts_year_period = df_pmb_50s_clean_elv.groupby(['YEAR', 'PERIOD'
                                                   ]).size().unstack()
counts_year_period.plot(kind='bar',
                        stacked=True,
                        color=[color_1, color_2],
                        ax=ax)
ax.set_title('Number of measurements per year for all glaciers')

ax = axs[1]
num_gl = df_pmb_50s_clean_elv.groupby(['GLACIER']).size().sort_values()
num_gl.plot(kind='bar', ax=ax)
ax.set_title('Number of total measurements per glacier since 1951')
plt.tight_layout()

In [None]:
# Alphabetical list of the glaciers left after cleaning
glacier_list = sorted(df_pmb_50s_clean_elv.GLACIER.unique())
print('Number of glaciers:', len(glacier_list))
glacier_list

In [None]:
# Number of measurements per glacier per year:
num_gl_yr = df_pmb_50s_clean_elv.groupby(['GLACIER', 'YEAR', 'PERIOD'
                                          ]).size().unstack().reset_index()

num_gl_annual = df_pmb_50s_clean_elv[
    df_pmb_50s_clean_elv.PERIOD == 'annual'].groupby(['GLACIER'
                                                      ]).size().sort_values()

# Plot one glacier per column:
big_gl = num_gl_annual[num_gl_annual > 250].index.sort_values()
num_glaciers = len(big_gl)
fig, ax = plt.subplots(num_glaciers, 1, figsize=(15, 5 * num_glaciers))
for i, gl in enumerate(big_gl):
    num_gl_yr[num_gl_yr.GLACIER == gl].plot(x='YEAR',
                                            kind='bar',
                                            stacked=True,
                                            ax=ax[i],
                                            title=gl)
    ax[i].set_ylabel('Number of measurements')
    ax[i].set_title

In [None]:
# Sample counts of the cleaned dataset, in total and per period
annual_rows = df_pmb_50s_clean_elv[df_pmb_50s_clean_elv.PERIOD == 'annual']
winter_rows = df_pmb_50s_clean_elv[df_pmb_50s_clean_elv.PERIOD == 'winter']
print('Number of winter and annual samples:', len(df_pmb_50s_clean_elv))
print('Number of annual samples:', len(annual_rows))
print('Number of winter samples:', len(winter_rows))

# Glacier names, de-duplicated and in alphabetical order
glacier_list = sorted(df_pmb_50s_clean_elv.GLACIER.unique())
print(f"Number of glaciers: {len(glacier_list)}")
print(f"Glaciers: {glacier_list}")

## Add topographical information from OGGM & SGI:

### Skyview:

In [None]:
# Folder with the per-glacier sky-view-factor NetCDF files (lat/lon grids)
_svf_topo_dir = "GLAMOS/topo/RGI_v6_11"
path_svf_latlon = os.path.join(cfg.dataPath, _svf_topo_dir, "svf_nc_latlon")


def _standardize_latlon(ds):
    """Return `ds` with dimensions named 'lat'/'lon', both sorted ascending."""
    aliases = {
        "x": "lon",
        "y": "lat",
        "longitude": "lon",
        "latitude": "lat",
    }
    # Rename only the dimensions that exist: rename() raises when asked to
    # rename a name that is not in the dataset.
    ren = {old: new for old, new in aliases.items() if old in ds.dims}
    if ren:
        ds = ds.rename(ren)
    # ascending coordinates are expected by the interpolation step
    if ds.lon[0] > ds.lon[-1]:
        ds = ds.sortby("lon")
    if ds.lat[0] > ds.lat[-1]:
        ds = ds.sortby("lat")
    return ds


def sample_svf_for_points(df_points: pd.DataFrame,
                          path_svf=None) -> pd.DataFrame:
    """
    Sample the sky-view rasters at the point locations.

    For each row in df_points (must contain ['POINT_LAT','POINT_LON','RGIId']),
    open the corresponding '<RGIId>_svf_latlon.nc' file and take the value of
    the nearest grid cell for each of the variables 'svf', 'asvf' and 'opns'.

    Args:
        df_points: table of points; it is not modified.
        path_svf: directory holding the '<RGIId>_svf_latlon.nc' files. When
            None (default) the notebook-level `path_svf_latlon` is used.

    Returns:
        A copy of df_points with new columns 'SVF', 'ASVF', 'OPNS'. Values stay
        NaN for glaciers whose file is missing or that lack the variable; a
        '[warn]' line is printed for missing files and for files without any
        of the three variables.
    """
    if path_svf is None:
        path_svf = path_svf_latlon

    # output column -> variable name inside the NetCDF file
    columns = {"SVF": "svf", "ASVF": "asvf", "OPNS": "opns"}

    out = df_points.copy()
    for col in columns:
        out[col] = np.nan

    # group by glacier to open each dataset once
    for rgi_id, df_g in out.groupby("RGIId", sort=False):
        svf_path = os.path.join(path_svf, f"{rgi_id}_svf_latlon.nc")
        if not os.path.exists(svf_path):
            # no SVF available for this glacier; leave NaNs
            print(f"[warn] Missing SVF file for {rgi_id}: {svf_path}")
            continue

        # The context manager closes the handle that was actually opened,
        # also when sampling raises part-way through.
        with xr.open_dataset(svf_path) as ds_raw:
            present = {
                col: var
                for col, var in columns.items() if var in ds_raw.data_vars
            }
            if not present:
                print(f"[warn] No SVF variables in {svf_path}")
                continue

            ds = _standardize_latlon(ds_raw)

            # clip to grid bounds to avoid all-NaNs near edges
            lon_grid = ds.lon.values
            lat_grid = ds.lat.values
            lons = xr.DataArray(np.clip(df_g["POINT_LON"].values,
                                        lon_grid.min(), lon_grid.max()),
                                dims="p")
            lats = xr.DataArray(np.clip(df_g["POINT_LAT"].values,
                                        lat_grid.min(), lat_grid.max()),
                                dims="p")

            # "nearest" avoids NaNs if points sit slightly off-grid due to
            # reprojection; .values is read while the file is still open
            for col, var in present.items():
                out.loc[df_g.index, col] = ds[var].interp(
                    lon=lons, lat=lats, method="nearest").values

    return out

In [None]:
# Load the point mass balance table (df_pmb_50s.csv) and add the sky-view
# columns 'SVF', 'ASVF' and 'OPNS' sampled at each point
df_pmb_50s_clean = pd.read_csv(cfg.dataPath + path_PMB_GLAMOS_csv +
                               'df_pmb_50s.csv')
df_with_svf = sample_svf_for_points(df_pmb_50s_clean)

### OGGM data:

In [None]:
# Initialize the OGGM glacier directories for RGI v6, region 11, using the
# pre-processed directories served at `base_url`.
# (df_pmb_50s_clean is already loaded in the Skyview cell above.)
gdirs, rgidf = initialize_oggm_glacier_directories(
    cfg,
    rgi_region="11",
    rgi_version="6",
    base_url=
    "https://cluster.klima.uni-bremen.de/~oggm/gdirs/oggm_v1.6/L3-L5_files/2023.1/elev_bands/W5E5_w_data/",
    log_level='WARNING',
    task_list=None,
    from_prepro_level=3,
    prepro_border=10,
)
# NOTE(review): unique_rgis is not used by the cells shown here; confirm
# whether it is still needed.
unique_rgis = df_with_svf['RGIId'].unique()

# presumably writes the per-glacier grids that are later opened from
# OGGM/xr_grids/<RGIId>.zarr (TODO confirm)
export_oggm_grids(cfg, gdirs)

# Merge the point table (with SVF columns) with the OGGM data; the result
# carries a 'within_glacier_shape' column that the next cell filters on.
df_pmb_topo = merge_pmb_with_oggm_data(
    df_pmb=df_with_svf,
    gdirs=gdirs,
    rgi_region="11",
    rgi_version="6",
)

In [None]:
# Keep only the points flagged as lying inside the glacier outline, then
# discard the flag column
df_pmb_topo = df_pmb_topo.loc[df_pmb_topo['within_glacier_shape']].drop(
    columns='within_glacier_shape')

# Sample counts: overall, then per measurement period
print('Number of winter and annual samples:', len(df_pmb_topo))
print('Number of annual samples:',
      int((df_pmb_topo['PERIOD'] == 'annual').sum()))
print('Number of winter samples:',
      int((df_pmb_topo['PERIOD'] == 'winter').sum()))
# Glacier names in alphabetical order
glacier_list = sorted(df_pmb_topo['GLACIER'].unique())
print(f"Number of glaciers: {len(glacier_list)}")
print(f"Glaciers: {glacier_list}")

### SGI data:

In [None]:
# First create the masked topographical arrays per glacier, for every glacier
# still present in df_pmb_topo. The output goes to the 'xr_masked_grids/'
# folder, from which later cells open '<glacier>.zarr'.
# (type='glacier_name' presumably tells the helper that the list holds
# glacier names rather than IDs; TODO confirm)
glacier_list = sorted(df_pmb_topo.GLACIER.unique())
create_sgi_topo_masks(cfg,
                      glacier_list,
                      type='glacier_name',
                      path_save=os.path.join(cfg.dataPath, path_SGI_topo,
                                             'xr_masked_grids/'))

In [None]:
# Example: masked SGI grids of one glacier, with its stakes drawn on the mask
i = 0
glacier_name = 'clariden'
df_pmb_gl = df_pmb_50s_clean.loc[df_pmb_50s_clean['GLACIER'] == glacier_name]

# (n, 2) array: longitude in column 0, latitude in column 1
stake_coordinates = df_pmb_gl[['POINT_LON', 'POINT_LAT']].to_numpy()

# SGI grid of that glacier
ds_sgi = xr.open_dataset(cfg.dataPath + path_SGI_topo + 'xr_masked_grids/' +
                         f'{glacier_name}.zarr')

# One panel per variable; the stakes go on the glacier-mask panel
fig, axs = plt.subplots(1, 4, figsize=(15, 6))
ds_sgi['masked_aspect'].plot(ax=axs[0], cmap='twilight_shifted')
ds_sgi['masked_slope'].plot(ax=axs[1], cmap='cividis')
ds_sgi['masked_elev'].plot(ax=axs[2], cmap='terrain')
ds_sgi['glacier_mask'].plot(ax=axs[3], cmap='binary')
axs[3].scatter(*stake_coordinates.T, c='r', s=10)
for _ax, _title in zip(axs, ("Aspect", "Slope", "DEM", "Glacier mask")):
    _ax.set_title(_title)
plt.tight_layout()

In [None]:
# Folder with the masked SGI grids written by create_sgi_topo_masks
path_masked_grids = os.path.join(cfg.dataPath, path_SGI_topo,
                                 'xr_masked_grids/')

# Merge PMB with SGI data
df_pmb_sgi = merge_pmb_with_sgi_data(
    df_pmb_topo,  # cleaned PMB DataFrame
    path_masked_grids,  # path to SGI grids
    voi=["masked_aspect", "masked_slope", "masked_elev"])

# Drop points that have no intersection with SGI mask: (have NaN values)
# NOTE(review): dropna() without `subset` removes a row when ANY column is
# NaN, so points that only lack e.g. SVF/ASVF/OPNS are dropped as well.
# Confirm that this is intended.
df_pmb_sgi = df_pmb_sgi.dropna()

In [None]:
# Overall sample count
print(f"Total number of winter and annual samples: {len(df_pmb_sgi)}")

# Samples per measurement period (0 when a period is absent)
period_counts = df_pmb_sgi.PERIOD.value_counts()
print(f"Number of annual samples: {period_counts.get('annual', 0)}")
print(f"Number of winter samples: {period_counts.get('winter', 0)}")

# Years covered, ascending
unique_years = np.unique(df_pmb_sgi['YEAR'].to_numpy())
print(f"Unique years: {unique_years}")

# Glacier names, alphabetical
glacier_list = sorted(df_pmb_sgi['GLACIER'].unique())
print(f"Number of glaciers: {len(glacier_list)}")
print(f"Glaciers: {glacier_list}")

In [None]:
# Example: stake positions of one glacier drawn on the OGGM glacier mask
# (projected coordinates) and on the SGI glacier mask (lon/lat).
glacierName = 'clariden'
# Stakes of that glacier; copy AFTER filtering so that the x/y columns added
# below are written to an independent frame, not to a slice of df_pmb_topo.
df_stakes = df_pmb_topo[df_pmb_topo['GLACIER'] == glacierName].copy()
RGIId = df_stakes.RGIId.unique()[0]
print(RGIId)
# Open the OGGM grid of that RGI ID
ds_oggm = xr.open_dataset(f'{cfg.dataPath}/OGGM/xr_grids/{RGIId}.zarr')

# Define the coordinate transformation
transf = pyproj.Transformer.from_proj(
    pyproj.CRS.from_user_input("EPSG:4326"),  # Input CRS (WGS84)
    pyproj.CRS.from_user_input(ds_oggm.pyproj_srs),  # Output CRS from dataset
    always_xy=True)

# Transform all stake coordinates (always_xy=True: longitude first)
lon, lat = df_stakes["POINT_LON"].values, df_stakes["POINT_LAT"].values
x_stake, y_stake = transf.transform(lon, lat)
df_stakes['x'] = x_stake
df_stakes['y'] = y_stake

# Left panel: stakes on the OGGM mask.
# No `palette`: it only applies together with `hue`, and the former hue column
# 'within_glacier_shape' has been dropped from df_pmb_topo.
plt.figure(figsize=(10, 5))
ax = plt.subplot(121)
ds_oggm.glacier_mask.plot(cmap='binary', ax=ax)
sns.scatterplot(data=df_stakes, x='x', y='y', ax=ax)
ax.set_title('Stakes on glacier OGGM')

# Right panel: stakes on the SGI mask.
# Use a dedicated name for the folder: rebinding the global `path_SGI_topo`
# would break the other cells that combine it with cfg.dataPath when they are
# re-run.
ax = plt.subplot(122)
path_sgi_example = f'{cfg.dataPath}/GLAMOS/topo/SGI2020/'
sgi_grid = xr.open_dataset(path_sgi_example +
                           f'xr_masked_grids/{glacierName}.zarr')
sgi_grid.glacier_mask.plot(cmap='binary', ax=ax)
sns.scatterplot(data=df_stakes, x='POINT_LON', y='POINT_LAT', ax=ax)
ax.set_title('Stakes on glacier SGI')

In [None]:
# Two stacked panels: measurements per year (split by period) on top,
# total measurements per glacier below
fig, axs = plt.subplots(2, 1, figsize=(20, 15))
ax = axs[0]
df_pmb_sgi.groupby(['YEAR', 'PERIOD']).size().unstack().plot.bar(
    stacked=True, color=[color_1, color_2], ax=ax)
ax.set_title('Number of measurements per year for all glaciers')

ax = axs[1]
num_gl = df_pmb_sgi.groupby('GLACIER').size().sort_values()
num_gl.plot.bar(ax=ax)
ax.set_title('Number of total measurements per glacier since 1951')
plt.tight_layout()

### Example:


In [None]:
glacierName = 'clariden'
df_pmb_gl = df_pmb_sgi.loc[df_pmb_sgi['GLACIER'] == glacierName]

# One panel per topographical variable: OGGM value on x, SGI value on y
fig, axs = plt.subplots(1, 3, figsize=(15, 6))
axs[0].scatter(df_pmb_gl['aspect'], df_pmb_gl['aspect_sgi'])
axs[0].set(xlabel='aspect oggm', ylabel='aspect sgi', title='Aspect')

axs[1].scatter(df_pmb_gl['slope'], df_pmb_gl['slope_sgi'])
axs[1].set(xlabel='slope oggm', ylabel='slope sgi', title='Slope')

axs[2].scatter(df_pmb_gl['topo'], df_pmb_gl['topo_sgi'])
axs[2].set(xlabel='topo oggm', ylabel='topo sgi', title='Topo')

# dashed 1:1 reference line spanning the current x-range of each panel
for ax in axs:
    ax.plot(ax.get_xlim(), ax.get_xlim(), ls="--", c=".3")

plt.tight_layout()

## Give new stake IDs:
Give each stake a new ID made of the glacier name followed by a number assigned according to elevation. This is needed because some stakes on different glaciers share the same ID, which is not practical.

In [None]:
# Exclude two glaciers before renaming the stakes:
#   - taelliboden: only one measurement
#   - plainemorte: big outlier
df_pmb_sgi = df_pmb_sgi[~df_pmb_sgi.GLACIER.isin(
    ['taelliboden', 'plainemorte'])]

# New stake IDs, followed by a check of the renamed IDs
df_pmb_sgi = rename_stakes_by_elevation(df_pmb_sgi)
check_point_ids_contain_glacier(df_pmb_sgi)

# Write the dataset to CSV
fname = 'CH_wgms_dataset_all.csv'
df_pmb_sgi.to_csv(os.path.join(cfg.dataPath, path_PMB_GLAMOS_csv, fname),
                  index=False)
log.info(f"-- Saved pmb & oggm dataset {fname} to: {path_PMB_GLAMOS_csv}")

# Sample counts: overall, then per measurement period
print('Number of winter and annual samples:', len(df_pmb_sgi))
print('Number of annual samples:',
      int((df_pmb_sgi['PERIOD'] == 'annual').sum()))
print('Number of winter samples:',
      int((df_pmb_sgi['PERIOD'] == 'winter').sum()))

# Distribution of the point mass balance values
df_pmb_sgi.POINT_BALANCE.hist(bins=20)
plt.xlabel('Mass balance [m w.e.]')

## Final cleaning:

In [None]:
# Work on an explicit copy: df_pmb_sgi is the result of boolean filtering, and
# the column assignments below must not act on a slice of another frame.
df_pmb_sgi = df_pmb_sgi.copy()

# Two-character month taken from positions 4-5 of the date string
# (assumes FROM_DATE / TO_DATE are formatted as YYYYMMDD; TODO confirm)
df_pmb_sgi['MONTH_START'] = [str(date)[4:6] for date in df_pmb_sgi.FROM_DATE]
df_pmb_sgi['MONTH_END'] = [str(date)[4:6] for date in df_pmb_sgi.TO_DATE]

# Drop rows where MONTH_START is '07'
df_pmb_sgi = df_pmb_sgi[df_pmb_sgi['MONTH_START'] != '07']

# Drop 'annual' rows whose MONTH_END is '06' or '11'
df_pmb_sgi = df_pmb_sgi.loc[~((df_pmb_sgi['PERIOD'] == 'annual') &
                              df_pmb_sgi['MONTH_END'].isin(['06', '11']))].copy()

# Rows where MONTH_END is '05' and PERIOD is 'annual': relabel as 'winter'
df_pmb_sgi.loc[(df_pmb_sgi['MONTH_END'] == '05') &
               (df_pmb_sgi['PERIOD'] == 'annual'), 'PERIOD'] = 'winter'

# Rows where MONTH_END is '08' and PERIOD is 'winter': relabel as 'annual'
df_pmb_sgi.loc[(df_pmb_sgi['MONTH_END'] == '08') &
               (df_pmb_sgi['PERIOD'] == 'winter'), 'PERIOD'] = 'annual'

# Save to csv (same file name as the dataset written before this cleaning):
df_pmb_sgi.to_csv(cfg.dataPath + path_PMB_GLAMOS_csv +
                  'CH_wgms_dataset_all.csv',
                  index=False)


In [None]:
# Read the saved dataset back and list the glaciers it contains
df = pd.read_csv(cfg.dataPath + path_PMB_GLAMOS_csv +
                 f'CH_wgms_dataset_all.csv')
df.GLACIER.unique()

# Glacier wide MB:
Pre-processing of glacier wide SMB data from GLAMOS. Transform .dat files to .csv. 

In [None]:
# Per the section description: transform the GLAMOS glacier-wide SMB .dat
# files to .csv
process_SMB_GLAMOS(cfg)

In [None]:
# "obs" files: no fixed dates, the observed measurement periods are used.
# Example (comma-separated, first row is the header: the read_csv defaults):
fileName = 'aletsch_obs.csv'
aletsch_csv = pd.read_csv(
    cfg.dataPath + path_SMB_GLAMOS_csv + 'obs/' + fileName,
    encoding='latin-1',
)
aletsch_csv.head(2)

In [None]:
# "fix" files: fixed periods (hydrological year).
# Example (comma-separated, first row is the header: the read_csv defaults):
fileName = 'aletsch_fix.csv'
aletsch_csv = pd.read_csv(
    cfg.dataPath + path_SMB_GLAMOS_csv + 'fix/' + fileName,
    encoding='latin-1',
)
aletsch_csv.head(2)

# Potential incoming clear sky solar radiation:

Pre-process the GLAMOS "potential incoming clear sky solar radiation" (pcsr) data, which is used as a topographical variable. There is one grid per day per glacier, for one year only; the year depends on the glacier.

In [None]:
# Set RUN to True to list the raw pcsr folders and (re)run the pcsr
# processing; with RUN = False this cell does nothing.
RUN = False
if RUN:
    glDirect = np.sort(os.listdir(cfg.dataPath + path_pcsr +
                                  'raw/'))  # Glaciers with data

    print('Number of glacier with clear sky radiation data:', len(glDirect))
    print('Glaciers with clear sky radiation data:', glDirect)

    process_pcsr(cfg)

In [None]:
# Read and plot one file: the 'grid_data' variable of the Aletsch store,
# one panel per time step, three panels per row
xr_file = xr.open_dataset(cfg.dataPath + path_pcsr + 'zarr/' +
                          'xr_direct_aletsch.zarr')
xr_file['grid_data'].plot(x='x', y='y', col='time', col_wrap=3)

In [None]:
# Number of entries in the raw pcsr folder
pcsr_glaciers = os.listdir(cfg.dataPath + path_pcsr + 'raw/')
len(pcsr_glaciers)

In [None]:
# Years available per glacier, read from the file names in the raw pcsr folder
geod_glaciers = [
    'schwarzbach', 'joeri', 'sanktanna', 'corvatsch', 'sexrouge', 'murtel',
    'plattalva', 'tortin', 'basodino', 'limmern', 'adler', 'hohlaub',
    'albigna', 'tsanfleuron', 'silvretta', 'oberaar', 'gries', 'clariden',
    'gietro', 'schwarzberg', 'forno', 'allalin', 'otemma', 'findelen', 'rhone',
    'morteratsch', 'corbassiere', 'gorner', 'aletsch'
]

base_dir = os.path.join(cfg.dataPath, path_pcsr, 'raw')

# first run of four digits in a file name is taken as the year
year_pattern = re.compile(r'(\d{4})')

glacier_years = {}

for glacier_name in geod_glaciers:
    glacier_path = os.path.join(base_dir, glacier_name)
    if not os.path.isdir(glacier_path):
        # glaciers without a raw folder are left out of the table
        continue
    matches = (year_pattern.search(fname)
               for fname in os.listdir(glacier_path))
    glacier_years[glacier_name] = sorted(
        {int(match.group(1))
         for match in matches if match})

# One row per glacier, columns 0..k-1 holding its sorted years.
# from_dict(orient='index') pads shorter year lists with NaN, whereas
# pd.DataFrame(glacier_years) raises when the lists differ in length.
pd.DataFrame.from_dict(
    glacier_years,
    orient='index').sort_values(by=0).reset_index().rename(columns={
        'index': 'glacier_name',
        0: 'pcsr year'
    }).to_csv('pcsr.csv')

In [None]:
# Display the pcsr year table: one row per glacier, sorted by the first year.
# from_dict(orient='index') pads shorter year lists with NaN, whereas
# pd.DataFrame(glacier_years) raises when the lists differ in length.
pd.DataFrame.from_dict(
    glacier_years,
    orient='index').sort_values(by=0).reset_index().rename(columns={
        'index': 'glacier_name',
        0: 'pcsr year'
    })