# Units

All of the user-facing interfaces in `jaxoplanet` support units using the [`jpu` package](https://github.com/dfm/jpu).
As far as we know, `jaxoplanet` is the only user of this unit system, and this API should be considered somewhat experimental, but we think that there are benefits to being unambiguous about units where possible.
So, here we provide an overview of how `jpu` works and how it interacts with `jaxoplanet`.

For technical reasons, the `jpu` package is built on top of [`Pint`](https://pint.readthedocs.io), rather than [`astropy.units`](https://docs.astropy.org/en/stable/units/), so the interface might be unfamiliar to some users, but we hope that we can provide enough information here to get you started.
Please also refer to the [`jpu`](https://github.com/dfm/jpu) and [`Pint`](https://pint.readthedocs.io) documentation for more details.

To start, we import the "unit registry" from `jaxoplanet`, which will provide the building blocks for interacting with our unit system:

In [None]:
from jaxoplanet.units import unit_registry as u

(1.0 * u.au).to(u.mile)

`jpu` also provides its own `UnitRegistry`, but the `jaxoplanet` registry includes some useful astronomy-specific definitions that aren't supported natively by `Pint` or `jpu`.

Then, using this registry, we can write JAX code that handles unit conversions:

In [None]:
import jax
import jax.numpy as jnp


@jax.jit
def add_two_lengths(x, y):
    return x + y


add_two_lengths(1.5 * u.m, jnp.linspace(10.0, 50.0, 3) * u.cm)

These `Quantity` objects should mostly play well with JAX's programming models and function transformations, but there are a couple of subtleties that should be emphasized.

## Using mathematical functions

First, the usual `jax.numpy` functions don't work as expected with quantities:

In [None]:
try:
    jnp.cos(45.0 * u.degrees)
except TypeError as e:
    print(e)

As you can see, calling a `jax.numpy` function with a `Quantity` as an argument throws a `TypeError`, because `jax.numpy` functions only accept array-like inputs (JAX or NumPy arrays and Python scalars).
Unlike `numpy`, `jax.numpy` doesn't (currently) support dispatching on custom array types, so it can't hand the call off to `jpu` the way NumPy can with Pint quantities.
Instead, you'll need to either use the interface defined in `jpu.numpy`:

In [None]:
import jpu.numpy as jnpu

jnpu.cos(45.0 * u.degrees)

Or manually convert to the units that the function expects (radians, in this case) and then extract the "magnitude" of your `Quantity`:

In [None]:
jnp.cos((45.0 * u.degrees).to(u.radian).magnitude)

## Gradients

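Most JAX function transformations (`jax.jit`, `jax.vmap`, etc.) work properly with `Quantity` inputs.

One technical note worth mentioning here is that unit conversions are recorded as ordinary array operations when JAX traces a function, so we can see them by inspecting the jaxpr for a simple function of two quantities with `jax.make_jaxpr`,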

In [None]:
jax.make_jaxpr(lambda x, y: x + y)(1.5 * u.m, 50.0 * u.cm)

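which shows how the unit conversion is handled within the jaxpr: the units themselves never appear, only the operations on the magnitudes needed to convert and add them.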
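To build some intuition for why this works, here is a toy, self-contained sketch using only JAX. It is *not* how `jpu` is implemented. It assumes, as we understand `jpu`'s design, that a quantity can be registered as a JAX pytree whose magnitude is a traced leaf and whose unit is static metadata. The `ToyLength` class and the `TO_METERS` table are hypothetical names used only for illustration:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

# Hypothetical conversion table: how many meters are in one of each unit
TO_METERS = {"m": 1.0, "cm": 0.01, "km": 1000.0}


@register_pytree_node_class
class ToyLength:
    """A toy length quantity: a traced magnitude plus a static unit string."""

    def __init__(self, magnitude, unit):
        self.magnitude = magnitude
        self.unit = unit

    def to(self, unit):
        # The factor is a plain Python float, computed at trace time
        factor = TO_METERS[self.unit] / TO_METERS[unit]
        return ToyLength(self.magnitude * factor, unit)

    def __add__(self, other):
        # Express the right operand in the left operand's unit, then add magnitudes
        return ToyLength(self.magnitude + other.to(self.unit).magnitude, self.unit)

    def tree_flatten(self):
        # Only the magnitude is a leaf that JAX traces; the unit is auxiliary data
        return (self.magnitude,), self.unit

    @classmethod
    def tree_unflatten(cls, unit, children):
        return cls(children[0], unit)


def add_lengths(x, y):
    return x + y


x, y = ToyLength(1.5, "m"), ToyLength(50.0, "cm")

# The jaxpr contains a multiplication by the conversion factor and an addition
print(jax.make_jaxpr(add_lengths)(x, y))

total = jax.jit(add_lengths)(x, y)
print(total.magnitude, total.unit)  # magnitude close to 2.0, unit "m"
```

Because the unit string lives in the pytree's auxiliary data, JAX treats it as a compile-time constant: the conversion factor is baked into the jaxpr as a multiplication, and calling the jitted function with different units triggers a retrace rather than any runtime unit bookkeeping.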