In [1]:
#| default_exp linops

In [2]:
#| export

import numpy as np
import igl

In [3]:
#| export

import jax
import jax.numpy as jnp
import jax.experimental.sparse as jsparse

import lineax

import functools

In [4]:
#| export

from jaxtyping import Float 

In [5]:
#| export

from triangulax import trigonometry as trig
from triangulax import mesh as msh
from triangulax import adjacency as adj
from triangulax import geometry as geom

In [6]:
#| hide

jax.config.update("jax_enable_x64", True)  # enable 64-bit (double precision) arrays
jax.config.update("jax_debug_nans", False)  # set to True to help track down NaNs while debugging
jax.config.update('jax_log_compiles', False) # use this to log JAX JIT compilations.

In [7]:
#| hide

import jaxtyping


In [8]:
#| hide

%load_ext jaxtyping 
%jaxtyping.typechecker beartype.beartype

# enables runtime type checking of jaxtyping annotations (dtype/shape), with beartype as the checker. Does not work for some cells (vmapping and loading/saving). For those, %unload_ext jaxtyping 


## Finite-element gradient and cotan-Laplacian

Building on the mesh geometry and adjacency-based operators, we can now define two important linear operators that depend both on mesh connectivity and on mesh geometry. They are the (discrete, triangulation-based) equivalents of the gradient and of the Laplace-Beltrami operator. The discrete Laplace-Beltrami operator is known as the _cotan Laplacian_.

We now implement the gradient (per-vertex scalar field -> per-face vector field) and the cotan-Laplacian (per-vertex field -> per-vertex field) using gather/scatter operations. In both cases, we start with a scalar field $u_i$ defined per vertex $i$ of the triangulation. The finite-element gradient is defined for each face $ijk$ as follows:
$$ 
(\nabla u)_{ijk} = \sum_{l\in \{i,j,k\}} u_l \nabla\phi_l
$$
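where $\phi_i$ is the linear finite-element basis ("hat") function of vertex $i$ (linear Lagrange element), which is 1 at vertex $i$, 0 at all other vertices, and linear on each face. Its gradient is constant on each face: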
$$
    \nabla\phi_i = \frac{1}{2a_{ijk}} (\mathbf{v}_k-\mathbf{v}_j)^\perp
$$
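plus cyclic permutations. Here, $a_{ijk}$ is the triangle area (signed, in 2D), $\mathbf{v}_i$ are the vertex positions, and $(\cdot)^\perp$ denotes counterclockwise rotation by 90 degrees (in 3D, the rotation is about the triangle normal $(\mathbf{v}_j-\mathbf{v}_i)\times(\mathbf{v}_k-\mathbf{v}_i)$).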

The cot-Laplacian computes the following per-vertex field:
$$
(\Delta u)_i = \frac{1}{2} \sum_{j} (\cot\alpha_j +\cot\beta_j) (u_j-u_i)
$$ 
where the sum is over adjacent vertices $j$, and $\alpha_j, \beta_j$ are the two triangle angles "opposite" to the edge $ij$. For a boundary edge, only one such angle exists, and only its cotangent contributes.

To check for correctness, we compare with [this `libigl` tutorial](https://libigl.github.io/libigl-python-bindings/tut-chapter1/), using the test mesh and some random test fields.

In [9]:
from triangulax.triangular import TriMesh

In [62]:
# load test data: a disk mesh, read once with 2D vertex coordinates and once embedded in 3D

mesh = TriMesh.read_obj("test_meshes/disk.obj")
# half-edge connectivity; built from the number of vertices and the faces only
hemesh = msh.HeMesh.from_triangles(mesh.vertices.shape[0], mesh.faces)
geommesh = msh.GeomMesh(*hemesh.n_items, mesh.vertices, mesh.face_positions)

# same file read with dim=3: the faces are identical, so `hemesh` is reused for the 3D geometry
mesh_3d = TriMesh.read_obj("test_meshes/disk.obj", dim=3)
geommesh_3d = msh.GeomMesh(*hemesh.n_items, mesh_3d.vertices, mesh_3d.face_positions)

  o flat_tri_ecmc
  o flat_tri_ecmc


### Cotan-Laplacian

In [41]:
#| export

def compute_cotan_laplace(vertices: Float[jax.Array, "n_vertices dim"], hemesh: msh.HeMesh,
                          vertex_field: Float[jax.Array, "n_vertices ..."]
                          ) -> Float[jax.Array, "n_vertices ..."]:
    """
    Apply the cotangent Laplacian to a per-vertex field (natural boundary conditions).

    For each vertex i this evaluates sum_j w_ij (u_j - u_i), where w_ij is the per-edge cotan
    weight returned by `geom.get_cotan_weights_per_egde`. Trailing axes of `vertex_field`
    (vector/tensor-valued fields) are handled component-wise.
    """
    cot_w = geom.get_cotan_weights_per_egde(vertices, hemesh)
    # field difference along every half-edge: destination minus origin
    delta = vertex_field[hemesh.dest] - vertex_field[hemesh.orig]
    # append singleton axes to the per-half-edge weights so they broadcast over trailing field axes
    cot_w = cot_w.reshape(cot_w.shape + (1,) * (delta.ndim - 1))
    # accumulate the weighted differences over the half-edges incoming to each vertex, then flip sign
    return -adj.sum_he_to_vertex_incoming(hemesh, cot_w * delta)

In [42]:
#| export

def cotan_laplace_sparse(vertices: Float[jax.Array, "n_vertices dim"], hemesh: msh.HeMesh
                         ) -> jsparse.BCOO:
    """Assemble the cotangent Laplacian as a sparse matrix (BCOO).

    Every undirected edge (a, b) with cotan weight w adds +w to the off-diagonal entries
    (a, b) and (b, a) and -w to the diagonal entries (a, a) and (b, b). Duplicate entries
    are summed, so `L @ u` reproduces `compute_cotan_laplace(vertices, hemesh, u)`.
    """
    cot_w = geom.get_cotan_weights_per_egde(vertices, hemesh)
    # `is_unique` selects one representative half-edge per undirected edge
    keep = hemesh.is_unique
    a, b, w = hemesh.orig[keep], hemesh.dest[keep], cot_w[keep]

    # (row, col, value) triplets: off-diagonal couplings first, then diagonal contributions
    triplets = [(a, b, w), (b, a, w), (a, a, -w), (b, b, -w)]
    rows = jnp.concatenate([r for r, _, _ in triplets])
    cols = jnp.concatenate([c for _, c, _ in triplets])
    data = jnp.concatenate([val for _, _, val in triplets])

    n = hemesh.n_vertices
    coo = jsparse.BCOO((data, jnp.stack([rows, cols], axis=1)), shape=(n, n))
    return coo.sum_duplicates()

In [43]:
# Test against libigl cotmatrix (natural boundary conditions)
key = jax.random.PRNGKey(0)
# independent keys for the scalar and the vector test field (JAX keys must not be reused)
key_u, key_u_vec = jax.random.split(key)
u = jax.random.normal(key_u, (hemesh.n_vertices,))
u_vec = jax.random.normal(key_u_vec, (hemesh.n_vertices, 3))

L = igl.cotmatrix(np.asarray(geommesh.vertices), np.asarray(hemesh.faces))

lap_jax = compute_cotan_laplace(geommesh.vertices, hemesh, u)
lap_igl = L @ np.asarray(u)

rel_err = np.linalg.norm(np.asarray(lap_jax) - lap_igl) / np.linalg.norm(lap_igl)
print("scalar field rel. error:", rel_err)
assert rel_err < 1e-10, rel_err

lap_jax_vec = compute_cotan_laplace(geommesh.vertices, hemesh, u_vec)
lap_igl_vec = L @ np.asarray(u_vec)

rel_err_vec = np.linalg.norm(np.asarray(lap_jax_vec) - lap_igl_vec) / np.linalg.norm(lap_igl_vec)
print("vector field rel. error:", rel_err_vec)
assert rel_err_vec < 1e-10, rel_err_vec

scalar field rel. error: 1.7692000627878292e-16
vector field rel. error: 2.011929541056845e-16


In [44]:
# test sparse cotan Laplacian vs apply function
key = jax.random.PRNGKey(0)
u_test = jax.random.normal(key, (hemesh.n_vertices,))

L_sparse = cotan_laplace_sparse(geommesh.vertices, hemesh)
lap_sparse = L_sparse @ u_test
lap_apply = compute_cotan_laplace(geommesh.vertices, hemesh, u_test)

rel_err_sparse = jnp.linalg.norm(lap_sparse - lap_apply) / jnp.linalg.norm(lap_apply)
print("cotan sparse vs apply rel. error:", rel_err_sparse)
assert rel_err_sparse < 1e-10, rel_err_sparse

cotan sparse vs apply rel. error: 1.302894564211555e-16


### Finite-element gradient

Not to be confused with the discrete-exterior-calculus operators, which only depend on mesh connectivity, not geometry.

In [116]:
#| export

def _fe_grad_phi_2d(vertices: Float[jax.Array, "n_vertices 2"], hemesh: msh.HeMesh,
                 ) -> Float[jax.Array, "n_faces 3 2"]:
    """Per-face gradients of the P1 hat functions (2D).

    For each face f=(i,j,k), returns an array grads[f, l, :] = ∇φ_l, with l=0,1,2
    corresponding to vertices (i,j,k). Faces with |2 * signed area| <= 1e-12 are treated
    as degenerate and get zero gradients.

    The division by the area uses a guarded denominator, so differentiating this function
    (e.g. w.r.t. `vertices`) gives finite gradients even when degenerate faces are present.
    """
    faces = hemesh.faces
    v0, v1, v2 = (vertices[faces[:, 0]], vertices[faces[:, 1]], vertices[faces[:, 2]])

    # twice the signed area (2D cross product), shape (n_faces, 1) to broadcast over components
    area2 = jnp.cross(v1 - v0, v2 - v0)[:, None]
    mask = jnp.abs(area2) > 1e-12
    # "double where": never divide by a (near-)zero area, not even in the branch jnp.where
    # discards -- inf/NaN there still turns reverse-mode gradients into NaN.
    safe_area2 = jnp.where(mask, area2, 1.0)
    grad_phi0 = jnp.where(mask, trig.get_perp_2d(v1 - v2)/safe_area2, 0)
    grad_phi1 = jnp.where(mask, trig.get_perp_2d(v2 - v0)/safe_area2, 0)
    grad_phi2 = jnp.where(mask, trig.get_perp_2d(v0 - v1)/safe_area2, 0)

    return jnp.stack([grad_phi0, grad_phi1, grad_phi2], axis=1)


def _fe_grad_phi_3d(vertices: Float[jax.Array, "n_vertices 3"], hemesh: msh.HeMesh,
                 ) -> Float[jax.Array, "n_faces 3 3"]:
    """Per-face gradients of the P1 hat functions (3D).

    For each face f=(i,j,k), returns grads[f, l, :] = ∇φ_l in R^3, l=0,1,2.
    Degenerate faces get zero gradients.
    """
    faces = hemesh.faces
    v0, v1, v2 = (vertices[faces[:, 0]], vertices[faces[:, 1]], vertices[faces[:, 2]])

    n = jnp.cross(v1 - v0, v2 - v0)
    area2 = jnp.linalg.norm(n, axis=-1)[:, None]**2
    mask = area2 > 1e-12
    
    mask = jnp.abs(area2) > 1e-12
    grad_phi0 = jnp.where(mask, jnp.cross(v1 - v2, n)/area2, 0)
    grad_phi1 = jnp.where(mask, jnp.cross(v2 - v0, n)/area2, 0)
    grad_phi2 = jnp.where(mask, jnp.cross(v0 - v1, n)/area2, 0)

    return jnp.stack([grad_phi0, grad_phi1, grad_phi2], axis=1)


def compute_gradient_2d(vertices: Float[jax.Array, "n_vertices 2"], hemesh: msh.HeMesh,
                        vertex_field: Float[jax.Array, "n_vertices ..."]
                        ) -> Float[jax.Array, "n_faces 2 ..."]:
    """Linear finite-element gradient of a per-vertex field in 2D (constant per face).

    On each face this is sum_l u_l ∇φ_l over the three corners. Trailing axes of
    `vertex_field` are carried through, giving an output of shape (n_faces, 2, ...).
    """
    hat_grads = _fe_grad_phi_2d(vertices, hemesh)   # (n_faces, 3, 2)
    corner_vals = vertex_field[hemesh.faces]        # (n_faces, 3, ...)
    # contract over the three corners (c) of each face (f); d is the spatial component
    return jnp.einsum("fcd,fc...->fd...", hat_grads, corner_vals)


def compute_gradient_3d(vertices: Float[jax.Array, "n_vertices 3"], hemesh: msh.HeMesh,
                        vertex_field: Float[jax.Array, "n_vertices ..."]
                        ) -> Float[jax.Array, "n_faces 3 ..."]:
    """Linear finite-element gradient of a per-vertex field in 3D (constant per face).

    On each face this is sum_l u_l ∇φ_l over the three corners. Trailing axes of
    `vertex_field` are carried through, giving an output of shape (n_faces, 3, ...).
    """
    hat_grads = _fe_grad_phi_3d(vertices, hemesh)   # (n_faces, 3, 3)
    corner_vals = vertex_field[hemesh.faces]        # (n_faces, 3, ...)
    # contract over the three corners (c) of each face (f); d is the spatial component
    return jnp.einsum("fcd,fc...->fd...", hat_grads, corner_vals)

In [117]:
#| export

def gradient_sparse_2d(vertices: Float[jax.Array, "n_vertices 2"], hemesh: msh.HeMesh,
                      ) -> jsparse.BCOO:
    """Assemble the 2D FE gradient as a sparse matrix (BCOO).

    The result G has shape (2*n_faces, n_vertices); row c*n_faces + f holds component c of
    the gradient on face f (components stacked in blocks, as in libigl's `grad` operator).
    For a scalar per-vertex field u (n_vertices,):
        g_flat = G @ u                      # (2*n_faces,)
        g = g_flat.reshape((2, n_faces)).T  # (n_faces, 2)
    """
    tri = hemesh.faces.astype(jnp.int32)
    nf = tri.shape[0]
    hat_grads = _fe_grad_phi_2d(vertices, hemesh)  # (n_faces, 3, 2)

    # enumerate all (component, corner, face) triples, with the face index varying fastest
    comp, corner, face = jnp.meshgrid(jnp.arange(2, dtype=jnp.int32),
                                      jnp.arange(3, dtype=jnp.int32),
                                      jnp.arange(nf, dtype=jnp.int32), indexing="ij")
    values = jnp.transpose(hat_grads, (2, 1, 0)).ravel()  # hat_grads[face, corner, comp]
    rows = (comp * nf + face).ravel()
    cols = tri[face, corner].ravel()

    return jsparse.BCOO((values, jnp.stack([rows, cols], axis=1)),
                        shape=(2 * nf, hemesh.n_vertices))


def gradient_sparse_3d(vertices: Float[jax.Array, "n_vertices 3"], hemesh: msh.HeMesh,
                      ) -> jsparse.BCOO:
    """Assemble the 3D FE gradient as a sparse matrix (BCOO).

    The result G has shape (3*n_faces, n_vertices); row c*n_faces + f holds component c of
    the gradient on face f (components stacked in blocks, as in libigl's `grad` operator).
    For a scalar per-vertex field u (n_vertices,):
        g_flat = G @ u                      # (3*n_faces,)
        g = g_flat.reshape((3, n_faces)).T  # (n_faces, 3)
    """
    tri = hemesh.faces.astype(jnp.int32)
    nf = tri.shape[0]
    hat_grads = _fe_grad_phi_3d(vertices, hemesh)  # (n_faces, 3, 3)

    # enumerate all (component, corner, face) triples, with the face index varying fastest
    comp, corner, face = jnp.meshgrid(jnp.arange(3, dtype=jnp.int32),
                                      jnp.arange(3, dtype=jnp.int32),
                                      jnp.arange(nf, dtype=jnp.int32), indexing="ij")
    values = jnp.transpose(hat_grads, (2, 1, 0)).ravel()  # hat_grads[face, corner, comp]
    rows = (comp * nf + face).ravel()
    cols = tri[face, corner].ravel()

    return jsparse.BCOO((values, jnp.stack([rows, cols], axis=1)),
                        shape=(3 * nf, hemesh.n_vertices))

In [122]:
#| export

def reshape_face_gradient(grad_flat: Float[jax.Array, "dim_n_faces ..."], n_faces: int, dim: int,
                          ) -> Float[jax.Array, "n_faces dim ..."]:
    """Turn a block-stacked FE gradient into per-face vectors.

    Intended for the output of `gradient_sparse_2d/3d` (or any operator that stacks the
    gradient components in blocks): `G @ u` has shape `(dim*n_faces, ...)` for scalar,
    vector or tensor per-vertex fields u.

    Parameters
    ----------
    grad_flat
        Result of `G @ u`, shape `(dim*n_faces, ...)`.
    n_faces
        Number of mesh faces.
    dim
        Spatial dimension (2 or 3).

    Returns
    -------
    grad
        Array of shape `(n_faces, dim, ...)`, the same layout `compute_gradient_2d/3d` return.

    Raises
    ------
    ValueError
        If the leading axis of `grad_flat` is not `dim * n_faces`.
    """
    expected_rows = dim * n_faces
    if grad_flat.shape[0] != expected_rows:
        raise ValueError(f"Expected grad_flat.shape[0] == dim*n_faces = {dim*n_faces}, got {grad_flat.shape[0]}")
    trailing = grad_flat.shape[1:]
    # row c*n_faces + f holds component c on face f -> split into (component, face) axes
    blocks = jnp.reshape(grad_flat, (dim, n_faces, *trailing))
    return jnp.moveaxis(blocks, 0, 1)

In [118]:
# here's how to compute the gradient in libigl

grad_matrix = igl.grad(np.asarray(geommesh.vertices), np.asarray(hemesh.faces))
# calculate the gradient of field by matrix multiplication
grad_igl = grad_matrix @ np.asarray(u)
# igl.grad stacks the gradient components in blocks (all x-components, then all y-components),
# so a column-major (order='F') reshape gives one row per face. order='F' copied from igl tutorial
grad_igl = grad_igl.reshape((hemesh.n_faces, geommesh.dim), order='F')

In [119]:
# test jax and libigl implementations

grad_jax = compute_gradient_2d(geommesh.vertices, hemesh, u)

rel_err_grad = np.linalg.norm(np.asarray(grad_jax) - grad_igl) / np.linalg.norm(grad_igl)
print("gradient rel. error:", rel_err_grad)
assert rel_err_grad < 1e-10, rel_err_grad

gradient rel. error: 1.413315746703021e-16


In [120]:
# same test, in 3d

grad_matrix_3d = igl.grad(np.asarray(geommesh_3d.vertices), np.asarray(hemesh.faces))
grad_igl_3d = grad_matrix_3d @ np.asarray(u)
# component-block layout, see the 2D cell above
grad_igl_3d = grad_igl_3d.reshape((hemesh.n_faces, geommesh_3d.dim), order='F')

grad_jax_3d = compute_gradient_3d(geommesh_3d.vertices, hemesh, u)

rel_err_grad_3d = np.linalg.norm(np.asarray(grad_jax_3d) - grad_igl_3d) / np.linalg.norm(grad_igl_3d)
print("3D gradient rel. error:", rel_err_grad_3d)
assert rel_err_grad_3d < 1e-10, rel_err_grad_3d

gradient rel. error: 1.5657863888820882e-16


In [123]:
# Test sparse gradient operators vs apply functions
key = jax.random.PRNGKey(123)
# independent keys for the scalar and the vector test field (JAX keys must not be reused)
key_scalar, key_vector = jax.random.split(key)
u_test = jax.random.normal(key_scalar, (hemesh.n_vertices,))

G2 = gradient_sparse_2d(geommesh.vertices, hemesh)
g2 = reshape_face_gradient(G2 @ u_test, hemesh.n_faces, dim=2)
g2_apply = compute_gradient_2d(geommesh.vertices, hemesh, u_test)
rel_err_g2 = jnp.linalg.norm(g2 - g2_apply) / jnp.linalg.norm(g2_apply)
print("2D grad sparse vs apply rel. error:", rel_err_g2)
assert rel_err_g2 < 1e-10, rel_err_g2

G3 = gradient_sparse_3d(geommesh_3d.vertices, hemesh)
g3 = reshape_face_gradient(G3 @ u_test, hemesh.n_faces, dim=3)
g3_apply = compute_gradient_3d(geommesh_3d.vertices, hemesh, u_test)
rel_err_g3 = jnp.linalg.norm(g3 - g3_apply) / jnp.linalg.norm(g3_apply)
print("3D grad sparse vs apply rel. error:", rel_err_g3)
assert rel_err_g3 < 1e-10, rel_err_g3

# quick sanity check for vector/tensor fields: u has extra axes
# (own name, so the global `u_vec` from the Laplacian test is not overwritten)
u_vec_test = jax.random.normal(key_vector, (hemesh.n_vertices, 3))
g2_vec = reshape_face_gradient(G2 @ u_vec_test, hemesh.n_faces, dim=2)
g2_vec_apply = compute_gradient_2d(geommesh.vertices, hemesh, u_vec_test)
rel_err_g2_vec = jnp.linalg.norm(g2_vec - g2_vec_apply) / jnp.linalg.norm(g2_vec_apply)
print("2D grad (vector field) sparse vs apply rel. error:", rel_err_g2_vec)
assert rel_err_g2_vec < 1e-10, rel_err_g2_vec

g3_vec = reshape_face_gradient(G3 @ u_vec_test, hemesh.n_faces, dim=3)
g3_vec_apply = compute_gradient_3d(geommesh_3d.vertices, hemesh, u_vec_test)
rel_err_g3_vec = jnp.linalg.norm(g3_vec - g3_vec_apply) / jnp.linalg.norm(g3_vec_apply)
print("3D grad (vector field) sparse vs apply rel. error:", rel_err_g3_vec)
assert rel_err_g3_vec < 1e-10, rel_err_g3_vec

2D grad sparse vs apply rel. error: 8.71017994729607e-17
3D grad sparse vs apply rel. error: 9.602668379845331e-17
2D grad (vector field) sparse vs apply rel. error: 8.285943150518157e-17


### Wrapping as linear operators

It's often useful to think of functions like `compute_cotan_laplace()` as linear operators acting on fields defined on meshes. For example, imagine you want to solve the Laplace equation on a mesh with fixed vertex positions and connectivity. For this you need a linear solver. Luckily, iterative (e.g. Krylov) solvers only need to compute the action of a linear operator on an input vector, and don't need an explicit matrix representation.

In the JAX ecosystem, the `lineax` library provides linear operators and linear solvers. We can _wrap_ `compute_cotan_laplace()` as a linear operator, which allows us to pass it to (iterative) linear-algebra routines.

In [49]:
#| export

def scipy_to_bcoo(A) -> jsparse.BCOO:
    """
    Convert a SciPy sparse matrix (e.g. CSR or CSC) to a JAX BCOO matrix without densifying.

    Parameters
    ----------
    A : scipy.sparse.spmatrix
        Input sparse matrix; any format offering `.tocoo()` works (CSR/CSC recommended).

    Returns
    -------
    B : jax.experimental.sparse.BCOO
        JAX sparse matrix with the same shape and stored entries as `A`.
    """
    coo = A.tocoo()  # COO exposes the (row, col, value) triplets directly
    index_pairs = jnp.column_stack([jnp.array(coo.row, dtype=jnp.int32),
                                    jnp.array(coo.col, dtype=jnp.int32)])
    values = jnp.array(coo.data)
    return jsparse.BCOO((values, index_pairs), shape=coo.shape)


def diag_jsparse(v : Float[jax.Array, " N"], k: int =0) -> jsparse.BCOO:
    """Construct a diagonal jax.sparse array. Plugin replacement for np.diag (1-D input).

    Places `v` on the k-th diagonal of an (N, N) matrix with N = len(v) + |k|: entries
    (i, i+k) for k >= 0 and (i-k, i) for k < 0, as `np.diag(v, k)` does.

    The size is computed with Python's `abs` (not `jnp.abs`), so it stays a static integer;
    this keeps the sparse shape a tuple of ints and lets the function run inside `jax.jit`
    (a traced size cannot be used for `arange` or for the sparse shape).
    """
    n = v.shape[0]
    N = n + abs(k)
    offsets = jnp.arange(n, dtype=jnp.int32)
    # shift rows down for sub-diagonals (k < 0), columns right for super-diagonals (k > 0)
    rows = offsets + max(-k, 0)
    cols = offsets + max(k, 0)
    return jsparse.BCOO((v, jnp.stack([rows, cols], axis=1)), shape=(N, N))

In [50]:
# turn off jaxtyping runtime type checking for the remaining cells: it does not work for some
# cells, e.g. ones that vmap (as `linear_op_to_sparse` below does)
%unload_ext jaxtyping 


In [54]:
# "bake in" the connectivity and vertex positions

laplace_op = functools.partial(compute_cotan_laplace, geommesh.vertices, hemesh)
_ = laplace_op(u) # you can apply this to vertex-fields

# define the linear operator
# NOTE: jax.eval_shape(laplace_op, u) gives the shape/dtype of the *output*; it can serve as
# `input_structure` here only because the Laplacian maps vertex fields to vertex fields of the
# same shape (a square operator).
laplace_op_lx = lineax.FunctionLinearOperator(laplace_op, input_structure=jax.eval_shape(laplace_op, u))

# now you can use the linear operator to compute matrix representations, solve linear systems, etc.
mat = laplace_op_lx.as_matrix()
mat.shape

(131, 131)

In [55]:
#| export

def linear_op_to_sparse(op: callable,
                        in_shape: tuple[int, ...],
                        out_shape: tuple[int, ...],
                        dtype: jnp.dtype | None = None,
                        chunk_size: int = 256,
                        tol: float = 0.0,
                        ) -> jsparse.BCOO:
    """Build a sparse matrix for a linear map using batched one-hot probes.

    Note: this function is general, but not necessarily very efficient for large matrix sizes.
    """
    if len(in_shape) != 1 or len(out_shape) != 1:
        raise ValueError("Only 1D input/output supported for now.")
    n_in = in_shape[0]
    n_out = out_shape[0]
    if dtype is None:
        dtype = jnp.result_type(op(jnp.zeros((n_in,))))

    data_list: list[np.ndarray] = []
    row_list: list[np.ndarray] = []
    col_list: list[np.ndarray] = []

    for start in range(0, n_in, chunk_size):
        end = min(start + chunk_size, n_in)
        idx = jnp.arange(start, end, dtype=jnp.int64)
        basis = jax.nn.one_hot(jnp.array(idx), n_in, dtype=dtype)
        cols = jax.vmap(op)(basis)  # (chunk, n_out)
        #cols = apply_op(basis)
        mask = jnp.abs(cols) > tol
        col_in_batch, row_out = jnp.nonzero(mask)
        if col_in_batch.size == 0:
            continue

        data_list.append(cols[col_in_batch, row_out])
        row_list.append(row_out)
        col_list.append(idx[col_in_batch])

    if len(data_list) == 0:
        return jsparse.empty((n_out, n_in)) 
    data = jnp.concatenate(data_list)
    rows = jnp.concatenate(row_list)
    cols = jnp.concatenate(col_list)
    indices = jnp.stack([jnp.array(rows, dtype=jnp.int32),
                            jnp.array(cols, dtype=jnp.int32)], axis=1)
    return jsparse.BCOO((data, indices), shape=(n_out, n_in))

In [57]:
# compare sparse construction to lineax dense matrix (small meshes only)
if hemesh.n_vertices <= 2000:
    laplace_op_local = functools.partial(compute_cotan_laplace, geommesh.vertices, hemesh)
    laplace_op_lx_local = lineax.FunctionLinearOperator(laplace_op_local,
                                                        input_structure=jax.eval_shape(laplace_op_local, u))
    sp_mat = linear_op_to_sparse(laplace_op_local, (hemesh.n_vertices,), (hemesh.n_vertices,))
    mat_dense = laplace_op_lx_local.as_matrix()
    rel_err_sparse = jnp.linalg.norm(sp_mat.todense() - mat_dense) / jnp.linalg.norm(mat_dense)
    print("sparse vs lineax rel. error:", rel_err_sparse)
    assert rel_err_sparse < 1e-10, rel_err_sparse
else:
    print("Skipping dense comparison for large mesh.")

sparse vs lineax rel. error: 0.0


In [60]:
## now let's try with a large mesh
# NOTE: this cell rebinds the notebook-level `mesh`, `hemesh`, `geommesh` and `laplace_op` to the
# torus, so re-running earlier cells after this one will operate on the torus, not the disk.

mesh = TriMesh.read_obj("test_meshes/torus_high_resolution.obj")
hemesh = msh.HeMesh.from_triangles(mesh.vertices.shape[0], mesh.faces)
geommesh = msh.GeomMesh(*hemesh.n_items, mesh.vertices, mesh.face_positions)

# jit-compile the operator; it is applied many times by `linear_op_to_sparse` below
laplace_op = jax.jit(functools.partial(compute_cotan_laplace, geommesh.vertices, hemesh))

  o Torus


In [61]:
# Probing with one-hot vectors evaluates the operator on all n_vertices basis vectors (with dense
# (chunk_size, n_vertices) probe blocks), so the cost grows quadratically with the mesh size.
# For the cotan Laplacian specifically, `cotan_laplace_sparse` assembles the same matrix directly.
sparse_laplace_op = linear_op_to_sparse(laplace_op, (hemesh.n_vertices,), (hemesh.n_vertices,))