From 4a7fb9dbd3606dc16683a80e782af4384e184e9f Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Tue, 11 Jun 2024 13:52:27 -0400 Subject: [PATCH] feat: CompositePSP and cleanup (#343) * fix: mockstream q interleaving * feat: CompositePhaseSpacePosition * feat: df sample returns CompositePSP * refactor: cleanup Signed-off-by: nstarman --- src/galax/coordinates/__init__.pyi | 2 + src/galax/coordinates/_psp/base_composite.py | 29 ++- src/galax/coordinates/_psp/core.py | 194 ++++++++++++++++-- src/galax/coordinates/_psp/utils.py | 33 +-- .../dynamics/_dynamics/mockstream/core.py | 5 +- .../dynamics/_dynamics/mockstream/df/_base.py | 37 ++-- .../_dynamics/mockstream/df/_fardal.py | 14 +- 7 files changed, 231 insertions(+), 83 deletions(-) diff --git a/src/galax/coordinates/__init__.pyi b/src/galax/coordinates/__init__.pyi index 79a774cc..2996bcde 100644 --- a/src/galax/coordinates/__init__.pyi +++ b/src/galax/coordinates/__init__.pyi @@ -9,6 +9,7 @@ __all__ = [ "AbstractPhaseSpacePosition", "AbstractCompositePhaseSpacePosition", "PhaseSpacePosition", + "CompositePhaseSpacePosition", "InterpolatedPhaseSpacePosition", "PhaseSpacePositionInterpolant", "ComponentShapeTuple", @@ -20,6 +21,7 @@ from ._psp import ( AbstractCompositePhaseSpacePosition, AbstractPhaseSpacePosition, ComponentShapeTuple, + CompositePhaseSpacePosition, InterpolatedPhaseSpacePosition, PhaseSpacePosition, PhaseSpacePositionInterpolant, diff --git a/src/galax/coordinates/_psp/base_composite.py b/src/galax/coordinates/_psp/base_composite.py index 440d9e2d..05a040d7 100644 --- a/src/galax/coordinates/_psp/base_composite.py +++ b/src/galax/coordinates/_psp/base_composite.py @@ -186,12 +186,39 @@ class of the target position class is used. Examples -------- - TODO + >>> from unxt import Quantity + >>> import coordinax as cx + >>> import galax.coordinates as gc + We define a composite phase-space position with two components. + Every component is a phase-space position in Cartesian coordinates. + + >>> psp1 = gc.PhaseSpacePosition(q=Quantity([1, 2, 3], "m"), + ... p=Quantity([4, 5, 6], "m/s"), + ... t=Quantity(7.0, "s")) + >>> psp2 = gc.PhaseSpacePosition(q=Quantity([1.5, 2.5, 3.5], "m"), + ... p=Quantity([4.5, 5.5, 6.5], "m/s"), + ... t=Quantity(6.0, "s")) + >>> cpsp = gc.CompositePhaseSpacePosition(psp1=psp1, psp2=psp2) + + We can transform the composite phase-space position to a new position class. + + >>> cx.represent_as(cpsp, cx.CylindricalPosition) + CompositePhaseSpacePosition({'psp1': PhaseSpacePosition( + q=CylindricalPosition( ... ), + p=CylindricalVelocity( ... ), + t=Quantity... + ), + 'psp2': PhaseSpacePosition( + q=CylindricalPosition( ... ), + p=CylindricalVelocity( ... ), + t=... + )}) """ differential_cls = ( position_cls.differential_cls if differential is None else differential ) + # TODO: can we use `replace`? return type(psp)( **{k: represent_as(v, position_cls, differential_cls) for k, v in psp.items()} ) diff --git a/src/galax/coordinates/_psp/core.py b/src/galax/coordinates/_psp/core.py index 70458f66..7c7fc1e6 100644 --- a/src/galax/coordinates/_psp/core.py +++ b/src/galax/coordinates/_psp/core.py @@ -1,17 +1,22 @@ """galax: Galactic Dynamics in Jax.""" -__all__ = ["PhaseSpacePosition"] +__all__ = ["PhaseSpacePosition", "CompositePhaseSpacePosition"] +from collections.abc import Iterable from typing import Any, NamedTuple, final import equinox as eqx -import jax.numpy as jnp +import jax.tree_util as jtu +from jaxtyping import Array, Int, PyTree, Shaped from typing_extensions import override -from coordinax import AbstractPosition3D, AbstractVelocity3D +import coordinax as cx +import quaxed.array_api as xp +import quaxed.numpy as jnp from unxt import Quantity import galax.typing as gt +from .base_composite import AbstractCompositePhaseSpacePosition from .base_psp import AbstractPhaseSpacePosition from .utils import _p_converter, _q_converter from galax.utils._shape import batched_shape, expand_batch_dims, vector_batched_shape @@ -46,18 +51,18 @@ class PhaseSpacePosition(AbstractPhaseSpacePosition): Parameters ---------- - q : :class:`~vector.AbstractPosition3D` + q : :class:`~coordinax.AbstractPosition3D` A 3-vector of the positions, allowing for batched inputs. This - parameter accepts any 3-vector, e.g. :class:`~vector.SphericalPosition`, + parameter accepts any 3-vector, e.g. :class:`~coordinax.SphericalPosition`, or any input that can be used to make a - :class:`~vector.CartesianPosition3D` via - :meth:`vector.AbstractPosition3D.constructor`. - p : :class:`~vector.AbstractVelocity3D` + :class:`~coordinax.CartesianPosition3D` via + :meth:`coordinax.AbstractPosition3D.constructor`. + p : :class:`~coordinax.AbstractVelocity3D` A 3-vector of the conjugate specific momenta at positions ``q``, allowing for batched inputs. This parameter accepts any 3-vector - differential, e.g. :class:`~vector.SphericalVelocity`, or any input - that can be used to make a :class:`~vector.CartesianVelocity3D` via - :meth:`vector.CartesianVelocity3D.constructor`. + differential, e.g. :class:`~coordinax.SphericalVelocity`, or any input + that can be used to make a :class:`~coordinax.CartesianVelocity3D` via + :meth:`coordinax.CartesianVelocity3D.constructor`. t : Quantity[float, (*batch,), 'time'] | None The time corresponding to the positions. @@ -70,18 +75,18 @@ class PhaseSpacePosition(AbstractPhaseSpacePosition): We assume the following imports: >>> from unxt import Quantity - >>> from coordinax import CartesianPosition3D, CartesianVelocity3D - >>> from galax.coordinates import PhaseSpacePosition + >>> import coordinax as cx + >>> import galax.coordinates as gc We can create a phase-space position: - >>> q = CartesianPosition3D(x=Quantity(1, "m"), y=Quantity(2, "m"), - ... z=Quantity(3, "m")) - >>> p = CartesianVelocity3D(d_x=Quantity(4, "m/s"), d_y=Quantity(5, "m/s"), - ... d_z=Quantity(6, "m/s")) + >>> q = cx.CartesianPosition3D(x=Quantity(1, "m"), y=Quantity(2, "m"), + ... z=Quantity(3, "m")) + >>> p = cx.CartesianVelocity3D(d_x=Quantity(4, "m/s"), d_y=Quantity(5, "m/s"), + ... d_z=Quantity(6, "m/s")) >>> t = Quantity(7.0, "s") - >>> psp = PhaseSpacePosition(q=q, p=p, t=t) + >>> psp = gc.PhaseSpacePosition(q=q, p=p, t=t) >>> psp PhaseSpacePosition( q=CartesianPosition3D( @@ -99,8 +104,8 @@ class PhaseSpacePosition(AbstractPhaseSpacePosition): Note that both `q` and `p` have convenience converters, allowing them to accept a variety of inputs when constructing a - :class:`~vector.CartesianPosition3D` or - :class:`~vector.CartesianVelocity3D`, respectively. For example, + :class:`~coordinax.CartesianPosition3D` or + :class:`~coordinax.CartesianVelocity3D`, respectively. For example, >>> psp2 = PhaseSpacePosition(q=Quantity([1, 2, 3], "m"), ... p=Quantity([4, 5, 6], "m/s"), t=t) @@ -109,13 +114,13 @@ class PhaseSpacePosition(AbstractPhaseSpacePosition): """ - q: AbstractPosition3D = eqx.field(converter=_q_converter) + q: cx.AbstractPosition3D = eqx.field(converter=_q_converter) """Positions, e.g CartesianPosition3D. This is a 3-vector with a batch shape allowing for vector inputs. """ - p: AbstractVelocity3D = eqx.field(converter=_p_converter) + p: cx.AbstractVelocity3D = eqx.field(converter=_p_converter) r"""Conjugate momenta, e.g. CartesianVelocity3D. This is a 3-vector with a batch shape allowing for vector inputs. @@ -164,3 +169,148 @@ def wt(self, *, units: Any) -> gt.BatchVec7: self.t, self.t is None, "No time defined for phase-space position" ) return super().wt(units=units) + + +############################################################################## + + +def _concat(values: Iterable[PyTree], time_sorter: Int[Array, "..."]) -> PyTree: + return jtu.tree_map( + lambda *xs: xp.concat(tuple(jnp.atleast_1d(x) for x in xs), axis=-1)[ + ..., time_sorter + ], + *values, + ) + + +@final +class CompositePhaseSpacePosition(AbstractCompositePhaseSpacePosition): + r"""Composite Phase-Space Position with time. + + The phase-space position is a point in the 7-dimensional phase space + :math:`\mathbb{R}^7` of a dynamical system. It is composed of the position + :math:`\boldsymbol{q}`, the time :math:`t`, and the conjugate momentum + :math:`\boldsymbol{p}`. + + This class has the same constructor semantics as `dict`. + + Parameters + ---------- + psps: dict | tuple, optional positional-only + initialize from a (key, value) mapping or tuple. + **kwargs : AbstractPhaseSpacePosition + The name=value pairs of the phase-space positions. + + Notes + ----- + - `q`, `p`, and `t` are a concatenation of all the constituent phase-space + positions, sorted by `t`. + - The batch shape of `q`, `p`, and `t` are broadcast together. + + Examples + -------- + We assume the following imports: + + >>> from unxt import Quantity + >>> import coordinax as cx + >>> import galax.coordinates as gc + + We can create a phase-space position. Here we will use the convenience + constructors for Cartesian positions and velocities. To see the full + constructor, see :class:`~galax.coordinates.PhaseSpacePosition`. + + >>> psp1 = gc.PhaseSpacePosition(q=Quantity([1, 2, 3], "m"), + ... p=Quantity([4, 5, 6], "m/s"), + ... t=Quantity(7.0, "s")) + >>> psp2 = gc.PhaseSpacePosition(q=Quantity([1.5, 2.5, 3.5], "m"), + ... p=Quantity([4.5, 5.5, 6.5], "m/s"), + ... t=Quantity(6.0, "s")) + + We can create a composite phase-space position from these two phase-space + positions: + + >>> cpsp = gc.CompositePhaseSpacePosition(psp1=psp1, psp2=psp2) + >>> cpsp + CompositePhaseSpacePosition({'psp1': PhaseSpacePosition( + q=CartesianPosition3D( ... ), + p=CartesianVelocity3D( ... ), + t=Quantity... + ), + 'psp2': PhaseSpacePosition( + q=CartesianPosition3D( ... ), + p=CartesianVelocity3D( ... ), + t=Quantity... + )}) + + The individual phase-space positions can be accessed via the keys: + + >>> cpsp["psp1"] + PhaseSpacePosition( + q=CartesianPosition3D( ... ), + p=CartesianVelocity3D( ... ), + t=Quantity... + ) + + The ``q``, ``p``, and ``t`` attributes are the concatenation of the + constituent phase-space positions, sorted by ``t``. Note that in this + example, the time of ``psp2`` is earlier than ``psp1``. + + >>> cpsp.t + Quantity['time'](Array([6., 7.], dtype=float64), unit='s') + + >>> cpsp.q.x + Quantity['length'](Array([1.5, 1. ], dtype=float64), unit='m') + + >>> cpsp.p.d_x + Quantity['speed'](Array([4.5, 4. ], dtype=float64), unit='m / s') + + We can transform the composite phase-space position to a new position class. + + >>> cx.represent_as(cpsp, cx.CylindricalPosition) + CompositePhaseSpacePosition({'psp1': PhaseSpacePosition( + q=CylindricalPosition( ... ), + p=CylindricalVelocity( ... ), + t=Quantity... + ), + 'psp2': PhaseSpacePosition( + q=CylindricalPosition( ... ), + p=CylindricalVelocity( ... ), + t=... + )}) + """ + + _time_sorter: Shaped[Array, "alltimes"] + + def __init__( + self, + psps: dict[str, AbstractPhaseSpacePosition] + | tuple[tuple[str, AbstractPhaseSpacePosition], ...] = (), + /, + **kwargs: AbstractPhaseSpacePosition, + ) -> None: + super().__init__(psps, **kwargs) + + # TODO: check up on the shapes + + # Construct time sorter + ts = xp.concat([jnp.atleast_1d(psp.t) for psp in self.values()], axis=0) + self._time_sorter = xp.argsort(ts) + + @property + def q(self) -> cx.AbstractPosition3D: + """Positions.""" + # TODO: get AbstractPosition to work with `stack` directly + return _concat((x.q for x in self.values()), self._time_sorter) + + @property + def p(self) -> cx.AbstractVelocity3D: + """Conjugate momenta.""" + # TODO: get AbstractPosition to work with `stack` directly + return _concat((x.p for x in self.values()), self._time_sorter) + + @property + def t(self) -> Shaped[Quantity["time"], "..."]: + """Times.""" + return xp.concat([jnp.atleast_1d(psp.t) for psp in self.values()], axis=0)[ + self._time_sorter + ] diff --git a/src/galax/coordinates/_psp/utils.py b/src/galax/coordinates/_psp/utils.py index 63baeb4e..fd96d5a3 100644 --- a/src/galax/coordinates/_psp/utils.py +++ b/src/galax/coordinates/_psp/utils.py @@ -2,48 +2,17 @@ __all__: list[str] = [] -from collections.abc import Sequence -from functools import partial, singledispatch +from functools import singledispatch from typing import Any, Protocol, cast, runtime_checkable import astropy.coordinates as apyc -import jax -from jaxtyping import Array, Shaped import coordinax as cx import quaxed.array_api as xp -from unxt import Quantity import galax.typing as gt -@partial(jax.jit, static_argnames="axis") -def interleave_concat( - arrays: Sequence[Shaped[Array, "shape"]] | Sequence[Shaped[Quantity, "shape"]], - /, - axis: int, -) -> Shaped[Array, "..."] | Shaped[Quantity, "..."]: # TODO: shape hint - # Check if input is a non-empty list - if not arrays or not isinstance(arrays, Sequence): - msg = "Input should be a non-empty sequence of arrays." - raise ValueError(msg) - - # Ensure all arrays have the same shape - shape0 = arrays[0].shape - if not all(arr.shape == shape0 for arr in arrays): - msg = "All arrays must have the same shape." - raise ValueError(msg) - - # Stack the arrays along a new axis to prepare for interleaving - axis = axis % len(shape0) # allows for negative axis - stacked = xp.stack(arrays, axis=axis + 1) - - # Flatten the new axis by interleaving values - return xp.reshape( - stacked, (*shape0[:axis], len(arrays) * shape0[axis], *shape0[axis + 1 :]) - ) - - @runtime_checkable class HasShape(Protocol): """Protocol for an object with a shape attribute.""" diff --git a/src/galax/dynamics/_dynamics/mockstream/core.py b/src/galax/dynamics/_dynamics/mockstream/core.py index b32cb11d..fd5a8bc2 100644 --- a/src/galax/dynamics/_dynamics/mockstream/core.py +++ b/src/galax/dynamics/_dynamics/mockstream/core.py @@ -25,7 +25,6 @@ _p_converter, _q_converter, getitem_vec1time_index, - interleave_concat, ) from galax.utils._shape import batched_shape, vector_batched_shape @@ -111,10 +110,10 @@ def __init__( @property def q(self) -> cx.AbstractPosition3D: """Positions.""" - # TODO: interleave by time # TODO: get AbstractPosition to work with `stack` directly return jtu.tree_map( - lambda *x: interleave_concat(x, axis=-1), *(x.q for x in self.values()) + lambda *x: xp.concat(x, axis=-1)[..., self._time_sorter], + *(x.q for x in self.values()), ) @property diff --git a/src/galax/dynamics/_dynamics/mockstream/df/_base.py b/src/galax/dynamics/_dynamics/mockstream/df/_base.py index 2eaf5634..9adc0396 100644 --- a/src/galax/dynamics/_dynamics/mockstream/df/_base.py +++ b/src/galax/dynamics/_dynamics/mockstream/df/_base.py @@ -16,11 +16,12 @@ import quaxed.array_api as xp from unxt import Quantity +import galax.coordinates as gc +import galax.potential as gp import galax.typing as gt from ._progenitor import ConstantMassProtenitor, ProgenitorMassCallable -from galax.dynamics._dynamics.mockstream.core import MockStream, MockStreamArm +from galax.dynamics._dynamics.mockstream.core import MockStreamArm from galax.dynamics._dynamics.orbit import Orbit -from galax.potential import AbstractPotentialBase Carry: TypeAlias = tuple[gt.LengthVec3, gt.SpeedVec3, gt.LengthVec3, gt.SpeedVec3] @@ -33,21 +34,21 @@ def sample( self, rng: PRNGKeyArray, # <\ parts of gala's ``prog_orbit`` - pot: AbstractPotentialBase, + pot: gp.AbstractPotentialBase, prog_orbit: Orbit, # /> /, prog_mass: gt.MassScalar | ProgenitorMassCallable, - ) -> MockStream: + ) -> gc.CompositePhaseSpacePosition: """Generate stream particle initial conditions. Parameters ---------- - rng : :class:`jaxtyping.PRNGKeyArray` + rng : :class:`jaxtyping.PRNGKeyArray`, positional-only Pseudo-random number generator. - pot : AbstractPotentialBase, positional-only + pot : :class:`~galax.potential.AbstractPotentialBase`, positional-only The potential of the host galaxy. - prog_orbit : Orbit, positional-only + prog_orbit : :class:`~galax.dynamics.Orbit`, positional-only The orbit of the progenitor. prog_mass : Quantity[float, (), 'mass'] | ProgenitorMassCallable @@ -55,7 +56,7 @@ def sample( Returns ------- - `galax.dynamics.MockStream` + `galax.coordinates.CompositePhaseSpacePosition` Phase-space positions of the leading and trailing arms. Examples @@ -77,8 +78,8 @@ def sample( # Progenitor positions and times. The orbit times are used as the # release times for the mock stream. prog_orbit = prog_orbit.represent_as(cx.CartesianPosition3D) - x = convert(prog_orbit.q, Quantity) - v = convert(prog_orbit.p, Quantity) + xs = convert(prog_orbit.q, Quantity) + vs = convert(prog_orbit.p, Quantity) ts = prog_orbit.t # Progenitor mass @@ -92,15 +93,15 @@ def sample( # conditions at each release time. def scan_fn(_: Carry, inputs: tuple[int, PRNGKeyArray]) -> tuple[Carry, Carry]: i, key = inputs - out = self._sample(key, pot, x[i], v[i], mprog(ts[i]), ts[i]) + out = self._sample(key, pot, xs[i], vs[i], mprog(ts[i]), ts[i]) return out, out # TODO: use ``jax.vmap`` instead of ``jax.lax.scan``? init_carry = ( - xp.zeros_like(x[0]), - xp.zeros_like(v[0]), - xp.zeros_like(x[0]), - xp.zeros_like(v[0]), + xp.zeros_like(xs[0]), + xp.zeros_like(vs[0]), + xp.zeros_like(xs[0]), + xp.zeros_like(vs[0]), ) subkeys = jr.split(rng, len(ts)) x_lead, v_lead, x_trail, v_trail = jax.lax.scan( @@ -120,14 +121,14 @@ def scan_fn(_: Carry, inputs: tuple[int, PRNGKeyArray]) -> tuple[Carry, Carry]: release_time=ts.to_units(pot.units["time"]), ) - return MockStream(lead=mock_lead, trail=mock_trail) + return gc.CompositePhaseSpacePosition(lead=mock_lead, trail=mock_trail) # TODO: keep units and PSP through this func @abc.abstractmethod def _sample( self, rng: PRNGKeyArray, - pot: AbstractPotentialBase, + pot: gp.AbstractPotentialBase, x: gt.LengthVec3, v: gt.SpeedVec3, prog_mass: gt.FloatQScalar, @@ -141,7 +142,7 @@ def _sample( ---------- rng : :class:`jaxtyping.PRNGKeyArray` Pseudo-random number generator. - pot : AbstractPotentialBase + pot : :class:`galax.potential.AbstractPotentialBase` The potential of the host galaxy. x : Quantity[float, (3,), "length"] 3d position (x, y, z) diff --git a/src/galax/dynamics/_dynamics/mockstream/df/_fardal.py b/src/galax/dynamics/_dynamics/mockstream/df/_fardal.py index 9efffa24..ec5da619 100644 --- a/src/galax/dynamics/_dynamics/mockstream/df/_fardal.py +++ b/src/galax/dynamics/_dynamics/mockstream/df/_fardal.py @@ -45,10 +45,10 @@ class FardalStreamDF(AbstractStreamDF): https://ui.adsabs.harvard.edu/abs/2015MNRAS.452..301F/abstract """ - @partial(jax.jit, static_argnums=(0,)) + @partial(jax.jit, inline=True) def _sample( self, - rng: PRNGKeyArray, + key: PRNGKeyArray, potential: AbstractPotentialBase, x: gt.LengthVec3, v: gt.SpeedVec3, @@ -57,7 +57,7 @@ def _sample( ) -> tuple[gt.LengthVec3, gt.SpeedVec3, gt.LengthVec3, gt.SpeedVec3]: """Generate stream particle initial conditions.""" # Random number generation - rng1, rng2, rng3, rng4 = jr.split(rng, 4) + key1, key2, key3, key4 = jr.split(key, 4) omega_val = orbital_angular_velocity_mag(x, v) @@ -77,10 +77,10 @@ def _sample( phi_hat = phi_vec / xp.linalg.vector_norm(phi_vec, axis=-1) # k vals - kr_samp = kr_bar + jr.normal(rng1, (1,)) * sigma_kr - kvphi_samp = kr_samp * (kvphi_bar + jr.normal(rng2, (1,)) * sigma_kvphi) - kz_samp = kz_bar + jr.normal(rng3, (1,)) * sigma_kz - kvz_samp = kvz_bar + jr.normal(rng4, (1,)) * sigma_kvz + kr_samp = kr_bar + jr.normal(key1, (1,)) * sigma_kr + kvphi_samp = kr_samp * (kvphi_bar + jr.normal(key2, (1,)) * sigma_kvphi) + kz_samp = kz_bar + jr.normal(key3, (1,)) * sigma_kz + kvz_samp = kvz_bar + jr.normal(key4, (1,)) * sigma_kvz # Trailing arm x_trail = x + r_tidal * (kr_samp * r_hat + kz_samp * z_hat)