# Jaxoplanet from scratch

Inspired by the [autodidax tutorial](https://jax.readthedocs.io/en/latest/autodidax.html) from the JAX documentation, in this tutorial we work through implementing some of the core `jaxoplanet` functionality from scratch, to demonstrate and discuss the choices made within the depths of the codebase.

## Solving Kepler's equation

One core piece of infrastructure provided by `jaxoplanet` is a function to solve Kepler's equation

$$
M = E - e \sin(E)
$$

for the eccentric anomaly $E$ as a function of the eccentricity $e$ and mean anomaly $M$.
There is a lot of literature dedicated to solving this equation efficiently and robustly, and we won't get into all the details here, but there are a few points we should highlight:

1. The methods most commonly used in astrophysics to solve this equation are all iterative, using some sort of root-finding scheme. While these methods can work well, they tend to be less computationally efficient than non-iterative approaches. More importantly for our purposes, non-iterative methods are better suited to massively parallel compute architectures like GPUs, because every input follows the same fixed sequence of operations instead of a data-dependent number of iterations. These non-iterative methods typically have a two-step form: (i) make a good initial guess ("starter") for $E$, then (ii) use a high-order root-finding update to refine this estimate.

2. In most Python codes, the Kepler solver is offloaded to a compiled library, but we will find that we can get comparable performance using only JAX and relying on its JIT compilation to accelerate the computation.
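For contrast with point 1, here is a minimal sketch of the conventional iterative approach: Newton's method run until the update falls below a tolerance, written with `jax.lax.while_loop`. The function name `newton_kepler`, the starting guess $E_0 = M$, and the `tol` and `max_iter` values are illustrative assumptions, not part of `jaxoplanet`. Under `jax.vmap`, a while loop with a batched stopping condition keeps iterating until every element has converged, so the cost of the whole batch is set by its slowest element. This is one concrete way to see why iterative solvers map less naturally onto GPUs.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def newton_kepler(mean_anom, ecc, tol=1e-12, max_iter=50):
    # Hypothetical iterative solver: Newton's method starting from E_0 = M,
    # repeated until the update falls below `tol` (or `max_iter` is reached)
    mean_anom = jnp.asarray(mean_anom, dtype=jnp.float64)

    def cond(state):
        _, delta, it = state
        return (jnp.abs(delta) > tol) & (it < max_iter)

    def body(state):
        ecc_anom, _, it = state
        f = ecc_anom - ecc * jnp.sin(ecc_anom) - mean_anom  # f(E)
        fp = 1 - ecc * jnp.cos(ecc_anom)  # f'(E)
        delta = -f / fp
        return ecc_anom + delta, delta, it + 1

    # Start with an infinite "previous update" so the loop runs at least once
    init = (mean_anom, jnp.full_like(mean_anom, jnp.inf), jnp.int32(0))
    ecc_anom, _, n_iter = jax.lax.while_loop(cond, body, init)
    return ecc_anom, n_iter


ecc = 0.5
mean_anom = jnp.linspace(0.0, jnp.pi, 5)

# Under vmap, the batched while_loop keeps running until *every* element has
# converged; elements that have already converged are held fixed meanwhile
ecc_anom, n_iter = jax.vmap(newton_kepler, in_axes=(0, None))(mean_anom, ecc)
print(n_iter)
```

The printed iteration counts differ between inputs, but the batched loop as a whole runs for the largest of them.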

With these points in mind, we can implement the solver that is included with `jaxoplanet`, which is based on [Markley (1995)](https://ui.adsabs.harvard.edu/abs/1995CeMDA..63..101M/abstract).
First we implement a "starter" function which uses Markley's approximation to estimate $E$ as a function of $M$ and $e$:

In [None]:
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def kepler_starter(mean_anom, ecc):
    ome = 1 - ecc
    M2 = jnp.square(mean_anom)
    alpha = 3 * jnp.pi / (jnp.pi - 6 / jnp.pi)
    alpha += 1.6 / (jnp.pi - 6 / jnp.pi) * (jnp.pi - mean_anom) / (1 + ecc)
    d = 3 * ome + alpha * ecc
    alphad = alpha * d
    r = (3 * alphad * (d - ome) + M2) * mean_anom
    q = 2 * alphad * ome - M2
    q2 = jnp.square(q)
    w = (jnp.abs(r) + jnp.sqrt(q2 * q + r * r)) ** (2.0 / 3)
    return (2 * r * w / (jnp.square(w) + w * q + q2) + mean_anom) / d

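Then we refine this estimate with a single high-order correction step. Writing $f(E) = E - e \sin E - M$, the variable `f_0` below is $f(E)$, and `f_1`, `f_2`, and `f_3` are its first three derivatives, $1 - e \cos E$, $e \sin E$, and $e \cos E$ (the fourth derivative is $-e \sin E$, i.e. `-f_2`). Setting the Taylor expansion $f(E + \delta) \approx f_0 + \delta\,(f_1 + \delta f_2 / 2 + \delta^2 f_3 / 6 - \delta^3 f_2 / 24)$ to zero, and substituting successively better estimates of $\delta$ into the parentheses (a Newton step to get `d_3`, then `d_3` to get `d_4`, then `d_4` to get `dE`), gives the update: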

In [None]:
def kepler_refiner(mean_anom, ecc, ecc_anom):
    ome = 1 - ecc
    sE = ecc_anom - jnp.sin(ecc_anom)
    cE = 1 - jnp.cos(ecc_anom)

    f_0 = ecc * sE + ecc_anom * ome - mean_anom
    f_1 = ecc * cE + ome
    f_2 = ecc * (ecc_anom - sE)
    f_3 = 1 - f_1
    d_3 = -f_0 / (f_1 - 0.5 * f_0 * f_2 / f_1)
    d_4 = -f_0 / (f_1 + 0.5 * d_3 * f_2 + (d_3 * d_3) * f_3 / 6)
    d_42 = d_4 * d_4
    dE = -f_0 / (f_1 + 0.5 * d_4 * f_2 + d_4 * d_4 * f_3 / 6 - d_42 * d_4 * f_2 / 24)

    return ecc_anom + dE
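The first intermediate estimate in `kepler_refiner`, `d_3 = -f_0 / (f_1 - 0.5 * f_0 * f_2 / f_1)`, is a Halley step, and each later estimate folds in another derivative. As a rough illustration of why higher-order updates pay off, the self-contained sketch below compares a single Newton step with a single Halley step from the same deliberately crude guess $E_0 = M$. The helper names are hypothetical, and the values $M = 1$, $e = 0.5$ are arbitrary choices.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def kepler_residual(ecc_anom, mean_anom, ecc):
    # f(E) = E - e sin(E) - M, whose root is the eccentric anomaly
    return ecc_anom - ecc * jnp.sin(ecc_anom) - mean_anom


def newton_step(ecc_anom, mean_anom, ecc):
    f_0 = kepler_residual(ecc_anom, mean_anom, ecc)
    f_1 = 1 - ecc * jnp.cos(ecc_anom)  # f'(E)
    return ecc_anom - f_0 / f_1


def halley_step(ecc_anom, mean_anom, ecc):
    f_0 = kepler_residual(ecc_anom, mean_anom, ecc)
    f_1 = 1 - ecc * jnp.cos(ecc_anom)  # f'(E)
    f_2 = ecc * jnp.sin(ecc_anom)  # f''(E)
    # E + d_3, with d_3 as defined in kepler_refiner
    return ecc_anom - f_0 / (f_1 - 0.5 * f_0 * f_2 / f_1)


mean_anom, ecc = 1.0, 0.5
guess = mean_anom  # deliberately crude starting guess, E_0 = M

newton = newton_step(guess, mean_anom, ecc)
halley = halley_step(guess, mean_anom, ecc)
print("Newton residual:", abs(kepler_residual(newton, mean_anom, ecc)))
print("Halley residual:", abs(kepler_residual(halley, mean_anom, ecc)))
```

With these inputs the Halley step leaves a residual more than ten times smaller than the Newton step, for the cost of one extra derivative.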

Putting these together, we can construct a solver function which includes some extra bookkeeping to handle the range reduction of the inputs:

In [None]:
@jax.jit
def kepler_solver_impl(mean_anom, ecc):
    mean_anom = mean_anom % (2 * jnp.pi)

    # Restrict to the range [0, pi] using the symmetry E(2 pi - M) = 2 pi - E(M)
    high = mean_anom > jnp.pi
    mean_anom = jnp.where(high, 2 * jnp.pi - mean_anom, mean_anom)

    # Solve
    ecc_anom = kepler_starter(mean_anom, ecc)
    ecc_anom = kepler_refiner(mean_anom, ecc, ecc_anom)

    # Re-wrap back into the full range
    ecc_anom = jnp.where(high, 2 * jnp.pi - ecc_anom, ecc_anom)

    return ecc_anom
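The reflection step in `kepler_solver_impl` relies on a symmetry of Kepler's equation: since $\sin(2\pi - E) = -\sin E$, if $E$ solves $M = E - e \sin E$ then $2\pi - E$ solves the equation for $2\pi - M$. The sketch below checks this numerically with a hypothetical fixed-iteration Newton solver, `newton_fixed`, which is independent of the Markley starter and refiner; the 30 iterations and the starting point $E_0 = \pi$ are illustrative choices.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def newton_fixed(mean_anom, ecc, n_iter=30):
    # Hypothetical reference solver: a fixed number of Newton steps
    # starting from E_0 = pi
    def step(_, ecc_anom):
        f = ecc_anom - ecc * jnp.sin(ecc_anom) - mean_anom
        return ecc_anom - f / (1 - ecc * jnp.cos(ecc_anom))

    init = jnp.full(jnp.shape(mean_anom), jnp.pi, dtype=jnp.float64)
    return jax.lax.fori_loop(0, n_iter, step, init)


ecc = 0.5
mean_anom = jnp.linspace(0.1, jnp.pi - 0.1, 7)

# Solve for M and, independently, for its mirror image 2 pi - M
ecc_anom = newton_fixed(mean_anom, ecc)
ecc_anom_mirror = newton_fixed(2 * jnp.pi - mean_anom, ecc)

# The symmetry predicts ecc_anom_mirror == 2 pi - ecc_anom
max_diff = jnp.max(jnp.abs(ecc_anom_mirror - (2 * jnp.pi - ecc_anom)))
print(max_diff)
```

Because of this symmetry, the starter and refiner only ever need to be accurate on half of the period.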

In [None]:
import matplotlib.pyplot as plt

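ecc = 0.5
# Choose a grid of "true" eccentric anomalies, compute the corresponding mean
# anomalies exactly, and check that the solver recovers the original values
true_ecc_anom = jnp.linspace(0, 2 * jnp.pi, 50_000)[:-1]
mean_anom = true_ecc_anom - ecc * jnp.sin(true_ecc_anom)

calc_ecc_anom = kepler_solver_impl(mean_anom, ecc)

plt.plot(true_ecc_anom, calc_ecc_anom - true_ecc_anom, "k")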
plt.axhline(0, color="k")
plt.xlabel("eccentric anomaly")
plt.ylabel("error from Kepler solver");

In [None]:
fig, axes = plt.subplots(2, 1, sharex=True)

ax = axes[0]
# Differentiate through the solver, mapping over M while holding e fixed
ax.plot(
    mean_anom, jax.vmap(jax.grad(kepler_solver_impl), in_axes=(0, None))(mean_anom, ecc)
)
ax.set_ylabel("dE / dM")

ax = axes[1]
ax.plot(
    mean_anom,
    jax.vmap(jax.grad(kepler_solver_impl, argnums=1), in_axes=(0, None))(
        mean_anom, ecc
    ),
)
ax.set_xlabel("mean anomaly")
ax.set_ylabel("dE / de");

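Since $E$ is defined implicitly by Kepler's equation, we don't actually need to differentiate through every operation in `kepler_starter` and `kepler_refiner` to get these derivatives. Taking the differential of Kepler's equation gives

$$
\mathrm{d}M = \mathrm{d}E (1 - e \cos E) - \mathrm{d}e \sin E
$$

and solving for $\mathrm{d}E$, holding first $e$ and then $M$ fixed, gives the partial derivatives

$$
\frac{\partial E}{\partial M} = \frac{1}{1 - e \cos E}
$$

$$
\frac{\partial E}{\partial e} = \frac{\sin E}{1 - e \cos E}
$$

These depend only on the solution $E$, so we can supply them to JAX directly with `jax.custom_jvp`: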

In [None]:
@jax.custom_jvp
def kepler_solver(mean_anom, ecc):
    return kepler_solver_impl(mean_anom, ecc)


@kepler_solver.defjvp
def kepler_solver_jvp(primals, tangents):
    mean_anom, ecc = primals
    d_mean_anom, d_ecc = tangents

    # Solve for E; the derivatives depend only on the solution, not on
    # how the solver arrived at it
    ecc_anom = kepler_solver(mean_anom, ecc)

    # Closed-form partial derivatives from implicit differentiation
    dEdM = 1 / (1 - ecc * jnp.cos(ecc_anom))
    dEde = jnp.sin(ecc_anom) * dEdM

    # With the default `defjvp` (symbolic_zeros=False), JAX passes the tangents
    # of inputs we are not differentiating with respect to as explicit zero
    # arrays, so a plain linear combination handles every case
    out_tangent = dEdM * d_mean_anom + dEde * d_ecc

    return ecc_anom, out_tangent

In [None]:
jax.vmap(jax.grad(kepler_solver, argnums=1), in_axes=(0, None))(mean_anom, ecc)
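To check the closed-form derivatives independently of the Markley solver, here is a self-contained sketch that compares them with JAX's automatic derivatives through a plain fixed-iteration Newton solver. The helpers `newton_fixed` and `closed_form_partials` are hypothetical, and the test point $M = 1$, $e = 0.5$ is arbitrary. Automatic differentiation has to trace every iteration of the loop, while the custom JVP rule above only needs the converged $E$.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def newton_fixed(mean_anom, ecc, n_iter=30):
    # Hypothetical reference solver: a fixed number of Newton steps from
    # E_0 = pi; jax.grad differentiates through every one of these steps
    def step(_, ecc_anom):
        f = ecc_anom - ecc * jnp.sin(ecc_anom) - mean_anom
        return ecc_anom - f / (1 - ecc * jnp.cos(ecc_anom))

    init = jnp.full(jnp.shape(mean_anom), jnp.pi, dtype=jnp.float64)
    return jax.lax.fori_loop(0, n_iter, step, init)


def closed_form_partials(mean_anom, ecc):
    # dE/dM and dE/de from implicitly differentiating Kepler's equation
    ecc_anom = newton_fixed(mean_anom, ecc)
    dEdM = 1 / (1 - ecc * jnp.cos(ecc_anom))
    return dEdM, jnp.sin(ecc_anom) * dEdM


mean_anom, ecc = jnp.float64(1.0), jnp.float64(0.5)

# Automatic derivatives, traced through all n_iter Newton steps
dEdM_auto = jax.grad(newton_fixed, argnums=0)(mean_anom, ecc)
dEde_auto = jax.grad(newton_fixed, argnums=1)(mean_anom, ecc)

dEdM_exact, dEde_exact = closed_form_partials(mean_anom, ecc)
print(dEdM_auto - dEdM_exact, dEde_auto - dEde_exact)
```

Once the iteration has converged ($f(E) \approx 0$), the derivative of each Newton update reduces to $-(\partial f / \partial \theta) / f'(E)$ for a parameter $\theta \in \{M, e\}$, which is exactly the implicit-function result, so the two approaches agree to near machine precision here.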