From e05ca567f8c4634e932408494538e50228f56213 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Wed, 31 Jan 2024 22:48:59 -0500 Subject: [PATCH] feat: add function `integrate_orbit` to dynamics module (#123) * feat(dynamics): add integrate_orbit * feat(dynamics): add default integrator Signed-off-by: nstarman --- src/galax/dynamics/_dynamics/orbit.py | 108 ++++++++++++++++++++++++- src/galax/potential/_potential/base.py | 73 +++-------------- 2 files changed, 115 insertions(+), 66 deletions(-) diff --git a/src/galax/dynamics/_dynamics/orbit.py b/src/galax/dynamics/_dynamics/orbit.py index dc13a29f..5ad86e50 100644 --- a/src/galax/dynamics/_dynamics/orbit.py +++ b/src/galax/dynamics/_dynamics/orbit.py @@ -1,6 +1,6 @@ """galax: Galactic Dynamix in Jax.""" -__all__ = ["Orbit"] +__all__ = ["Orbit", "integrate_orbit"] from dataclasses import replace from functools import partial @@ -9,9 +9,11 @@ import equinox as eqx import jax import jax.numpy as jnp +from astropy.units import Quantity +from galax.integrate import DiffraxIntegrator, Integrator from galax.potential._potential.base import AbstractPotentialBase -from galax.typing import BatchFloatScalar, BroadBatchVec3, VecTime +from galax.typing import BatchFloatScalar, BatchVec6, BroadBatchVec3, VecTime from galax.utils._shape import batched_shape from galax.utils.dataclasses import converter_float_array @@ -21,6 +23,8 @@ if TYPE_CHECKING: from typing import Self +############################################################################## + @final class Orbit(AbstractPhaseSpacePosition): @@ -109,3 +113,103 @@ def energy( The kinetic energy. """ return self.kinetic_energy() + self.potential_energy(potential) + + +############################################################################## + + +_default_integrator: Integrator = DiffraxIntegrator() + + +@partial(jax.jit, static_argnames=("integrator",)) +def integrate_orbit( + pot: AbstractPotentialBase, + w0: BatchVec6, + t: VecTime | Quantity, + *, + integrator: Integrator | None = None, +) -> Orbit: + """Integrate an orbit in potential. + + Parameters + ---------- + pot : :class:`~galax.potential.AbstractPotentialBase` + The potential in which to integrate the orbit. + w0 : Array[float, (*batch, 6)] + Initial position and velocity. + t: Array[float, (time,)] + Array of times at which to compute the orbit. The first element should + be the initial time and the last element should be the final time and + the array should be monotonically moving from the first to final time. + See the Examples section for options when constructing this argument. + + .. warning:: + + This is NOT the timesteps to use for integration, which are + controlled by the `integrator`; the default integrator + :class:`~galax.integrator.DiffraxIntegrator` uses adaptive + timesteps. + + integrator : :class:`~galax.integrate.Integrator`, keyword-only + Integrator to use. If `None`, the default integrator + :class:`~galax.integrator.DiffraxIntegrator` is used. + + Returns + ------- + orbit : :class:`~galax.dynamics.Orbit` + The integrated orbit evaluated at the given times. + + Examples + -------- + We start by integrating a single orbit in the potential of a point mass. + A few standard imports are needed: + + >>> import astropy.units as u + >>> import jax.experimental.array_api as xp # preferred over `jax.numpy` + >>> import galax.potential as gp + >>> from galax.units import galactic + + We can then create the point-mass' potential, with galactic units: + + >>> potential = gp.KeplerPotential(m=1e12 * u.Msun, units=galactic) + + We can then integrate an initial phase-space position in this potential to + get an orbit: + + >>> xv0 = xp.asarray([10., 0., 0., 0., 0.1, 0.]) # (x, v) galactic units + >>> ts = xp.linspace(0., 1000, 4) # (1 Gyr, 4 steps) + >>> orbit = potential.integrate_orbit(xv0, ts) + >>> orbit + Orbit( + q=f64[4,3], p=f64[4,3], t=f64[4], potential=KeplerPotential(...) + ) + + Note how there are 4 points in the orbit, corresponding to the 4 steps. + Changing the number of steps is easy: + + >>> ts = xp.linspace(0., 1000, 10) # (1 Gyr, 4 steps) + >>> orbit = potential.integrate_orbit(xv0, ts) + >>> orbit + Orbit( + q=f64[10,3], p=f64[10,3], t=f64[10], potential=KeplerPotential(...) + ) + + We can also integrate a batch of orbits at once: + + >>> xv0 = xp.asarray([[10., 0., 0., 0., 0.1, 0.], [10., 0., 0., 0., 0.2, 0.]]) + >>> orbit = potential.integrate_orbit(xv0, ts) + >>> orbit + Orbit( + q=f64[2,10,3], p=f64[2,10,3], t=f64[10], potential=KeplerPotential(...) + ) + """ + # Determine the integrator + # Reboot the integrator to avoid stateful issues + integrator = replace(integrator) if integrator is not None else _default_integrator + + # Integrate the orbit + ws = integrator(pot._integrator_F, w0, t) # noqa: SLF001 + # TODO: ꜛ reduce repeat dimensions of `time`. + + # Construct the orbit object + return Orbit(q=ws[..., 0:3], p=ws[..., 3:6], t=t, potential=pot) diff --git a/src/galax/potential/_potential/base.py b/src/galax/potential/_potential/base.py index 6c2c88ce..f5d6f3ae 100644 --- a/src/galax/potential/_potential/base.py +++ b/src/galax/potential/_potential/base.py @@ -1,10 +1,10 @@ __all__ = ["AbstractPotentialBase"] import abc -from dataclasses import KW_ONLY, fields, replace +from dataclasses import KW_ONLY, fields from functools import partial from types import MappingProxyType -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, cast import equinox as eqx import jax @@ -14,10 +14,8 @@ from astropy.coordinates import BaseRepresentation from astropy.units import Quantity from jax import grad, hessian, jacfwd -from jaxtyping import Array, Float from galax.integrate._api import Integrator -from galax.integrate._builtin import DiffraxIntegrator from galax.potential._potential.param.attr import ParametersAttribute from galax.potential._potential.param.utils import all_parameters from galax.typing import ( @@ -31,6 +29,7 @@ Matrix33, Vec3, Vec6, + VecTime, ) from galax.units import UnitSystem, dimensionless from galax.utils._jax import vectorize_method @@ -43,9 +42,6 @@ from galax.dynamics._dynamics.orbit import Orbit -default_integrator: Integrator = DiffraxIntegrator() - - class AbstractPotentialBase(eqx.Module, metaclass=ModuleMeta, strict=True): # type: ignore[misc] """Abstract Potential Class.""" @@ -333,12 +329,14 @@ def _integrator_F(self, t: FloatScalar, w: Vec6, args: tuple[Any, ...]) -> Vec6: def integrate_orbit( self, w0: BatchVec6, - t: Float[Array, "time"] | Quantity, + t: VecTime | Quantity, *, integrator: Integrator | None = None, ) -> "Orbit": """Integrate an orbit in the potential. + See :func:`~galax.dynamics.integrate_orbit` for more details. + Parameters ---------- w0 : Array[float, (6,)] @@ -350,10 +348,6 @@ def integrate_orbit( final time. See the Examples section for options when constructing this argument. - .. note:: - - To integrate backwards in time, ... - .. warning:: This is NOT the timesteps to use for integration, which are @@ -361,7 +355,7 @@ def integrate_orbit( :class:`~galax.integrator.DiffraxIntegrator` uses adaptive timesteps. - integrator : AbstractIntegrator, keyword-only + integrator : AbstractIntegrator | None, keyword-only Integrator to use. If `None`, the default integrator :class:`~galax.integrator.DiffraxIntegrator` is used. @@ -369,56 +363,7 @@ def integrate_orbit( ------- orbit : Orbit The integrated orbit evaluated at the given times. - - Examples - -------- - We start by integrating a single orbit in the potential of a point mass. - A few standard imports are needed: - - >>> import astropy.units as u - >>> import jax.experimental.array_api as xp # preferred over `jax.numpy` - >>> import galax.potential as gp - >>> from galax.units import galactic - - We can then create the point-mass' potential, with galactic units: - - >>> potential = gp.KeplerPotential(m=1e12 * u.Msun, units=galactic) - - We can then integrate an initial phase-space position in this potential - to get an orbit: - - >>> xv0 = xp.asarray([10., 0., 0., 0., 0.1, 0.]) # (x, v) galactic units - >>> ts = xp.linspace(0., 1000, 4) # (1 Gyr, 4 steps) - >>> orbit = potential.integrate_orbit(xv0, ts) - >>> orbit - Orbit( - q=f64[4,3], p=f64[4,3], t=f64[4], potential=KeplerPotential(...) - ) - - Note how there are 4 points in the orbit, corresponding to the 4 steps. - Changing the number of steps is easy: - - >>> ts = xp.linspace(0., 1000, 10) # (1 Gyr, 4 steps) - >>> orbit = potential.integrate_orbit(xv0, ts) - >>> orbit - Orbit( - q=f64[10,3], p=f64[10,3], t=f64[10], potential=KeplerPotential(...) - ) - - We can also integrate a batch of orbits at once: - - >>> xv0 = xp.asarray([[10., 0., 0., 0., 0.1, 0.], [10., 0., 0., 0., 0.2, 0.]]) - >>> orbit = potential.integrate_orbit(xv0, ts) - >>> orbit - Orbit( - q=f64[2,10,3], p=f64[2,10,3], t=f64[10], potential=KeplerPotential(...) - ) """ - # TODO: ꜛ get NORMALIZE_WHITESPACE to work correctly so Orbit is 1 line - from galax.dynamics._dynamics.orbit import Orbit - - integrator_ = default_integrator if integrator is None else replace(integrator) + from galax.dynamics._dynamics.orbit import integrate_orbit - ws = integrator_(self._integrator_F, w0, t) - # TODO: ꜛ reduce repeat dimensions of `time`. - return Orbit(q=ws[..., 0:3], p=ws[..., 3:6], t=t, potential=self) + return cast("Orbit", integrate_orbit(self, w0, t, integrator=integrator))