Skip to content

Commit

Permalink
feat(PSP): vectors (GalacticDynamics#155)
Browse files Browse the repository at this point in the history
* refactor: move psp
* feat: add vector dependency
* feat: make PSP work with vectors

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Feb 26, 2024
1 parent 638a0d4 commit 89e3b2a
Show file tree
Hide file tree
Showing 16 changed files with 476 additions and 202 deletions.
9 changes: 6 additions & 3 deletions docs/getting_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,9 @@ in :mod:`galax.integrate`, but do so using the convenience interface available
on any Potential object through the
:func:`~galax.potential.AbstractPotential.integrate_orbit` method::

>>> import galax.dynamics as gd
>>> t = jnp.arange(0.0, 2.0, step=1/1000) # Gyr
>>> orbit = mw.integrate_orbit(psp.w(), t=t)
>>> orbit = gd.evaluate_orbit(mw, psp.w(units=mw.units), t=t)

By default, this method uses Leapfrog integration , which is a fast, symplectic
integration scheme. The returned object is an instance of the
Expand All @@ -129,7 +130,9 @@ phase-space positions at times::

>>> orbit
Orbit(
q=f64[2000,3], p=f64[2000,3], t=f64[2000], ...
q=Cartesian3DVector(
x=Quantity[PhysicalType('length')](value=f64[2000], unit=Unit("kpc")),
...

:class:`~galax.dynamics.Orbit` objects have many of their own useful methods for
performing common tasks, like plotting an orbit::
Expand All @@ -151,7 +154,7 @@ performing common tasks, like plotting an orbit::
mw = gp.MilkyWayPotential()
psp = gc.PhaseSpacePosition(pos=[-8.1, 0, 0.02] * u.kpc,
vel=[13, 245, 8.] * u.km/u.s)
orbit = mw.integrate_orbit(psp.w(), dt=1*u.Myr, t1=0, t2=2*u.Gyr)
orbit = gd.evaluate_orbit(psp.w(units=mw.units), dt=1*u.Myr, t1=0, t2=2*u.Gyr)

orbit.plot(['x', 'y'])

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"lazy_loader",
"quax >= 0.0.3",
"typing_extensions",
"vector @ git+https://github.com/GalacticDynamics/vector.git",
]
description = "Galactic Dynamix in Jax."
dynamic = ["version"]
Expand Down
94 changes: 64 additions & 30 deletions src/galax/coordinates/_psp/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,18 @@
import array_api_jax_compat as xp
import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float

from galax.typing import BatchFloatScalar, BatchVec3, BatchVec6
from jaxtyping import Shaped
from plum import convert

from jax_quantity import Quantity
from vector import (
Abstract3DVector,
Abstract3DVectorDifferential,
Cartesian3DVector,
CartesianDifferential3D,
)

from galax.typing import BatchQVec3, BatchVec6
from galax.units import UnitSystem

if TYPE_CHECKING:
Expand All @@ -22,16 +30,23 @@
class AbstractPhaseSpacePositionBase(eqx.Module, strict=True): # type: ignore[call-arg, misc]
"""Abstract base class for all the types of phase-space positions.
Parameters
----------
q : :class:`~vector.Abstract3DVector`
Positions.
p : :class:`~vector.Abstract3DVectorDifferential`
Conjugate momenta at positions ``q``.
See Also
--------
:class:`~galax.coordinates.AbstractPhaseSpacePosition`
:class:`~galax.coordinates.AbstractPhaseSpaceTimePosition`
"""

q: eqx.AbstractVar[Float[Array, "*#batch #time 3"]]
q: eqx.AbstractVar[Abstract3DVector]
"""Positions."""

p: eqx.AbstractVar[Float[Array, "*#batch #time 3"]]
p: eqx.AbstractVar[Abstract3DVectorDifferential]
"""Conjugate momenta at positions ``q``."""

# ==========================================================================
Expand Down Expand Up @@ -72,7 +87,7 @@ def full_shape(self) -> tuple[int, ...]:
# ==========================================================================
# Convenience methods

def w(self, *, units: UnitSystem | None = None) -> BatchVec6:
def w(self, *, units: UnitSystem) -> BatchVec6:
"""Phase-space position as an Array[float, (*batch, Q + P)].
This is the full phase-space position, not including the time.
Expand All @@ -85,23 +100,22 @@ def w(self, *, units: UnitSystem | None = None) -> BatchVec6:
Returns
-------
w : Array[float, (*batch, Q + P)]
The phase-space position.
The phase-space position as a 6-vector in Cartesian coordinates.
"""
if units is not None:
msg = "units not yet implemented."
raise NotImplementedError(msg)

batch_shape, component_shapes = self._shape_tuple
q = xp.broadcast_to(self.q, batch_shape + component_shapes[0:1])
p = xp.broadcast_to(self.p, batch_shape + component_shapes[1:2])
return xp.concat((q, p), axis=-1)
batch_shape, comp_shapes = self._shape_tuple
q = xp.broadcast_to(convert(self.q, Quantity), (*batch_shape, comp_shapes[0]))
p = xp.broadcast_to(
convert(self.p.represent_as(CartesianDifferential3D, self.q), Quantity),
(*batch_shape, comp_shapes[1]),
)
return xp.concat((q.decompose(units).value, p.decompose(units).value), axis=-1)

# ==========================================================================
# Dynamical quantities

# TODO: property?
@partial(jax.jit)
def kinetic_energy(self) -> BatchFloatScalar:
def kinetic_energy(self) -> Shaped[Quantity["specific energy"], "*batch"]:
r"""Return the specific kinetic energy.
.. math::
Expand All @@ -114,11 +128,11 @@ def kinetic_energy(self) -> BatchFloatScalar:
The kinetic energy.
"""
# TODO: use a ``norm`` function so that this works for non-Cartesian.
return 0.5 * xp.sum(self.p**2, axis=-1)
return 0.5 * self.p.norm(self.q) ** 2

# TODO: property?
@partial(jax.jit)
def angular_momentum(self) -> BatchVec3:
def angular_momentum(self) -> BatchQVec3:
r"""Compute the angular momentum.
.. math::
Expand All @@ -131,23 +145,43 @@ def angular_momentum(self) -> BatchVec3:
Returns
-------
L : Array[float, (*batch,3)]
Array of angular momentum vectors.
Array of angular momentum vectors in Cartesian coordinates.
Examples
--------
We assume the following imports
>>> import numpy as np
>>> import astropy.units as u
>>> from galax.coordinates import PhaseSpacePosition
>>> from jax_quantity import Quantity
>>> from galax.coordinates import PhaseSpacePosition
We can compute the angular momentum of a single object
>>> pos = np.array([1., 0, 0]) * u.au
>>> vel = np.array([0, 2*np.pi, 0]) * u.au/u.yr
>>> w = PhaseSpacePosition(pos, vel)
>>> w.angular_momentum()
Array([0. , 0. , 6.28318531], dtype=float64)
>>> pos = Quantity([1., 0, 0], "au")
>>> vel = Quantity([0, 2., 0], "au/yr")
>>> w = PhaseSpacePosition(pos, vel)
>>> w.angular_momentum()
Quantity['diffusivity'](Array([0., 0., 2.], dtype=float64), unit='AU2 / yr')
"""
# TODO: when q, p are not Cartesian.
return jnp.cross(self.q, self.p)
# TODO: keep as a vector.
# https://github.com/GalacticDynamics/vector/issues/27
q = convert(self.q, Quantity)
p = convert(self.p.represent_as(CartesianDifferential3D, self.q), Quantity)
return xp.linalg.cross(q, p)


# =============================================================================
# helper functions


def _q_converter(x: Any) -> Abstract3DVector:
"""Convert input to a 3D vector."""
return x if isinstance(x, Abstract3DVector) else Cartesian3DVector.constructor(x)


def _p_converter(x: Any) -> Abstract3DVectorDifferential:
"""Convert input to a 3D vector differential."""
return (
x
if isinstance(x, Abstract3DVectorDifferential)
else CartesianDifferential3D.constructor(x)
)
100 changes: 81 additions & 19 deletions src/galax/coordinates/_psp/psp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,19 @@
import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float
from plum import convert

from .base import AbstractPhaseSpacePositionBase
from galax.typing import BatchFloatScalar, BroadBatchVec3, FloatScalar
from galax.utils._shape import batched_shape
from galax.utils.dataclasses import converter_float_array
from jax_quantity import Quantity
from vector import Abstract3DVector, Abstract3DVectorDifferential

from .base import AbstractPhaseSpacePositionBase, _p_converter, _q_converter
from galax.typing import (
BatchableFloatOrIntScalarLike,
BatchFloatOrIntQScalar,
BatchFloatQScalar,
FloatScalar,
)
from galax.utils._shape import vector_batched_shape

if TYPE_CHECKING:
from typing import Self
Expand All @@ -28,12 +35,20 @@ class AbstractPhaseSpacePosition(AbstractPhaseSpacePositionBase):
The phase-space position is a point in the 6-dimensional phase space
:math:`\mathbb{R}^6` of a dynamical system. It is composed of the position
:math:`\boldsymbol{q}` and the conjugate momentum :math:`\boldsymbol{p}`.
Parameters
----------
q : :class:`~vector.Abstract3DVector`
Positions.
p : :class:`~vector.Abstract3DVectorDifferential`
Conjugate momenta at positions ``q``.
"""

q: eqx.AbstractVar[Float[Array, "*#batch #time 3"]]
# TODO: hint shape Float[Array, "*#batch #time 3"]
q: eqx.AbstractVar[Abstract3DVector]
"""Positions."""

p: eqx.AbstractVar[Float[Array, "*#batch #time 3"]]
p: eqx.AbstractVar[Abstract3DVectorDifferential]
"""Conjugate momenta at positions ``q``."""

# ==========================================================================
Expand All @@ -48,8 +63,11 @@ def __getitem__(self, index: Any) -> "Self":
# Dynamical quantities

def potential_energy(
self, potential: "AbstractPotentialBase", /, t: FloatScalar
) -> BatchFloatScalar:
self,
potential: "AbstractPotentialBase",
/,
t: BatchFloatOrIntQScalar | BatchableFloatOrIntScalarLike,
) -> BatchFloatQScalar:
r"""Return the specific potential energy.
.. math::
Expand All @@ -60,21 +78,22 @@ def potential_energy(
----------
potential : :class:`~galax.potential.AbstractPotentialBase`
The potential object to compute the energy from.
t : float
t : :class:`jax_quantity.Quantity[float, (*batch,), "time"]`
The time at which to compute the potential energy at the given
positions.
Returns
-------
E : Array[float, (*batch,)]
E : Quantity[float, (*batch,), "specific energy"]
The specific potential energy.
"""
return potential.potential_energy(self.q, t=t)
x = convert(self.q, Quantity).value # Cartesian positions
return potential.potential_energy(x, t=t)

@partial(jax.jit)
def energy(
self, potential: "AbstractPotentialBase", /, t: FloatScalar
) -> BatchFloatScalar:
) -> BatchFloatQScalar:
r"""Return the specific total energy.
.. math::
Expand All @@ -86,13 +105,13 @@ def energy(
----------
potential : :class:`~galax.potential.AbstractPotentialBase`
The potential object to compute the energy from.
t : float
t : Quantity[float, (*batch,), "time"]
The time at which to compute the potential energy at the given
positions.
Returns
-------
E : Array[float, (*batch,)]
E : Quantity[float, (*batch,), "specific energy"]
The kinetic energy.
"""
return self.kinetic_energy() + self.potential_energy(potential, t=t)
Expand All @@ -109,19 +128,62 @@ class PhaseSpacePosition(AbstractPhaseSpacePosition):
:math:`\\mathbb{R}^6` of a dynamical system. It is composed of the position
:math:`\boldsymbol{q}` and the conjugate momentum :math:`\boldsymbol{p}`.
Parameters
----------
q : :class:`~vector.Abstract3DVector`
Positions.
p : :class:`~vector.Abstract3DVectorDifferential`
Conjugate momenta at positions ``q``.
See Also
--------
:class:`~galax.coordinates.PhaseSpaceTimePosition`
A phase-space position with time.
Examples
--------
We assume the following imports:
>>> from jax_quantity import Quantity
>>> from vector import Cartesian3DVector, CartesianDifferential3D
>>> from galax.coordinates import PhaseSpacePosition
We can create a phase-space position:
>>> q = Cartesian3DVector(x=Quantity(1, "m"), y=Quantity(2, "m"),
... z=Quantity(3, "m"))
>>> p = CartesianDifferential3D(d_x=Quantity(4, "m/s"), d_y=Quantity(5, "m/s"),
... d_z=Quantity(6, "m/s"))
>>> pos = PhaseSpacePosition(q=q, p=p)
>>> pos
PhaseSpacePosition(
q=Cartesian3DVector(
x=Quantity[PhysicalType('length')](value=f64[], unit=Unit("m")),
y=Quantity[PhysicalType('length')](value=f64[], unit=Unit("m")),
z=Quantity[PhysicalType('length')](value=f64[], unit=Unit("m"))
),
p=CartesianDifferential3D(
d_x=Quantity[PhysicalType({'speed', 'velocity'})](
value=f64[], unit=Unit("m / s")
),
d_y=Quantity[PhysicalType({'speed', 'velocity'})](
value=f64[], unit=Unit("m / s")
),
d_z=Quantity[PhysicalType({'speed', 'velocity'})](
value=f64[], unit=Unit("m / s")
)
)
)
"""

q: BroadBatchVec3 = eqx.field(converter=converter_float_array)
q: Abstract3DVector = eqx.field(converter=_q_converter)
"""Positions (x, y, z).
This is a 3-vector with a batch shape allowing for vector inputs.
"""

p: BroadBatchVec3 = eqx.field(converter=converter_float_array)
p: Abstract3DVectorDifferential = eqx.field(converter=_p_converter)
r"""Conjugate momenta (v_x, v_y, v_z).
This is a 3-vector with a batch shape allowing for vector inputs.
Expand All @@ -133,7 +195,7 @@ class PhaseSpacePosition(AbstractPhaseSpacePosition):
@property
def _shape_tuple(self) -> tuple[tuple[int, ...], tuple[int, int]]:
"""Batch, component shape."""
qbatch, qshape = batched_shape(self.q, expect_ndim=1)
pbatch, pshape = batched_shape(self.p, expect_ndim=1)
qbatch, qshape = vector_batched_shape(self.q)
pbatch, pshape = vector_batched_shape(self.p)
batch_shape = jnp.broadcast_shapes(qbatch, pbatch)
return batch_shape, qshape + pshape
Loading

0 comments on commit 89e3b2a

Please sign in to comment.