# Glacier grids from RGI:

Creates monthly grid files for the MassBalanceMachine (MBM) so that it can make PMB predictions over the whole glacier grid. The files are derived from the RGI glacier grids with OGGM topography. Computation takes a long time because of the conversion to monthly format.
## Setting up:

In [None]:
# --- System & utilities ---
import os
import sys
import warnings
from tqdm.notebook import tqdm
import rioxarray

# Add repo root for MBM imports
sys.path.append(os.path.join(os.getcwd(), "../../"))

# --- Data science stack ---
import matplotlib.pyplot as plt

# --- Custom MBM modules ---
import massbalancemachine as mbm

# --- Warnings & autoreload (notebook) ---
warnings.filterwarnings("ignore")
%load_ext autoreload
%autoreload 2


# --- Configuration ---
cfg = mbm.EuropeConfig()

from regions.TF_Europe.scripts.config_TF_Europe import *
from regions.TF_Europe.scripts.oggm import *
from regions.TF_Europe.scripts.geodata import *

# Reproducibility (global seed) and the MBM plotting style:
mbm.utils.seed_all(cfg.seed)
mbm.plots.use_mbm_style()
print("Using seed:", cfg.seed)

# Choose the compute device. `torch` is not imported explicitly in this
# notebook; presumably it comes from one of the star imports above -- confirm.
cuda_ok = torch.cuda.is_available()
if not cuda_ok:
    print("CUDA is NOT available")
else:
    print("CUDA is available")
    mbm.utils.free_up_cuda()
device = torch.device("cuda" if cuda_ok else "cpu")

## Create RGI grids for all glaciers:

### Create masked xarray grids: 
(Requires the sky view factor (SVF) files to have been computed separately beforehand.)

In [None]:
from concurrent.futures import ProcessPoolExecutor, as_completed
import glob


def _rename_present(ds, mapping):
    """Rename the names in *mapping* that exist in *ds*, skipping absent ones.

    Renaming is only attempted when at least one key of *mapping* is a
    dimension of *ds*. Only names that actually exist are passed to
    ``Dataset.rename``, which raises ``ValueError`` for missing names.
    """
    if not any(old in ds.dims for old in mapping):
        return ds
    present = {
        old: new
        for old, new in mapping.items()
        if old in ds.dims or old in ds.variables
    }
    return ds.rename(present)


def _sort_ascending(ds):
    """Return *ds* sorted so that ``lon`` and ``lat`` both increase.

    Ascending coordinates keep the later ``interp`` call stable.
    """
    for dim in ("lon", "lat"):
        if ds[dim][0] > ds[dim][-1]:
            ds = ds.sortby(dim)
    return ds


def _attach_svf(ds_latlon, ds_svf):
    """Add the SVF variables of *ds_svf* to the lat/lon grid *ds_latlon*.

    Only ``svf``, ``asvf`` and ``opns`` are considered; if none is present,
    *ds_latlon* is returned unchanged.

    If the two grids share identical lon/lat values, the variables are merged
    directly. Otherwise they are linearly interpolated onto *ds_latlon*'s
    coordinates. Both paths store the SVF variables as float32.

    When *ds_latlon* has a ``glacier_mask`` variable, a ``masked_<var>`` copy
    is added for each SVF variable, with NaN wherever the mask is not 1.
    """
    svf_vars = [v for v in ("svf", "asvf", "opns") if v in ds_svf.data_vars]
    if not svf_vars:
        return ds_latlon

    same_grid = (np.array_equal(ds_latlon.lon.values, ds_svf.lon.values)
                 and np.array_equal(ds_latlon.lat.values, ds_svf.lat.values))
    if same_grid:
        # Cast here as well, so the SVF dtype does not depend on which
        # branch was taken.
        ds_latlon = xr.merge(
            [ds_latlon, ds_svf[svf_vars].astype("float32")])
    else:
        svf_on_grid = ds_svf[svf_vars].interp(lon=ds_latlon.lon,
                                              lat=ds_latlon.lat,
                                              method="linear")
        ds_latlon = ds_latlon.assign(
            **{v: svf_on_grid[v].astype("float32")
               for v in svf_vars})

    # Masked SVF versions using glacier_mask (if present)
    if "glacier_mask" in ds_latlon:
        gmask = xr.where(ds_latlon["glacier_mask"] == 1, 1.0, np.nan)
        for v in svf_vars:
            ds_latlon[f"masked_{v}"] = gmask * ds_latlon[v]
    return ds_latlon


def process_one_glacier(
    rgi_gl: str,
    path_RGIs: str,
    path_xr_svf: str,
    path_xr_grids: str,
    target_res_m: int = 50,
):
    """
    Worker: build the masked lat/lon grid of one glacier and save it as zarr.

    Steps:
    1. Load the masked OGGM grid via ``create_masked_glacier_grid``.
    2. Coarsen it in projected space when its x resolution lies strictly
       between 20 m and ``target_res_m``. Finer grids (<= 20 m) are left
       as-is; this is presumably intentional, to be confirmed.
    3. Reproject to EPSG:4326 and rename ``x``/``y`` to ``lon``/``lat``.
    4. If ``<path_xr_svf>/<rgi_gl>_svf_latlon.nc`` exists, attach its SVF
       variables.
    5. Write ``<path_xr_grids>/<rgi_gl>.zarr``, overwriting any existing store.

    Returns
    -------
    tuple
        ``(rgi_gl, "ok", "")`` on success, or
        ``(rgi_gl, "error", "<ExceptionType>: <message>")`` if any step
        raised. Errors are caught so that one bad glacier does not take
        down the process pool.
    """
    try:
        # 1) Masked OGGM grid in projected coords
        ds, _ = create_masked_glacier_grid(path_RGIs, rgi_gl)

        # 2) Optional coarsen in projected space
        dx_m, _ = get_res_from_projected(ds)
        if 20 < dx_m < target_res_m:
            ds = coarsenDS_mercator(ds, target_res_m=target_res_m)

        # 3) Reproject to WGS84 lat/lon
        ds = ds.rio.write_crs(ds.pyproj_srs)
        ds_latlon = ds.rio.reproject("EPSG:4326").rename({
            "x": "lon",
            "y": "lat"
        })

        # 4) Load SVF + merge (if exists)
        svf_path = os.path.join(path_xr_svf, f"{rgi_gl}_svf_latlon.nc")
        if os.path.exists(svf_path):
            with xr.open_dataset(svf_path) as ds_svf:
                # Load fully into memory so the file handle can be closed
                # (avoids file-handle issues in multiprocessing).
                ds_svf = ds_svf.load()

            # Normalize coord names; tolerate files that carry only one
            # member of a pair.
            ds_svf = _rename_present(ds_svf, {"x": "lon", "y": "lat"})
            ds_svf = _rename_present(ds_svf, {
                "longitude": "lon",
                "latitude": "lat"
            })

            ds_latlon = _sort_ascending(ds_latlon)
            ds_svf = _sort_ascending(ds_svf)
            ds_latlon = _attach_svf(ds_latlon, ds_svf)

        # 5) Save final lat/lon grid
        save_path = os.path.join(path_xr_grids, f"{rgi_gl}.zarr")
        ds_latlon.to_zarr(save_path, mode="w")

        return (rgi_gl, "ok", "")

    except Exception as e:
        return (rgi_gl, "error", f"{type(e).__name__}: {e}")


def glacier_ids_from_xr_grids(path_RGIs: str):
    """
    List the glacier ids that have a ``<id>.zarr`` entry directly in *path_RGIs*.

    The ids are the entry names with the trailing ``.zarr`` removed
    (e.g. ``RGI60-07.01615.zarr`` -> ``RGI60-07.01615``). They are returned
    in sorted path order; an empty or missing folder yields ``[]``.
    """
    pattern = os.path.join(path_RGIs, "*.zarr")
    ids = []
    for store in sorted(glob.glob(pattern)):
        stem, _ext = os.path.splitext(os.path.basename(store))
        ids.append(stem)
    return ids


def run_parallel_processing_region(
    rgi_ids,
    path_RGIs,
    path_xr_svf,
    path_xr_grids,
    n_workers=6,
    clear_out=False,
    target_res_m=50,
):
    """
    Run ``process_one_glacier`` over *rgi_ids* in a process pool.

    Parameters
    ----------
    rgi_ids : iterable of str
        Glacier ids to process.
    path_RGIs, path_xr_svf, path_xr_grids, target_res_m
        Forwarded unchanged to ``process_one_glacier``.
    n_workers : int
        Maximum number of worker processes.
    clear_out : bool
        If True, ``emptyfolder`` is called on *path_xr_grids* first.
        Otherwise the folder is created if it does not exist.

    Returns
    -------
    list of (rgi_id, status, message) tuples, in completion order.
        ``status`` is ``"ok"`` or ``"error"``. A summary and every error
        message are printed.
    """
    if clear_out:
        emptyfolder(path_xr_grids)
    else:
        os.makedirs(path_xr_grids, exist_ok=True)

    results = []
    with ProcessPoolExecutor(max_workers=n_workers) as ex:
        # Map each future back to its glacier id so failures can be attributed.
        futures = {
            ex.submit(
                process_one_glacier,
                rgi_id,
                path_RGIs,
                path_xr_svf,
                path_xr_grids,
                target_res_m,
            ):
            rgi_id
            for rgi_id in rgi_ids
        }

        for fut in tqdm(as_completed(futures), total=len(futures)):
            try:
                results.append(fut.result())
            except Exception as e:
                # process_one_glacier catches its own errors. Reaching this
                # point means the worker itself failed, e.g. the process died
                # (BrokenProcessPool). Record the failure rather than
                # aborting the region and discarding the collected results.
                results.append(
                    (futures[fut], "error", f"{type(e).__name__}: {e}"))

    n_ok = sum(r[1] == "ok" for r in results)
    n_err = sum(r[1] == "error" for r in results)
    print(f"Done. ok={n_ok}, error={n_err}")

    if n_err:
        for rgi_id, status, msg in results:
            if status == "error":
                print(f"[{rgi_id}] {msg}")

    return results


def run_all_regions(
    cfg,
    RGI_REGIONS,
    n_workers=6,
    clear_out=False,
    target_res_m=50,
):
    """
    Build the masked lat/lon glacier grids for every region in *RGI_REGIONS*.

    Each value of *RGI_REGIONS* must provide ``"folder"`` and ``"name"``.
    The per-region paths are:

    - input OGGM grids: ``<dataPath>/OGGM/rgi_region_<NN>/xr_grids``
    - SVF files:        ``<dataPath>/RGI_v6/<folder>/svf_nc_latlon``
    - output grids:     ``<dataPath>/RGI_v6/<folder>/xr_masked_grids``

    A region is skipped, with a message, when either input folder is missing
    or when it contains no ``.zarr`` grids.

    Returns
    -------
    dict
        Region key -> list of ``(rgi_id, status, message)`` tuples, for the
        regions that were actually processed.
    """
    summary = {}

    for region_key, region_spec in RGI_REGIONS.items():
        folder = region_spec["folder"]
        print(f"\n========== RGI {region_key} ({region_spec['name']}) ==========")

        rgi_v6_dir = os.path.join(cfg.dataPath, "RGI_v6", folder)
        svf_dir = os.path.join(rgi_v6_dir, "svf_nc_latlon")
        out_dir = os.path.join(rgi_v6_dir, "xr_masked_grids")

        # OGGM xr grid input folder read by create_masked_glacier_grid.
        oggm_dir = os.path.join(
            cfg.dataPath, f"OGGM/rgi_region_{str(region_key).zfill(2)}",
            "xr_grids")

        if not os.path.isdir(oggm_dir):
            print(f"Skipping: missing xr_grids folder: {oggm_dir}")
            continue
        if not os.path.isdir(svf_dir):
            print(f"Skipping: missing svf folder: {svf_dir}")
            continue

        glacier_ids = glacier_ids_from_xr_grids(oggm_dir)
        print(f"Found {len(glacier_ids)} glacier grids in {oggm_dir}")
        if not glacier_ids:
            continue

        # The per-glacier worker does not create the output folder itself.
        os.makedirs(out_dir, exist_ok=True)

        summary[region_key] = run_parallel_processing_region(
            rgi_ids=glacier_ids,
            path_RGIs=oggm_dir,
            path_xr_svf=svf_dir,
            path_xr_grids=out_dir,
            n_workers=n_workers,
            clear_out=clear_out,
            target_res_m=target_res_m,
        )

    return summary

In [None]:
# Set RUN = False to skip the (slow) regeneration of all regional grids.
# NOTE: with clear_out=True, emptyfolder() is called on every region's
# output folder before processing.
RUN = True
if RUN:
    run_kwargs = dict(n_workers=6, clear_out=True, target_res_m=50)
    all_results = run_all_regions(cfg=cfg, RGI_REGIONS=RGI_REGIONS, **run_kwargs)

In [None]:
# Look at an example:

rgi_id = "RGI60-07.01615"
region_id = "07"
# --- Paths ---
basepath = os.path.join(cfg.dataPath, "RGI_v6",
                        RGI_REGIONS[region_id]["folder"])
dem_path = os.path.join(basepath, "geotiff", f"{rgi_id}.tif")
zarr_path = os.path.join(basepath, "xr_masked_grids", f"{rgi_id}.zarr")

# --- Load data ---
# DEM in its native projected CRS, plus the masked lat/lon grid written by
# process_one_glacier (SVF, when available, is already merged into it).
dem = rioxarray.open_rasterio(dem_path).squeeze()
ds = xr.open_zarr(zarr_path)

# --- Figure layout ---
fig, axes = plt.subplots(1, 3, figsize=(18, 6), constrained_layout=True)

# 1️⃣ DEM (projected)
dem.plot(ax=axes[0], cmap="terrain")
axes[0].set_title("DEM (projected meters)")
axes[0].set_xlabel("Easting [m]")
axes[0].set_ylabel("Northing [m]")

# 2️⃣ Masked aspect (lat/lon grid read from the zarr store)
ds["masked_aspect"].plot(ax=axes[1])
axes[1].set_title("Masked Aspect (°)")
axes[1].set_xlabel("Longitude (°)")
axes[1].set_ylabel("Latitude (°)")

# 3️⃣ SVF (lat/lon); only present if an SVF file existed when the grid was built
if "svf" in ds:
    ds["svf"].plot(ax=axes[2])
else:
    axes[2].text(0.5, 0.5, "No SVF in grid", ha="center", va="center",
                 transform=axes[2].transAxes)

axes[2].set_title("Sky View Factor (lat/lon)")
axes[2].set_xlabel("Longitude (°)")
axes[2].set_ylabel("Latitude (°)")

plt.suptitle(f"{rgi_id}", fontsize=15)
plt.show()

In [None]:
# Look at an example:

rgi_id = "RGI60-08.00630"
region_id = "08"
# --- Paths ---
basepath = os.path.join(cfg.dataPath, "RGI_v6",
                        RGI_REGIONS[region_id]["folder"])
dem_path = os.path.join(basepath, "geotiff", f"{rgi_id}.tif")
zarr_path = os.path.join(basepath, "xr_masked_grids", f"{rgi_id}.zarr")

# --- Load data ---
# DEM in its native projected CRS, plus the masked lat/lon grid written by
# process_one_glacier (SVF, when available, is already merged into it).
dem = rioxarray.open_rasterio(dem_path).squeeze()
ds = xr.open_zarr(zarr_path)

# --- Figure layout ---
fig, axes = plt.subplots(1, 3, figsize=(18, 6), constrained_layout=True)

# 1️⃣ DEM (projected)
dem.plot(ax=axes[0], cmap="terrain")
axes[0].set_title("DEM (projected meters)")
axes[0].set_xlabel("Easting [m]")
axes[0].set_ylabel("Northing [m]")

# 2️⃣ Masked aspect (lat/lon grid read from the zarr store)
ds["masked_aspect"].plot(ax=axes[1])
axes[1].set_title("Masked Aspect (°)")
axes[1].set_xlabel("Longitude (°)")
axes[1].set_ylabel("Latitude (°)")

# 3️⃣ SVF (lat/lon); only present if an SVF file existed when the grid was built
if "svf" in ds:
    ds["svf"].plot(ax=axes[2])
else:
    axes[2].text(0.5, 0.5, "No SVF in grid", ha="center", va="center",
                 transform=axes[2].transAxes)

axes[2].set_title("Sky View Factor (lat/lon)")
axes[2].set_xlabel("Longitude (°)")
axes[2].set_ylabel("Latitude (°)")

plt.suptitle(f"{rgi_id}", fontsize=15)
plt.show()