In [1]:
import jax.numpy as jnp
import jax
import typing as tp
import utils

# Make JAX default to 64-bit floats (the arrays printed below are float64).
jax.config.update("jax_enable_x64", True)
# Debug aid: per the JAX docs, raise an error as soon as any operation
# produces a NaN. It adds per-op checking overhead; consider turning it off
# for long simulation runs once the model is trusted.
jax.config.update("jax_debug_nans", True)

In [2]:
# State of a set of gravitating bodies, as a NamedTuple (and therefore a JAX
# pytree, so jax.grad can return a Body of gradients).
# For the full system state built below:
#   x -- positions,  shape (N_BODIES, 3)
#   v -- velocities, shape (N_BODIES, 3)
#   m -- masses,     shape (N_BODIES,)
# When mapped pairwise (utils.map_product), each field presumably holds a
# single body's slice instead -- TODO confirm against utils.
Body = tp.NamedTuple(
    "Body",
    [
        ("x", jax.Array),
        ("v", jax.Array),
        ("m", jax.Array),
    ],
)


# RNG keys: one independent stream per randomly initialised quantity.
seed = jax.random.PRNGKey(42)
kx, kv, km = jax.random.split(seed, 3)

# Simulation Parameters (units presumably SI, since G below is the SI value)
N_BODIES = 3
T0 = 0.0
months = 15
# A month is approximated as 30.5 days: 30.5 d * 24 h/d * 3600 s/h.
SECONDS_PER_MONTH = 30.5 * 24 * 3600.0
Tf = SECONDS_PER_MONTH * months  # in seconds
DT = 1.0  # 1 second
# NOTE: with DT = 1 s this is ~3.95e7 time points (~316 MB as float64);
# raise DT if that is too large.
T = jnp.arange(T0, Tf, DT)
# Gravitational constant, m^3 kg^-1 s^-2 (CODATA 2014 value). Kept in
# float64 so the constant is not rounded to float32 while x64 is enabled.
G = jnp.array(6.67408e-11, dtype=jnp.float64)

# Initial Values
X0 = jax.random.uniform(kx, shape=(N_BODIES, 3), minval=-748e7, maxval=748e7)
V0 = jax.random.uniform(kv, shape=(N_BODIES, 3), minval=-1e5, maxval=1e5)
# Masses as a flat (N_BODIES,) vector, one scalar per body.
M = jax.random.uniform(km, shape=(N_BODIES,), minval=5.972e27, maxval=1.898e30)
# ignore z-axis: confine the system to the xy-plane
X0 = X0.at[:, 2].set(0.0)
V0 = V0.at[:, 2].set(0.0)

# system state
Y: Body = Body(X0, V0, M)

In [3]:
def gravitational_force(a: Body, b: Body) -> jax.Array:
    """Pull exerted by body ``b`` on body ``a``, per unit mass of ``a``.

    Computes ``G * b.m * (b.x - a.x) / safe_norm(b.x, a.x) ** 3`` (an
    acceleration vector; ``utils.safe_norm`` is presumably the distance
    between the two positions -- TODO confirm), then replaces NaN entries
    with 0 via ``jnp.nan_to_num``.
    """
    separation = b.x - a.x
    distance_cubed = utils.safe_norm(b.x, a.x) ** 3
    pull = G * b.m * separation
    # NaNs (presumably from a body paired with itself) become 0. Note that
    # nan_to_num also maps +/-inf to the largest finite float.
    return jnp.nan_to_num(pull / distance_cubed)


def gravitational_energy(a: Body, b: Body) -> jax.Array:
    """Gravitational potential at ``a`` due to ``b``, per unit mass of ``a``.

    Returns ``-G * b.m / r`` with ``r = utils.safe_norm(a.x, b.x)``
    (presumably the distance between the two positions -- TODO confirm).
    Returns 0 when ``a`` and ``b`` sit at exactly the same position, which is
    the self-interaction term when evaluated over all pairs.

    Its negative gradient w.r.t. ``a.x`` is meant to match
    ``gravitational_force(a, b)``.
    """
    r = utils.safe_norm(a.x, b.x)
    energy = -G * b.m / r
    # Mask only exact coincidence. jnp.allclose's default tolerance
    # (rtol=1e-5) scales with |b.x|. At coordinates ~1e9 that tolerance would
    # zero the potential between distinct bodies up to tens of km apart,
    # while gravitational_force still pulls them together.
    is_self = jnp.all(a.x == b.x)
    return jnp.where(is_self, 0.0, energy)

In [4]:
a = Body(X0, V0, M)

# Two independent routes to the per-body acceleration:
#   f1 -- pairwise gravitational_force summed over sources (axis=1 presumably
#         indexes the source body in utils.map_product's output)
#   f2 -- negative gradient of the summed pairwise potential w.r.t. the
#         targets' positions (jax.grad differentiates its first argument)
f1 = utils.map_product(gravitational_force)(a, a).sum(axis=1)
f2 = -jax.grad(
    lambda targets, sources: utils.map_product(gravitational_energy)(targets, sources).sum()
)(a, a).x

# Sanity check: both routes must give the same accelerations.
assert jnp.allclose(f1, f2), "pairwise force and -grad(potential) disagree"

print(f"{a=}")
print(f"{f1=}")
print(f"{f2=}")

a=Body(x=Array([[-4.59295233e+08,  4.53854906e+09,  0.00000000e+00],
       [-5.85289399e+09,  3.37854154e+09,  0.00000000e+00],
       [ 5.82123955e+09,  4.86107921e+09,  0.00000000e+00]],      dtype=float64), v=Array([[ 81284.96237573, -82094.51176463,      0.        ],
       [-29689.1169026 ,  44408.16038729,      0.        ],
       [-67493.46596226,  36168.10304887,      0.        ]],      dtype=float64), m=Array([1.14009763e+30, 1.54176942e+30, 1.12464207e+30], dtype=float64))
f1=Array([[-1.4098114 , -0.61351644,  0.        ],
       [ 2.98179572,  0.59393946,  0.        ],
       [-2.65855148, -0.19228257,  0.        ]], dtype=float64)
f2=Array([[-1.4098114 , -0.61351644,  0.        ],
       [ 2.98179572,  0.59393946,  0.        ],
       [-2.65855148, -0.19228257,  0.        ]], dtype=float64)


In [5]:
# Expected: one 3-vector per body, i.e. (N_BODIES, 3).
jnp.shape(f1)

(3, 3)

In [6]:
# Expected: same layout as f1, i.e. (N_BODIES, 3).
jnp.shape(f2)

(3, 3)