# centerflow_extruded

**centerflow_extruded** is a Python module for modeling glacier dynamics along flowlines lifted onto an extruded mesh, which permits higher-order modeling. It is mostly a direct analogue of the accompanying "centerflow" module, and provides the following core functionality:

1. **ID finder**  
   A single glacier sometimes needs more than one ID — for example, both its RGI v6 ID and its RGI v7 ID. These helpers derive one from the other by finding the outline with the largest overlap.

2. **Mesh generation**  
   Construct a 1D finite element mesh along an RGI-defined glacier centerline, and extrude it to include a vertical coordinate, $\zeta$, which ranges from 0 at the base to 1 at the surface.

3. **Mesh trimming**  
   Trim the mesh to the part of the centerline covered by valid data, in case of partial DEMs or other data limitations.

4. **Data interpolation**  
   Interpolate gridded geospatial datasets (e.g., surface elevation, velocity, surface mass balance) onto the extruded mesh.

5. **Function lifting**  
   Reproject functions back to the extruded mesh after they have been flattened to the base mesh.

6. **Smoothing**  
   Smooth functions defined on an extruded mesh.

7. **Bed inversion**  
   Apply a forward-model-based bed inversion scheme following the approach of [van Pelt et al. (2013)](https://tc.copernicus.org/articles/7/987/2013/), using observed surface elevations to iteratively estimate basal topography.

## Imports


In [None]:
from dataclasses import dataclass
import firedrake
from firedrake.mesh import ExtrudedMeshTopology
import geopandas as gpd
import icepack
import numpy as np
import pandas as pd
from pyproj import Geod
from pathlib import Path
import rasterio
from rasterio.warp import calculate_default_transform, reproject, Resampling
from rasterio.crs import CRS
from rasterio.io import MemoryFile
import scipy.ndimage
from scipy.interpolate import interp1d
from scipy.ndimage import gaussian_filter1d
from shapely.geometry import LineString
from tqdm import trange

## rgi6_from_rgi7 etc.

In [None]:
def rgi6_from_rgi7(**kwargs):
    """Return the RGI v6 ID whose outline overlaps a given RGI v7 glacier the most.

    Keyword Args:
        rgiid: Substring identifying the glacier in the RGI v7 ``rgi_id`` column
            (the first matching row is used).
        rgi6_path: Path to the RGI v6 outlines (anything ``geopandas.read_file`` reads).
        rgi7_path: Path to the RGI v7 outlines.
        metric_crs: Projected CRS used to compute overlap areas
            (default ``"EPSG:32646"``, UTM zone 46N).

    Returns:
        The part of the best-matching RGI v6 ID after its first ``-``
        (e.g. ``"15.xxxxx"``), or None if no RGI v6 outline intersects it.

    Raises:
        IndexError: if no RGI v7 ``rgi_id`` contains ``rgiid``.
        ValueError: if the RGI v6 file has neither an ``RGIId`` nor an ``rgi_id`` column.
    """
    rgiid = kwargs["rgiid"]
    rgi6_path = kwargs["rgi6_path"]
    rgi7_path = kwargs["rgi7_path"]
    metric_crs = kwargs.get("metric_crs", "EPSG:32646")

    # Load RGI7 outlines and get the matching geometry
    gdf7 = gpd.read_file(rgi7_path)
    matches = gdf7[gdf7["rgi_id"].str.contains(rgiid, regex=False, na=False)]
    if matches.empty:
        raise IndexError(f"RGI7 ID '{rgiid}' not found in {rgi7_path}")
    outline = matches.geometry.values[0]

    # Load RGI6 outlines and normalize the ID column name (it differs between releases)
    gdf6 = gpd.read_file(rgi6_path)
    if "RGIId" in gdf6.columns:
        id_col = "RGIId"
    elif "rgi_id" in gdf6.columns:
        id_col = "rgi_id"
    else:
        raise ValueError(f"No RGI ID column found in {rgi6_path}")
    gdf6 = gdf6[[id_col, "geometry"]].rename(columns={id_col: "rgiid_6"})

    # Project to a metric CRS so overlap areas are comparable
    gdf6 = gdf6.to_crs(metric_crs)
    outline_gdf = gpd.GeoDataFrame(geometry=[outline], crs=gdf7.crs).to_crs(metric_crs)

    # Intersect and pick the largest overlap
    inter = gpd.overlay(gdf6, outline_gdf, how="intersection")
    if inter.empty:
        return None
    best_match = inter["rgiid_6"].iloc[int(np.argmax(inter.geometry.area.to_numpy()))]

    # Keep only the part after the first "-" (e.g. "15.xxxxx")
    return best_match.split("-", 1)[-1]


def rgi7_from_rgi6(**kwargs):
    """Return the RGI v7 ID whose outline overlaps a given RGI v6 glacier the most.

    Keyword Args:
        rgiid: Substring identifying the glacier in the RGI v6 ID column
            (``RGIId`` or ``rgi_id``; the first matching row is used).
        rgi6_path: Path to the RGI v6 outlines.
        rgi7_path: Path to the RGI v7 outlines (must have an ``rgi_id`` column).
        metric_crs: Projected CRS used to compute overlap areas
            (default ``"EPSG:32646"``, UTM zone 46N).

    Returns:
        The ``-``-separated fields of the best-matching RGI v7 ID from the
        fourth field on (e.g. ``"15-xxxxx"``), or None if nothing intersects.

    Raises:
        IndexError: if no RGI v6 ID contains ``rgiid``.
        ValueError: if the RGI v6 file has neither an ``RGIId`` nor an ``rgi_id`` column.
    """
    rgiid = kwargs["rgiid"]
    rgi6_path = kwargs["rgi6_path"]
    rgi7_path = kwargs["rgi7_path"]
    metric_crs = kwargs.get("metric_crs", "EPSG:32646")

    # Load RGI6 outlines, normalize the ID column name, and get the matching geometry
    gdf6 = gpd.read_file(rgi6_path)
    if "RGIId" in gdf6.columns:
        id_col = "RGIId"
    elif "rgi_id" in gdf6.columns:
        id_col = "rgi_id"
    else:
        raise ValueError(f"No RGI ID column found in {rgi6_path}")
    gdf6 = gdf6[[id_col, "geometry"]].rename(columns={id_col: "rgiid_6"})
    matches = gdf6[gdf6["rgiid_6"].str.contains(rgiid, regex=False, na=False)]
    if matches.empty:
        raise IndexError(f"RGI6 ID '{rgiid}' not found in {rgi6_path}")
    outline = matches.geometry.values[0]

    # Load RGI7 outlines and project both layers to a metric CRS for area calculation
    gdf7 = gpd.read_file(rgi7_path)[["rgi_id", "geometry"]].to_crs(metric_crs)
    outline_gdf = gpd.GeoDataFrame(geometry=[outline], crs=gdf6.crs).to_crs(metric_crs)

    # Intersect and pick the largest overlap
    inter = gpd.overlay(gdf7, outline_gdf, how="intersection")
    if inter.empty:
        return None
    best_match = inter["rgi_id"].iloc[int(np.argmax(inter.geometry.area.to_numpy()))]

    # Keep only the "15-xxxxx" part
    return "-".join(best_match.split("-")[3:])

def _tile_name(lat, lon):
    """Name the 1-degree tile containing (lat, lon) by its south-west corner, e.g. ``N27E090``."""
    # Flooring the signed coordinate gives the south-west corner in every
    # hemisphere (lat -12.3 -> S13, lon -70.4 -> W071); flooring abs() would not.
    lat_sw = int(np.floor(lat))
    lon_sw = int(np.floor(lon))
    lat_prefix = "N" if lat_sw >= 0 else "S"
    lon_prefix = "E" if lon_sw >= 0 else "W"
    return f"{lat_prefix}{abs(lat_sw):02d}{lon_prefix}{abs(lon_sw):03d}"


def latlon_from_rgi7(**kwargs):
    """Return the 1-degree tile name (e.g. ``N27E090``) containing an RGI v7 glacier's centroid.

    Keyword Args:
        rgiid: Substring identifying the glacier in the RGI v7 ``rgi_id`` column.
        rgi7_path: Path to the RGI v7 outlines.
        metric_crs: Projected CRS in which the centroid is computed
            (default ``"EPSG:32646"``, UTM zone 46N, covering Bhutan).

    Raises:
        ValueError: if no ``rgi_id`` contains ``rgiid``.
    """
    rgiid_7 = kwargs["rgiid"]
    rgi7_path = kwargs["rgi7_path"]
    metric_crs = kwargs.get("metric_crs", "EPSG:32646")

    # Load and filter
    gdf7 = gpd.read_file(rgi7_path)
    match = gdf7[gdf7['rgi_id'].str.contains(rgiid_7, regex=False, na=False)]
    if match.empty:
        raise ValueError(f"RGI7 ID '{rgiid_7}' not found in {rgi7_path}")

    # Centroids are only meaningful in a projected CRS, so compute there...
    centroid_metric = match.to_crs(metric_crs).geometry.centroid.iloc[0]

    # ...and convert the centroid back to geographic coordinates
    centroid_geo = gpd.GeoSeries([centroid_metric], crs=metric_crs).to_crs("EPSG:4326").iloc[0]

    return _tile_name(centroid_geo.y, centroid_geo.x)

## centerline_mesh

In [None]:
@dataclass
class IntervalMeshResult:
    """Container for a one-layer extruded centerline mesh and the geometry it came from.

    Built by ``centerline_mesh``; ``crop_mesh`` builds a new instance of the
    same type for a trimmed centerline.
    """
    mesh: firedrake.ExtrudedMesh  # 1D IntervalMesh extruded with a single layer
    x: np.ndarray  # centerline x-coordinates (presumably longitude; geodesic WGS84 distances are computed from them)
    y: np.ndarray  # centerline y-coordinates (presumably latitude)
    X: np.ndarray  # base-mesh vertex coordinates: along-flow distance in metres, 0 to glacier_length
    basal_coords: np.ndarray  # (X, 0) pairs: the bottom of the extruded mesh (zeta = 0)
    surface_coords: np.ndarray  # (X, 1) pairs: the top of the extruded mesh (zeta = 1)
    glacier_length: float  # geodesic centerline length in metres
    centerline: object  # centerline geometry the mesh follows
    outline: object  # glacier outline geometry taken from the outline file


def centerline_mesh(**kwargs):
    """Build a one-layer extruded mesh along the longest centerline of a glacier.

    Keyword Args:
        rgiid: Substring identifying the glacier in the outline file's ``rgi_id``
            column (default ``'15-09534'``); matched literally, not as a regex.
        centerline_path: Path to the centerline file (read with geopandas).
        outline_path: Path to the glacier outline file.
        n_cells: Number of cells in the 1D base mesh.
        metric_crs: Projected CRS used only to choose the longest intersecting
            centerline (default ``'EPSG:32646'``, UTM zone 46N).

    Returns:
        IntervalMeshResult whose ``x``/``y`` are ``n_cells + 1`` centerline points
        at uniform geodesic spacing, matching the base-mesh vertices ``X``.

    Raises:
        IndexError: if no outline ``rgi_id`` contains ``rgiid``.
        ValueError: if no centerline intersects the glacier outline.
    """
    rgiid = kwargs.get('rgiid', '15-09534')
    centerline_path = kwargs['centerline_path']
    outline_path = kwargs['outline_path']
    n_cells = kwargs['n_cells']
    metric_crs = kwargs.get('metric_crs', 'EPSG:32646')

    outlines = gpd.read_file(outline_path)
    centerlines = gpd.read_file(centerline_path)
    matches = outlines[outlines['rgi_id'].str.contains(rgiid, regex=False, na=False)]
    if matches.empty:
        raise IndexError(f'RGI ID {rgiid!r} not found in {outline_path}')
    outline = matches.geometry.values[0]

    # The outline may intersect several smaller tributary flowlines, so keep the longest.
    flowlines = centerlines[centerlines.intersects(outline)]
    if flowlines.empty:
        raise ValueError(f'No centerline in {centerline_path} intersects the outline of {rgiid!r}')
    centerline = flowlines.loc[flowlines.to_crs(metric_crs).length.idxmax(), 'geometry']

    # Cumulative geodesic distance (metres) along the centerline vertices
    geod = Geod(ellps='WGS84')
    x = np.asarray(centerline.xy[0], dtype=float)
    y = np.asarray(centerline.xy[1], dtype=float)
    segment_lengths = geod.inv(x[:-1], y[:-1], x[1:], y[1:])[2]
    distances = np.concatenate(([0.0], np.cumsum(segment_lengths)))
    glacier_length = distances[-1]

    # Repeated vertices give zero-length segments; drop them so the distance
    # axis used by interp1d is strictly increasing.
    keep = np.concatenate(([True], np.diff(distances) > 0))
    distances, x, y = distances[keep], x[keep], y[keep]

    # Interpolate x, y to match the n_cells + 1 mesh vertices
    uniform_dists = np.linspace(0, glacier_length, n_cells + 1)
    x_interp = interp1d(distances, x)(uniform_dists)
    y_interp = interp1d(distances, y)(uniform_dists)

    base_mesh = firedrake.IntervalMesh(n_cells, glacier_length)
    mesh = firedrake.ExtrudedMesh(base_mesh, layers=1)
    X = base_mesh.coordinates.dat.data_ro.flatten()
    surface_coords = np.column_stack((X, np.ones_like(X)))  # zeta = 1
    basal_coords = np.column_stack((X, np.zeros_like(X)))   # zeta = 0

    return IntervalMeshResult(
        mesh=mesh,
        x=np.asarray(x_interp),
        y=np.asarray(y_interp),
        X=X,
        surface_coords=surface_coords,
        basal_coords=basal_coords,
        glacier_length=glacier_length,
        centerline=centerline,
        outline=outline,
    )

## crop_mesh

In [None]:
def _cut(line, d):
    """Split *line* at distance *d* measured along it (in the line's own units).

    Returns ``[before, after]``. ``before`` is None when *d* <= 0, and ``after``
    is None when *d* reaches the line's length or no split segment is found.
    The split point is linearly interpolated inside the segment containing it.
    """
    if d <= 0.0:
        return [None, LineString(line.coords)]
    if d >= line.length:
        return [LineString(line.coords), None]

    vertices = list(line.coords)
    travelled = 0.0
    for idx, (start, end) in enumerate(zip(vertices[:-1], vertices[1:])):
        seg_len = LineString([start, end]).length
        if travelled + seg_len >= d:
            frac = (d - travelled) / seg_len
            split_pt = (
                start[0] + frac * (end[0] - start[0]),
                start[1] + frac * (end[1] - start[1]),
            )
            head = vertices[:idx + 1] + [split_pt]
            tail = [split_pt] + vertices[idx + 1:]
            return [LineString(head), LineString(tail)]
        travelled += seg_len
    return [LineString(vertices), None]

def _segment(line, d0, d1):
    """Return the part of *line* between distances *d0* and *d1* along it.

    Cuts at *d0*, then cuts the remainder at ``d1 - d0``. If ``d1 - d0 <= 0``
    the result is None. Assumes *d0* is shorter than the line; otherwise the
    remainder is None and the second cut fails.
    """
    _, remainder = _cut(line, d0)
    return _cut(remainder, d1 - d0)[0]

def crop_mesh(**kwargs):
    """Trim a centerline mesh to the stretch of centerline covered by valid raster data.

    The centerline is projected into the raster's CRS and sampled about 400
    times (spacing at least 25 CRS units). The first and last valid samples
    delimit the kept segment. A sample is valid if it is finite and not the
    raster's nodata value.

    Keyword Args:
        mesh: An ``IntervalMeshResult``; its ``centerline`` is treated as EPSG:4326.
        data_path: Path to a raster readable by rasterio (band 1 is sampled).

    Returns:
        ``mesh`` itself if both ends of the centerline have valid data.
        Otherwise, a new result of the same type with the same number of cells
        on the cropped centerline. Its ``x``/``y`` are resampled to
        ``n_cells + 1`` equally spaced points, as in ``centerline_mesh``.

    Raises:
        ValueError: if no sample, or only a single sample, has valid data.
    """
    mesh = kwargs['mesh']
    data_path = kwargs['data_path']

    target_samples = 400
    with rasterio.open(data_path) as src:
        r_crs = src.crs
        nodata = src.nodata

        cl_m = gpd.GeoSeries([mesh.centerline], crs='EPSG:4326').to_crs(r_crs).iloc[0]
        step = max(cl_m.length / target_samples, 25.0)
        dists = np.arange(0.0, cl_m.length + step, step, dtype=float)
        pts = [cl_m.interpolate(float(d)) for d in dists]
        vals = np.array([v[0] for v in src.sample([(p.x, p.y) for p in pts])])

    # Check finiteness explicitly: when nodata is NaN, `vals != nodata` is True
    # for every sample (NaN != NaN), so NaN gaps would otherwise count as valid.
    valid = np.isfinite(vals)
    if nodata is not None and not np.isnan(nodata):
        valid &= vals != nodata
    if not valid.any():
        raise ValueError('No valid data found along centerline for the provided raster.')

    first = int(np.argmax(valid))
    last = int(len(valid) - 1 - np.argmax(valid[::-1]))

    # If no cropping is needed, return unchanged mesh
    if first == 0 and last == len(valid) - 1:
        return mesh
    if last <= first:
        raise ValueError('Valid data covers only a single sample along the centerline; cannot crop to it.')

    d0, d1 = float(dists[first]), float(dists[last])
    cl_m_cropped = _segment(cl_m, d0, d1)
    cl_ll = gpd.GeoSeries([cl_m_cropped], crs=r_crs).to_crs('EPSG:4326').iloc[0]

    # Cumulative geodesic distance (metres) along the cropped centerline's vertices
    geod = Geod(ellps='WGS84')
    xs = np.asarray(cl_ll.xy[0], dtype=float)
    ys = np.asarray(cl_ll.xy[1], dtype=float)
    seg_lengths = geod.inv(xs[:-1], ys[:-1], xs[1:], ys[1:])[2]
    cum = np.concatenate(([0.0], np.cumsum(seg_lengths)))
    new_len = float(cum[-1])

    n_cells = len(mesh.X) - 1

    # Resample x/y onto n_cells + 1 uniformly spaced points, as centerline_mesh
    # does, so they line up with X. Repeated vertices (e.g. a cut landing exactly
    # on a vertex) are dropped to keep the distance axis strictly increasing.
    keep = np.concatenate(([True], np.diff(cum) > 0))
    uniform = np.linspace(0.0, new_len, n_cells + 1)
    x_new = interp1d(cum[keep], xs[keep])(uniform)
    y_new = interp1d(cum[keep], ys[keep])(uniform)

    base_mesh = firedrake.IntervalMesh(n_cells, new_len)
    extruded = firedrake.ExtrudedMesh(base_mesh, layers=1)
    X = base_mesh.coordinates.dat.data_ro.flatten()
    surface_coords = np.column_stack((X, np.ones_like(X)))
    basal_coords = np.column_stack((X, np.zeros_like(X)))

    return type(mesh)(
        mesh=extruded,
        x=x_new,
        y=y_new,
        X=X,
        basal_coords=basal_coords,
        surface_coords=surface_coords,
        glacier_length=new_len,
        centerline=cl_ll,
        outline=mesh.outline,
    )


## map_to_mesh

In [None]:
def _sample_raster(data_path, x, y, projection):
    """Sample band 1 of a raster at the points ``(x, y)``, given in ``projection``.

    If the raster's CRS differs from ``projection``, band 1 is first reprojected
    in memory with bilinear resampling. Returns ``(values, nodata)``, where
    ``nodata`` belongs to the raster that was actually sampled.
    """
    target_crs = CRS.from_string(projection)
    points = list(zip(x, y))
    with rasterio.open(data_path) as src:
        if src.crs == target_crs:
            return np.array(list(src.sample(points, indexes=1))).flatten(), src.nodata

        print(f'Reprojecting {data_path} from {src.crs} to {target_crs}')
        transform, width, height = calculate_default_transform(
            src.crs, target_crs, src.width, src.height, *src.bounds)

        meta = src.meta.copy()
        meta.update({
            'crs': target_crs,
            'transform': transform,
            'width': width,
            'height': height,
            'count': 1,  # only band 1 is sampled, so only band 1 is reprojected
        })

        with MemoryFile() as memfile:
            with memfile.open(**meta) as dst:
                reproject(
                    source=rasterio.band(src, 1),
                    destination=rasterio.band(dst, 1),
                    src_transform=src.transform,
                    src_crs=src.crs,
                    dst_transform=transform,
                    dst_crs=target_crs,
                    resampling=Resampling.bilinear,
                )
            with memfile.open() as reproj:
                values = np.array(list(reproj.sample(points, indexes=1))).flatten()
                return values, reproj.nodata


def map_to_mesh(**kwargs):
    """Interpolate a raster or tabular dataset onto a vertically-constant extruded function.

    Keyword Args:
        mesh: Mesh result providing ``mesh.mesh`` (extruded mesh) and centerline
            points ``mesh.x``/``mesh.y``.
        data_path: ``.tif``/``.tiff`` raster or ``.csv`` table.
        dimension: Horizontal polynomial degree of the output space (default 1).
        element: Horizontal element family (default ``'CG'``).
        key_value, data_value: CSV column names. ``data_value`` is interpolated
            as a function of ``key_value``, evaluated at ``key_dataset``'s values.
        key_dataset: Function whose DOF values are the CSV lookup keys (CSV only).
        projection: CRS of ``mesh.x``/``mesh.y`` (default ``'EPSG:4326'``). The
            raster is reprojected to it if needed. Along-track distances are
            geodesic, so x/y are treated as lon/lat.

    Returns:
        Function on ``FunctionSpace(mesh.mesh, element, dimension, vfamily='R', vdegree=0)``.

    Raises:
        ValueError: for unsupported extensions, missing CSV columns or
            ``key_dataset``, or fewer than two valid raster samples.
    """
    mesh = kwargs['mesh']
    data_path = kwargs['data_path']
    extension = Path(data_path).suffix.lower()
    dimension = kwargs.get('dimension', 1)
    element = kwargs.get('element', 'CG')
    key_value = kwargs.get('key_value', 'n/a')
    data_value = kwargs.get('data_value', 'n/a')
    key_dataset = kwargs.get('key_dataset', None)
    projection = kwargs.get('projection', 'EPSG:4326')

    V = firedrake.FunctionSpace(mesh.mesh, element, dimension, vfamily='R', vdegree=0)

    if extension in ('.tif', '.tiff'):
        x = np.asarray(mesh.x, dtype=float)
        y = np.asarray(mesh.y, dtype=float)
        values, nodata = _sample_raster(data_path, x, y, projection)
        values = values.astype(float)

        # Geodesic along-track distance (metres) of each sample point
        geod = Geod(ellps='WGS84')
        distances = np.concatenate(([0.0], np.cumsum(geod.inv(x[:-1], y[:-1], x[1:], y[1:])[2])))

        # Leave nodata/NaN samples out of the interpolant instead of interpolating them as data
        valid = np.isfinite(values)
        if nodata is not None and not np.isnan(nodata):
            valid &= values != nodata
        if np.count_nonzero(valid) < 2:
            raise ValueError(f'Fewer than two valid samples of {data_path} along the centerline.')

        interp_func = interp1d(distances[valid], values[valid], bounds_error=False, fill_value='extrapolate')

        # Evaluate at the along-flow position of every DOF of V. Using V's own DOF
        # coordinates (not the base-mesh vertices) keeps sizes and ordering right
        # for any horizontal degree.
        dof_x = firedrake.Function(V).interpolate(firedrake.SpatialCoordinate(mesh.mesh)[0])
        interp_vals = interp_func(dof_x.dat.data_ro)

    elif extension == '.csv':
        df = pd.read_csv(data_path)
        if key_value not in df.columns or data_value not in df.columns:
            raise ValueError(f'CSV must include {key_value!r} and {data_value!r} columns.')
        if key_dataset is None:
            raise ValueError('Must provide "key_dataset" (a Firedrake Function) when using CSV input.')

        key_array = key_dataset.dat.data_ro
        interp_func = interp1d(df[key_value], df[data_value], bounds_error=False, fill_value='extrapolate')
        interp_vals = interp_func(key_array)

    else:
        raise ValueError(f'Unsupported file extension: {extension}')

    data_function = firedrake.Function(V)
    data_function.dat.data[:] = interp_vals

    return data_function

## lift_to_mesh

This is mainly useful for function smoothing. Smoothing first flattens a function to the base mesh and applies a Gaussian filter, and the smoothed result must then be reprojected onto the extruded mesh. It may be useful for other purposes too, but none have come up so far.

In [None]:
def lift_to_mesh(**kwargs):
    """Extend a function defined on the base (flattened) mesh onto an extruded function space.

    Keyword Args:
        function: Source function; its ``dat.data_ro`` is read as one value per
            base-mesh DOF (in this module it is the output of ``icepack.depth_average``).
        function_space: Target extruded function space. Its UFL element is
            unpacked into (horizontal, vertical) sub-elements.
        mesh: Mesh result providing ``mesh.mesh`` and ``glacier_length``.

    Returns:
        A Function on ``function_space`` when the vertical element is R0.
        Otherwise, a ``firedrake.project`` result on a space rebuilt from the
        same families and degrees.
    """
    input_function = kwargs['function']
    function_space = kwargs['function_space']
    mesh = kwargs['mesh']

    element = function_space.ufl_element()
    horizontal_elem, vertical_elem = element.sub_elements
    hfamily, vfamily = horizontal_elem.family(), vertical_elem.family()
    hdegree, vdegree = horizontal_elem.degree(), vertical_elem.degree()

    # Fast path: vertically uniform elements (e.g., CG2xR0)
    if vfamily == 'R' and vdegree == 0:
        f = firedrake.Function(function_space)
        base_vals = input_function.dat.data_ro
        f_data = f.dat.data

        # Number of extruded DOFs per base DOF; presumably 1 for an R0 vertical element.
        n_base = len(base_vals)
        n_total = f_data.shape[0]
        n_layers = n_total // n_base

        # Copy the base values into every layer. NOTE(review): this assumes the
        # extruded DOFs are stored as base_index * n_layers + layer and that base
        # DOF order matches the source function's order — confirm for new elements.
        for layer in range(n_layers):
            f_data[layer::n_layers] = base_vals

        return f

    # Fallback: interpolate the source values at the base-mesh vertex coordinates.
    # NOTE(review): input_coords assumes the source DOFs are equally spaced and
    # stored in along-flow order from 0 to glacier_length, which may not hold for
    # higher-degree source functions — confirm.
    base_coords = mesh.mesh._base_mesh.coordinates.dat.data_ro.flatten()
    input_coords = np.linspace(0, mesh.glacier_length, len(input_function.dat.data_ro))
    interp_func = interp1d(input_coords, input_function.dat.data_ro, bounds_error = False, fill_value = 'extrapolate')
    interp_vals = interp_func(base_coords)

    # Degree-1 horizontal space with the target's vertical element.
    # NOTE(review): one value per base vertex is assigned below; if the vertical
    # element has more than one DOF per column this assignment likely fails on
    # shape — TODO confirm.
    function_space_interp = firedrake.FunctionSpace(
        mesh.mesh, hfamily, 1, vfamily = vfamily, vdegree = vdegree
    )
    f = firedrake.Function(function_space_interp)
    f.dat.data[:] = interp_vals

    # Project onto a space with the target's horizontal degree
    function_space_proj = firedrake.FunctionSpace(
        mesh.mesh, hfamily, hdegree, vfamily = vfamily, vdegree = vdegree
    )
    return firedrake.project(f, function_space_proj)

## smooth_extruded_function

Smoothing is important for icepack's higher-order (HO) model, whose solver can struggle to converge, especially when some input fields are very noisy.

In [None]:
def smooth_extruded_function(**kwargs):
    """Smooth an extruded-mesh function along the flowline with a Gaussian filter.

    The function is depth-averaged with ``icepack.depth_average``. Its DOF values
    are put in along-flow order and filtered with
    ``scipy.ndimage.gaussian_filter1d``. The result is lifted back onto the
    original function space with ``lift_to_mesh``. Assumes a scalar function.

    Keyword Args:
        function: Function on an extruded mesh.
        mesh: Mesh result, passed through to ``lift_to_mesh``.
        sigma: Gaussian standard deviation, in DOF samples.
        window: Smoothing window in metres. Used to derive ``sigma`` when
            ``sigma`` is not given.

    Returns:
        Smoothed Function on ``function.function_space()``.

    Raises:
        ValueError: if neither ``sigma`` nor ``window`` is given.
    """
    f = kwargs['function']
    mesh = kwargs['mesh']
    sigma = kwargs.get('sigma', None)
    window_meters = kwargs.get('window', None)

    f_flat = icepack.depth_average(f)
    f_data = f_flat.dat.data_ro.copy()

    # Along-flow position of each DOF of the flattened function. DOF storage order
    # need not be spatial (e.g. for CG2, vertex and midpoint DOFs are separate
    # entities), so filter in sorted order and scatter the result back afterwards.
    Q_flat = f_flat.function_space()
    dof_x = firedrake.Function(Q_flat).interpolate(
        firedrake.SpatialCoordinate(Q_flat.mesh())[0]
    ).dat.data_ro
    order = np.argsort(dof_x, kind='stable')

    if sigma is None:
        if window_meters is None:
            raise ValueError('Must specify either "sigma" or "window"')
        # Spacing between neighbouring DOFs in metres (half the cell size for CG2)
        dx = np.mean(np.diff(dof_x[order]))
        # Rule of thumb: treat 2σ as the full width at half maximum (exact FWHM ≈ 2.355σ)
        sigma = window_meters / dx / 2.0

    smoothed_data = np.empty_like(f_data)
    smoothed_data[order] = scipy.ndimage.gaussian_filter1d(f_data[order], sigma=sigma)

    f_flat_smoothed = f_flat.copy(deepcopy=True)
    f_flat_smoothed.dat.data[:] = smoothed_data

    return lift_to_mesh(
        function=f_flat_smoothed,
        mesh=mesh,
        function_space=f.function_space(),
    )

## solve_bed

In [None]:
@dataclass
class InversionResult:
    """Outputs of ``solve_bed``: the final bed estimate plus per-iteration diagnostics."""
    bed: firedrake.Function  # bed estimate after the last completed update
    misfits: list  # per completed iteration: integral of (s_mod - s_ref) divided by glacier_length
    bed_evolution: list  # bed DOF arrays: the initial guess, then one per completed iteration
    surface_evolution: list  # modelled surface DOF arrays, one per completed iteration
    thickness_evolution: list  # modelled thickness DOF arrays, one per completed iteration
    velocity_evolution: list  # depth-averaged velocity lifted to the surface's space, one per completed iteration
    s_ref: firedrake.Function  # reference surface the modelled surface was compared against


def _accumulation_at(a, t):
    """Return the accumulation to use at model time ``t`` (same units as the dict keys).

    For a ``dict`` keyed by time, ``t`` is tried exactly, then rounded to 10
    decimals (``start + step * Δt`` accumulates floating-point drift), then as
    the nearest key within 1e-6. If none match, ``KeyError`` is raised rather
    than silently passing the whole dict to the solver. Any other object is
    indexed with ``t`` if that works, and otherwise returned unchanged (a
    time-independent field).
    """
    if isinstance(a, dict):
        for key in (t, round(t, 10)):
            if key in a:
                return a[key]
        nearest = min(a, key=lambda k: abs(k - t))
        if abs(nearest - t) <= 1e-6:
            return a[nearest]
        raise KeyError(f'No accumulation entry for t = {t}')
    try:
        return a[t]
    except Exception:
        return a


def solve_bed(**kwargs):
    """Invert for the bed by iteratively matching a modelled surface to a reference surface.

    Follows the forward-model-based approach of van Pelt et al. (2013). Each
    iteration runs the flow model forward from ``surface`` on the current bed
    guess, then updates the bed by ``-K * (s_mod - s_ref)``.

    Keyword Args:
        mesh: Mesh result providing ``mesh.glacier_length``.
        surface: Initial surface Function. Its function space is used for the
            bed, thickness and misfit. An optional ``year`` attribute sets the
            model start time.
        surface_2: Reference surface to match (default: ``surface``).
            ``surface_2.year - surface.year`` sets the run length when available.
        thickness_guess: Initial thickness; the first bed guess is ``surface - thickness_guess``.
        velocity: Initial velocity guess for the diagnostic solve.
        accumulation: Surface mass balance. Either a field, or a dict keyed by
            time whose two smallest keys define the timestep.
        fluidity, friction: Passed to ``solver.diagnostic_solve``.
        K: Bed-correction gain.
        num_iterations: Number of inversion iterations.
        model: Required but not used here.
        solver: Object with ``diagnostic_solve``/``prognostic_solve`` (presumably an icepack flow solver).
        model_time: Run length, used when the surfaces carry no ``year`` or their years coincide.
        timestep: Δt, used when ``accumulation`` is not a dict with at least two keys.

    Returns:
        InversionResult. If a forward step raises, the failure is printed and
        the results gathered so far are returned.
    """
    # Inputs
    mesh = kwargs['mesh']
    s_init = kwargs['surface']
    s_ref = kwargs.get('surface_2', s_init)
    H_guess = kwargs['thickness_guess']
    u_guess = kwargs['velocity']
    a = kwargs['accumulation']
    A = kwargs['fluidity']
    K = kwargs['K']
    num_iterations = kwargs['num_iterations']
    model = kwargs['model']  # required by the interface but not used below
    solver = kwargs['solver']
    friction = kwargs['friction']

    # Run length: the time between the two DEMs if they carry a year, otherwise model_time.
    # If both surfaces share a year (e.g. surface_2 omitted), a zero-length run would never
    # update the bed, so fall back to model_time when it is given.
    try:
        num_years = s_ref.year - s_init.year
    except (AttributeError, TypeError):
        num_years = kwargs['model_time']
    if num_years == 0 and 'model_time' in kwargs:
        num_years = kwargs['model_time']

    # Timestep: spacing of the SMB dict's time keys if available, otherwise it must be given
    if isinstance(a, dict) and len(a) >= 2:
        first_key, second_key = sorted(a)[:2]
        Δt = round(second_key - first_key, 10)
    else:
        Δt = kwargs['timestep']

    start_year = getattr(s_init, 'year', None)

    # Function space shared by bed, thickness, surface and misfit
    Q = s_init.function_space()

    # Initial bed guess
    bed_guess = firedrake.Function(Q).project(s_init - H_guess)

    # Initialize storage
    misfits = []
    bed_evolution = [bed_guess.dat.data_ro.copy()]
    surface_evolution = []
    velocity_evolution = []
    thickness_evolution = []

    def _result():
        # bed_guess is updated in place, so this always reflects the latest estimate
        return InversionResult(
            bed=bed_guess,
            misfits=misfits,
            bed_evolution=bed_evolution,
            surface_evolution=surface_evolution,
            velocity_evolution=velocity_evolution,
            thickness_evolution=thickness_evolution,
            s_ref=s_ref,
        )

    # The small epsilon stops ratios like 0.3 / 0.1 = 2.999... from losing a step
    num_timesteps = int(np.floor(num_years / Δt + 1e-9))

    bed_correction = firedrake.Function(Q)
    surface_misfit = firedrake.Function(Q)

    for iteration in trange(num_iterations):
        bed_mod = bed_guess.copy(deepcopy=True)
        H_mod = firedrake.Function(Q).project(s_init - bed_mod)
        H_0 = H_mod.copy(deepcopy=True)
        u_mod = u_guess.copy(deepcopy=True)
        s_mod = s_init.copy(deepcopy=True)

        for step in range(num_timesteps):
            try:
                if start_year is None:
                    accumulation = a
                else:
                    accumulation = _accumulation_at(a, start_year + step * Δt)

                u_mod = solver.diagnostic_solve(
                    velocity=u_mod,
                    thickness=H_mod,
                    surface=s_mod,
                    fluidity=A,
                    friction=friction,
                )

                H_mod = solver.prognostic_solve(
                    Δt,
                    thickness=H_mod,
                    velocity=u_mod,
                    thickness_inflow=H_0,
                    accumulation=accumulation,
                )
                s_mod.project(icepack.compute_surface(bed=bed_mod, thickness=H_mod))

            except Exception as err:  # not bare: let KeyboardInterrupt/SystemExit propagate
                print(f'Bed solver failed on step {step + 1} of iteration {iteration + 1}')
                print(f'  reason: {err!r}')
                return _result()

        # van Pelt-style update: lower the bed where the model surface is too high
        surface_misfit.project(s_mod - s_ref)
        bed_correction.project(-K * surface_misfit)
        bed_guess.project(bed_mod + bed_correction)

        # Store evolution values
        misfits.append(float(firedrake.assemble(surface_misfit * firedrake.dx) / mesh.glacier_length))
        bed_evolution.append(bed_guess.dat.data_ro.copy())
        surface_evolution.append(s_mod.dat.data_ro.copy())
        thickness_evolution.append(H_mod.dat.data_ro.copy())
        velocity_lifted = lift_to_mesh(
            function=icepack.depth_average(u_mod),
            mesh=mesh,
            function_space=s_mod.function_space(),
        )
        velocity_evolution.append(velocity_lifted.dat.data_ro.copy())

    return _result()

In [None]:
# !jupyter nbconvert --to script centerflow_extruded.ipynb