# triangulax

> JAX-compatible triangular meshes and triangular-mesh-based simulations


## Overview

This Python package provides data structures for triangular meshes, together with geometry processing tools, built on JAX and fully compatible with JAX's just-in-time compilation and automatic differentiation.
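
At its simplest, a triangular mesh is a pair of arrays: vertex positions of shape `(n_vertices, 3)` and integer faces of shape `(n_faces, 3)`. The sketch below uses only `jax.numpy`, not the `triangulax` API, to show why this layout suits JAX. Per-face quantities become vectorized array operations, so they can be JIT-compiled and differentiated, and both transformations compose. The example mesh and the helpers `face_areas` and `total_area` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Unit square in the z = 0 plane, split into four triangles around a center vertex (index 4).
vertices = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0],
                      [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]])
faces = jnp.array([[0, 1, 4], [1, 2, 4], [2, 3, 4], [3, 0, 4]])

def face_areas(vertices, faces):
    """Area of each triangle: half the norm of the cross product of two edge vectors."""
    a, b, c = (vertices[faces[:, i]] for i in range(3))
    return 0.5 * jnp.linalg.norm(jnp.cross(b - a, c - a), axis=-1)

def total_area(vertices, faces):
    return jnp.sum(face_areas(vertices, faces))

# JIT compilation and automatic differentiation compose freely.
areas = jax.jit(face_areas)(vertices, faces)
area_grad = jax.jit(jax.grad(total_area))

print(areas)  # four triangles of area 0.25 each
print(area_grad(vertices, faces)[4])  # ~0: sliding the center vertex within the plane keeps the area at 1
lifted = vertices.at[4, 2].set(0.3)
print(area_grad(lifted, faces)[4])  # positive z-component: raising the center vertex increases the area
```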

### Use cases

Triangular meshes are ubiquitous in computer graphics and in scientific computing. The tools provided by `triangulax` make it easy to implement custom geometry processing operations, in Python rather than C++. `triangulax` complements libraries like the excellent [`libigl` python bindings](https://libigl.github.io/libigl-python-bindings/), which focus on providing "ready-made" geometry processing tools.
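
As an illustration of a custom geometry processing operation written directly in Python, the sketch below computes area-weighted vertex normals by scattering each face normal onto its three vertices with JAX's `.at[...].add` scatter-add. This is plain `jax.numpy`, not a `triangulax` function; the helper name `vertex_normals` is hypothetical, and the sketch assumes consistently oriented faces.

```python
import jax
import jax.numpy as jnp

@jax.jit
def vertex_normals(vertices, faces):
    """Area-weighted unit normals per vertex, accumulated with scatter-adds."""
    a, b, c = (vertices[faces[:, i]] for i in range(3))
    # Unnormalized face normals: their length is twice the face area,
    # so summing them weights each face by its area.
    face_normals = jnp.cross(b - a, c - a)
    normals = jnp.zeros_like(vertices)
    for k in range(3):  # add each face normal to all three of its vertices
        normals = normals.at[faces[:, k]].add(face_normals)
    return normals / jnp.linalg.norm(normals, axis=-1, keepdims=True)

# Square pyramid: four triangles around an apex (index 4) raised above the unit square.
vertices = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0],
                      [0.0, 1.0, 0.0], [0.5, 0.5, 0.3]])
faces = jnp.array([[0, 1, 4], [1, 2, 4], [2, 3, 4], [3, 0, 4]])
normals = vertex_normals(vertices, faces)
print(normals[4])  # by symmetry, the apex normal points straight up, (0, 0, 1)
```

Repeated indices in `.at[...].add` accumulate rather than overwrite, which is what makes this one-line gather/scatter pattern correct for vertices shared by many faces.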

A second use case is simulating 2D tissue mechanics ([Active tension networks](https://www.pnas.org/doi/10.1073/pnas.2321928121), [area-perimeter vertex model](https://journals.aps.org/prx/abstract/10.1103/PhysRevX.6.021011), [active foams](https://www.nature.com/articles/s41567-021-01215-1)). 2D cell tilings are conveniently represented by their dual triangular network.
`triangulax` makes it (relatively) easy to implement custom models within a single library, and JAX makes it easy to compute mechanical forces from arbitrary energy functions.
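
Computing forces from an energy amounts to taking a negative gradient with respect to vertex positions. The minimal sketch below assumes a toy area-elasticity energy, $E = \frac{k}{2}\sum_t (A_t - A_0)^2$ over 2D triangles, chosen for brevity; it is not one of the tissue models linked above, and the helper names are hypothetical.

```python
import jax
import jax.numpy as jnp

def signed_areas(positions, faces):
    """Signed area of each 2D triangle (positive for counter-clockwise vertex order)."""
    a, b, c = (positions[faces[:, i]] for i in range(3))
    ab, ac = b - a, c - a
    return 0.5 * (ab[:, 0] * ac[:, 1] - ab[:, 1] * ac[:, 0])

def area_energy(positions, faces, target_area=0.5, stiffness=1.0):
    """Toy area elasticity: (stiffness / 2) * sum over triangles of (A - target_area)**2."""
    areas = signed_areas(positions, faces)
    return 0.5 * stiffness * jnp.sum((areas - target_area) ** 2)

@jax.jit
def forces(positions, faces):
    """Force on each vertex = minus the gradient of the energy w.r.t. its position."""
    return -jax.grad(area_energy)(positions, faces)

# Unit square split into two counter-clockwise triangles, each of area 0.5.
positions = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
faces = jnp.array([[0, 1, 2], [0, 2, 3]])

print(forces(positions, faces))  # all zero: both triangles are at their target area
stretched = 1.2 * positions
f = forces(stretched, faces)
print(f.sum(axis=0))  # ~0: the energy is translation invariant, so forces sum to zero
```

Swapping in a different energy function requires no change to `forces`; only the energy is model-specific.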

### Automatic differentiation

The main feature of `triangulax` is compatibility with automatic differentiation. This makes it possible to compute gradients of any differentiable mesh-based function written with JAX. Most tools are also compatible with JAX's JIT compilation, delivering high performance from high-level Python code.

For example, consider:

1. Flattening or deforming 3D models (computer graphics)
2. Mechanics of thin plates or membranes (mechanics) 
3. Cell-resolved tissue simulations

Such tasks often feature some mesh-based "energy" (like the [Dirichlet functional](https://multires.caltech.edu/pubs/ConfEquiv.pdf), the [Helfrich elastic energy](https://en.wikipedia.org/wiki/Elasticity_of_cell_membranes), or the [area-perimeter energy](https://journals.aps.org/prx/abstract/10.1103/PhysRevX.6.021011), respectively). Automatic differentiation with JAX makes it trivial to compute the gradients of such energies.
This makes it easy to minimize energies or to simulate the resulting forces.
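
A minimal sketch of minimizing a mesh energy by gradient descent, assuming the simplest uniform-weight (graph) version of the Dirichlet energy, $\frac{1}{2}\sum_{ij} \lVert x_i - x_j \rVert^2$, rather than the cotangent-weighted functional linked above. Boundary vertices are held fixed with a mask, so the interior vertex relaxes to the mean of its neighbors. The mesh, step size, and helper names are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def dirichlet_energy(x, faces):
    """Uniform-weight (graph) Dirichlet energy: half the sum of squared edge lengths.
    Each edge is counted once per incident triangle."""
    energy = 0.0
    for k in range(3):
        d = x[faces[:, k]] - x[faces[:, (k + 1) % 3]]
        energy = energy + 0.5 * jnp.sum(d ** 2)
    return energy

# Unit square with one interior vertex (index 4), initially off-center.
x0 = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0], [0.9, 0.2]])
faces = jnp.array([[0, 1, 4], [1, 2, 4], [2, 3, 4], [3, 0, 4]])
is_free = jnp.array([0.0, 0.0, 0.0, 0.0, 1.0])[:, None]  # boundary vertices stay fixed
STEP_SIZE = 0.05

@jax.jit
def descent_step(x):
    """One gradient-descent step; the mask zeroes the update of boundary vertices."""
    return x - STEP_SIZE * is_free * jax.grad(dirichlet_energy)(x, faces)

x = x0
for _ in range(200):
    x = descent_step(x)
print(x[4])  # the interior vertex settles at the mean of its neighbors, (0.5, 0.5)
```

The same loop, read as an explicit time integrator, is an overdamped simulation with forces `-grad E`.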

#### Gradient-based "meta-optimization" and inverse problems

Automatic differentiation goes further: once we have a dynamical model for triangular meshes, we can _differentiate_ the model output w.r.t. its parameters. This means one can apply gradient-based optimization to _inverse problems_. For example, in the tissue mechanics context: what mechanical actions do individual cells need to take so that the tissue as a whole takes on a certain shape?
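
A toy version of such an inverse problem, sketched under explicit assumptions: a square with one free interior vertex relaxes under an area-elasticity energy, and the per-triangle target areas stand in for the "actions" of individual cells. JAX differentiates through the whole relaxation (`jax.lax.fori_loop`), and plain gradient descent on the target areas steers the interior vertex to a desired position. The model, the goal position, and all helper names are hypothetical and are not part of `triangulax`.

```python
import jax
import jax.numpy as jnp

# Unit square with one free interior vertex (index 4); corners are held fixed.
x_init = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0], [0.5, 0.5]])
faces = jnp.array([[0, 1, 4], [1, 2, 4], [2, 3, 4], [3, 0, 4]])
is_free = jnp.array([0.0, 0.0, 0.0, 0.0, 1.0])[:, None]

def signed_areas(x):
    a, b, c = (x[faces[:, i]] for i in range(3))
    ab, ac = b - a, c - a
    return 0.5 * (ab[:, 0] * ac[:, 1] - ab[:, 1] * ac[:, 0])

def energy(x, target_areas):
    """Toy area elasticity; target_areas play the role of per-cell 'actions'."""
    return 0.5 * jnp.sum((signed_areas(x) - target_areas) ** 2)

def simulate(target_areas, n_steps=50, dt=1.0):
    """Overdamped relaxation x <- x - dt * grad E, written so JAX can differentiate it."""
    def body(_, x):
        return x - dt * is_free * jax.grad(energy)(x, target_areas)
    return jax.lax.fori_loop(0, n_steps, body, x_init)

goal = jnp.array([0.7, 0.4])  # desired final position of the interior vertex

def loss(target_areas):
    return jnp.sum((simulate(target_areas)[4] - goal) ** 2)

loss_and_grad = jax.jit(jax.value_and_grad(loss))
params = jnp.full(4, 0.25)  # start from equal target areas
for _ in range(100):
    value, g = loss_and_grad(params)
    params = params - 0.1 * g
print(loss(params), simulate(params)[4])  # loss near 0, vertex near (0.7, 0.4)
```

Here the gradient is taken through every step of the simulation (backpropagation through time); for long simulations, implicit differentiation at equilibrium is a common alternative.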

## Developer guide and installation instructions

This package is developed in Jupyter notebooks, which are converted into Python modules using `nbdev`.
Take a look at `.github/workflows/copilot-instructions.md` for details.

### Install triangulax in Development mode


1. Clone the github repository

```sh
$ git clone https://github.com/nikolas-claussen/triangulax.git
```

2. Create a conda environment with all Python dependencies

```sh
$ conda env create -n triangulax -f triangulax.yml
$ conda activate triangulax
```

3. Install the `triangulax` package

```sh
# make sure triangulax package is installed in development mode
$ pip install -e .
```

4. If necessary, edit the package notebooks and export them to Python modules:

```sh
# make changes under nbs/ directory
# ...

# export the notebooks so that the changes apply to the triangulax package
$ nbdev_prepare
```

## Documentation

Documentation is hosted on this GitHub [repository][repo]'s [pages][docs].
Jupyter notebooks with example simulations can be found in the `nbs/` folder.

[repo]: https://github.com/nikolas-claussen/triangulax
[docs]: https://nikolas-claussen.github.io/triangulax/

## Usage

`triangulax` comprises the following modules:

- `triangular`: input/output for triangular meshes
- `trigonometry`: trigonometry
- `mesh`: a half-edge data structure for triangular meshes compatible with JAX.
- `topology`: topological modifications (flip, collapse, and split)
- `adjacency`, `geometry`, `linops`: geometry processing tools
- Notebooks `nbs/08_example_simulation.ipynb`, `nbs/09_self_propelled_Voronoi.ipynb`: examples for simulating mesh dynamics
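
To give a rough idea of what a half-edge structure stores, the sketch below builds origin, destination, next, and twin arrays from a face list with NumPy. It is an illustration of the general technique, not the `triangulax` implementation: the array names are assumptions and need not match the fields of `mesh.HeMesh`, and boundary half-edges are simply marked with `twin = -1` instead of being stored explicitly.

```python
import numpy as np

def build_half_edges(faces):
    """Build origin, destination, next, and twin arrays for a triangle mesh.

    Half-edge 3*f + k runs from faces[f, k] to faces[f, (k + 1) % 3].
    Half-edges without an opposite half-edge (on the boundary) get twin = -1.
    """
    faces = np.asarray(faces)
    n_hes = 3 * faces.shape[0]
    origin = faces.reshape(-1)
    dest = np.roll(faces, -1, axis=1).reshape(-1)
    # Next half-edge within the same face: local index 0 -> 1 -> 2 -> 0.
    local = np.arange(n_hes) % 3
    nxt = np.arange(n_hes) - local + (local + 1) % 3
    # Twin: the half-edge with origin and destination swapped, if it exists.
    pairs = list(zip(origin.tolist(), dest.tolist()))
    lookup = {pair: he for he, pair in enumerate(pairs)}
    twin = np.array([lookup.get((d, o), -1) for o, d in pairs])
    return origin, dest, nxt, twin

# Two triangles sharing the diagonal 0-2 of the unit square.
origin, dest, nxt, twin = build_half_edges([[0, 1, 2], [0, 2, 3]])
print(twin)  # half-edges 2 and 3 are twins across the shared diagonal; the rest are boundary (-1)
```

Once built, these flat integer arrays can be moved to JAX, and mesh traversals become array indexing (for example `dest[nxt[he]]`).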

### Minimal example

```python
import igl
import jax
import jax.numpy as jnp
from triangulax import mesh, geometry

# load example mesh and convert to half-edge mesh

vertices, _, _, faces, _, _ = igl.readOBJ("test_meshes/disk.obj")
hemesh = mesh.HeMesh.from_triangles(vertices.shape[0], faces)

# with the half-edge mesh, you can carry out various operations, for example
# compute the coordination number by summing incoming half-edges per vertex

coord_number = jnp.zeros(hemesh.n_vertices)
coord_number = coord_number.at[hemesh.dest].add(jnp.ones(hemesh.n_hes))
print("Mean coordination number:", coord_number.mean())

# Let's define a simple geometric function and compute its gradient with JAX

def mean_voronoi_area(vertices, hemesh: mesh.HeMesh) -> float:
    """Compute the mean Voronoi area per vertex."""
    voronoi_areas = geometry.get_voronoi_areas(vertices, hemesh)
    return jnp.mean(voronoi_areas)

value, gradient = jax.value_and_grad(mean_voronoi_area)(vertices, hemesh)
print("Mean gradient norm:", jnp.linalg.norm(gradient, axis=1).mean())
```

```
Mean coordination number: 5.40458
Mean gradient norm: 0.00036383414
```
