# triangulax

> JAX-compatible triangular meshes and triangular-mesh-based simulations


This package provides data structures for triangular meshes, plus geometry-processing tools, built on JAX and fully compatible with JAX's just-in-time compilation and automatic differentiation.

### Use cases

Triangular meshes are ubiquitous in computer graphics and scientific computing. The tools provided by `triangulax` make it easy to implement custom geometry-processing operations in familiar Python. But the main feature of `triangulax` is compatibility with automatic differentiation. Why? Often, one is interested in deforming or dynamical meshes. For example:
1. Flattening or deforming 3D models (computer graphics)
2. Simulating thin plates or membranes (mechanics) 

This usually requires optimizing some mesh-based "energy" (like the [Helfrich energy](https://en.wikipedia.org/wiki/Elasticity_of_cell_membranes) in membrane elasticity, or the [Dirichlet energy](https://multires.caltech.edu/pubs/ConfEquiv.pdf) in mesh parametrization). Automatic differentiation makes it trivial to compute gradients, optimize such energies, and rapidly explore different ideas.
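
As an illustration of the "compute gradients and optimize" workflow, here is a minimal sketch in plain JAX (not the `triangulax` API). It assumes a hypothetical four-triangle toy mesh and uses the uniform-weight ("graph") Dirichlet energy, a simplification of the cotangent-weighted version used in parametrization. With the boundary pinned, gradient descent should move the interior vertex to the average of its neighbours.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical toy mesh: a unit square split into four triangles around
# one interior vertex (index 4) that starts off-centre.
vertices = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0], [0.9, 0.2]])
faces = np.array([[0, 1, 4], [1, 2, 4], [2, 3, 4], [3, 0, 4]])

# Undirected edges (i < j), extracted once with NumPy outside of JAX.
directed = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]])
edges = np.unique(np.sort(directed, axis=1), axis=0)

def dirichlet_energy(x):
    """Uniform-weight Dirichlet energy: 1/2 * sum over edges of |x_i - x_j|^2."""
    d = x[edges[:, 0]] - x[edges[:, 1]]
    return 0.5 * jnp.sum(d ** 2)

# Pin the boundary (vertices 0-3); only the interior vertex may move.
free = jnp.array([0.0, 0.0, 0.0, 0.0, 1.0])[:, None]
grad_fn = jax.jit(jax.grad(dirichlet_energy))

x = vertices
for _ in range(200):
    x = x - 0.1 * free * grad_fn(x)  # gradient step restricted to free vertices

print("energy:", dirichlet_energy(vertices), "->", dirichlet_energy(x))
print("interior vertex:", x[4])
```

Swapping in a different energy (for example a cotangent-weighted one) only changes `dirichlet_energy`; the gradient comes from `jax.grad` either way.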

Another use case is as a simulation framework for 2D tissue mechanics ([Active tension networks](https://www.pnas.org/doi/10.1073/pnas.2321928121), [area-perimeter vertex model](https://journals.aps.org/prx/abstract/10.1103/PhysRevX.6.021011), [active foams](https://www.nature.com/articles/s41567-021-01215-1)). 2D cell tilings are conveniently represented by their dual triangular network.
`triangulax` makes it easy to implement various models, and JAX makes it trivial to compute the resulting forces and run simulations.
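
A minimal sketch of "forces from an energy" in plain JAX (not the `triangulax` API). The toy energy penalizes deviations of each triangle's signed area from a target area; it is a stand-in for illustration, not one of the models cited above. Forces are the negative gradient, and overdamped dynamics dx/dt = F(x) are integrated with forward Euler. The mesh, `target_area`, `stiffness`, and time step are all assumptions.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy mesh: unit square, four counter-clockwise triangles around vertex 4.
x0 = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0], [0.5, 0.5]])
faces = jnp.array([[0, 1, 4], [1, 2, 4], [2, 3, 4], [3, 0, 4]])

def signed_areas(x, faces):
    """Signed area of each triangle (positive for counter-clockwise faces)."""
    a, b, c = x[faces[:, 0]], x[faces[:, 1]], x[faces[:, 2]]
    ab, ac = b - a, c - a
    return 0.5 * (ab[:, 0] * ac[:, 1] - ab[:, 1] * ac[:, 0])

def area_energy(x, faces, target_area=0.5, stiffness=1.0):
    """Toy area elasticity: E = stiffness/2 * sum_T (A_T - target_area)^2."""
    return 0.5 * stiffness * jnp.sum((signed_areas(x, faces) - target_area) ** 2)

@jax.jit
def forces(x):
    # Force on each vertex is minus the energy gradient w.r.t. its position.
    return -jax.grad(area_energy)(x, faces)

def simulate(x, n_steps=200, dt=0.1):
    """Overdamped dynamics dx/dt = F(x) with unit mobility, forward Euler."""
    for _ in range(n_steps):
        x = x + dt * forces(x)
    return x

x_final = simulate(x0)
print("energy:", area_energy(x0, faces), "->", area_energy(x_final, faces))
print("triangle areas:", signed_areas(x_final, faces))
```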

#### Gradient-based "meta-optimization"

We can go further: once we have written a dynamical model for triangular meshes, JAX lets us _differentiate_ the model output w.r.t. model parameters. We can then use gradient-based optimization to find dynamical models that produce certain behaviors of interest. For example, in the tissue mechanics context you could ask: what mechanical actions do individual cells need to take so that the tissue as a whole takes on a certain shape?
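
A minimal sketch of this idea in plain JAX (not the `triangulax` API), reusing a toy area-elasticity model on a hypothetical four-triangle mesh. The simulation is written with `jax.lax.fori_loop`, and `jax.value_and_grad` differentiates a loss on the final shape with respect to the model parameter `target_area`. Here the "behavior of interest" is simply a desired total area of 3.0, an arbitrary choice; since each of the four triangles relaxes to the target area, gradient descent should drive `target_area` towards 0.75.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy mesh: unit square, four counter-clockwise triangles around vertex 4.
x0 = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0], [0.5, 0.5]])
faces = jnp.array([[0, 1, 4], [1, 2, 4], [2, 3, 4], [3, 0, 4]])

def signed_areas(x):
    """Signed area of each triangle (positive for counter-clockwise faces)."""
    a, b, c = x[faces[:, 0]], x[faces[:, 1]], x[faces[:, 2]]
    ab, ac = b - a, c - a
    return 0.5 * (ab[:, 0] * ac[:, 1] - ab[:, 1] * ac[:, 0])

def energy(x, target_area):
    """Toy area elasticity with unit stiffness."""
    return 0.5 * jnp.sum((signed_areas(x) - target_area) ** 2)

def simulate(target_area, n_steps=150, dt=0.1):
    """Overdamped forward-Euler simulation; returns final vertex positions."""
    def step(_, x):
        return x - dt * jax.grad(energy)(x, target_area)
    # Static trip count, so the loop is reverse-mode differentiable.
    return jax.lax.fori_loop(0, n_steps, step, x0)

def loss(target_area, desired_total_area=3.0):
    """Penalize deviation of the simulated total area from a desired value."""
    final_area = jnp.sum(signed_areas(simulate(target_area)))
    return (final_area - desired_total_area) ** 2

loss_and_grad = jax.jit(jax.value_and_grad(loss))
target_area = 0.5  # initial guess for the model parameter
for _ in range(50):
    value, g = loss_and_grad(target_area)
    target_area = target_area - 0.01 * g  # gradient step on the parameter

print("optimized target_area:", target_area, "loss:", value)
```

The gradient here flows through all 150 simulation steps; for long simulations, memory use of reverse-mode differentiation through the loop becomes the main cost to watch.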


## Developer Guide

This package is developed based on jupyter notebooks, which are converted into python modules using `nbdev`.
Take a look at `.github/workflows/copilot-instructions.md` for details.

### Install triangulax in development mode


1. Clone the github repository

```sh
$ git clone https://github.com/nikolas-claussen/triangulax.git
```

2. Create a conda environment with all Python dependencies

```sh
$ conda env create -n triangulax -f triangulax.yml
$ conda activate triangulax
```

3. Install the `triangulax` package

```sh
# make sure triangulax package is installed in development mode
$ pip install -e .
```

4. If necessary, edit the package notebooks and export them:
```sh
# make changes under nbs/ directory
# ...

# export notebooks to the triangulax package (also runs tests and cleans notebooks)
$ nbdev_prepare
```

## Documentation

Documentation is hosted on this GitHub [repository][repo]'s [pages][docs].
Jupyter notebooks with example simulations can be found in the `nbs/` folder.

[repo]: https://github.com/nikolas-claussen/triangulax
[docs]: https://nikolas-claussen.github.io/triangulax/

## Usage

- The `mesh` module provides a half-edge data structure for triangular meshes compatible with JAX.
- The `linops` module provides linear operators on meshes (gradient, Laplacian).
- The notebook `nbs/05_example_simulation.ipynb` showcases how to simulate mesh dynamics with `triangulax`
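
To give a feel for how a half-edge mesh can be stored as flat arrays that JAX can index inside transformations, here is a sketch. It is a hypothetical layout for illustration, not the internal representation of `triangulax.mesh.HeMesh`: each triangle contributes three half-edges, `nxt` links them around the face, and `twin` points to the oppositely oriented half-edge (or -1 on the boundary). The construction runs once in NumPy, and the resulting arrays support the same scatter-add idiom used in the minimal example.

```python
import numpy as np
import jax.numpy as jnp

def build_half_edges(faces):
    """Hypothetical flat half-edge arrays for a consistently oriented triangle mesh.

    Half-edge 3*f + k runs from faces[f, k] to faces[f, (k + 1) % 3].
    Returns (origin, dest, nxt, twin) as JAX arrays; twin is -1 on the boundary.
    """
    faces = np.asarray(faces)
    origin = faces.reshape(-1)
    dest = np.roll(faces, -1, axis=1).reshape(-1)
    he = np.arange(origin.shape[0])
    nxt = 3 * (he // 3) + (he + 1) % 3  # next half-edge around the same face
    # Match each half-edge (o -> d) with its opposite (d -> o), if present.
    index = {(int(o), int(d)): i for i, (o, d) in enumerate(zip(origin, dest))}
    twin = np.array([index.get((int(d), int(o)), -1) for o, d in zip(origin, dest)])
    return tuple(jnp.asarray(a) for a in (origin, dest, nxt, twin))

faces = np.array([[0, 1, 4], [1, 2, 4], [2, 3, 4], [3, 0, 4]])
origin, dest, nxt, twin = build_half_edges(faces)

# Incoming half-edges per vertex via scatter-add. Boundary vertices count one
# fewer than their neighbours here, since this layout stores no exterior half-edges.
incoming = jnp.zeros(5).at[dest].add(1.0)
print("incoming half-edges per vertex:", incoming)
print("boundary half-edges:", int((twin == -1).sum()))
```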

### Minimal example

```python
import igl
import jax
import jax.numpy as jnp
from triangulax import mesh, linops

# load example mesh and convert to half-edge mesh
vertices, _, _, faces, _, _ = igl.readOBJ("test_meshes/disk.obj")
hemesh = mesh.HeMesh.from_triangles(vertices.shape[0], faces)

# with the half-edge mesh, you can carry out various operations, for example
# compute the coordination number by summing incoming half-edges per vertex
coord_number = jnp.zeros(hemesh.n_vertices)
coord_number = coord_number.at[hemesh.dest].add(jnp.ones(hemesh.n_hes))
print("Mean coordination number:", coord_number.mean())

# Let's define a simple geometric function and compute its gradient with JAX
def mean_voronoi_area(vertices, hemesh: mesh.HeMesh) -> float:
    """Compute the mean Voronoi area per vertex."""
    voronoi_areas = linops.get_cell_area(vertices, hemesh)
    return jnp.mean(voronoi_areas)

value, gradient = jax.value_and_grad(mean_voronoi_area)(vertices, hemesh)
print("Mean gradient norm:", jnp.linalg.norm(gradient, axis=1).mean())
```

```
Mean coordination number: 5.40458
Mean gradient norm: 0.009738781
```


