# centerflow_extruded

**centerflow_extruded** is a Python module for modeling glacier dynamics along flowlines discretized on an extruded mesh, which permits higher-order ice-flow modeling. It provides tools for five core tasks:

1. **Mesh generation**  
   Construct a 1D finite element mesh along an RGI-defined glacier centerline, and extrude it to include a vertical coordinate, $\zeta$, which ranges from 0 at the base to 1 at the surface.

2. **Data interpolation**  
   Interpolate gridded geospatial datasets (e.g., surface elevation, velocity, surface mass balance) onto the extruded mesh.

3. **Function lifting**  
   Lift functions that have been flattened to the base mesh (e.g., by depth averaging) back onto the full extruded mesh.

4. **Smoothing**  
   Smooth functions defined on an extruded mesh.

5. **Bed inversion**  
   Apply a forward-model-based bed inversion scheme following the approach of [van Pelt et al. (2013)](https://tc.copernicus.org/articles/7/987/2013/), using observed surface elevations to iteratively estimate basal topography.

## Imports


In [None]:
from dataclasses import dataclass
import firedrake
from firedrake.mesh import ExtrudedMeshTopology
import geopandas as gpd
import icepack
import numpy as np
import pandas as pd
from pyproj import Geod
from pathlib import Path
import rasterio
from rasterio.warp import calculate_default_transform, reproject, Resampling
from rasterio.crs import CRS
from rasterio.io import MemoryFile
import scipy.ndimage
from scipy.interpolate import interp1d
from scipy.ndimage import gaussian_filter1d
from tqdm import trange

## centerline_mesh

In [None]:
@dataclass
class IntervalMeshResult:
    """Container for the mesh and centerline geometry produced by ``centerline_mesh``."""

    # Extruded mesh built on top of the 1D base interval mesh.
    mesh: firedrake.ExtrudedMesh
    # Centerline x coordinates resampled at the base-mesh vertices.
    x: np.ndarray
    # Centerline y coordinates resampled at the base-mesh vertices.
    y: np.ndarray
    # Total along-centerline distance (geodesic, from pyproj ``Geod.inv``).
    glacier_length: float


def centerline_mesh(**kwargs):
    """Build a one-layer extruded mesh along a glacier centerline.

    Keyword arguments:
        rgiid: Substring identifying the glacier in the outlines' ``rgi_id``
            column (default ``'15-09534'``). Matched literally, not as a regex.
        centerline_path: Path to a vector file of centerlines (read with geopandas).
        outline_path: Path to a vector file of outlines with an ``rgi_id`` column.
        n_cells: Number of cells in the base interval mesh.

    Returns:
        IntervalMeshResult holding the extruded mesh, the centerline x/y
        coordinates resampled at the ``n_cells + 1`` equally spaced base-mesh
        vertices, and the geodesic length of the centerline.

    Raises:
        IndexError: If no outline matches ``rgiid`` or no centerline intersects
            the selected outline.
    """
    rgiid = kwargs.get('rgiid', '15-09534')
    centerline_path = kwargs['centerline_path']
    outline_path = kwargs['outline_path']
    n_cells = kwargs['n_cells']

    outlines = gpd.read_file(outline_path)
    centerlines = gpd.read_file(centerline_path)

    # regex=False: ids such as 'RGI60-15.09534' may contain '.', which would
    # otherwise act as a regex wildcard. na=False: missing ids count as no match
    # instead of producing a non-boolean mask.
    matches = outlines[outlines['rgi_id'].str.contains(rgiid, regex = False, na = False)]
    if matches.empty:
        raise IndexError(f'No outline with an rgi_id containing {rgiid!r} in {outline_path}')
    outline = matches.geometry.values[0]

    # The first centerline that intersects the outline is taken as the flowline.
    candidates = centerlines[centerlines.intersects(outline)]
    if candidates.empty:
        raise IndexError(f'No centerline in {centerline_path} intersects the outline for {rgiid!r}')
    centerline = candidates.geometry.values[0]

    # Geod.inv takes lon/lat, so the centerline is assumed to be in geographic
    # coordinates (TODO confirm against the input files).
    geod = Geod(ellps = 'WGS84')
    lons, lats = (np.asarray(coord, dtype = float) for coord in centerline.xy)
    _, _, segment_lengths = geod.inv(lons[:-1], lats[:-1], lons[1:], lats[1:])
    distances = np.concatenate(([0.0], np.cumsum(segment_lengths)))
    glacier_length = distances[-1]

    # Resample x, y at the n_cells + 1 equally spaced vertices of the base mesh.
    uniform_dists = np.linspace(0, glacier_length, n_cells + 1)
    x_interp = interp1d(distances, lons)(uniform_dists)
    y_interp = interp1d(distances, lats)(uniform_dists)

    base_mesh = firedrake.IntervalMesh(n_cells, glacier_length)
    mesh = firedrake.ExtrudedMesh(base_mesh, layers = 1)

    return IntervalMeshResult(
        mesh = mesh,
        x = np.array(x_interp),
        y = np.array(y_interp),
        glacier_length = glacier_length
    )

## map_to_mesh

In [None]:
def map_to_mesh(**kwargs):
    """Interpolate a raster or tabular dataset onto a vertically constant function.

    Keyword arguments:
        mesh: IntervalMeshResult (``.mesh`` always; ``.x``/``.y`` for rasters).
        data_path: Path to a ``.tif``/``.tiff`` raster or a ``.csv`` table.
        dimension: Horizontal polynomial degree of the output space (default 1).
        element: Horizontal element family of the output space (default 'CG').
        key_value: CSV column holding the lookup key (CSV input only).
        data_value: CSV column holding the value to map (CSV input only).
        key_dataset: Firedrake Function whose DOF values are looked up in the
            ``key_value`` column (CSV input only).
        projection: CRS of ``mesh.x``/``mesh.y`` (default 'EPSG:4326'); rasters
            in another CRS are reprojected to it before sampling.

    Returns:
        firedrake.Function on ``FunctionSpace(mesh.mesh, element, dimension,
        vfamily='R', vdegree=0)``.

    Raises:
        ValueError: For an unsupported extension, missing CSV columns or
            ``key_dataset``, or a raster with no valid samples on the centerline.
    """
    mesh = kwargs['mesh']
    data_path = kwargs['data_path']
    extension = Path(data_path).suffix.lower()
    dimension = kwargs.get('dimension', 1)
    element = kwargs.get('element', 'CG')
    key_value = kwargs.get('key_value', 'n/a')
    data_value = kwargs.get('data_value', 'n/a')
    key_dataset = kwargs.get('key_dataset', None)
    projection = kwargs.get('projection', 'EPSG:4326')

    V = firedrake.FunctionSpace(mesh.mesh, element, dimension, vfamily = 'R', vdegree = 0)

    if extension in ('.tif', '.tiff'):
        x, y = mesh.x, mesh.y

        def _sample(dataset):
            # Read band 1 only, so multi-band rasters still give one value per point.
            # Mark the dataset's nodata value (and NaNs) as invalid.
            samples = np.array(
                [v[0] for v in dataset.sample(zip(x, y), indexes = 1)], dtype = float
            )
            valid = np.isfinite(samples)
            if dataset.nodata is not None:
                valid &= samples != dataset.nodata
            return samples, valid

        with rasterio.open(data_path) as src:
            src_crs = src.crs
            target_crs = CRS.from_string(projection)

            if src_crs != target_crs:
                print(f'Reprojecting {data_path} from {src_crs} to {target_crs}')
                transform, width, height = calculate_default_transform(
                    src_crs, target_crs, src.width, src.height, *src.bounds)

                meta = src.meta.copy()
                meta.update({
                    'crs': target_crs,
                    'transform': transform,
                    'width': width,
                    'height': height
                })

                with MemoryFile() as memfile:
                    with memfile.open(**meta) as dst:
                        for i in range(1, src.count + 1):
                            reproject(
                                source = rasterio.band(src, i),
                                destination = rasterio.band(dst, i),
                                src_transform = src.transform,
                                src_crs = src_crs,
                                dst_transform = transform,
                                dst_crs = target_crs,
                                resampling = Resampling.bilinear
                            )
                    with memfile.open() as reproj:
                        values, valid = _sample(reproj)
            else:
                values, valid = _sample(src)

        # Geodesic distance along the sampled centerline points.
        geod = Geod(ellps = 'WGS84')
        distances = np.insert(np.cumsum([
            geod.inv(x[i], y[i], x[i + 1], y[i + 1])[2] for i in range(len(x) - 1)
        ]), 0, 0)

        if not valid.any():
            raise ValueError(f'No valid samples of {data_path} along the centerline.')

        interp_func = interp1d(
            distances[valid], values[valid], bounds_error = False, fill_value = 'extrapolate'
        )
        # Evaluate at the x-coordinate of every DOF of V. This stays correct
        # for any degree or DOF ordering, unlike using the base-mesh vertices.
        dof_x = firedrake.Function(V).interpolate(
            firedrake.SpatialCoordinate(mesh.mesh)[0]
        ).dat.data_ro
        interp_vals = interp_func(dof_x)

    elif extension == '.csv':
        df = pd.read_csv(data_path)
        if key_value not in df.columns or data_value not in df.columns:
            raise ValueError(f'CSV must include {key_value!r} and {data_value!r} columns.')
        if key_dataset is None:
            raise ValueError('Must provide "key_dataset" (a Firedrake Function) when using CSV input.')

        # Piecewise-linear lookup of data_value as a function of key_value,
        # evaluated at each DOF value of key_dataset.
        key_array = key_dataset.dat.data_ro
        interp_func = interp1d(df[key_value], df[data_value], bounds_error = False, fill_value = 'extrapolate')
        interp_vals = interp_func(key_array)

    else:
        raise ValueError(f'Unsupported file extension: {extension}')

    data_function = firedrake.Function(V)
    data_function.dat.data[:] = interp_vals

    return data_function

## lift_to_mesh

In [None]:
def lift_to_mesh(**kwargs):
    """Transfer a flattened (vertically constant) function into ``function_space``.

    Keyword arguments:
        function: Source Firedrake Function, e.g. the output of
            ``icepack.depth_average``.
        function_space: Target space on ``mesh.mesh``. Its UFL element must
            expose horizontal and vertical ``sub_elements``.
        mesh: IntervalMeshResult whose ``.mesh`` is the extruded mesh.

    Returns:
        A Function in ``function_space`` (direct-copy path). Otherwise, the
        projection of x-interpolated values onto a space rebuilt from the
        target element's families and degrees.
    """
    input_function = kwargs['function']
    function_space = kwargs['function_space']
    mesh = kwargs['mesh']

    element = function_space.ufl_element()
    horizontal_elem, vertical_elem = element.sub_elements
    hfamily, vfamily = horizontal_elem.family(), vertical_elem.family()
    hdegree, vdegree = horizontal_elem.degree(), vertical_elem.degree()

    base_vals = input_function.dat.data_ro
    n_base = len(base_vals)

    # Fast path: vertically uniform target (e.g., CG2xR0) whose DOF count is a
    # whole multiple of the source's, so values can be copied directly.
    if vfamily == 'R' and vdegree == 0:
        f = firedrake.Function(function_space)
        f_data = f.dat.data
        n_total = f_data.shape[0]
        if n_base > 0 and n_total % n_base == 0:
            n_layers = n_total // n_base
            for layer in range(n_layers):
                f_data[layer::n_layers] = base_vals
            return f
        # Mismatched DOF counts (e.g. CG2 source into a CG1 target): a direct
        # copy would leave zeros or fail, so use the interpolating path below.

    # Fallback: interpolate in x using the real DOF coordinates of both spaces
    # rather than assuming the source DOFs are evenly spaced and sorted.
    source_space = input_function.function_space()
    source_coords = firedrake.Function(source_space).interpolate(
        firedrake.SpatialCoordinate(source_space.mesh())[0]
    )
    interp_func = interp1d(
        source_coords.dat.data_ro, base_vals, bounds_error = False, fill_value = 'extrapolate'
    )

    function_space_interp = firedrake.FunctionSpace(
        mesh.mesh, hfamily, 1, vfamily = vfamily, vdegree = vdegree
    )
    target_coords = firedrake.Function(function_space_interp).interpolate(
        firedrake.SpatialCoordinate(mesh.mesh)[0]
    )
    f = firedrake.Function(function_space_interp)
    f.dat.data[:] = interp_func(target_coords.dat.data_ro)

    function_space_proj = firedrake.FunctionSpace(
        mesh.mesh, hfamily, hdegree, vfamily = vfamily, vdegree = vdegree
    )
    return firedrake.project(f, function_space_proj)

## smooth_extruded_function

In [None]:
def smooth_extruded_function(**kwargs):
    """Gaussian-smooth the depth average of a function along the flowline.

    Keyword arguments:
        function: Firedrake Function on the extruded mesh.
        mesh: IntervalMeshResult (passed through to ``lift_to_mesh``).
        sigma: Gaussian standard deviation in units of DOF spacing.
        window: Smoothing window in mesh length units, used only when ``sigma``
            is None. Converted via ``sigma = window / dx / 2``, where ``dx`` is
            the mean spacing between neighbouring DOFs.

    Returns:
        The smoothed depth average, lifted into ``function``'s function space.

    Raises:
        ValueError: If neither ``sigma`` nor ``window`` is given.
    """
    f = kwargs['function']
    mesh = kwargs['mesh']
    sigma = kwargs.get('sigma', None)
    window_meters = kwargs.get('window', None)

    f_flat = icepack.depth_average(f)
    f_data = f_flat.dat.data_ro.copy()

    # Along-flow coordinate of every DOF of the depth-averaged function. DOF
    # numbering need not follow x (notably for degree > 1), so the filter is
    # applied to the values sorted by x and the result is scattered back.
    flat_space = f_flat.function_space()
    dof_x = firedrake.Function(flat_space).interpolate(
        firedrake.SpatialCoordinate(flat_space.mesh())[0]
    ).dat.data_ro.copy()
    order = np.argsort(dof_x, kind = 'stable')

    if sigma is None:
        if window_meters is None:
            raise ValueError('Must specify either "sigma" or "window"')
        # Spacing between neighbouring DOFs (smaller than the cell size for degree > 1).
        dx = np.mean(np.diff(dof_x[order]))
        # The window is treated as ~2σ (rough; the Gaussian's FWHM is ≈2.355σ).
        sigma = window_meters / dx / 2.0

    smoothed_data = np.empty_like(f_data)
    smoothed_data[order] = scipy.ndimage.gaussian_filter1d(f_data[order], sigma = sigma)

    f_flat_smoothed = f_flat.copy(deepcopy = True)
    f_flat_smoothed.dat.data[:] = smoothed_data

    return lift_to_mesh(
        function = f_flat_smoothed,
        mesh = mesh,
        function_space = f.function_space()
    )

## solve_bed

In [None]:
@dataclass
class InversionResult:
    """Final bed and per-iteration histories returned by ``solve_bed``."""

    # Bed elevation after the last iteration.
    bed: firedrake.Function
    # Euclidean norm of the surface-misfit DOF vector, one entry per iteration.
    misfits: list[float]
    # Bed DOF arrays: the initial guess followed by one entry per iteration.
    bed_evolution: list[np.ndarray]
    # Modeled surface DOF arrays at the end of each iteration's forward run.
    surface_evolution: list[np.ndarray]
    # Modeled thickness DOF arrays at the end of each iteration's forward run.
    thickness_evolution: list[np.ndarray]
    # Depth-averaged velocity, lifted to the surface's function space, per iteration.
    velocity_evolution: list[np.ndarray]


def solve_bed(**kwargs):
    """Estimate basal topography by iteratively correcting the surface misfit.

    Starting from ``bed = surface - thickness_guess``, each iteration runs the
    forward model from the observed surface for ``model_time`` (in steps of
    ``timestep``). It then updates ``bed <- bed - K * (s_model - surface)`` and
    optionally smooths the bed.

    Keyword arguments:
        mesh: IntervalMeshResult (used for smoothing and lifting).
        surface: Observed surface elevation. Its function space is used for
            the bed, misfit, and correction.
        thickness_guess: Initial thickness guess.
        velocity: Initial velocity guess for ``solver.diagnostic_solve``.
        accumulation: Passed as ``accumulation`` to ``solver.prognostic_solve``.
        fluidity: Passed as ``fluidity`` to ``solver.diagnostic_solve``.
        K: Scale factor applied to the surface misfit in the bed update.
        timestep: Forward-model time step.
        model_time: Forward-model duration per iteration (same units as timestep).
        num_iterations: Number of bed updates.
        model: Required for interface compatibility; not used directly.
        solver: Object providing ``diagnostic_solve`` and ``prognostic_solve``.
        friction: Passed as ``friction`` to ``solver.diagnostic_solve``.
        smoothing_window: Optional smoothing window for ``smooth_extruded_function``.
        sigma: Optional Gaussian sigma for ``smooth_extruded_function``.

    Returns:
        InversionResult with the final bed and per-iteration histories.

    Raises:
        ValueError: If ``model_time / timestep`` yields fewer than one time step.
    """
    # Inputs
    mesh = kwargs['mesh']
    s_ref = kwargs['surface']
    H_guess = kwargs['thickness_guess']
    u_guess = kwargs['velocity']
    a = kwargs['accumulation']
    A = kwargs['fluidity']
    K = kwargs['K']
    Δt = kwargs['timestep']
    num_years = kwargs['model_time']
    num_iterations = kwargs['num_iterations']
    model = kwargs['model']  # still required by the interface; `solver` does the work
    solver = kwargs['solver']
    friction = kwargs['friction']
    window = kwargs.get('smoothing_window', None)
    sigma = kwargs.get('sigma', None)

    Q = s_ref.function_space()

    # Initial bed guess
    bed_guess = firedrake.Function(Q).project(s_ref - H_guess)

    # Initialize storage
    misfits = []
    bed_evolution = [bed_guess.dat.data_ro.copy()]
    surface_evolution = []
    velocity_evolution = []
    thickness_evolution = []

    # int() alone truncates ratios such as 0.3 / 0.1 == 2.9999999999999996
    # to 2, so snap ratios that are integers up to rounding error.
    ratio = num_years / Δt
    nearest = round(ratio)
    if np.isclose(ratio, nearest, rtol = 1e-9, atol = 1e-9):
        num_timesteps = nearest
    else:
        num_timesteps = int(ratio)
    if num_timesteps < 1:
        # With zero steps s_mod == s_ref, the misfit is zero and the bed never changes.
        raise ValueError(
            f'model_time ({num_years}) / timestep ({Δt}) must give at least one time step.'
        )

    bed_correction = firedrake.Function(Q)
    surface_misfit = firedrake.Function(Q)

    for iteration in trange(num_iterations):
        # Each forward run restarts from the observed surface and the current bed.
        bed_mod = bed_guess.copy(deepcopy = True)
        H_mod = firedrake.Function(Q).project(s_ref - bed_mod)
        u_mod = u_guess.copy(deepcopy = True)
        s_mod = s_ref.copy(deepcopy = True)

        for step in range(num_timesteps):
            u_mod = solver.diagnostic_solve(
                velocity = u_mod,
                thickness = H_mod,
                surface = s_mod,
                fluidity = A,
                friction = friction
            )
            # The current thickness is also used as the inflow thickness.
            H_mod = solver.prognostic_solve(
                Δt,
                thickness = H_mod,
                velocity = u_mod,
                thickness_inflow = H_mod,
                accumulation = a,
            )
            s_mod.project(icepack.compute_surface(bed = bed_mod, thickness = H_mod))

        # A modeled surface that is too high lowers the bed (and vice versa).
        surface_misfit.project(s_mod - s_ref)
        bed_correction.project(-K * surface_misfit)
        bed_guess.project(bed_mod + bed_correction)

        if window is not None or sigma is not None:
            bed_guess = smooth_extruded_function(
                function = bed_guess,
                mesh = mesh,
                window = window,
                sigma = sigma
            )

        # Store evolution values
        misfits.append(np.linalg.norm(surface_misfit.dat.data_ro))
        bed_evolution.append(bed_guess.dat.data_ro.copy())
        surface_evolution.append(s_mod.dat.data_ro.copy())
        thickness_evolution.append(H_mod.dat.data_ro.copy())
        velocity_lifted = lift_to_mesh(
            function = icepack.depth_average(u_mod),
            mesh = mesh,
            function_space = s_mod.function_space()
        )
        velocity_evolution.append(velocity_lifted.dat.data_ro.copy())

    return InversionResult(
        bed = bed_guess,
        misfits = misfits,
        bed_evolution = bed_evolution,
        surface_evolution = surface_evolution,
        velocity_evolution = velocity_evolution,
        thickness_evolution = thickness_evolution
    )

In [None]:
# !jupyter nbconvert --to script centerflow_extruded.ipynb