In [4]:
import jax
import jax.numpy as jnp
import numpy as np
import astropy.constants as c
from astropy.units import Quantity
import astropy.units as u
from jax import grad

In [42]:
# nfw potential

vel_unit = u.km / u.s

def nfw_potential(
        q,
        mass: Quantity[u.Msun] = 1e12 * u.Msun,
        rs: Quantity[u.kpc] = 25 * u.kpc) -> Quantity[(u.km/u.s)**2]:
    
    r = jnp.sqrt(q[0]**2 + q[1]**2 + q[2]**2)
    rho0 = mass / (4 * np.pi * rs**3 * (1 + rs/r))
    return (-4 * np.pi * c.G * rho0 * rs**3 * (np.log(1 + r/rs)) / r).to(vel_unit**2).value

# nfw potential but jaxified

def nfw_potential_jax(
        r,
        mass=1e12 * u.Msun,
        rs=25 * u.kpc):
    
    # r = jnp.sqrt(q[0]**2 + q[1]**2 + q[2]**2)
    rho0 = mass/ (4 * jnp.pi * rs**3 * (1 + rs/r))
    return (-4 * jnp.pi * c.G * rho0 * rs**3 * (jnp.log(1 + r/rs)) / r).to(vel_unit**2)

In [43]:
grad_pot = grad(nfw_potential_jax)

In [44]:
grad_pot(25 * u.kpc)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

In [20]:
grad_test = grad(jnp.sqrt)

In [36]:
def g(x):
    return jnp.sqrt(3. * u.km /(x**2 + 1))

In [37]:
grad_test = grad(g)

In [40]:
g(0.00)

Array(1.7320508, dtype=float32)