In [2]:
#| default_exp adjacency

In [3]:
#| export

import numpy as np
import igl

In [None]:
#| export

import jax
import jax.numpy as jnp
import jax.experimental.sparse as jsparse

In [5]:
#| export

from jaxtyping import Float 

In [None]:
#| export

from triangulax import mesh as msh

In [9]:
#| hide

# 64-bit floats (the libigl comparisons below report relative errors of order 1e-16).
# NaN debugging and compile logging are off; set them to True while debugging.
for _option, _value in (
    ("jax_enable_x64", True),
    ("jax_debug_nans", False),
    ("jax_log_compiles", False),  # when True, logs every JAX JIT compilation
):
    jax.config.update(_option, _value)

In [10]:
#| hide

import jaxtyping


In [11]:
#| hide

# Enable jaxtyping's runtime checking of annotations such as `Float[jax.Array, "n_hes ..."]`,
# using beartype as the checker, for the code in the following cells.
%load_ext jaxtyping 
%jaxtyping.typechecker beartype.beartype

# Type checking does not work for some cells (vmapping and loading/saving).
# For those, run `%unload_ext jaxtyping` first.


## Adjacency-like operators on half-edge meshes

Using the `HeMesh` data structure, we can efficiently "traverse" our mesh. Using such traversals, one can express many adjacency-based _linear operators_, for example:

- Sum over all half-edges "incoming" to a vertex (special case: _count_ the incoming half-edges, i.e., compute the coordination number)
- Compute the finite-element gradient of a function defined on vertices

These operations can be done efficiently using a "gather/scatter" approach, see [`jax.numpy.ndarray.at`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at). There is no need to explicitly instantiate a matrix for the operators.

All operators defined in this notebook depend only on the mesh topology, not on its geometry (vertex/face positions).

In [14]:
from triangulax.triangular import TriMesh

In [15]:
# load test data: the same disk mesh, once with default vertex dimension and once with dim=3
MESH_PATH = "test_meshes/disk.obj"

mesh = TriMesh.read_obj(MESH_PATH)
hemesh = msh.HeMesh.from_triangles(mesh.vertices.shape[0], mesh.faces)
geommesh = msh.GeomMesh(*hemesh.n_items, mesh.vertices, mesh.face_positions)

# same file, hence same triangles: the 3D geometry reuses the connectivity counts of `hemesh`
mesh_3d = TriMesh.read_obj(MESH_PATH, dim=3)
geommesh_3d = msh.GeomMesh(*hemesh.n_items, mesh_3d.vertices, mesh_3d.face_positions)

  o flat_tri_ecmc
  o flat_tri_ecmc


### Discrete derivative

On a triangular mesh, there are two natural "derivatives": for a per-vertex field, the difference across half-edges, and for a per-half-edge field, the circulation around a face (this is the basis of [discrete exterior calculus](https://en.wikipedia.org/wiki/Discrete_exterior_calculus)).

In [43]:
#| export

def get_exterior_gradient(hemesh: msh.HeMesh, v_field: Float[jax.Array, "n_vertices ..."]) -> Float[jax.Array, "n_hes ..."]:
    """
    Discrete derivative of a vertex field along half-edges.

    For every half-edge h, returns v_field[dest[h]] - v_field[orig[h]].
    Trailing dimensions of `v_field` (vector fields) are carried through unchanged.
    """
    head_values = v_field[hemesh.dest]
    tail_values = v_field[hemesh.orig]
    return head_values - tail_values

def get_exterior_circulation(hemesh: msh.HeMesh, he_field: Float[jax.Array, "n_hes ..."]) -> Float[jax.Array, "n_faces ..."]:
    """
    Circulation of a half-edge field around every (triangular) face.

    With h = face_incident[face], adds he_field over the face's three half-edges
    prv[h], h and nxt[h]. The circulation of an exterior gradient vanishes.
    Trailing dimensions of `he_field` are carried through unchanged.
    """
    he = hemesh.face_incident
    return he_field[hemesh.prv[he]] + he_field[he] + he_field[hemesh.nxt[he]]

In [44]:
# define a random scalar field on vertices and compute its gradient on halfedges.
# The circulation of a gradient around every face must vanish (discrete "curl of grad = 0").
v_field = jax.random.uniform(jax.random.PRNGKey(0), (hemesh.n_vertices,))
he_gradient = get_exterior_gradient(hemesh, v_field)
f_circulation = get_exterior_circulation(hemesh, he_gradient)

# assertions make this cell fail under `nbdev_test` instead of only displaying the result
assert he_gradient.shape == (hemesh.n_hes,)
assert f_circulation.shape == (hemesh.n_faces,)
assert jnp.allclose(f_circulation, 0)

hemesh, he_gradient.shape, f_circulation.shape, jnp.allclose(f_circulation, 0)

(HeMesh(N_V=131, N_HE=708, N_F=224), (708,), (224,), Array(True, dtype=bool))

### Summing over adjacent mesh elements

A second important class of operations is summing over adjacent mesh elements. For example, to get the coordination number of a vertex, you sum the value $1$ over all incoming half-edges. For computing quantities like cell areas, it is also useful to sum over the half-edges _opposite_ to a vertex.

In [47]:
#| export

def sum_he_to_vertex_incoming(hemesh: msh.HeMesh, he_field: Float[jax.Array, "n_hes ..."]
                              ) -> Float[jax.Array, "n_vertices ..."]:
    """
    Sum a half-edge field onto the destination vertex of each half-edge.

    hemesh: connectivity information
    he_field: (n_hes,) or (n_hes, ...) array; trailing dimensions are kept.
    Returns an (n_vertices, ...) array whose entry v is the sum of he_field[h]
    over all half-edges h with dest[h] == v.
    """
    # scatter-add into zeros; repeated destination indices accumulate
    accumulator = jnp.zeros((hemesh.n_vertices, *he_field.shape[1:]), dtype=he_field.dtype)
    return accumulator.at[hemesh.dest].add(he_field)

def sum_he_to_vertex_opposite(hemesh: msh.HeMesh, he_field: Float[jax.Array, "n_hes ..."]
                              ) -> Float[jax.Array, "n_vertices ..."]:
    """
    Sum a half-edge field onto the vertex opposite each half-edge.

    The vertex opposite h is dest[nxt[h]]; for a half-edge of a triangle this is the
    corner not on h. Attention: can include boundary half-edges! Every half-edge
    contributes, so zero out `he_field` on boundary half-edges if only face-opposite
    vertices are wanted.

    hemesh: connectivity information
    he_field: (n_hes,) or (n_hes, ...) array; trailing dimensions are kept.
    """
    opposite_vertex = hemesh.dest[hemesh.nxt]
    accumulator = jnp.zeros((hemesh.n_vertices, *he_field.shape[1:]), dtype=he_field.dtype)
    return accumulator.at[opposite_vertex].add(he_field)


In [48]:
#| export

def sum_he_to_face(hemesh: msh.HeMesh, he_field: Float[jax.Array, "n_hes ..."]
                  ) -> Float[jax.Array, "n_faces ..."]:
    """
    Sum over all half-edges of a face. Alias of get_exterior_circulation.

    Delegates to `get_exterior_circulation` instead of duplicating its indexing,
    so the two cannot drift apart. Returns an (n_faces, ...) array; trailing
    dimensions of `he_field` are kept.
    """
    return get_exterior_circulation(hemesh, he_field)

def _he_face_index(hemesh: msh.HeMesh) -> jax.Array:
    """
    Face index of every half-edge, or -1 for half-edges that bound no face (boundary).

    `face_incident` maps each face to one of its half-edges; the other two half-edges
    of the triangle are reached via nxt/prv, so every face half-edge gets its face id.
    """
    faces = jnp.arange(hemesh.n_faces)
    he_face = jnp.full(hemesh.n_hes, -1, dtype=faces.dtype)
    he = hemesh.face_incident
    for face_he in (he, hemesh.nxt[he], hemesh.prv[he]):
        he_face = he_face.at[face_he].set(faces)
    return he_face

def sum_face_to_he(hemesh: msh.HeMesh, f_field: Float[jax.Array, "n_faces ..."]
                  ) -> Float[jax.Array, "n_hes ..."]:
    """
    Sum face-field to half-edges. Sums over the face of the half-edge and its twin.

    hemesh: connectivity information
    f_field: (n_faces,) or (n_faces, ...) array; trailing dimensions are kept.
    Returns an (n_hes, ...) array. A side without a face (boundary half-edge, or the
    twin of a half-edge on the boundary) contributes zero.

    Note: `face_incident` is a per-face array (face -> half-edge), so the per-half-edge
    face map has to be built first; indexing `f_field` with it directly is wrong.
    """
    he_face = _he_face_index(hemesh)

    def gather(face_idx):
        # clamp the -1 sentinel to a valid index (JAX would wrap it to the last face),
        # then zero out those entries
        values = f_field[jnp.maximum(face_idx, 0)]
        has_face = (face_idx >= 0).reshape(face_idx.shape + (1,) * (f_field.ndim - 1))
        return jnp.where(has_face, values, jnp.zeros_like(values))

    return gather(he_face) + gather(he_face[hemesh.twin])

In [None]:
#| export

def sum_vertex_to_face(hemesh: msh.HeMesh, v_field: Float[jax.Array, "n_vertices ..."]
                  ) -> Float[jax.Array, "n_faces ..."]:
    """
    Sum vertex-field to faces: each face gets the sum over its three corner vertices.

    With h = face_incident[face], the corners are orig[h], dest[h] and dest[nxt[h]].
    Trailing dimensions of `v_field` are kept.
    """
    he = hemesh.face_incident
    corner_a = v_field[hemesh.orig[he]]
    corner_b = v_field[hemesh.dest[he]]
    corner_c = v_field[hemesh.dest[hemesh.nxt[he]]]
    return corner_a + corner_b + corner_c

def average_vertex_to_face(hemesh: msh.HeMesh, v_field: Float[jax.Array, "n_vertices ..."]
                  ) -> Float[jax.Array, "n_faces ..."]:
    """Average vertex-field to faces: the mean over the three corners of each triangle."""
    corner_sum = sum_vertex_to_face(hemesh, v_field)
    n_corners = 3  # triangular faces
    return corner_sum / n_corners

def sum_face_to_vertex(hemesh: msh.HeMesh, f_field: Float[jax.Array, "n_faces ..."]
                      ) -> Float[jax.Array, "n_vertices ..."]:
    """
    Sum face-field to vertices: each vertex gets the sum over the faces incident on it.

    Every face scatters its value to its three corners orig[h], dest[h], dest[nxt[h]]
    (h = face_incident[face]). Trailing dimensions of `f_field` are kept.
    """
    he = hemesh.face_incident
    corners = (hemesh.orig[he], hemesh.dest[he], hemesh.dest[hemesh.nxt[he]])
    v_field = jnp.zeros((hemesh.n_vertices, *f_field.shape[1:]), dtype=f_field.dtype)
    for corner in corners:
        v_field = v_field.at[corner].add(f_field)
    return v_field

def average_face_to_vertex(hemesh: msh.HeMesh, f_field: Float[jax.Array, "n_faces ..."]
                           ) -> Float[jax.Array, "n_vertices ..."]:
    """
    Average face-field to vertices. Uniform weights.

    Each vertex gets the mean of `f_field` over its incident faces. Vertices with no
    incident face (isolated vertices) get 0 instead of NaN. Trailing dimensions of
    `f_field` are kept.
    """
    summed_field = sum_face_to_vertex(hemesh, f_field)
    # number of incident faces per vertex, in the field's dtype
    n_incident = sum_face_to_vertex(hemesh, jnp.ones(hemesh.n_faces, dtype=f_field.dtype))
    # avoid 0/0 for isolated vertices: their summed_field is 0, so the result is 0
    n_incident = jnp.where(n_incident > 0, n_incident, 1)
    # broadcast the per-vertex count over any trailing dimensions
    n_incident = n_incident.reshape((-1,) + (1,) * (summed_field.ndim - 1))
    return summed_field / n_incident

In [97]:
# tests vs libigl
key = jax.random.PRNGKey(123)

u_v = jax.random.normal(key, (hemesh.n_vertices,))
faces_avg_jax = average_vertex_to_face(hemesh, u_v)
faces_avg_igl = igl.average_onto_faces(np.asarray(hemesh.faces), np.asarray(u_v))

rel_err_faces = jnp.linalg.norm(faces_avg_jax - faces_avg_igl) / jnp.linalg.norm(faces_avg_igl)
print("vertex->face rel. error:", rel_err_faces)

vertex->face rel. error: 0.0


In [102]:
u_f = jax.random.normal(key, (hemesh.n_faces,))
verts_avg_jax = average_face_to_vertex(hemesh, u_f)
verts_avg_igl = igl.average_onto_vertices(mesh.vertices, np.asarray(hemesh.faces), np.asarray(u_f))
rel_err_verts = jnp.linalg.norm(verts_avg_jax-verts_avg_igl) / jnp.linalg.norm(verts_avg_igl)

print("face->vertex rel. error:", rel_err_verts)
# fail the cell (and nbdev_test) on a mismatch instead of only printing it
assert rel_err_verts < 1e-10


face->vertex rel. error: 8.339340577730768e-17


In [105]:
# also works for vector fields: every column is averaged independently
u_f = jax.random.normal(key, (hemesh.n_faces, 10))
verts_avg_jax = average_face_to_vertex(hemesh, u_f)
assert verts_avg_jax.shape == (hemesh.n_vertices, 10)
assert jnp.allclose(verts_avg_jax[:, 0], average_face_to_vertex(hemesh, u_f[:, 0]))
verts_avg_jax.shape


(131, 10)

In [50]:
#| export

def get_coordination_number(hemesh: msh.HeMesh) -> Float[jax.Array, " n_vertices"]:
    """
    Coordination number of every vertex: the count of half-edges pointing into it.

    Computed as a sum of ones over incoming half-edges, so the result is a float array.
    """
    ones_per_he = jnp.ones(hemesh.n_hes)
    return sum_he_to_vertex_incoming(hemesh, ones_per_he)

In [51]:
coordination = get_coordination_number(hemesh)
# every half-edge has exactly one destination vertex, so the counts add up to n_hes
assert coordination.sum() == hemesh.n_hes
coordination.mean()

Array(5.40458015, dtype=float64)

### Uniform/graph Laplacian

In [73]:
#| export

def compute_uniform_laplacian(hemesh: msh.HeMesh, v_field: Float[jax.Array, "n_vertices ..."]
                              ) -> Float[jax.Array, "n_vertices ..."]:
    """
    Apply the uniform (graph) Laplacian to a vertex field, without assembling a matrix.

    Entry i is the sum over incoming half-edges j -> i of (v_field[i] - v_field[j]),
    i.e. coordination(i) * v_i minus the sum over neighbors. Non-normalized and positive
    semi-definite: constant fields lie in the kernel. Works for scalar and vector fields.
    """
    return sum_he_to_vertex_incoming(hemesh, get_exterior_gradient(hemesh, v_field))

def get_uniform_laplacian(hemesh: msh.HeMesh) -> jsparse.BCOO:
    """
    Assemble the uniform Laplacian as an (n_vertices, n_vertices) sparse BCOO matrix.

    Matches `compute_uniform_laplacian`: L[dest[h], orig[h]] = -1 for every half-edge h,
    and the diagonal holds the coordination number. Non-normalized, positive semi-definite.
    """
    vertex_ids = jnp.arange(hemesh.n_vertices)
    rows = jnp.concatenate([hemesh.dest, vertex_ids])
    cols = jnp.concatenate([hemesh.orig, vertex_ids])
    values = jnp.concatenate([-jnp.ones(hemesh.n_hes), get_coordination_number(hemesh)])
    indices = jnp.stack([rows, cols], axis=1)
    return jsparse.BCOO((values, indices), shape=(hemesh.n_vertices, hemesh.n_vertices))

In [76]:
# test that the matrix and function versions are equivalent, and that the Laplacian is
# positive semi-definite: positive on a random field, zero on a constant field.

laplace_mat = get_uniform_laplacian(hemesh)

matches_matrix_free = jnp.allclose(laplace_mat @ v_field, compute_uniform_laplacian(hemesh, v_field))
is_positive = jnp.dot(laplace_mat @ v_field, v_field) > 0
assert matches_matrix_free and is_positive
assert jnp.allclose(laplace_mat @ jnp.ones(hemesh.n_vertices), 0)

matches_matrix_free, is_positive

(Array(True, dtype=bool), Array(True, dtype=bool))