Skip to content

Commit

Permalink
feat: add function integrate_orbit to dynamics module (GalacticDyna…
Browse files Browse the repository at this point in the history
…mics#123)

* feat(dynamics): add integrate_orbit
* feat(dynamics): add default integrator

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Feb 1, 2024
1 parent 1e1f081 commit e05ca56
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 66 deletions.
108 changes: 106 additions & 2 deletions src/galax/dynamics/_dynamics/orbit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""galax: Galactic Dynamix in Jax."""

__all__ = ["Orbit"]
__all__ = ["Orbit", "integrate_orbit"]

from dataclasses import replace
from functools import partial
Expand All @@ -9,9 +9,11 @@
import equinox as eqx
import jax
import jax.numpy as jnp
from astropy.units import Quantity

from galax.integrate import DiffraxIntegrator, Integrator
from galax.potential._potential.base import AbstractPotentialBase
from galax.typing import BatchFloatScalar, BroadBatchVec3, VecTime
from galax.typing import BatchFloatScalar, BatchVec6, BroadBatchVec3, VecTime
from galax.utils._shape import batched_shape
from galax.utils.dataclasses import converter_float_array

Expand All @@ -21,6 +23,8 @@
if TYPE_CHECKING:
from typing import Self

##############################################################################


@final
class Orbit(AbstractPhaseSpacePosition):
Expand Down Expand Up @@ -109,3 +113,103 @@ def energy(
The kinetic energy.
"""
return self.kinetic_energy() + self.potential_energy(potential)


##############################################################################


_default_integrator: Integrator = DiffraxIntegrator()


@partial(jax.jit, static_argnames=("integrator",))
def integrate_orbit(
pot: AbstractPotentialBase,
w0: BatchVec6,
t: VecTime | Quantity,
*,
integrator: Integrator | None = None,
) -> Orbit:
"""Integrate an orbit in potential.
Parameters
----------
pot : :class:`~galax.potential.AbstractPotentialBase`
The potential in which to integrate the orbit.
w0 : Array[float, (*batch, 6)]
Initial position and velocity.
t: Array[float, (time,)]
Array of times at which to compute the orbit. The first element should
be the initial time and the last element should be the final time and
the array should be monotonically moving from the first to final time.
See the Examples section for options when constructing this argument.
.. warning::
This is NOT the timesteps to use for integration, which are
controlled by the `integrator`; the default integrator
:class:`~galax.integrator.DiffraxIntegrator` uses adaptive
timesteps.
integrator : :class:`~galax.integrate.Integrator`, keyword-only
Integrator to use. If `None`, the default integrator
:class:`~galax.integrator.DiffraxIntegrator` is used.
Returns
-------
orbit : :class:`~galax.dynamics.Orbit`
The integrated orbit evaluated at the given times.
Examples
--------
We start by integrating a single orbit in the potential of a point mass.
A few standard imports are needed:
>>> import astropy.units as u
>>> import jax.experimental.array_api as xp # preferred over `jax.numpy`
>>> import galax.potential as gp
>>> from galax.units import galactic
We can then create the point-mass' potential, with galactic units:
>>> potential = gp.KeplerPotential(m=1e12 * u.Msun, units=galactic)
We can then integrate an initial phase-space position in this potential to
get an orbit:
>>> xv0 = xp.asarray([10., 0., 0., 0., 0.1, 0.]) # (x, v) galactic units
>>> ts = xp.linspace(0., 1000, 4) # (1 Gyr, 4 steps)
>>> orbit = potential.integrate_orbit(xv0, ts)
>>> orbit
Orbit(
q=f64[4,3], p=f64[4,3], t=f64[4], potential=KeplerPotential(...)
)
Note how there are 4 points in the orbit, corresponding to the 4 steps.
Changing the number of steps is easy:
>>> ts = xp.linspace(0., 1000, 10) # (1 Gyr, 4 steps)
>>> orbit = potential.integrate_orbit(xv0, ts)
>>> orbit
Orbit(
q=f64[10,3], p=f64[10,3], t=f64[10], potential=KeplerPotential(...)
)
We can also integrate a batch of orbits at once:
>>> xv0 = xp.asarray([[10., 0., 0., 0., 0.1, 0.], [10., 0., 0., 0., 0.2, 0.]])
>>> orbit = potential.integrate_orbit(xv0, ts)
>>> orbit
Orbit(
q=f64[2,10,3], p=f64[2,10,3], t=f64[10], potential=KeplerPotential(...)
)
"""
# Determine the integrator
# Reboot the integrator to avoid stateful issues
integrator = replace(integrator) if integrator is not None else _default_integrator

# Integrate the orbit
ws = integrator(pot._integrator_F, w0, t) # noqa: SLF001
# TODO: ꜛ reduce repeat dimensions of `time`.

# Construct the orbit object
return Orbit(q=ws[..., 0:3], p=ws[..., 3:6], t=t, potential=pot)
73 changes: 9 additions & 64 deletions src/galax/potential/_potential/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
__all__ = ["AbstractPotentialBase"]

import abc
from dataclasses import KW_ONLY, fields, replace
from dataclasses import KW_ONLY, fields
from functools import partial
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, ClassVar
from typing import TYPE_CHECKING, Any, ClassVar, cast

import equinox as eqx
import jax
Expand All @@ -14,10 +14,8 @@
from astropy.coordinates import BaseRepresentation
from astropy.units import Quantity
from jax import grad, hessian, jacfwd
from jaxtyping import Array, Float

from galax.integrate._api import Integrator
from galax.integrate._builtin import DiffraxIntegrator
from galax.potential._potential.param.attr import ParametersAttribute
from galax.potential._potential.param.utils import all_parameters
from galax.typing import (
Expand All @@ -31,6 +29,7 @@
Matrix33,
Vec3,
Vec6,
VecTime,
)
from galax.units import UnitSystem, dimensionless
from galax.utils._jax import vectorize_method
Expand All @@ -43,9 +42,6 @@
from galax.dynamics._dynamics.orbit import Orbit


default_integrator: Integrator = DiffraxIntegrator()


class AbstractPotentialBase(eqx.Module, metaclass=ModuleMeta, strict=True): # type: ignore[misc]
"""Abstract Potential Class."""

Expand Down Expand Up @@ -333,12 +329,14 @@ def _integrator_F(self, t: FloatScalar, w: Vec6, args: tuple[Any, ...]) -> Vec6:
def integrate_orbit(
self,
w0: BatchVec6,
t: Float[Array, "time"] | Quantity,
t: VecTime | Quantity,
*,
integrator: Integrator | None = None,
) -> "Orbit":
"""Integrate an orbit in the potential.
See :func:`~galax.dynamics.integrate_orbit` for more details.
Parameters
----------
w0 : Array[float, (6,)]
Expand All @@ -350,75 +348,22 @@ def integrate_orbit(
final time. See the Examples section for options when constructing
this argument.
.. note::
To integrate backwards in time, ...
.. warning::
This is NOT the timesteps to use for integration, which are
controlled by the `integrator`; the default integrator
:class:`~galax.integrator.DiffraxIntegrator` uses adaptive
timesteps.
integrator : AbstractIntegrator, keyword-only
integrator : AbstractIntegrator | None, keyword-only
Integrator to use. If `None`, the default integrator
:class:`~galax.integrator.DiffraxIntegrator` is used.
Returns
-------
orbit : Orbit
The integrated orbit evaluated at the given times.
Examples
--------
We start by integrating a single orbit in the potential of a point mass.
A few standard imports are needed:
>>> import astropy.units as u
>>> import jax.experimental.array_api as xp # preferred over `jax.numpy`
>>> import galax.potential as gp
>>> from galax.units import galactic
We can then create the point-mass' potential, with galactic units:
>>> potential = gp.KeplerPotential(m=1e12 * u.Msun, units=galactic)
We can then integrate an initial phase-space position in this potential
to get an orbit:
>>> xv0 = xp.asarray([10., 0., 0., 0., 0.1, 0.]) # (x, v) galactic units
>>> ts = xp.linspace(0., 1000, 4) # (1 Gyr, 4 steps)
>>> orbit = potential.integrate_orbit(xv0, ts)
>>> orbit
Orbit(
q=f64[4,3], p=f64[4,3], t=f64[4], potential=KeplerPotential(...)
)
Note how there are 4 points in the orbit, corresponding to the 4 steps.
Changing the number of steps is easy:
>>> ts = xp.linspace(0., 1000, 10) # (1 Gyr, 4 steps)
>>> orbit = potential.integrate_orbit(xv0, ts)
>>> orbit
Orbit(
q=f64[10,3], p=f64[10,3], t=f64[10], potential=KeplerPotential(...)
)
We can also integrate a batch of orbits at once:
>>> xv0 = xp.asarray([[10., 0., 0., 0., 0.1, 0.], [10., 0., 0., 0., 0.2, 0.]])
>>> orbit = potential.integrate_orbit(xv0, ts)
>>> orbit
Orbit(
q=f64[2,10,3], p=f64[2,10,3], t=f64[10], potential=KeplerPotential(...)
)
"""
# TODO: ꜛ get NORMALIZE_WHITESPACE to work correctly so Orbit is 1 line
from galax.dynamics._dynamics.orbit import Orbit

integrator_ = default_integrator if integrator is None else replace(integrator)
from galax.dynamics._dynamics.orbit import integrate_orbit

ws = integrator_(self._integrator_F, w0, t)
# TODO: ꜛ reduce repeat dimensions of `time`.
return Orbit(q=ws[..., 0:3], p=ws[..., 3:6], t=t, potential=self)
return cast("Orbit", integrate_orbit(self, w0, t, integrator=integrator))

0 comments on commit e05ca56

Please sign in to comment.