# centerflow_extruded

**centerflow_extruded** is a Python module for modeling glacier dynamics along flowlines lifted onto an extruded mesh, which permits higher-order modeling. It is largely a direct analogue of the accompanying "centerflow" module. It provides tools for the following core tasks:

1. **ID finder**
    A single glacier can carry more than one ID, because datasets may be keyed by either its RGI v6 or its RGI v7 ID. These helpers translate between the two versions by matching glacier outlines.

2. **Mesh generation**  
   Construct a 1D finite element mesh along an RGI-defined glacier centerline, and extrude it to include a vertical coordinate, $\zeta$, which ranges from 0 at the base to 1 at the surface.

3. **Mesh trimming**  
   Crop the mesh to the stretch of centerline covered by valid raster data, e.g. when a DEM only partially covers the glacier.

4. **Data interpolation**  
   Interpolate gridded geospatial datasets (e.g., surface elevation, velocity, surface mass balance) onto the extruded mesh.

5. **Function lifting**  
   Reproject functions back to the extruded mesh after they have been flattened to the base mesh.

6. **Smoothing**  
   Smooth functions defined on an extruded mesh.

7. **Bed inversion**  
   Apply a forward-model-based bed inversion scheme following the approach of [van Pelt et al. (2013)](https://tc.copernicus.org/articles/7/987/2013/), using observed surface elevations to iteratively estimate basal topography.

## Imports


In [None]:
from dataclasses import dataclass
import firedrake
from firedrake.mesh import ExtrudedMeshTopology
import geopandas as gpd
import icepack
import numpy as np
import pandas as pd
from pyproj import Geod
from pathlib import Path
import rasterio
from rasterio.warp import calculate_default_transform, reproject, Resampling
from rasterio.crs import CRS
from rasterio.io import MemoryFile
import scipy.ndimage
from scipy.interpolate import interp1d
from scipy.ndimage import gaussian_filter1d
from shapely.geometry import LineString
from tqdm import trange

## rgi6_from_rgi7 etc.

In [None]:
def rgi6_from_rgi7(**kwargs):
    """
    Find the RGI v6 ID whose outline best overlaps a given RGI v7 glacier.

    Keyword Args:
      rgiid: RGI v7 ID, or a distinctive substring of it; matched literally
        (no regex) against the 'rgi_id' column of the v7 outlines.
      rgi6_path: path to the RGI v6 outlines (anything geopandas can read).
      rgi7_path: path to the RGI v7 outlines.
      target_crs: optional projected CRS used for the overlap-area
        computation (default 'EPSG:32646').

    Returns:
      The best-matching v6 ID with everything up to and including its first
      '-' removed (i.e. the '15.xxxxx' part), or None when no v6 outline
      intersects the v7 outline.

    Raises:
      KeyError: a required keyword argument is missing.
      ValueError: `rgiid` matches no v7 outline, or the v6 file has neither
        an 'RGIId' nor an 'rgi_id' column.
    """
    rgiid = kwargs["rgiid"]
    rgi6_path = kwargs["rgi6_path"]
    rgi7_path = kwargs["rgi7_path"]
    target_crs = kwargs.get("target_crs", "EPSG:32646")

    # Load RGI7 outlines and get the matching geometry. Fail with a clear
    # message (same style as latlon_from_rgi7) instead of an IndexError.
    gdf7 = gpd.read_file(rgi7_path)
    matches = gdf7[gdf7["rgi_id"].str.contains(rgiid, regex=False)]
    if matches.empty:
        raise ValueError(f"RGI7 ID '{rgiid}' not found in {rgi7_path}")
    outline = matches.geometry.values[0]

    # Load RGI6 outlines and normalize the ID column name to 'rgiid_6'
    gdf6 = gpd.read_file(rgi6_path)
    if "RGIId" in gdf6.columns:
        gdf6 = gdf6[["RGIId", "geometry"]].rename(columns={"RGIId": "rgiid_6"})
    elif "rgi_id" in gdf6.columns:
        gdf6 = gdf6[["rgi_id", "geometry"]].rename(columns={"rgi_id": "rgiid_6"})
    else:
        raise ValueError(f"No RGI ID column found in {rgi6_path}")

    # Project both layers to a metric CRS so intersection areas are meaningful
    gdf6 = gdf6.to_crs(target_crs)
    outline_gdf = gpd.GeoDataFrame(geometry=[outline], crs=gdf7.crs).to_crs(target_crs)

    # Intersect and pick the largest overlap
    inter = gpd.overlay(gdf6, outline_gdf, how="intersection")
    if inter.empty:
        return None
    inter["overlap_area_m2"] = inter.geometry.area
    best_match = inter.sort_values("overlap_area_m2", ascending=False).iloc[0]["rgiid_6"]

    # Keep only the "15.xxxxx" part
    return best_match.split("-", 1)[-1]


def rgi7_from_rgi6(**kwargs):
    """
    Find the RGI v7 ID whose outline best overlaps a given RGI v6 glacier.

    Keyword Args:
      rgiid: RGI v6 ID, or a distinctive substring of it; matched literally
        (no regex) against the v6 ID column ('RGIId' or 'rgi_id').
      rgi6_path: path to the RGI v6 outlines (anything geopandas can read).
      rgi7_path: path to the RGI v7 outlines; must have an 'rgi_id' column.
      target_crs: optional projected CRS used for the overlap-area
        computation (default 'EPSG:32646').

    Returns:
      The best-matching v7 ID reduced to its '-'-separated fields from the
      fourth onward (i.e. the '15-xxxxx' part), or None when no v7 outline
      intersects the v6 outline.

    Raises:
      KeyError: a required keyword argument is missing.
      ValueError: the v6 file has no recognised ID column, or `rgiid`
        matches no v6 outline.
    """
    rgiid = kwargs["rgiid"]
    rgi6_path = kwargs["rgi6_path"]
    rgi7_path = kwargs["rgi7_path"]
    target_crs = kwargs.get("target_crs", "EPSG:32646")

    # Load RGI6 outlines and normalize the ID column name to 'rgiid_6'
    gdf6 = gpd.read_file(rgi6_path)
    if "RGIId" in gdf6.columns:
        gdf6 = gdf6[["RGIId", "geometry"]].rename(columns={"RGIId": "rgiid_6"})
    elif "rgi_id" in gdf6.columns:
        gdf6 = gdf6[["rgi_id", "geometry"]].rename(columns={"rgi_id": "rgiid_6"})
    else:
        raise ValueError(f"No RGI ID column found in {rgi6_path}")

    # Fail with a clear message (same style as latlon_from_rgi7) instead of
    # the IndexError that indexing an empty selection would produce.
    matches = gdf6[gdf6["rgiid_6"].str.contains(rgiid, regex=False)]
    if matches.empty:
        raise ValueError(f"RGI6 ID '{rgiid}' not found in {rgi6_path}")
    outline = matches.geometry.values[0]

    # Load RGI7 outlines
    gdf7 = gpd.read_file(rgi7_path)[["rgi_id", "geometry"]]

    # Project both layers to a metric CRS so intersection areas are meaningful
    gdf7 = gdf7.to_crs(target_crs)
    outline_gdf = gpd.GeoDataFrame(geometry=[outline], crs=gdf6.crs).to_crs(target_crs)

    # Intersect and pick the largest overlap
    inter = gpd.overlay(gdf7, outline_gdf, how="intersection")
    if inter.empty:
        return None
    inter["overlap_area_m2"] = inter.geometry.area
    best_match = inter.sort_values("overlap_area_m2", ascending=False).iloc[0]["rgi_id"]

    # Keep only the "15-xxxxx" part
    return "-".join(best_match.split("-")[3:])

def latlon_from_rgi7(**kwargs):
    """
    Return the 1-degree tile label (e.g. 'N27E090') containing a glacier's centroid.

    Keyword Args:
      rgiid: RGI v7 ID, or a distinctive substring of it; matched literally
        (no regex) against the 'rgi_id' column.
      rgi7_path: path to the RGI v7 outlines (anything geopandas can read).
      metric_crs: optional projected CRS in which the centroid is computed
        (default 'EPSG:32646').

    Returns:
      A string '<N|S><lat:02d><E|W><lon:03d>'.

    Raises:
      KeyError: a required keyword argument is missing.
      ValueError: `rgiid` matches no outline in the file.
    """
    rgiid_7 = kwargs["rgiid"]
    rgi7_path = kwargs["rgi7_path"]
    metric_crs = kwargs.get("metric_crs", "EPSG:32646")

    # Load and filter
    gdf7 = gpd.read_file(rgi7_path)
    match = gdf7[gdf7['rgi_id'].str.contains(rgiid_7, regex=False)]
    if match.empty:
        raise ValueError(f"RGI7 ID '{rgiid_7}' not found in {rgi7_path}")

    # Compute the centroid in a metric CRS (centroids of geographic
    # coordinates are not geometrically meaningful), then convert it back.
    centroid_metric = match.to_crs(metric_crs).geometry.centroid.iloc[0]
    centroid_geo = gpd.GeoSeries([centroid_metric], crs=metric_crs).to_crs("EPSG:4326").iloc[0]
    lat, lon = centroid_geo.y, centroid_geo.x

    # Build tile string from the tile's south-west corner, i.e. the floor of
    # the signed coordinate. Taking floor(abs(.)) instead would name the
    # neighbouring tile for southern/western coordinates (-33.5 -> 'S33'
    # rather than 'S34'). Results for N/E coordinates are unchanged.
    # NOTE(review): assumes the consuming dataset names tiles by their SW
    # corner, as SRTM-style 1-degree tiles do; confirm for the data in use.
    lat_floor = int(np.floor(lat))
    lon_floor = int(np.floor(lon))
    lat_prefix = "N" if lat_floor >= 0 else "S"
    lon_prefix = "E" if lon_floor >= 0 else "W"
    tile = f"{lat_prefix}{abs(lat_floor):02d}{lon_prefix}{abs(lon_floor):03d}"

    return tile

## centerline_mesh

In [None]:
@dataclass
class IntervalMeshResult:
    """
    Bundle of an extruded flowline mesh and the geometry it was built from.

    Produced by centerline_mesh and crop_mesh and consumed by the mapping,
    smoothing and inversion helpers, which read `mesh`, `X`, `x`, `y`,
    `basal_coords`, `centerline`, `outline` and `length` from it.
    """
    mesh: firedrake.ExtrudedMesh      # extruded mesh (single layer over the 1-D base mesh)
    x: np.ndarray                     # lon of the centerline vertices (possibly extended/truncated)
    y: np.ndarray                     # lat of the centerline vertices
    X: np.ndarray                     # base-mesh vertex coordinates, i.e. chainage (m)
    basal_coords: np.ndarray          # columns (X, 0): base-mesh vertices at zeta = 0
    surface_coords: np.ndarray        # columns (X, 1): base-mesh vertices at zeta = 1
    length: float                     # domain length (m)
    centerline: object                # shapely LineString in lon/lat
    outline: object                   # glacier outline geometry

def centerline_mesh(**kwargs):
    """
    Build an extruded interval mesh along a glacier centerline.

    The longest centerline intersecting the glacier outline is selected (or
    the one stored on a previous result is reused), optionally extended or
    truncated at its last vertex, and meshed as [0, length] with one
    extruded layer.

    Keyword Args:
      n_cells: number of cells of the base interval mesh (required).
      rgiid: RGI ID, or distinctive substring, matched literally against the
        outlines' 'rgi_id' column (default '15-09534').
      centerline_path, outline_path: files readable by geopandas; both are
        required unless `mesh` is given.
      extra_length: metres to add (+) beyond, or remove (-) from, the end of
        the centerline (default 0.0).
      mesh: optional previous IntervalMeshResult whose `outline` and
        `centerline` are reused instead of reading the files.

    Returns:
      IntervalMeshResult for the (possibly modified) centerline.

    Raises:
      KeyError: `n_cells` is missing.
      ValueError: the paths are missing and no `mesh` was given, `rgiid`
        matches no outline, no centerline intersects the outline, or the
        resulting length is not positive.
    """
    rgiid = kwargs.get('rgiid', '15-09534')
    centerline_path = kwargs.get('centerline_path', None)
    outline_path = kwargs.get('outline_path', None)
    extra_length = float(kwargs.get('extra_length', 0.0))  # may be + (extend) or - (truncate)
    n_cells = int(kwargs['n_cells'])
    prev = kwargs.get('mesh', None)  # optionally reuse outline/centerline from an existing mesh

    # Get outline + centerline in lon/lat
    if prev is None:
        if centerline_path is None or outline_path is None:
            raise ValueError("Provide centerline_path and outline_path (or pass mesh=...).")
        outlines = gpd.read_file(outline_path)
        centerlines = gpd.read_file(centerline_path)

        # Literal (non-regex) match, consistent with the other ID lookups in
        # this module; an ID containing '.' must not act as a wildcard.
        matches = outlines[outlines['rgi_id'].str.contains(rgiid, regex=False)]
        if matches.empty:
            raise ValueError(f"RGI ID '{rgiid}' not found in {outline_path}")
        outline = matches.geometry.values[0]

        flowlines = centerlines[centerlines.intersects(outline)]
        if flowlines.empty:
            raise ValueError(f"No centerline in {centerline_path} intersects the outline of '{rgiid}'")
        # Lengths are compared in a metric CRS; keep the longest flowline.
        centerline = flowlines.loc[flowlines.to_crs('EPSG:32646').length.idxmax(), 'geometry']
    else:
        outline = prev.outline
        centerline = prev.centerline

    # Extract and prepare arrays
    geod = Geod(ellps='WGS84')
    x_ll, y_ll = centerline.xy
    x_ll = np.asarray(x_ll, dtype=float)
    y_ll = np.asarray(y_ll, dtype=float)

    # Cumulative geodesic chainage along original line
    seg = np.array(
        [
            geod.inv(x0, y0, x1, y1)[2]
            for x0, y0, x1, y1 in zip(x_ll[:-1], y_ll[:-1], x_ll[1:], y_ll[1:])
        ],
        dtype=float,
    )
    chainage = np.concatenate(([0.0], np.cumsum(seg)))
    base_len = float(chainage[-1])
    new_len = base_len + extra_length
    if new_len <= 0.0:
        raise ValueError("Resulting length must be > 0.")

    # Extend or truncate geometry to match new_len
    if extra_length > 0.0:
        # Continue along the azimuth of the final segment.
        az, _, _ = geod.inv(x_ll[-2], y_ll[-2], x_ll[-1], y_ll[-1])
        x_new, y_new, _ = geod.fwd(x_ll[-1], y_ll[-1], az, extra_length)
        x_mod = np.append(x_ll, x_new)
        y_mod = np.append(y_ll, y_new)
    elif extra_length < 0.0:
        # idx is the last original vertex at or before the new end point.
        idx = int(np.searchsorted(chainage, new_len, side='right') - 1)
        rem = new_len - chainage[idx]
        if rem == 0.0:
            x_mod = x_ll[:idx+1]
            y_mod = y_ll[:idx+1]
        else:
            # Place the new end point `rem` metres into the next segment.
            az, _, _ = geod.inv(x_ll[idx], y_ll[idx], x_ll[idx+1], y_ll[idx+1])
            x_cut, y_cut, _ = geod.fwd(x_ll[idx], y_ll[idx], az, rem)
            x_mod = np.r_[x_ll[:idx+1], x_cut]
            y_mod = np.r_[y_ll[:idx+1], y_cut]
    else:
        x_mod, y_mod = x_ll, y_ll

    centerline = LineString(np.column_stack((x_mod, y_mod)))

    # Build meshes over [0, new_len]
    base_mesh = firedrake.IntervalMesh(n_cells, new_len)
    mesh = firedrake.ExtrudedMesh(base_mesh, layers=1)

    X = base_mesh.coordinates.dat.data_ro.flatten()
    basal_coords = np.column_stack((X, np.zeros_like(X)))
    surface_coords = np.column_stack((X, np.ones_like(X)))

    return IntervalMeshResult(
        mesh=mesh,
        x=x_mod,
        y=y_mod,
        X=X,
        basal_coords=basal_coords,
        surface_coords=surface_coords,
        length=new_len,
        centerline=centerline,
        outline=outline,
    )

## crop_mesh

In [None]:
def _cut(line, d):
    """
    Split `line` at distance `d` measured from its first vertex.

    Returns a two-element list [before, after]. When `d` is at or before the
    start the first element is None; when `d` is at or beyond the end the
    second element is None. Only the first two ordinates of each vertex are
    used for the inserted split point.
    """
    if d <= 0.0:
        return [None, LineString(line.coords)]
    if d >= line.length:
        return [LineString(line.coords), None]

    vertices = list(line.coords)
    travelled = 0.0
    for k, (start, end) in enumerate(zip(vertices[:-1], vertices[1:])):
        span = LineString([start, end]).length
        reached = travelled + span
        if reached >= d:
            # Linear interpolation inside the segment that contains `d`.
            frac = (d - travelled) / span
            split = (
                start[0] + frac * (end[0] - start[0]),
                start[1] + frac * (end[1] - start[1]),
            )
            head = vertices[:k + 1] + [split]
            tail = [split] + vertices[k + 1:]
            return [LineString(head), LineString(tail)]
        travelled = reached

    # Reached only if accumulated segment lengths never attain `d`.
    return [LineString(vertices), None]

def _segment(line, d0, d1):
    """Return the part of `line` between distances `d0` and `d1` from its start."""
    # Drop everything before d0, then keep the first (d1 - d0) of what remains.
    _, remainder = _cut(line, d0)
    piece, _ = _cut(remainder, d1 - d0)
    return piece

def crop_mesh(**kwargs):
    """
    Crop a flowline mesh to the stretch of centerline covered by valid raster data.

    The centerline is sampled at regular intervals on the raster; the mesh is
    rebuilt between the first and last valid samples, keeping the number of
    cells of the input mesh.

    Keyword Args:
      mesh: IntervalMeshResult whose `centerline` is in EPSG:4326.
      data_path: raster readable by rasterio.

    Returns:
      The input `mesh` unchanged when the first and last samples are both
      valid; otherwise a new object of the same type for the cropped
      centerline.

    Raises:
      KeyError: a required keyword argument is missing.
      ValueError: no valid raster value is found along the centerline, or
        the valid stretch has zero length.
    """
    mesh = kwargs['mesh']
    data_path = kwargs['data_path']

    target_samples = 400

    # Open the raster once for both its metadata and the sampling.
    with rasterio.open(data_path) as src:
        r_crs = src.crs
        nodata = src.nodata

        cl_m = gpd.GeoSeries([mesh.centerline], crs='EPSG:4326').to_crs(r_crs).iloc[0]

        # NOTE(review): the 25.0 floor is in raster CRS units, so this
        # presumably expects a metric raster CRS - confirm.
        step = max(cl_m.length / target_samples, 25.0)
        # Clip to the line length: interpolate() already clamps the sample
        # location to the end point, and clipped distances keep the later
        # cut positions on the line.
        dists = np.minimum(
            np.arange(0.0, cl_m.length + step, step, dtype=float), cl_m.length
        )
        pts = [cl_m.interpolate(float(d)) for d in dists]
        vals = np.array([v[0] for v in src.sample([(p.x, p.y) for p in pts])])

    # NaN never counts as data. A NaN nodata value cannot be detected with
    # `!=` (NaN != NaN is always True), so it is covered by the isnan test.
    valid = ~np.isnan(vals.astype(float))
    if nodata is not None and not np.isnan(nodata):
        valid &= vals != nodata
    if not valid.any():
        raise ValueError('No valid data found along centerline for the provided raster.')

    first = int(np.argmax(valid))
    last = int(len(valid) - 1 - np.argmax(valid[::-1]))

    # If no cropping is needed, return unchanged mesh
    if first == 0 and last == len(valid) - 1:
        return mesh

    d0, d1 = float(dists[first]), float(dists[last])
    if d1 <= d0:
        # A zero-length stretch cannot be meshed; _segment would return None.
        raise ValueError('Valid data along the centerline spans zero length; cannot crop the mesh.')
    cl_m_cropped = _segment(cl_m, d0, d1)
    cl_ll = gpd.GeoSeries([cl_m_cropped], crs=r_crs).to_crs('EPSG:4326').iloc[0]

    # Geodesic length of the cropped centerline
    geod = Geod(ellps='WGS84')
    xs, ys = np.asarray(cl_ll.xy[0]), np.asarray(cl_ll.xy[1])
    segs = [
        geod.inv(x0, y0, x1, y1)[2]
        for x0, y0, x1, y1 in zip(xs[:-1], ys[:-1], xs[1:], ys[1:])
    ]
    new_len = float(np.sum(segs))

    n_cells = len(mesh.X) - 1
    base_mesh = firedrake.IntervalMesh(n_cells, new_len)
    extruded = firedrake.ExtrudedMesh(base_mesh, layers=1)
    X = base_mesh.coordinates.dat.data_ro.flatten()
    surface_coords = np.column_stack((X, np.ones_like(X)))
    basal_coords = np.column_stack((X, np.zeros_like(X)))

    return type(mesh)(
        mesh=extruded,
        x=xs,
        y=ys,
        X=X,
        basal_coords=basal_coords,
        surface_coords=surface_coords,
        centerline=cl_ll,
        outline=mesh.outline,
        length=new_len,
    )


## map_to_mesh

In [None]:
def map_to_mesh(**kwargs):
    """
    Map data onto the provided *extruded* mesh.

    Use one of:
      - function=<Firedrake Function>: reproject an existing field, or
      - data_path=<.tif/.tiff or .csv>: sample along the centerline and interpolate.

    Options:
      mesh: IntervalMeshResult (extruded)
      element: str (default 'CG')
      dimension: int polynomial degree (default 1)
      ice_free_value: constant to apply outside the source x-extent (if provided)
      projection: CRS string for rasters (default 'EPSG:4326')
      key_value, data_value, key_dataset: CSV helpers (as in SIA)
      target_space: optional FunctionSpace to write into (overrides element/dimension)

    Returns:
      Firedrake Function on `mesh.mesh`. For CSV input it lives in the
      element of `key_dataset`; otherwise in `target_space` if given, else
      in <element>(<dimension>) x R0.

    Raises:
      KeyError: `mesh` is missing.
      ValueError: neither `function` nor `data_path` is given; the file
        extension is unsupported; CSV columns or `key_dataset` are missing;
        or a source function does not cover the target mesh and no
        `ice_free_value` is given.
    """
    mesh = kwargs['mesh']
    data_path = kwargs.get('data_path', None)
    function = kwargs.get('function', None)
    degree = int(kwargs.get('dimension', 1))
    family = kwargs.get('element', 'CG')
    key_value = kwargs.get('key_value', 'n/a')
    data_value = kwargs.get('data_value', 'n/a')
    key_dataset = kwargs.get('key_dataset', None)
    projection = kwargs.get('projection', 'EPSG:4326')
    ice_free_value = kwargs.get('ice_free_value', None)
    target_space = kwargs.get('target_space', None)

    # Base-line vertices (chainage in meters) and basal coords (x, zeta=0)
    X = mesh.X
    basal_coords = mesh.basal_coords

    # Helper: land values on CG(1)×R^0 and (optionally) project to target space/degree
    def _finalize(values_on_vertices):
        V1 = firedrake.FunctionSpace(mesh.mesh, 'CG', 1, vfamily='R', vdegree=0)
        f1 = firedrake.Function(V1)
        # NOTE(review): assumes the CG1 x R0 dof ordering matches mesh.X - confirm.
        f1.dat.data[:] = values_on_vertices
        if target_space is not None:
            return firedrake.project(f1, target_space)
        if degree != 1 or family != 'CG':
            Vt = firedrake.FunctionSpace(mesh.mesh, family, degree, vfamily='R', vdegree=0)
            return firedrake.project(f1, Vt)
        return f1

    # ====================== Branch A: function → mesh (no clamping) ======================
    if function is not None:
        src_mesh = function.function_space().mesh()
        gdim = src_mesh.geometric_dimension()

        # Source horizontal extent (use the base mesh when the source is extruded)
        try:
            x_src = src_mesh._base_mesh.coordinates.dat.data_ro.flatten()
        except AttributeError:
            x_src = src_mesh.coordinates.dat.data_ro.flatten()
        x_min, x_max = float(np.min(x_src)), float(np.max(x_src))
        tol = 1e-10 * max(1.0, x_max - x_min)

        # Mask target points that are inside the source domain
        inside = (X >= x_min - tol) & (X <= x_max + tol)

        if ice_free_value is None and not np.all(inside):
            raise ValueError(
                "map_to_mesh(function=...) targets extend beyond the source function's domain. "
                "Remap the source to the new mesh first (extend_to_mesh), or pass ice_free_value=..."
            )

        # Prepare output as zeros (rather than uninitialised memory); outside
        # points are then overwritten with ice_free_value when it is provided.
        vals = np.zeros_like(X, dtype=float)
        if ice_free_value is not None:
            vals[:] = float(ice_free_value)

        # Evaluate only the inside points on the source mesh
        if np.any(inside):
            if gdim == 1:
                vals_inside = np.array(function.at(X[inside], tolerance=1e-10)).reshape(-1)
            else:
                # 2-D (extruded) source: evaluate along the base, zeta = 0
                eval_pts = basal_coords[inside].copy()
                vals_inside = np.array(function.at(eval_pts, tolerance=1e-10)).reshape(-1)
            vals[inside] = vals_inside

        return _finalize(vals)

    # ====================== Branch B: file → mesh ======================
    if data_path is None:
        raise ValueError('data_path is required unless you pass function=...')

    extension = Path(data_path).suffix.lower()

    if extension in ('.tif', '.tiff'):
        x_ll, y_ll = mesh.x, mesh.y

        def _sample_first_band(dataset):
            # One row per point, one column per band. Keep band 1 only so the
            # result always has one value per centerline vertex; flattening a
            # multi-band raster would break the interpolation below.
            return np.array(list(dataset.sample(zip(x_ll, y_ll))))[:, 0]

        with rasterio.open(data_path) as src:
            src_crs = src.crs
            tgt_crs = CRS.from_string(projection)

            if src_crs != tgt_crs:
                # Warp the raster into the target CRS in memory before sampling
                transform, width, height = calculate_default_transform(
                    src_crs, tgt_crs, src.width, src.height, *src.bounds
                )
                meta = src.meta.copy()
                meta.update({'crs': tgt_crs, 'transform': transform, 'width': width, 'height': height})
                with MemoryFile() as memfile:
                    with memfile.open(**meta) as dst:
                        for i in range(1, src.count + 1):
                            reproject(
                                source=rasterio.band(src, i), destination=rasterio.band(dst, i),
                                src_transform=src.transform, src_crs=src_crs,
                                dst_transform=transform, dst_crs=tgt_crs,
                                resampling=Resampling.bilinear,
                            )
                    with memfile.open() as reproj:
                        sampled = _sample_first_band(reproj)
            else:
                sampled = _sample_first_band(src)

        # Geodesic chainage of the centerline vertices
        geod = Geod(ellps='WGS84')
        chainage = np.insert(np.cumsum([
            geod.inv(x_ll[i], y_ll[i], x_ll[i+1], y_ll[i+1])[2] for i in range(len(x_ll) - 1)
        ]), 0, 0)

        vals = interp1d(chainage, sampled, bounds_error=False, fill_value='extrapolate')(X)
        if ice_free_value is not None:
            # Beyond the end of the sampled centerline, use the constant
            vals[X > chainage[-1] + 1e-12] = ice_free_value

        return _finalize(vals)

    elif extension == '.csv':
        df = pd.read_csv(data_path)
        if key_value not in df.columns or data_value not in df.columns:
            raise ValueError(f'CSV must include {key_value!r} and {data_value!r} columns.')
        if key_dataset is None:
            raise ValueError('Must provide "key_dataset" (a Firedrake Function) when using CSV input.')

        # Look up data_value as a function of key_value at each dof of key_dataset
        key_array = key_dataset.dat.data_ro
        vals = interp1d(df[key_value], df[data_value], bounds_error=False, fill_value='extrapolate')(key_array)

        # Mirror SIA behavior: place result in the key dataset's space
        V_csv = firedrake.FunctionSpace(mesh.mesh, key_dataset.function_space().ufl_element())
        out_csv = firedrake.Function(V_csv)
        out_csv.dat.data[:] = vals
        return out_csv

    else:
        raise ValueError(f'Unsupported file extension: {extension}')


## extend_to_mesh

In [None]:
def extend_to_mesh(**kwargs):
    """
    Carry a Firedrake Function over to the (extruded) target mesh.

    Args:
      function: the Firedrake Function to transfer
      mesh: IntervalMeshResult describing the target extruded mesh
      ice_free_value: optional constant used where the target lies outside
        the source domain
      target_space: optional FunctionSpace for the result (by default the
        source's element rebuilt on the target mesh)

    Returns:
      A Firedrake Function on the target mesh: a deep copy of the source if
      it already lives on that mesh with the same element, otherwise the
      result of map_to_mesh.
    """
    source = kwargs['function']
    target = kwargs['mesh']
    fill = kwargs.get('ice_free_value', None)
    requested_space = kwargs.get('target_space', None)

    source_space = source.function_space()

    # The destination space always lives on the target mesh.
    V_tgt = requested_space or firedrake.FunctionSpace(
        target.mesh, source_space.ufl_element()
    )

    same_mesh = source_space.mesh() is V_tgt.mesh()
    if same_mesh and source_space.ufl_element() == V_tgt.ufl_element():
        # Nothing to remap: same mesh object and same element.
        return source.copy(deepcopy=True)

    return map_to_mesh(
        mesh=target,
        function=source,
        target_space=V_tgt,
        ice_free_value=fill,
    )


## smooth_function

Smoothing is important for icepack's higher-order (HO) model, whose solver can struggle to converge when input fields contain a lot of small-scale jitter.

In [None]:
def smooth_function(**kwargs):
    """
    Smooth a scalar field on the extruded mesh along the flowline.

    The field is depth-averaged, filtered with a 1-D Gaussian along X, and
    mapped back into the function space of the input on the same mesh.

    Args:
      function: Firedrake Function (on extruded mesh)
      mesh: IntervalMeshResult (extruded)
      sigma: optional Gaussian sigma in grid points (takes precedence over window)
      window: optional window width in meters (sigma = window / dx / 2)

    Returns:
      Firedrake Function in the same space as `function`.

    Raises:
      ValueError: neither `sigma` nor `window` is given.
    """
    field = kwargs['function']
    mesh = kwargs['mesh']
    sigma = kwargs.get('sigma', None)
    window_m = kwargs.get('window', None)

    # Collapse the vertical coordinate
    depth_avg = icepack.depth_average(field)
    raw = depth_avg.dat.data_ro.copy()

    # Mean vertex spacing converts a window in meters to grid points
    chain = mesh.X
    if len(chain) > 1:
        spacing = float(np.mean(np.diff(chain)))
    else:
        spacing = 1.0

    if sigma is None:
        if window_m is None:
            raise ValueError("Provide either 'sigma' (grid pts) or 'window' (meters).")
        sigma = (window_m / spacing) / 2.0  # window is treated as two sigma

    # Filter along chainage
    filtered = scipy.ndimage.gaussian_filter1d(raw, sigma=float(sigma))

    smoothed = depth_avg.copy(deepcopy=True)
    smoothed.dat.data[:] = filtered

    # Lift the smoothed 1-D field back into the original space
    return map_to_mesh(
        mesh=mesh,
        function=smoothed,
        target_space=field.function_space(),
        ice_free_value=None,
    )


## solve_bed

In [None]:
@dataclass
class InversionResult:
    """
    Output of solve_bed: the final bed estimate plus per-iteration diagnostics.

    The *_evolution lists hold copies of the raw dof arrays, not Firedrake
    Functions. bed_evolution starts with the initial guess, so it has one
    more entry than the other lists.
    """
    bed: firedrake.Function        # bed estimate after the last completed iteration
    misfits: list                  # per iteration: integral of (modeled - reference surface) / mesh length
    bed_evolution: list            # bed dof arrays: initial guess, then one per iteration
    surface_evolution: list        # modeled surface dof arrays, one per iteration
    thickness_evolution: list      # modeled thickness dof arrays, one per iteration
    velocity_evolution: list       # depth-averaged velocity mapped back to the mesh, one per iteration
    s_ref: firedrake.Function      # reference surface the modeled surface is compared against


def solve_bed(**kwargs):
    """
    Iteratively estimate the bed by running the forward model and correcting
    the bed with the surface misfit.

    Each iteration starts from the initial surface, steps the diagnostic and
    prognostic solves forward, and then updates the bed as
    bed_new = bed - K * (modeled surface - reference surface).

    Keyword Args:
      mesh: IntervalMeshResult (its `length` normalises the misfit).
      surface: initial surface (Firedrake Function); its function space is
        used for all working fields.
      surface_2: optional reference surface (defaults to `surface`).
      thickness_guess: initial thickness; the first bed is surface - thickness.
      velocity: initial velocity guess for the diagnostic solve.
      accumulation: either a mapping indexed by time (looked up at
        surface.year + step * timestep) or a single field used at every step.
      fluidity, friction: passed to the diagnostic solve.
      K: relaxation factor applied to the surface misfit.
      num_iterations: number of inversion iterations.
      model: required for interface compatibility; not used by this routine.
      solver: object providing diagnostic_solve and prognostic_solve.
      model_time: years to model; required only when the surfaces carry no
        usable `.year` attributes.
      timestep: time step; required only when it cannot be derived from the
        first two entries of `accumulation`.

    Returns:
      InversionResult. If a solve fails, the failure is printed and the
      results accumulated up to the previous iteration are returned.

    Raises:
      KeyError: a required keyword argument is missing.
    """
    # Inputs
    mesh = kwargs['mesh']
    s_init = kwargs['surface']
    s_ref = kwargs.get('surface_2', s_init)
    H_guess = kwargs['thickness_guess']
    u_guess = kwargs['velocity']
    a = kwargs['accumulation']
    A = kwargs['fluidity']
    K = kwargs['K']
    num_iterations = kwargs['num_iterations']
    _model = kwargs['model']  # still required so existing call signatures behave the same
    solver = kwargs['solver']
    friction = kwargs['friction']

    # Extract the time difference from the DEMs, if they carry a year
    try:
        num_years = s_ref.year - s_init.year
    except (AttributeError, TypeError):
        num_years = kwargs['model_time']  # otherwise it must be specified

    # Extract the time step from the SMB container, if it is iterable/indexable.
    # Many input types are accepted, so any ordinary failure means "not
    # applicable" - but KeyboardInterrupt/SystemExit are no longer swallowed.
    try:
        times = list(a)
        Δt = round(times[1] - times[0], 10)
    except Exception:
        Δt = kwargs['timestep']  # otherwise it must be specified

    # Function space shared by all working fields
    Q = s_init.function_space()

    # Initial bed guess
    bed_guess = firedrake.Function(Q).project(s_init - H_guess)

    # Initialize storage
    misfits = []
    bed_evolution = [bed_guess.dat.data_ro.copy()]
    surface_evolution = []
    velocity_evolution = []
    thickness_evolution = []

    def _result():
        # Snapshot of everything accumulated so far
        return InversionResult(
            bed=bed_guess,
            misfits=misfits,
            bed_evolution=bed_evolution,
            surface_evolution=surface_evolution,
            velocity_evolution=velocity_evolution,
            thickness_evolution=thickness_evolution,
            s_ref=s_ref,
        )

    # The small offset guards against float round-off (0.7 / 0.1 evaluates
    # to 6.999...), which plain truncation would turn into a lost step.
    num_timesteps = int(num_years / Δt + 1e-9)

    bed_correction = firedrake.Function(Q)
    surface_misfit = firedrake.Function(Q)

    for iteration in trange(num_iterations):
        # Every iteration restarts the forward model from the initial state
        bed_mod = bed_guess.copy(deepcopy=True)
        H_mod = firedrake.Function(Q).project(s_init - bed_mod)
        H_0 = H_mod.copy(deepcopy=True)
        u_mod = u_guess.copy(deepcopy=True)
        s_mod = s_init.copy(deepcopy=True)

        for step in range(num_timesteps):

            # If SMB is a mapping with date keys, pick the entry for this
            # step; otherwise use it as-is.
            try:
                accumulation = a[s_init.year + step * Δt]
            except Exception:
                accumulation = a

            try:
                u_mod = solver.diagnostic_solve(
                    velocity=u_mod,
                    thickness=H_mod,
                    surface=s_mod,
                    fluidity=A,
                    friction=friction,
                )

                H_mod = solver.prognostic_solve(
                    Δt,
                    thickness=H_mod,
                    velocity=u_mod,
                    thickness_inflow=H_0,
                    accumulation=accumulation,
                )
                s_mod.project(icepack.compute_surface(bed=bed_mod, thickness=H_mod))

            except Exception as err:
                # Best effort: report the cause and hand back what we have.
                print(f'Bed solver failed on step {step + 1} of iteration {iteration + 1}: {err!r}')
                return _result()

        surface_misfit.project(s_mod - s_ref)
        bed_correction.project(-K * surface_misfit)
        bed_guess.project(bed_mod + bed_correction)

        # Store evolution values
        misfits.append(float(firedrake.assemble(surface_misfit * firedrake.dx) / mesh.length))
        bed_evolution.append(bed_guess.dat.data_ro.copy())
        surface_evolution.append(s_mod.dat.data_ro.copy())
        thickness_evolution.append(H_mod.dat.data_ro.copy())
        velocity_lifted = map_to_mesh(function=icepack.depth_average(u_mod), mesh=mesh)
        velocity_evolution.append(velocity_lifted.dat.data_ro.copy())

    return _result()

In [None]:
# !jupyter nbconvert --to script centerflow_extruded.ipynb