## Coding in JAX

`triangulax` aims to create a triangulation data structure compatible with the JAX library for automatic differentiation and numerical computing (see [JAX: The Sharp Bits](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html)). What does this mean in practice?

1. Use `jnp` (= `jax.numpy`) instead of `numpy`.
2. Use a _functional programming_ paradigm (pure functions, no side effects), and avoid dynamically changing array shapes. For example, instead of modifying arrays in place, use JAX's `x = x.at[idx].set(y)`, which returns an updated array.
3. Use JAX idioms for [control flow](https://docs.jax.dev/en/latest/control-flow.html).
4. _Register_ any new classes, so that JAX knows how to handle them during gradient computation and just-in-time compilation. See [here](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#using-jax-jit-with-class-methods) and [here](https://docs.jax.dev/en/latest/custom_pytrees.html).
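A small illustration of point 2: JAX arrays are immutable, so NumPy-style item assignment raises an error, while the `.at[...]` methods return an updated array and leave the original untouched.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

try:
    x[2] = 1.0  # in-place assignment, as one would write it in NumPy
except TypeError as err:
    print("in-place assignment failed with", type(err).__name__)

y = x.at[2].set(1.0)     # new array with y[2] == 1.0
z = y.at[1:3].add(10.0)  # other update methods exist as well, e.g. .add, .multiply
# x itself is still all zeros
```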

To provide type signatures for all functions (what are the inputs? what do the array dimensions mean?), we use [jaxtyping](https://docs.kidger.site/jaxtyping). Later on, we will also use the `equinox` library, which adds a few useful tools to JAX.

### PyTrees

JAX supports not only arrays as inputs/outputs/intermediate variables, but also [pytrees](https://docs.jax.dev/en/latest/pytrees.html). Pytrees are nested structures (dicts, lists of lists, etc.) whose leaves are "elementary" objects like arrays. Fortunately, our triangulation data structure classes are already a lot like pytrees: each is a collection of arrays. For JAX to understand this, we need to register our classes as a [custom pytree node](https://docs.jax.dev/en/latest/custom_pytrees.html#pytrees-custom-pytree-nodes) via `jax.tree_util.register_dataclass`.

Sidenote: Neural networks in libraries like Flax or Equinox are built in much the same way. They are dataclass-like classes which hold all the arrays associated with a NN (the weights of the different layers, and maybe some parameters), with class methods like `__call__` specifying the forward pass through the NN. In Equinox, a model is registered as a pytree automatically when its class inherits from `equinox.Module`.

### [Control flow](https://docs.jax.dev/en/latest/control-flow.html)

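For just-in-time compilation, JAX distinguishes two types of variables: dynamic (traced) and static. Python control flow (`if`, `while`, ...) inside a jitted function cannot depend on the _value_ of dynamic variables, only on their shape and dtype.

Upshots:
1. Replace `if` with `jax.lax.cond` / `jnp.where` (compatible with both forward- and reverse-mode autodiff), and `while` with `jax.lax.while_loop` (forward-mode autodiff only; for reverse mode, use `jax.lax.scan` or `jax.lax.fori_loop` with static bounds).
2. Mark variables which are not going to change during the simulation as static (e.g. via `static_argnames` in `jax.jit`).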
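A minimal sketch of these idioms on toy scalar functions (all function names are made up for illustration):

```python
from functools import partial

import jax
import jax.numpy as jnp


@jax.jit
def relu(x):
    # `if x > 0:` would fail here, because the value of the traced x is unknown.
    # jnp.where evaluates both branches elementwise and selects the result.
    return jnp.where(x > 0, x, 0.0)


@jax.jit
def soften(e):
    # lax.cond picks one of two branches based on a scalar traced predicate
    return jax.lax.cond(e > 1.0, lambda e: 1.0 + jnp.log(e), lambda e: e, e)


@jax.jit
def halve_until_small(x):
    # lax.while_loop: the number of iterations may depend on traced values
    return jax.lax.while_loop(lambda x: x > 1.0, lambda x: 0.5 * x, x)


@partial(jax.jit, static_argnames="n_steps")
def relax(x, n_steps: int):
    # n_steps is static, so this Python loop is unrolled at trace time;
    # every new value of n_steps triggers a recompilation
    for _ in range(n_steps):
        x = 0.5 * x
    return x


grad_soften = jax.grad(soften)(jnp.array(2.0))  # derivative of 1 + log(e) at e = 2
```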

### Static array shapes

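JAX works best if the _shapes_ of arrays do not change during the computation, since `jax.jit` compiles a separate version of a function for every new combination of input shapes. For this reason, we (first) focus on triangulations where the number of vertices does not change.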
Topological modifications (like edge flips) are nevertheless possible, as long as they do not change the number of mesh elements (vertices, edges, and faces). 
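To illustrate how a topological change can keep all shapes fixed, here is a sketch of an edge flip on a hypothetical flat face-index array (the real `triangulax` data structure may store connectivity differently). Both triangles are overwritten with `.at[...].set()`, so the `(n_faces, 3)` shape never changes and the function can be jitted:

```python
import jax
import jax.numpy as jnp


@jax.jit
def flip_shared_edge(faces: jax.Array, i: int, j: int) -> jax.Array:
    # Assumes counterclockwise faces faces[i] = (a, b, c) and faces[j] = (c, b, d),
    # which share the edge (b, c). The flip replaces it by the edge (a, d).
    a, b, c = faces[i]
    d = faces[j, 2]
    faces = faces.at[i].set(jnp.array([a, b, d]))
    faces = faces.at[j].set(jnp.array([a, d, c]))
    return faces


# unit square with vertices 0=(0,0), 1=(1,0), 2=(0,1), 3=(1,1), split along the diagonal 1-2;
# after the flip it is split along the diagonal 0-3
faces = jnp.array([[0, 1, 2], [2, 1, 3]])
flipped = flip_shared_edge(faces, 0, 1)
```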

### Batching

In simulations, we may want to "batch" over several initial conditions/random seeds/etc (analogous to batching over training data in normal ML). In JAX, one can efficiently and concisely vectorize operations over such "batch axes" with `jax.vmap`. 

To batch over our custom data structures, we need a small trick: convert a list of instances into a single mesh whose arrays carry an additional leading batch axis. Luckily, this can be [done using JAX's pytree tools](https://stackoverflow.com/questions/79123001/storing-and-jax-vmap-over-pytrees).

## Simulation loops with `jax.lax.scan`

In simulations, we generally start with an initial state (call it `init`), do a series of timesteps (via a function `make_step(state)`), and record some "measurement" at each timestep (via a `measure(state)` function). As a result, we get a time series of measurements and the final simulation state. In plain Python, you would do that with a `for` loop. When working with JAX, we need to [replace control-flow operations like `for` with their JAX counterparts](https://docs.jax.dev/en/latest/control-flow.html). For `for` loops, this is `jax.lax.scan(f, init, xs)`, which is equivalent to the Python code
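```python
import numpy as np

def scan(f, init, xs):
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)  # f returns the new carry and a per-step output
    ys.append(y)
  return carry, np.stack(ys)  # outputs are stacked along a new leading axis
```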
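A quick sanity check of this equivalence, using a running sum as the scanning function (a toy example, not part of `triangulax`):

```python
import jax
import jax.numpy as jnp


def cumsum_step(carry, x):
    carry = carry + x
    return carry, carry  # (new carry, value recorded for this step)


xs = jnp.arange(5.0)
final, ys = jax.lax.scan(cumsum_step, jnp.zeros(()), xs)

# the same computation as an explicit Python loop, mirroring the reference code above
carry, ys_loop = jnp.zeros(()), []
for x in xs:
    carry, y = cumsum_step(carry, x)
    ys_loop.append(y)
```

Unlike the Python loop, `jax.lax.scan` traces `cumsum_step` only once, so compilation time does not grow with the number of steps.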
In our pattern, `xs` is the array of time points `timepoints`, and the "scanning function" `f` generally consists of two parts, a time step and a measurement/logging step (above, we logged energy and T1 count):
```python
def f(carry, t):
    new_state = make_step(carry, t)
    measurements = measure(new_state)
    return new_state, measurements
```
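As a concrete (hypothetical) instance of this pattern, consider independent Brownian particles in 2D. The carry is a tuple of positions and a PRNG key: since JAX's random number generation is functional, the key has to be threaded through the carry like any other state. `n_particles`, `dt`, and `diffusivity` are made-up parameters.

```python
import jax
import jax.numpy as jnp

n_particles, dt, diffusivity = 100, 0.01, 1.0  # hypothetical parameters


def make_step(carry, t):
    # carry = (positions, PRNG key); a fresh subkey is split off at every step
    positions, key = carry
    key, subkey = jax.random.split(key)
    noise = jax.random.normal(subkey, positions.shape)
    return positions + jnp.sqrt(2 * diffusivity * dt) * noise, key


def measure(carry):
    positions, _ = carry
    # mean squared displacement; all particles start at the origin
    return jnp.mean(jnp.sum(positions**2, axis=-1))


def f(carry, t):
    new_state = make_step(carry, t)
    measurements = measure(new_state)
    return new_state, measurements


timepoints = dt * jnp.arange(1, 501)  # 500 steps, up to t = 5
init = (jnp.zeros((n_particles, 2)), jax.random.PRNGKey(0))
(final_positions, final_key), msd = jax.lax.scan(f, init, timepoints)
```

For 2D diffusion, the recorded MSD should grow roughly like `4 * diffusivity * t`, which makes a handy sanity check.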

The `carry` variable contains all information about the state of the simulation. Typically, `carry` is itself composed of multiple pieces (the physical state `physical_state`, as well as ancillary variables like the ODE solver state `solver_state`). To keep things organized, it can make sense to define dataclasses for the simulation state and the measurements, as in this (schematic) example:

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp

@jax.tree_util.register_dataclass
@dataclass
class SimState:
    physical_state: jax.Array
    solver_state: dict  # or another PyTree
    current_time: jax.Array

@jax.tree_util.register_dataclass
@dataclass
class Log:
    energy: jax.Array  # one scalar per step; scan stacks these into a time series

def scan_function(carry: SimState, next_time: jax.Array) -> tuple[SimState, Log]:
    # make_step and compute_energy are placeholders for the actual model
    physical_state, solver_state = make_step(carry.physical_state, carry.solver_state,
                                             carry.current_time, next_time)
    log = Log(energy=compute_energy(physical_state))
    # the new carry must have the same structure as the old one, including the time
    return SimState(physical_state, solver_state, next_time), log

timepoints = jnp.arange(t0, t1, dt)
init = ...  # define initial condition (a SimState)
final_state, measurements = jax.lax.scan(scan_function, init, timepoints)
```