From 191595db31bd8aa7655965cbb95dd928f29691c7 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Wed, 12 Jun 2024 16:20:30 -0400 Subject: [PATCH] refactor: consolidate integrate (#346) Signed-off-by: nstarman --- src/galax/coordinates/_psp/interp.py | 4 + .../dynamics/_dynamics/integrate/_base.py | 109 ++++- .../dynamics/_dynamics/integrate/_builtin.py | 446 +++++++----------- .../_dynamics/mockstream/df/_fardal15.py | 30 +- 4 files changed, 307 insertions(+), 282 deletions(-) diff --git a/src/galax/coordinates/_psp/interp.py b/src/galax/coordinates/_psp/interp.py index 46581b5b..6522401a 100644 --- a/src/galax/coordinates/_psp/interp.py +++ b/src/galax/coordinates/_psp/interp.py @@ -25,6 +25,10 @@ class PhaseSpacePositionInterpolant(Protocol): units: AbstractUnitSystem """The unit system for the interpolation.""" + added_ndim: int + """The number of dimensions added to the input time.""" + # TODO: not require this for Diffrax + def __call__(self, t: gt.QVecTime) -> PhaseSpacePosition: """Evaluate the interpolation. diff --git a/src/galax/dynamics/_dynamics/integrate/_base.py b/src/galax/dynamics/_dynamics/integrate/_base.py index f0843111..5b6530c4 100644 --- a/src/galax/dynamics/_dynamics/integrate/_base.py +++ b/src/galax/dynamics/_dynamics/integrate/_base.py @@ -1,15 +1,19 @@ __all__ = ["AbstractIntegrator"] import abc -from typing import Literal +from typing import Any, Literal, TypeVar import equinox as eqx +from plum import overload -from unxt import AbstractUnitSystem +import quaxed.array_api as xp +from unxt import AbstractUnitSystem, Quantity, to_units_value import galax.coordinates as gc import galax.typing as gt -from ._api import VectorField +from ._api import SaveT, VectorField + +Interp = TypeVar("Interp") class AbstractIntegrator(eqx.Module, strict=True): # type: ignore[call-arg, misc] @@ -24,8 +28,66 @@ class AbstractIntegrator(eqx.Module, strict=True): # type: ignore[call-arg, mis motion. They must not be stateful since they are used in a functional way. """ - # TODO: shape hint of the return type + InterpolantClass: eqx.AbstractClassVar[type[gc.PhaseSpacePositionInterpolant]] + @abc.abstractmethod + def _call_implementation( + self, + F: VectorField, + w0: gt.BatchVec6, + t0: gt.FloatScalar, + t1: gt.FloatScalar, + ts: gt.BatchVecTime, + /, + interpolated: Literal[False, True], + ) -> tuple[gt.BatchVecTime7, Any | None]: # TODO: type hint Interpolant + """Integrator implementation.""" + ... + + def _process_interpolation( + self, interp: Interp, w0: gt.BatchVec6 + ) -> tuple[Interp, int]: + """Process the interpolation. + + This is the default implementation and will probably need to be + overridden in a subclass. + """ + # Determine if an extra dimension was added to the output + added_ndim = int(w0.shape[:-1] in ((), (1,))) + # Return the interpolation and the number of added dimensions + return interp, added_ndim + + # ------------------------------------------------------------------------ + + @overload + def __call__( + self, + F: VectorField, + w0: gc.AbstractPhaseSpacePosition | gt.BatchVec6, + t0: gt.FloatQScalar | gt.FloatScalar, + t1: gt.FloatQScalar | gt.FloatScalar, + /, + saveat: SaveT | None = None, + *, + units: AbstractUnitSystem, + interpolated: Literal[False] = False, + ) -> gc.PhaseSpacePosition: ... + + @overload + def __call__( + self, + F: VectorField, + w0: gc.AbstractPhaseSpacePosition | gt.BatchVec6, + t0: gt.FloatQScalar | gt.FloatScalar, + t1: gt.FloatQScalar | gt.FloatScalar, + /, + saveat: SaveT | None = None, + *, + units: AbstractUnitSystem, + interpolated: Literal[True], + ) -> gc.InterpolatedPhaseSpacePosition: ... + + # TODO: shape hint of the return type def __call__( self, F: VectorField, @@ -131,4 +193,41 @@ def __call__( (2,) """ - ... + # Parse inputs + w0_: gt.Vec6 = ( + w0.w(units=units) if isinstance(w0, gc.AbstractPhaseSpacePosition) else w0 + ) + t0_: gt.VecTime = to_units_value(t0, units["time"]) + t1_: gt.VecTime = to_units_value(t1, units["time"]) + # Either save at `saveat` or at the final time. The final time is + # a scalar and the saveat is a vector, so a dimension is added. + saveat_ = ( + xp.asarray([t1_]) + if saveat is None + else to_units_value(saveat, units["time"]) + ) + + # Perform the integration + w, interp = self._call_implementation(F, w0_, t0_, t1_, saveat_, interpolated) + w = w[..., -1, :] if saveat is None else w # get rid of added dimension + + # Return + if interpolated: + interp, added_ndim = self._process_interpolation(interp, w0_) + + out = gc.InterpolatedPhaseSpacePosition( # shape = (*batch, T) + q=Quantity(w[..., 0:3], units["length"]), + p=Quantity(w[..., 3:6], units["speed"]), + t=Quantity(saveat_, units["time"]), + interpolant=self.InterpolantClass( + interp, units=units, added_ndim=added_ndim + ), + ) + else: + out = gc.PhaseSpacePosition( # shape = (*batch, T) + q=Quantity(w[..., 0:3], units["length"]), + p=Quantity(w[..., 3:6], units["speed"]), + t=Quantity(w[..., -1], units["time"]), + ) + + return out diff --git a/src/galax/dynamics/_dynamics/integrate/_builtin.py b/src/galax/dynamics/_dynamics/integrate/_builtin.py index 189b9108..3c61a39d 100644 --- a/src/galax/dynamics/_dynamics/integrate/_builtin.py +++ b/src/galax/dynamics/_dynamics/integrate/_builtin.py @@ -4,7 +4,7 @@ from collections.abc import Callable, Mapping from dataclasses import KW_ONLY from functools import partial -from typing import Any, Literal, ParamSpec, TypeVar, final +from typing import Any, ClassVar, Literal, ParamSpec, TypeVar, final import diffrax import equinox as eqx @@ -12,14 +12,12 @@ import jax.numpy as jnp from diffrax import DenseInterpolation from jax._src.numpy.vectorize import _parse_gufunc_signature, _parse_input_dimensions -from plum import overload -import quaxed.array_api as xp -from unxt import AbstractUnitSystem, Quantity, to_units_value, unitsystem +from unxt import AbstractUnitSystem, Quantity, unitsystem import galax.coordinates as gc import galax.typing as gt -from ._api import SaveT, VectorField +from ._api import VectorField from ._base import AbstractIntegrator from galax.utils import ImmutableDict @@ -27,6 +25,53 @@ R = TypeVar("R") +class DiffraxInterpolant(eqx.Module): # type: ignore[misc]# + """Wrapper for ``diffrax.DenseInterpolation``.""" + + interpolant: DenseInterpolation + """:class:`diffrax.DenseInterpolation` object. + + This object is the result of the integration and can be used to evaluate the + interpolated solution at any time. However it does not understand units, so + the input is the time in ``units["time"]``. The output is a 6-vector of + (q, p) values in the units of the integrator. + """ + + units: AbstractUnitSystem = eqx.field(static=True, converter=unitsystem) + """The :class:`unxt.AbstractUnitSystem`. + + This is used to convert the time input to the interpolant and the phase-space + position output. + """ + + added_ndim: gt.Shape = eqx.field(static=True) + """The number of dimensions added to the output of the interpolation. + + This is used to reshape the output of the interpolation to match the batch + shape of the input to the integrator. The means of vectorizing the + interpolation means that the input must always be a batched array, resulting + in an extra dimension when the integration was on a scalar input. + """ + + # @partial(jax.jit) + def __call__(self, t: gt.QVecTime, **_: Any) -> gc.PhaseSpacePosition: + """Evaluate the interpolation.""" + # Parse t + t_ = jnp.atleast_1d(t.to_units_value(self.units["time"])) + + # Evaluate the interpolation + ys = jax.vmap(lambda s: jax.vmap(s.evaluate)(t_))(self.interpolant) + extra_dims: int = ys.ndim - 3 + self.added_ndim + (t_.ndim - t.ndim) + ys = ys[(0,) * extra_dims] + + # Construct and return the result + return gc.PhaseSpacePosition( + q=Quantity(ys[..., 0:3], self.units["length"]), + p=Quantity(ys[..., 3:6], self.units["speed"]), + t=t, + ) + + def vectorize( pyfunc: Callable[P, R], *, signature: str | None = None ) -> Callable[P, R]: @@ -101,6 +146,112 @@ class DiffraxIntegrator(AbstractIntegrator): Keyword arguments to pass to the solver. Default is ``{"scan_kind": "bounded"}``. + Examples + -------- + First some imports: + + >>> import quaxed.array_api as xp + >>> from unxt import Quantity + >>> from unxt.unitsystems import galactic + >>> import galax.coordinates as gc + >>> import galax.dynamics as gd + >>> import galax.potential as gp + + Then we define initial conditions: + + >>> w0 = gc.PhaseSpacePosition(q=Quantity([10., 0., 0.], "kpc"), + ... p=Quantity([0., 200., 0.], "km/s")) + + (Note that the ``t`` attribute is not used.) + + Now we can integrate the phase-space position for 1 Gyr, getting the + final position. The integrator accepts any function for the equations + of motion. Here we will reproduce what happens with orbit integrations. + + >>> pot = gp.HernquistPotential(m_tot=Quantity(1e12, "Msun"), + ... r_s=Quantity(5, "kpc"), units="galactic") + + >>> integrator = gd.integrate.DiffraxIntegrator() + >>> t0, t1 = Quantity(0, "Gyr"), Quantity(1, "Gyr") + >>> w = integrator(pot._integrator_F, w0, t0, t1, units=galactic) + >>> w + PhaseSpacePosition( + q=CartesianPosition3D( ... ), + p=CartesianVelocity3D( ... ), + t=Quantity[...](value=f64[], unit=Unit("Myr")) + ) + >>> w.shape + () + + Instead of just returning the final position, we can get the state of + the system at any times ``saveat``: + + >>> ts = Quantity(xp.linspace(0, 1, 10), "Gyr") # 10 steps + >>> ws = integrator(pot._integrator_F, w0, t0, t1, + ... saveat=ts, units=galactic) + >>> ws + PhaseSpacePosition( + q=CartesianPosition3D( ... ), + p=CartesianVelocity3D( ... ), + t=Quantity[...](value=f64[10], unit=Unit("Myr")) + ) + >>> ws.shape + (10,) + + In all these examples the integrator was used to integrate a single + position. The integrator can also be used to integrate a batch of + initial conditions at once, returning a batch of final conditions (or a + batch of conditions at the requested times): + + >>> w0 = gc.PhaseSpacePosition(q=Quantity([[10., 0, 0], [11., 0, 0]], "kpc"), + ... p=Quantity([[0, 200, 0], [0, 210, 0]], "km/s")) + >>> ws = integrator(pot._integrator_F, w0, t0, t1, units=galactic) + >>> ws.shape + (2,) + + A cool feature of the integrator is that it can return an interpolated + solution. + + >>> w = integrator(pot._integrator_F, w0, t0, t1, saveat=ts, units=galactic, + ... interpolated=True) + >>> type(w) + + + The interpolated solution can be evaluated at any time in the domain to get + the phase-space position at that time: + + >>> t = Quantity(xp.e, "Gyr") + >>> w(t) + PhaseSpacePosition( + q=CartesianPosition3D( ... ), + p=CartesianVelocity3D( ... ), + t=Quantity[PhysicalType('time')](value=f64[1], unit=Unit("Gyr")) + ) + + The interpolant is vectorized: + + >>> t = Quantity(xp.linspace(0, 1, 100), "Gyr") + >>> w(t) + PhaseSpacePosition( + q=CartesianPosition3D( ... ), + p=CartesianVelocity3D( ... ), + t=Quantity[PhysicalType('time')](value=f64[1,100], unit=Unit("Gyr")) + ) + + And it works on batches: + + >>> w0 = gc.PhaseSpacePosition(q=Quantity([[10., 0, 0], [11., 0, 0]], "kpc"), + ... p=Quantity([[0, 200, 0], [0, 210, 0]], "km/s")) + >>> ws = integrator(pot._integrator_F, w0, t0, t1, units=galactic, + ... interpolated=True) + >>> ws.shape + (2,) + >>> w(t) + PhaseSpacePosition( + q=CartesianPosition3D( ... ), + p=CartesianVelocity3D( ... ), + t=Quantity[PhysicalType('time')](value=f64[1,100], unit=Unit("Gyr")) + ) """ _: KW_ONLY @@ -119,6 +270,10 @@ class DiffraxIntegrator(AbstractIntegrator): default=(("scan_kind", "bounded"),), static=True, converter=ImmutableDict ) + InterpolantClass: ClassVar[type[gc.PhaseSpacePositionInterpolant]] = ( # type: ignore[misc] + DiffraxInterpolant + ) + # ===================================================== # Call @@ -169,271 +324,16 @@ def solve_diffeq( return w, interp - @overload - def __call__( - self, - F: VectorField, - w0: gc.AbstractPhaseSpacePosition | gt.BatchVec6, - t0: gt.FloatQScalar | gt.FloatScalar, - t1: gt.FloatQScalar | gt.FloatScalar, - /, - saveat: SaveT | None = None, - *, - units: AbstractUnitSystem, - interpolated: Literal[False] = False, - ) -> gc.PhaseSpacePosition: ... - - @overload - def __call__( - self, - F: VectorField, - w0: gc.AbstractPhaseSpacePosition | gt.BatchVec6, - t0: gt.FloatQScalar | gt.FloatScalar, - t1: gt.FloatQScalar | gt.FloatScalar, - /, - saveat: SaveT | None = None, - *, - units: AbstractUnitSystem, - interpolated: Literal[True], - ) -> gc.InterpolatedPhaseSpacePosition: ... - - def __call__( - self, - F: VectorField, - w0: gc.AbstractPhaseSpacePosition | gt.BatchVec6, - t0: gt.FloatQScalar | gt.FloatScalar, - t1: gt.FloatQScalar | gt.FloatScalar, - /, - saveat: SaveT | None = None, - *, - units: AbstractUnitSystem, - interpolated: Literal[False, True] = False, - ) -> gc.PhaseSpacePosition | gc.InterpolatedPhaseSpacePosition: - """Run the integrator. - - Parameters - ---------- - F : VectorField, positional-only - The function to integrate. - w0 : AbstractPhaseSpacePosition | Array[float, (6,)], positional-only - Initial conditions ``[q, p]``. - t0, t1 : Quantity, positional-only - Initial and final times. - - saveat : (Quantity | Array)[float, (T,)] | None, optional - Times to return the computation. If `None`, the computation is - returned only at the final time. - - units : `unxt.AbstractUnitSystem` - The unit system to use. - interpolated : bool, keyword-only - Whether to return an interpolated solution. - - Returns - ------- - PhaseSpacePosition[float, (time, 7)] - The solution of the integrator [q, p, t], where q, p are the - generalized 3-coordinates. - - Examples - -------- - For this example, we will use the - :class:`~galax.integrate.DiffraxIntegrator` - - First some imports: - - >>> import quaxed.array_api as xp - >>> from unxt import Quantity - >>> import unxt.unitsystems as usx - >>> import galax.coordinates as gc - >>> import galax.dynamics as gd - >>> import galax.potential as gp - - Then we define initial conditions: - - >>> w0 = gc.PhaseSpacePosition(q=Quantity([10., 0., 0.], "kpc"), - ... p=Quantity([0., 200., 0.], "km/s")) - - (Note that the ``t`` attribute is not used.) - - Now we can integrate the phase-space position for 1 Gyr, getting the - final position. The integrator accepts any function for the equations - of motion. Here we will reproduce what happens with orbit integrations. - - >>> pot = gp.HernquistPotential(m_tot=Quantity(1e12, "Msun"), - ... r_s=Quantity(5, "kpc"), units="galactic") - - >>> integrator = gd.integrate.DiffraxIntegrator() - >>> t0, t1 = Quantity(0, "Gyr"), Quantity(1, "Gyr") - >>> w = integrator(pot._integrator_F, w0, t0, t1, units=usx.galactic) - >>> w - PhaseSpacePosition( - q=CartesianPosition3D( ... ), - p=CartesianVelocity3D( ... ), - t=Quantity[...](value=f64[], unit=Unit("Myr")) - ) - >>> w.shape - () - - Instead of just returning the final position, we can get the state of - the system at any times ``saveat``: - - >>> ts = Quantity(xp.linspace(0, 1, 10), "Gyr") # 10 steps - >>> ws = integrator(pot._integrator_F, w0, t0, t1, - ... saveat=ts, units=usx.galactic) - >>> ws - PhaseSpacePosition( - q=CartesianPosition3D( ... ), - p=CartesianVelocity3D( ... ), - t=Quantity[...](value=f64[10], unit=Unit("Myr")) - ) - >>> ws.shape - (10,) - - In all these examples the integrator was used to integrate a single - position. The integrator can also be used to integrate a batch of - initial conditions at once, returning a batch of final conditions (or a - batch of conditions at the requested times): - - >>> w0 = gc.PhaseSpacePosition(q=Quantity([[10., 0, 0], [11., 0, 0]], "kpc"), - ... p=Quantity([[0, 200, 0], [0, 210, 0]], "km/s")) - >>> ws = integrator(pot._integrator_F, w0, t0, t1, units=usx.galactic) - >>> ws.shape - (2,) - - A cool feature of the integrator is that it can return an interpolated - solution. - - >>> w = integrator(pot._integrator_F, w0, t0, t1, saveat=ts, units=usx.galactic, - ... interpolated=True) - >>> type(w) - - - The interpolated solution can be evaluated at any time in the domain to get - the phase-space position at that time: - - >>> t = Quantity(xp.e, "Gyr") - >>> w(t) - PhaseSpacePosition( - q=CartesianPosition3D( ... ), - p=CartesianVelocity3D( ... ), - t=Quantity[PhysicalType('time')](value=f64[1], unit=Unit("Gyr")) - ) - - The interpolant is vectorized: - - >>> t = Quantity(xp.linspace(0, 1, 100), "Gyr") - >>> w(t) - PhaseSpacePosition( - q=CartesianPosition3D( ... ), - p=CartesianVelocity3D( ... ), - t=Quantity[PhysicalType('time')](value=f64[1,100], unit=Unit("Gyr")) - ) - - And it works on batches: - - >>> w0 = gc.PhaseSpacePosition(q=Quantity([[10., 0, 0], [11., 0, 0]], "kpc"), - ... p=Quantity([[0, 200, 0], [0, 210, 0]], "km/s")) - >>> ws = integrator(pot._integrator_F, w0, t0, t1, units=usx.galactic, - ... interpolated=True) - >>> ws.shape - (2,) - >>> w(t) - PhaseSpacePosition( - q=CartesianPosition3D( ... ), - p=CartesianVelocity3D( ... ), - t=Quantity[PhysicalType('time')](value=f64[1,100], unit=Unit("Gyr")) - ) - """ - # Parse inputs - w0_: gt.Vec6 = ( - w0.w(units=units) if isinstance(w0, gc.AbstractPhaseSpacePosition) else w0 - ) - t0_: gt.VecTime = to_units_value(t0, units["time"]) - t1_: gt.VecTime = to_units_value(t1, units["time"]) - saveat_ = ( - xp.asarray([t1_]) - if saveat is None - else to_units_value(saveat, units["time"]) - ) - - # Perform the integration - w, interp = self._call_implementation(F, w0_, t0_, t1_, saveat_, interpolated) - w = w[..., -1, :] if saveat is None else w # TODO: undo this - - # Return - if interpolated: - # Determine if an extra dimension was added to the output - added_ndim = int(w0_.shape[:-1] in ((), (1,))) - # If one was, then the interpolant must be reshaped since the input - # was squeezed beforehand and the dimension must be added back. - if added_ndim == 1: - arr, narr = eqx.partition(interp, eqx.is_array) - arr = jax.tree_util.tree_map(lambda x: x[None], arr) - interp = eqx.combine(arr, narr) - - out = gc.InterpolatedPhaseSpacePosition( # shape = (*batch, T) - q=Quantity(w[..., 0:3], units["length"]), - p=Quantity(w[..., 3:6], units["speed"]), - t=Quantity(saveat_, units["time"]), - interpolant=DiffraxInterpolant( - interp, units=units, added_ndim=added_ndim - ), - ) - else: - out = gc.PhaseSpacePosition( # shape = (*batch, T) - q=Quantity(w[..., 0:3], units["length"]), - p=Quantity(w[..., 3:6], units["speed"]), - t=Quantity(w[..., -1], units["time"]), - ) - - return out - - -class DiffraxInterpolant(eqx.Module): # type: ignore[misc]# - """Wrapper for ``diffrax.DenseInterpolation``.""" - - interpolant: DenseInterpolation - """:class:`diffrax.DenseInterpolation` object. - - This object is the result of the integration and can be used to evaluate the - interpolated solution at any time. However it does not understand units, so - the input is the time in ``units["time"]``. The output is a 6-vector of - (q, p) values in the units of the integrator. - """ - - units: AbstractUnitSystem = eqx.field(static=True, converter=unitsystem) - """The :class:`unxt.AbstractUnitSystem`. - - This is used to convert the time input to the interpolant and the phase-space - position output. - """ - - added_ndim: gt.Shape = eqx.field(static=True) - """The number of dimensions added to the output of the interpolation. - - This is used to reshape the output of the interpolation to match the batch - shape of the input to the integrator. The means of vectorizing the - interpolation means that the input must always be a batched array, resulting - in an extra dimension when the integration was on a scalar input. - """ - - # @partial(jax.jit) - def __call__(self, t: gt.QVecTime, **_: Any) -> gc.PhaseSpacePosition: - """Evaluate the interpolation.""" - # Parse t - t_ = jnp.atleast_1d(t.to_units_value(self.units["time"])) - # t_ = eqx.error_if(t_, xp.any(t_self.interpolant.t1), "t>t1") # noqa: ERA001 - - # Evaluate the interpolation - ys = jax.vmap(lambda s: jax.vmap(s.evaluate)(t_))(self.interpolant) - extra_dims: int = ys.ndim - 3 + self.added_ndim + (t_.ndim - t.ndim) - ys = ys[(0,) * extra_dims] - - # Construct and return the result - return gc.PhaseSpacePosition( - q=Quantity(ys[..., 0:3], self.units["length"]), - p=Quantity(ys[..., 3:6], self.units["speed"]), - t=t, - ) + def _process_interpolation( + self, interp: DenseInterpolation, w0: gt.BatchVec6 + ) -> tuple[DenseInterpolation, int]: + # Determine if an extra dimension was added to the output + added_ndim = int(w0.shape[:-1] in ((), (1,))) + # If one was, then the interpolant must be reshaped since the input + # was squeezed beforehand and the dimension must be added back. + if added_ndim == 1: + arr, narr = eqx.partition(interp, eqx.is_array) + arr = jax.tree_util.tree_map(lambda x: x[None], arr) + interp = eqx.combine(arr, narr) + + return interp, added_ndim diff --git a/src/galax/dynamics/_dynamics/mockstream/df/_fardal15.py b/src/galax/dynamics/_dynamics/mockstream/df/_fardal15.py index e1401975..9260d93e 100644 --- a/src/galax/dynamics/_dynamics/mockstream/df/_fardal15.py +++ b/src/galax/dynamics/_dynamics/mockstream/df/_fardal15.py @@ -300,15 +300,37 @@ def lagrange_points( potential : `galax.potential.AbstractPotentialBase` The gravitational potential of the host. x: Quantity[float, (3,), "length"] - 3d position (x, y, z) + Cartesian 3D position ($x$, $y$, $z$) v: Quantity[float, (3,), "speed"] - 3d velocity (v_x, v_y, v_z) + Cartesian 3D velocity ($v_x$, $v_y$, $v_z$) prog_mass: Quantity[float, (), "mass"] Cluster mass. t: Quantity[float, (), "time"] Time. - """ - r_hat = x / xp.linalg.vector_norm(x) + + Returns + ------- + L_1, L_2: Quantity[float, (3,), "length"] + The lagrange points L_1 and L_2. + + Examples + -------- + >>> from unxt import Quantity + >>> import galax.potential as gp + + >>> pot = gp.MilkyWayPotential() + >>> x = Quantity(xp.asarray([8.0, 0.0, 0.0]), "kpc") + >>> v = Quantity(xp.asarray([0.0, 220.0, 0.0]), "km/s") + >>> prog_mass = Quantity(1e4, "Msun") + >>> t = Quantity(0.0, "Gyr") + + >>> L1, L2 = lagrange_points(pot, x, v, prog_mass, t) + >>> L1 + Quantity['length'](Array([7.97070926, 0. , 0. ], dtype=float64), unit='kpc') + >>> L2 + Quantity['length'](Array([8.02929074, 0. , 0. ], dtype=float64), unit='kpc') + """ # noqa: E501 + r_hat = x / xp.linalg.vector_norm(x, axis=-1, keepdims=True) r_t = tidal_radius(potential, x, v, prog_mass, t) L_1 = x - r_hat * r_t # close L_2 = x + r_hat * r_t # far