# Glacier grids from RGI:

Creates monthly grid files that the MassBalanceMachine (MBM) uses to make PMB predictions over the whole glacier grid. The files are derived from the RGI grids with OGGM topography. The computation is slow because of the conversion to monthly format.
## Setting up:

In [None]:
# --- System & utilities ---
import os
import sys
import re
import csv
import ast
import math
import traceback
import itertools
import random
import pickle
import logging
import warnings
from datetime import datetime
from functools import partial
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed

# Add repo root for MBM imports
sys.path.append(os.path.join(os.getcwd(), "../../"))

# --- Data science stack ---
import numpy as np
import pandas as pd
import xarray as xr
import rioxarray
from tqdm.notebook import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from cmcrameri import cm
import geopandas as gpd

# --- Machine learning / DL ---
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler, SubsetRandomSampler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from skorch.helper import SliceDataset
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint

# --- Cartography / plotting ---
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER

# --- Custom MBM modules ---
import massbalancemachine as mbm

# --- Warnings & autoreload (notebook) ---
warnings.filterwarnings("ignore")
%load_ext autoreload
%autoreload 2

from regions.Switzerland.scripts.geo_data import *
from regions.Switzerland.scripts.oggm import initialize_oggm_glacier_directories, export_oggm_grids
from regions.Switzerland.scripts.config_CH import * 
from regions.Switzerland.scripts.utils import * 

# --- Configuration, seeding and plot style ---
cfg = mbm.SwitzerlandConfig()
mbm.utils.seed_all(cfg.seed)
mbm.plots.use_mbm_style()
print("Using seed:", cfg.seed)

# --- Compute device: CUDA when available, CPU otherwise ---
if not torch.cuda.is_available():
    print("CUDA is NOT available")
    device = torch.device("cpu")
else:
    print("CUDA is available")
    mbm.utils.free_up_cuda()
    device = torch.device("cuda")

In [None]:
# Location of the pre-processed OGGM glacier directories that are requested.
oggm_gdirs_url = ("https://cluster.klima.uni-bremen.de/~oggm/gdirs/oggm_v1.6/"
                  "L1-L2_files/2025.6/elev_bands_w_data/")

gdirs, rgidf = initialize_oggm_glacier_directories(
    cfg,
    rgi_region="11",
    rgi_version="62",
    base_url=oggm_gdirs_url,
    log_level='WARNING',
    task_list=None,
)

# Export the OGGM grids of RGI region 11. The returned table is used below
# through its "rgi_id" and "missing_vars" columns.
df_missing = export_oggm_grids(cfg, gdirs, rgi_region="11")

# --- Glacier areas from the RGI outlines ---
gdf = gpd.read_file(cfg.dataPath + path_rgi_outlines)
# Areas are computed after reprojecting to EPSG:3035.
gdf_proj = gdf.to_crs(3035).rename(columns={"RGIId": "rgi_id"})
gdf_proj["area_m2"] = gdf_proj.geometry.area
gdf_proj["area_km2"] = gdf_proj["area_m2"] / 1e6
total_area = gdf_proj["area_km2"].sum()

# --- Attach areas to the missing-variable table; one row per (glacier, var) ---
df_missing = df_missing.merge(gdf_proj[['area_km2', 'rgi_id']], on="rgi_id")
df_exploded = df_missing.explode("missing_vars")
grouped_by_var = df_exploded.groupby("missing_vars")

# Number of distinct glaciers missing each variable
counts_missing_per_var = grouped_by_var["rgi_id"].nunique().sort_values(
    ascending=False)

# Area (and share of the total area) affected by ANY missing variable.
# NOTE(review): this sums every row of df_missing, so it presumes the table
# only lists glaciers with at least one missing variable — confirm.
total_missing_area_km2 = df_missing["area_km2"].sum()
total_missing_area_pct = (total_missing_area_km2 / total_area) * 100
print(f"Total glacier area with ANY missing variable: "
      f"{total_missing_area_km2:,.2f} km² ({total_missing_area_pct:.2f}%)")

# Share of the total area affected, per variable
area_missing_per_var = grouped_by_var["area_km2"].sum().sort_values(
    ascending=False)
perc_missing_per_var = (area_missing_per_var / total_area) * 100

print("\n% of total glacier area missing per variable:")
for var, pct in perc_missing_per_var.items():
    print(f"  - {var}: {pct:.2f}%")

# --- Barplot: number of glaciers missing each variable ---
plt.figure(figsize=(7, 4))
plt.bar(counts_missing_per_var.index, counts_missing_per_var.values)
plt.xlabel("Missing variable")
plt.ylabel("Number of glaciers")
plt.title("Count of glaciers missing each variable")
plt.tight_layout()
plt.show()

In [None]:
# Glacier id table: strip whitespace from the column names, then sort by and
# index on the glacier short name.
rgi_df = (
    pd.read_csv(cfg.dataPath + path_glacier_ids, sep=',')
    .rename(columns=lambda col: col.strip())
    .sort_values(by='short_name')
    .set_index('short_name')
)
rgi_df.loc['rhone']

## Export GeoTIFFs of DEMs (needed for the SVF computation in a separate notebook):

In [None]:
# Input folder with the OGGM grids and output folder for the DEM GeoTIFFs
path_RGIs = os.path.join(cfg.dataPath, path_OGGM, "xr_grids/")
path_geotiff = os.path.join(
    cfg.dataPath,
    "RGI_v6/RGI_11_CentralEurope",
    "geotiff/",
)

glaciers = os.listdir(path_RGIs)
print(f"Found {len(glaciers)} glaciers in RGI region 11")

# Switch to True to (re)export one DEM GeoTIFF per glacier directory.
RUN = False
if RUN:
    emptyfolder(path_geotiff)

    for gdir in tqdm(gdirs):
        rgi_gl = gdir.rgi_id
        try:
            out_tif = export_glacier_dems_to_geotiff(path_RGIs, rgi_gl,
                                                     path_geotiff)
        except ValueError as err:
            # A ValueError skips this glacier; the loop carries on.
            print(f"Skipping {rgi_gl}: {err}")

## Create RGI grids for all glaciers:

### Create masked xarray grids:

In [None]:
from concurrent.futures import ProcessPoolExecutor, as_completed


def process_one_glacier(
    rgi_gl: str,
    path_RGIs: str,
    path_xr_svf: str,
    path_xr_grids: str,
    target_res_m: int = 50,
):
    """
    Worker: build the masked lat/lon grid of one glacier and write it to zarr.

    Steps:
      1. Load the masked OGGM grid (projected coordinates).
      2. Coarsen it to ``target_res_m`` when its x-resolution lies strictly
         between 20 and ``target_res_m``.
      3. Reproject to EPSG:4326 and rename ``x``/``y`` to ``lon``/``lat``.
      4. If ``<path_xr_svf>/<rgi_gl>_svf_latlon.nc`` exists, add its
         ``svf`` / ``asvf`` / ``opns`` variables (interpolated linearly onto
         the glacier grid when the grids differ) plus ``masked_<var>``
         versions restricted to ``glacier_mask == 1``.
      5. Write ``<path_xr_grids>/<rgi_gl>.zarr`` (overwriting).

    Args:
        rgi_gl: RGI id of the glacier.
        path_RGIs: Folder passed to ``create_masked_glacier_grid``.
        path_xr_svf: Folder holding the per-glacier SVF NetCDF files.
        path_xr_grids: Output folder for the per-glacier zarr stores.
        target_res_m: Target resolution handed to ``coarsenDS_mercator``.

    Returns:
        ``(rgi_gl, "ok", "")`` on success, or
        ``(rgi_gl, "error", "<ExceptionType>: <message>")`` if any step
        raised an ``Exception``; the exception itself is not propagated.
    """
    try:
        # 1) Masked OGGM grid in projected coords
        ds, _ = create_masked_glacier_grid(path_RGIs, rgi_gl)

        # 2) Optional coarsen in projected space (only dx is inspected)
        dx_m, _ = get_res_from_projected(ds)
        if 20 < dx_m < target_res_m:
            ds = coarsenDS_mercator(ds, target_res_m=target_res_m)

        # 3) Reproject to WGS84 lat/lon
        original_proj = ds.pyproj_srs
        ds = ds.rio.write_crs(original_proj)
        ds_latlon = ds.rio.reproject("EPSG:4326").rename({
            "x": "lon",
            "y": "lat"
        })

        # 4) Load SVF + merge (if exists)
        svf_path = os.path.join(path_xr_svf, f"{rgi_gl}_svf_latlon.nc")
        if os.path.exists(svf_path):
            # load_dataset reads the file into memory and closes it, so the
            # worker does not accumulate one open handle per glacier.
            ds_svf = xr.load_dataset(svf_path)

            # Normalize coord names; rename only the aliases actually present
            # so a file carrying just one of a pair does not raise.
            aliases = {
                "x": "lon",
                "y": "lat",
                "longitude": "lon",
                "latitude": "lat",
            }
            to_rename = {
                old: new
                for old, new in aliases.items() if old in ds_svf.dims
            }
            if to_rename:
                ds_svf = ds_svf.rename(to_rename)

            # Sort ascending for interp stability
            for dim in ("lon", "lat"):
                if ds_latlon[dim][0] > ds_latlon[dim][-1]:
                    ds_latlon = ds_latlon.sortby(dim)
                if ds_svf[dim][0] > ds_svf[dim][-1]:
                    ds_svf = ds_svf.sortby(dim)

            svf_vars = [
                v for v in ("svf", "asvf", "opns") if v in ds_svf.data_vars
            ]

            if svf_vars:
                # Merge directly if grids match; else interpolate
                same_grid = (
                    np.array_equal(ds_latlon.lon.values, ds_svf.lon.values)
                    and np.array_equal(ds_latlon.lat.values,
                                       ds_svf.lat.values))
                if same_grid:
                    ds_latlon = xr.merge([ds_latlon, ds_svf[svf_vars]])
                else:
                    svf_on_grid = ds_svf[svf_vars].interp(lon=ds_latlon.lon,
                                                          lat=ds_latlon.lat,
                                                          method="linear")
                    for v in svf_vars:
                        svf_on_grid[v] = svf_on_grid[v].astype("float32")
                    ds_latlon = ds_latlon.assign(
                        **{v: svf_on_grid[v]
                           for v in svf_vars})

                # Masked SVF versions using glacier_mask (if present):
                # NaN everywhere except where glacier_mask == 1.
                if "glacier_mask" in ds_latlon:
                    gmask = xr.where(ds_latlon["glacier_mask"] == 1, 1.0,
                                     np.nan)
                    for v in svf_vars:
                        ds_latlon[f"masked_{v}"] = gmask * ds_latlon[v]

        # 5) Save final lat/lon grid
        os.makedirs(path_xr_grids, exist_ok=True)
        save_path = os.path.join(path_xr_grids, f"{rgi_gl}.zarr")
        ds_latlon.to_zarr(save_path, mode="w")

        return (rgi_gl, "ok", "")

    except Exception as e:
        # Worker boundary: report the failure to the parent instead of raising.
        return (rgi_gl, "error", f"{type(e).__name__}: {e}")


def run_parallel_processing(
    gdirs,
    path_RGIs,
    path_xr_svf,
    path_xr_grids,
    n_workers=None,
    clear_out=False,
    target_res_m=50,
):
    """
    Run ``process_one_glacier`` for every glacier directory in a process pool.

    Args:
        gdirs: Iterable of objects exposing an ``rgi_id`` attribute.
        path_RGIs: Forwarded to ``process_one_glacier``.
        path_xr_svf: Forwarded to ``process_one_glacier``.
        path_xr_grids: Output folder; passed to ``emptyfolder`` when
            ``clear_out`` is true, otherwise created if missing.
        n_workers: ``max_workers`` of the ``ProcessPoolExecutor``.
        clear_out: Whether to call ``emptyfolder`` on the output folder first.
        target_res_m: Forwarded to ``process_one_glacier``.

    Returns:
        List of ``(rgi_id, status, message)`` tuples in completion order,
        with ``status`` either ``"ok"`` or ``"error"``. A summary and the
        error messages are printed.
    """
    rgi_ids = [g.rgi_id for g in gdirs]

    if clear_out:
        emptyfolder(path_xr_grids)
    else:
        os.makedirs(path_xr_grids, exist_ok=True)

    results = []
    with ProcessPoolExecutor(max_workers=n_workers) as ex:
        # Map each future back to its glacier so failures can be attributed.
        futures = {
            ex.submit(
                process_one_glacier,
                rgi_id,
                path_RGIs,
                path_xr_svf,
                path_xr_grids,
                target_res_m,
            ):
            rgi_id
            for rgi_id in rgi_ids
        }

        for fut in tqdm(as_completed(futures), total=len(futures)):
            try:
                results.append(fut.result())
            except Exception as e:
                # The worker catches its own exceptions, so this only fires
                # for pool-level failures (e.g. a crashed worker process).
                # Record it rather than aborting and losing the results
                # collected so far.
                results.append(
                    (futures[fut], "error", f"{type(e).__name__}: {e}"))

    # quick summary
    failures = [r for r in results if r[1] == "error"]
    n_ok = sum(r[1] == "ok" for r in results)
    print(f"Done. ok={n_ok}, error={len(failures)}")

    for rgi_id, _, msg in failures:
        print(f"[{rgi_id}] {msg}")

    return results

In [None]:
# Output folder for the masked lat/lon grids and input folder of the SVF files
rgi_region_dir = os.path.join(cfg.dataPath, "RGI_v6/RGI_11_CentralEurope")
path_xr_grids = os.path.join(rgi_region_dir, "xr_masked_grids/")
path_xr_svf = os.path.join(rgi_region_dir, "svf_nc_latlon/")

# Set RUN to False to skip the (slow) regeneration of the per-glacier zarrs.
RUN = True
if RUN:
    results = run_parallel_processing(
        gdirs,
        path_RGIs,
        path_xr_svf,
        path_xr_grids,
        n_workers=6,  # a modest pool size
        clear_out=True,  # set to False to keep existing zarrs
        target_res_m=50,
    )

In [None]:
rgi_id = "RGI60-11.01238"

# --- Input files for this glacier ---
dem_path = os.path.join(path_geotiff, rgi_id + ".tif")
zarr_path = os.path.join(path_xr_grids, rgi_id + ".zarr")
svf_path = os.path.join(path_xr_svf, rgi_id + "_svf_latlon.nc")

# --- Load data ---
dem = rioxarray.open_rasterio(dem_path).squeeze()
ds = xr.open_zarr(zarr_path)
# NOTE(review): ds_svf is loaded (and its coords normalised) but not plotted
# below; the third panel uses the "svf" variable stored in the zarr grid.
ds_svf = xr.open_dataset(svf_path)
if "lon" not in ds_svf.coords:
    ds_svf = ds_svf.rename({"x": "lon", "y": "lat"})

# --- One row, three panels: DEM | masked aspect | sky view factor ---
fig, axes = plt.subplots(1, 3, figsize=(18, 6), constrained_layout=True)

dem.plot(ax=axes[0], cmap="terrain")
ds["masked_aspect"].plot(ax=axes[1])
ds["svf"].plot(ax=axes[2])

# Titles and axis labels are applied after plotting so they replace the
# defaults set by the plot calls.
panel_labels = [
    ("DEM (projected meters)", "Easting [m]", "Northing [m]"),
    ("Masked Aspect (°)", "Longitude (°)", "Latitude (°)"),
    ("Sky View Factor (lat/lon)", "Longitude (°)", "Latitude (°)"),
]
for ax, (title, xlabel, ylabel) in zip(axes, panel_labels):
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

plt.suptitle(rgi_id, fontsize=15)
plt.show()

### Create monthly dataframes:

In [None]:
# NOTE(review): this cell points the output/SVF folders at RGI_08_Scandinavia
# but passes the same `gdirs` (initialised with rgi_region="11") and
# `path_RGIs` as the Central Europe run, and clear_out=True is handed to
# run_parallel_processing for the output folder. It also sits under the
# "Create monthly dataframes" heading although it regenerates masked grids.
# Confirm this is intended before running.
path_xr_grids = os.path.join(cfg.dataPath, "RGI_v6/RGI_08_Scandinavia",
                             "xr_masked_grids/")
path_xr_svf = os.path.join(cfg.dataPath, "RGI_v6/RGI_08_Scandinavia",
                           "svf_nc_latlon/")

# Set RUN to False to skip this step.
RUN = True
if RUN:
    results = run_parallel_processing(
        gdirs=gdirs,
        path_RGIs=path_RGIs,
        path_xr_svf=path_xr_svf,
        path_xr_grids=path_xr_grids,
        n_workers=6,  # start modest (4–8 is usually good)
        clear_out=True,  # or False if you want to keep existing zarrs
        target_res_m=50,
    )

In [None]:
# Look at one example: September t2m values of one glacier for two years.
target_rgi_id = 'RGI60-11.00001'
matches = [g for g in gdirs if g.rgi_id == target_rgi_id]
if not matches:
    # Fail loudly instead of silently reusing a `gdir_rhone` left over from
    # another cell.
    raise ValueError(f"{target_rgi_id} not found in gdirs")
# Variable name kept so other cells relying on it keep working.
gdir_rhone = matches[-1]

rgi_gl = gdir_rhone.rgi_id

for year in (2000, 2004):
    df = pd.read_parquet(
        os.path.join(path_rgi_alps, rgi_gl, f"{rgi_gl}_grid_{year}.parquet"))
    df = df[df.MONTHS == 'sep']
    print(df['t2m'].unique())

In [None]:
# Look at one example: September grid of one glacier, one panel per variable.
target_rgi_id = 'RGI60-11.01238'
matches = [g for g in gdirs if g.rgi_id == target_rgi_id]
if not matches:
    # Fail loudly instead of silently reusing a `gdir_rhone` left over from
    # another cell.
    raise ValueError(f"{target_rgi_id} not found in gdirs")
# Variable name kept so other cells relying on it keep working.
gdir_rhone = matches[-1]

year = 2000
rgi_gl = gdir_rhone.rgi_id

df = pd.read_parquet(
    os.path.join(path_rgi_alps, rgi_gl, f"{rgi_gl}_grid_{year}.parquet"))
df = df[df.MONTHS == 'sep']

# 2 x 3 panels, flattened so they can be indexed alongside `voi`.
fig, axs = plt.subplots(2, 3, figsize=(15, 10))
voi = [
    't2m', 'tp', 'ALTITUDE_CLIMATE', 'ELEVATION_DIFFERENCE', 'hugonnet_dhdt',
    'consensus_ice_thickness'
]
axs = axs.flatten()
# One scatter panel per variable in `voi`: points placed at
# (POINT_LON, POINT_LAT) and coloured by that variable.
for i, var in enumerate(voi):
    sns.scatterplot(df,
                    x='POINT_LON',
                    y='POINT_LAT',
                    hue=var,
                    s=5,
                    alpha=0.5,
                    palette='twilight_shifted',
                    ax=axs[i])