# Units

To handle physical units, *jaxoplanet* uses [jpu](https://github.com/dfm/jpu), a Python package that provides an interface between [JAX](https://jax.readthedocs.io) and [Pint](https://pint.readthedocs.io) so that JAX operations can carry units.

Here is an example from the jpu documentation:

In [16]:
import jax
import jax.numpy as jnp
import jpu

u = jpu.UnitRegistry()


@jax.jit
def add_two_lengths(a, b):
    return a + b


result = add_two_lengths(3 * u.m, jnp.array([4.5, 1.2, 3.9]) * u.cm)

`result` is a quantity, that is, a value with units attached to it, and is of type

In [18]:
type(result)

pint.JpuQuantity

To get the bare numerical value of the quantity, use its `magnitude` attribute

In [35]:
result.magnitude

Array([3.045, 3.012, 3.039], dtype=float32)

To convert to a different unit, use the `to` method, which accepts either a unit object or a string

In [36]:
result.to(u.cm)

# or

result.to("cm")

| | |
|---|---|
| Magnitude | [304.5 301.2 303.9] |
| Units | centimeter |


One particularity of a `JpuQuantity` is that it cannot be passed directly to a `jax.numpy` function

In [37]:
try:
    jnp.power(result, 2)
except Exception as e:
    print(e)

power requires ndarray or scalar arguments, got <class 'pint.JpuQuantity'> at position 0.


Hence, if you ever need to apply a `jax.numpy` function to a `JpuQuantity`, use the corresponding function from the `jpu.numpy` module instead

In [38]:
import jpu.numpy as jnpu

jnpu.power(result, 2)

| | |
|---|---|
| Magnitude | [9.272025 9.0721445 9.235521 ] |
| Units | meter² |
