From d30501627fd70b5021b1ecb06295bab1a5f13fd9 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Tue, 21 May 2024 12:19:42 -0400 Subject: [PATCH] feat(psp): Composite PSP (#301) * feat: composite MockStream class Signed-off-by: nstarman --- pyproject.toml | 2 +- src/galax/coordinates/__init__.pyi | 4 + src/galax/coordinates/_psp/__init__.py | 6 + src/galax/coordinates/_psp/base.py | 128 ++---------- src/galax/coordinates/_psp/base_composite.py | 197 ++++++++++++++++++ src/galax/coordinates/_psp/base_psp.py | 126 +++++++++++ src/galax/coordinates/_psp/core.py | 2 +- src/galax/coordinates/_psp/interp.py | 3 +- src/galax/coordinates/_psp/operator_compat.py | 2 +- src/galax/coordinates/_psp/utils.py | 35 +++- src/galax/coordinates/operators/_rotating.py | 2 +- src/galax/dynamics/__init__.pyi | 3 +- src/galax/dynamics/_compat.py | 16 +- .../dynamics/_dynamics/mockstream/core.py | 71 ++++++- .../dynamics/_dynamics/mockstream/df/_base.py | 24 +-- .../mockstream/mockstream_generator.py | 63 +++--- .../dynamics/_dynamics/mockstream/utils.py | 18 +- tests/smoke/coordinates/test_package.py | 4 +- tests/smoke/dynamics/test_package.py | 1 + .../psp/{test_base.py => test_base_psp.py} | 0 tests/unit/coordinates/psp/test_psp.py | 2 +- tests/unit/dynamics/test_orbit.py | 2 +- 22 files changed, 513 insertions(+), 198 deletions(-) create mode 100644 src/galax/coordinates/_psp/base_composite.py create mode 100644 src/galax/coordinates/_psp/base_psp.py rename tests/unit/coordinates/psp/{test_base.py => test_base_psp.py} (100%) diff --git a/pyproject.toml b/pyproject.toml index 6cb80dfb..12ebb7e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,7 +107,7 @@ filterwarnings = [ "error", "ignore:ast\\.Str is deprecated:DeprecationWarning", "ignore:numpy\\.ndarray size changed:RuntimeWarning", - "ignore:Passing arguments 'a':DeprecationWarning", # upstream: diffrax + "ignore:Passing arguments 'a':DeprecationWarning", # TODO: from diffrax ] log_cli_level = "INFO" markers = [ diff --git a/src/galax/coordinates/__init__.pyi b/src/galax/coordinates/__init__.pyi index 00adc5d6..79a774cc 100644 --- a/src/galax/coordinates/__init__.pyi +++ b/src/galax/coordinates/__init__.pyi @@ -5,7 +5,9 @@ __all__ = [ # Modules "operators", # Phase-space positions + "AbstractBasePhaseSpacePosition", "AbstractPhaseSpacePosition", + "AbstractCompositePhaseSpacePosition", "PhaseSpacePosition", "InterpolatedPhaseSpacePosition", "PhaseSpacePositionInterpolant", @@ -14,6 +16,8 @@ __all__ = [ from . import operators from ._psp import ( + AbstractBasePhaseSpacePosition, + AbstractCompositePhaseSpacePosition, AbstractPhaseSpacePosition, ComponentShapeTuple, InterpolatedPhaseSpacePosition, diff --git a/src/galax/coordinates/_psp/__init__.py b/src/galax/coordinates/_psp/__init__.py index 4b20897c..0a747854 100644 --- a/src/galax/coordinates/_psp/__init__.py +++ b/src/galax/coordinates/_psp/__init__.py @@ -2,6 +2,8 @@ from . import ( base, + base_composite, + base_psp, compat_apy, # noqa: F401 core, interp, @@ -9,12 +11,16 @@ utils, ) from .base import * +from .base_composite import * +from .base_psp import * from .core import * from .interp import * from .utils import * __all__: list[str] = [] __all__ += base.__all__ +__all__ += base_psp.__all__ +__all__ += base_composite.__all__ __all__ += core.__all__ __all__ += interp.__all__ __all__ += utils.__all__ diff --git a/src/galax/coordinates/_psp/base.py b/src/galax/coordinates/_psp/base.py index 5441d750..9d67f2fc 100644 --- a/src/galax/coordinates/_psp/base.py +++ b/src/galax/coordinates/_psp/base.py @@ -1,10 +1,9 @@ -"""galax: Galactic Dynamics in Jax.""" +"""ABC for phase-space positions.""" -__all__ = ["AbstractPhaseSpacePosition", "ComponentShapeTuple"] +__all__ = ["AbstractBasePhaseSpacePosition", "ComponentShapeTuple"] from abc import abstractmethod from collections.abc import Mapping -from dataclasses import replace from functools import partial from typing import TYPE_CHECKING, Any, NamedTuple, cast @@ -18,14 +17,11 @@ from unxt import Quantity, unitsystem import galax.typing as gt -from .utils import getitem_broadscalartime_index from galax.utils.dataclasses import dataclass_items if TYPE_CHECKING: from typing import Self - from galax.potential import AbstractPotentialBase - class ComponentShapeTuple(NamedTuple): """Component shape of the phase-space position.""" @@ -40,22 +36,14 @@ class ComponentShapeTuple(NamedTuple): """Shape of the time component.""" -class AbstractPhaseSpacePosition(eqx.Module, strict=True): # type: ignore[call-arg, misc] - r"""Abstract base class of phase-space positions. +class AbstractBasePhaseSpacePosition(eqx.Module, strict=True): # type: ignore[call-arg, misc] + r"""ABC underlying phase-space positions and their composites. - The phase-space position is a point in the 7-dimensional phase space + The phase-space position is a point in the 3+3+1-dimensional phase space :math:`\mathbb{R}^7` of a dynamical system. It is composed of the position - :math:`\boldsymbol{q}`, the conjugate momentum :math:`\boldsymbol{p}`, and - the time :math:`t`. - - Parameters - ---------- - q : :class:`~vector.Abstract3DVector` - Positions. - p : :class:`~vector.Abstract3DVectorDifferential` - Conjugate momenta at positions ``q``. - t : :class:`~unxt.Quantity` - Time corresponding to the positions and momenta. + :math:`\boldsymbol{q}\in\mathbb{R}^3`, the conjugate momentum + :math:`\boldsymbol{p}\in\mathbb{R}^3`, and the time + :math:`t\in\mathbb{R}^1`. """ q: eqx.AbstractVar[cx.Abstract3DVector] @@ -73,20 +61,20 @@ class AbstractPhaseSpacePosition(eqx.Module, strict=True): # type: ignore[call- @classmethod @dispatch # type: ignore[misc] def constructor( - cls: "type[AbstractPhaseSpacePosition]", obj: Mapping[str, Any], / - ) -> "AbstractPhaseSpacePosition": + cls: "type[AbstractBasePhaseSpacePosition]", obj: Mapping[str, Any], / + ) -> "AbstractBasePhaseSpacePosition": """Construct from a mapping. Parameters ---------- - cls : type[:class:`~galax.coordinates.AbstractPhaseSpacePosition`] + cls : type[:class:`~galax.coordinates.AbstractBasePhaseSpacePosition`] The class to construct. obj : Mapping[str, Any] The mapping from which to construct. Returns ------- - :class:`~galax.coordinates.AbstractPhaseSpacePosition` + :class:`~galax.coordinates.AbstractBasePhaseSpacePosition` The constructed phase-space position. Examples @@ -173,12 +161,10 @@ def __len__(self) -> int: """Return the number of particles.""" return self.shape[0] + @abstractmethod def __getitem__(self, index: Any) -> "Self": """Return a new object with the given slice applied.""" - # Compute subindex - subindex = getitem_broadscalartime_index(index, self.t) - # Apply slice - return replace(self, q=self.q[index], p=self.p[index], t=self.t[subindex]) + ... # ========================================================================== # Further Array properties @@ -348,6 +334,7 @@ class of the target position class is used. """ # noqa: E501 return cast("Self", cx.represent_as(self, position_cls, differential_cls)) + @abstractmethod def to_units(self, units: Any) -> "Self": """Return with the components transformed to the given unit system. @@ -385,13 +372,7 @@ def to_units(self, units: Any) -> "Self": t=Quantity[...](value=f64[], unit=Unit("yr")) ) """ - usys = unitsystem(units) - return replace( - self, - q=self.q.to_units(usys), - p=self.p.to_units(usys), - t=self.t.to_units(usys["time"]) if self.t is not None else None, - ) + ... # ========================================================================== # Dynamical quantities @@ -584,22 +565,22 @@ def angular_momentum(self) -> gt.BatchQVec3: # Register additional constructors -@AbstractPhaseSpacePosition.constructor._f.register # type: ignore[misc] # noqa: SLF001 +@AbstractBasePhaseSpacePosition.constructor._f.register # type: ignore[misc] # noqa: SLF001 def constructor( - cls: type[AbstractPhaseSpacePosition], obj: AbstractPhaseSpacePosition, / -) -> AbstractPhaseSpacePosition: - """Construct from a `AbstractPhaseSpacePosition`. + cls: type[AbstractBasePhaseSpacePosition], obj: AbstractBasePhaseSpacePosition, / +) -> AbstractBasePhaseSpacePosition: + """Construct from a `AbstractBasePhaseSpacePosition`. Parameters ---------- - cls : type[:class:`~galax.coordinates.AbstractPhaseSpacePosition`] + cls : type[:class:`~galax.coordinates.AbstractBasePhaseSpacePosition`] The class to construct. - obj : :class:`~galax.coordinates.AbstractPhaseSpacePosition` + obj : :class:`~galax.coordinates.AbstractBasePhaseSpacePosition` The phase-space position object from which to construct. Returns ------- - :class:`~galax.coordinates.AbstractPhaseSpacePosition` + :class:`~galax.coordinates.AbstractBasePhaseSpacePosition` The constructed phase-space position. Raises @@ -644,66 +625,3 @@ def constructor( return obj return cls(**dict(dataclass_items(obj))) - - -# ----------------------------------------------- -# Register AbstractPhaseSpacePosition with `coordinax.represent_as` -@dispatch # type: ignore[misc] -def represent_as( - psp: AbstractPhaseSpacePosition, - position_cls: type[cx.AbstractVectorBase], - /, - differential: type[cx.AbstractVectorDifferential] | None = None, -) -> AbstractPhaseSpacePosition: - """Return with the components transformed. - - Parameters - ---------- - psp : :class:`~galax.coordinates.AbstractPhaseSpacePosition` - The phase-space position. - position_cls : type[:class:`~vector.AbstractVectorBase`] - The target position class. - differential : type[:class:`~vector.AbstractVectorDifferential`], optional - The target differential class. If `None` (default), the differential - class of the target position class is used. - - Examples - -------- - With the following imports: - - >>> from unxt import Quantity - >>> import coordinax as cx - >>> from galax.coordinates import PhaseSpacePosition - - We can create a phase-space position and convert it to a 6-vector: - - >>> psp = PhaseSpacePosition(q=Quantity([1, 2, 3], "kpc"), - ... p=Quantity([4, 5, 6], "km/s"), - ... t=Quantity(0, "Gyr")) - >>> psp.w(units="galactic") - Array([1. , 2. , 3. , 0.00409085, 0.00511356, 0.00613627], dtype=float64) - - We can also convert it to a different representation: - - >>> psp.represent_as(cx.CylindricalVector) - PhaseSpacePosition( q=CylindricalVector(...), - p=CylindricalDifferential(...), - t=Quantity[...](value=f64[], unit=Unit("Gyr")) ) - - We can also convert it to a different representation with a different - differential class: - - >>> psp.represent_as(cx.LonLatSphericalVector, cx.LonCosLatSphericalDifferential) - PhaseSpacePosition( q=LonLatSphericalVector(...), - p=LonCosLatSphericalDifferential(...), - t=Quantity[...](value=f64[], unit=Unit("Gyr")) ) - - """ - differential_cls = ( - position_cls.differential_cls if differential is None else differential - ) - return replace( - psp, - q=psp.q.represent_as(position_cls), - p=psp.p.represent_as(differential_cls, psp.q), - ) diff --git a/src/galax/coordinates/_psp/base_composite.py b/src/galax/coordinates/_psp/base_composite.py new file mode 100644 index 00000000..4e7137db --- /dev/null +++ b/src/galax/coordinates/_psp/base_composite.py @@ -0,0 +1,197 @@ +"""ABC for composite phase-space positions.""" + +__all__ = ["AbstractCompositePhaseSpacePosition"] + +from abc import abstractmethod +from collections.abc import Mapping +from types import MappingProxyType +from typing import TYPE_CHECKING, Any + +from jaxtyping import Shaped +from plum import dispatch + +import coordinax as cx +import quaxed.numpy as qnp +from unxt import Quantity + +import galax.typing as gt +from .base import AbstractBasePhaseSpacePosition, ComponentShapeTuple +from galax.utils import ImmutableDict +from galax.utils._misc import first +from galax.utils.dataclasses import dataclass_items + +if TYPE_CHECKING: + from typing import Self + + +# Note: cannot have `strict=True` because of inheriting from ImmutableDict. +class AbstractCompositePhaseSpacePosition( + ImmutableDict[AbstractBasePhaseSpacePosition], # TODO: as a TypeVar + AbstractBasePhaseSpacePosition, + strict=False, # type: ignore[call-arg] +): + r"""Abstract base class of composite phase-space positions. + + The composite phase-space position is a point in the 3 spatial + 3 kinematic + + 1 time -dimensional phase space :math:`\mathbb{R}^7` of a dynamical + system. It is composed of multiple phase-space positions, each of which + represents a component of the system. + + The input signature matches that of :class:`dict` (and + :class:`~galax.utils.ImmutableDict`), so you can pass in the components as + keyword arguments or as a dictionary. + + The components are stored as a dictionary and can be key accessed. However, + the composite phase-space position itself acts as a single + `AbstractBasePhaseSpacePosition` object, so you can access the composite + positions, velocities, and times as if they were a single object. In this + base class the composition of the components is abstract and must be + implemented in the subclasses. + + Examples + -------- + >>> from dataclasses import replace + >>> import quaxed.array_api as xp + >>> from unxt import Quantity + >>> import coordinax as cx + >>> import galax.coordinates as gc + + >>> def stack(vs: list[cx.AbstractVectorBase]) -> cx.AbstractVectorBase: + ... comps = {k: xp.stack([getattr(v, k) for v in vs], axis=-1) + ... for k in vs[0].components} + ... return replace(vs[0], **comps) + + >>> class CompositePhaseSpacePosition(gc.AbstractCompositePhaseSpacePosition): + ... @property + ... def q(self) -> cx.Abstract3DVector: + ... return stack([psp.q for psp in self.values()]) + ... + ... @property + ... def p(self) -> cx.Abstract3DVector: + ... return stack([psp.p for psp in self.values()]) + ... + ... @property + ... def t(self) -> Shaped[Quantity["time"], "..."]: + ... return stack([psp.t for psp in self.values()]) + + >>> psp1 = gc.PhaseSpacePosition(q=Quantity([1, 2, 3], "kpc"), + ... p=Quantity([4, 5, 6], "km/s"), + ... t=Quantity(7, "Myr")) + >>> psp2 = gc.PhaseSpacePosition(q=Quantity([10, 20, 30], "kpc"), + ... p=Quantity([40, 50, 60], "km/s"), + ... t=Quantity(7, "Myr")) + + >>> c_psp = CompositePhaseSpacePosition(psp1=psp1, psp2=psp2) + >>> c_psp["psp1"] is psp1 + True + + >>> c_psp.q + Cartesian3DVector( + x=Quantity[...](value=f64[2], unit=Unit("kpc")), + y=Quantity[...](value=f64[2], unit=Unit("kpc")), + z=Quantity[...](value=f64[2], unit=Unit("kpc")) + ) + + >>> c_psp.p.d_x + Quantity['speed'](Array([ 4., 40.], dtype=float64), unit='km / s') + """ + + _data: dict[str, AbstractBasePhaseSpacePosition] + + def __init__( + self, + psps: ( + dict[str, AbstractBasePhaseSpacePosition] + | tuple[tuple[str, AbstractBasePhaseSpacePosition], ...] + ) = (), + /, + **kwargs: AbstractBasePhaseSpacePosition, + ) -> None: + super().__init__(psps, **kwargs) # <- ImmutableDict.__init__ + + @property + @abstractmethod + def q(self) -> cx.Abstract3DVector: + """Positions.""" + + @property + @abstractmethod + def p(self) -> cx.Abstract3DVector: + """Conjugate momenta.""" + + @property + @abstractmethod + def t(self) -> Shaped[Quantity["time"], "..."]: + """Times.""" + + # ========================================================================== + # Array properties + + def __getitem__(self, key: Any) -> "Self": + """Get item from the key.""" + # Get specific item + if isinstance(key, str): + return self._data[key] + + # Get from each value, e.g. a slice + return type(self)(**{k: v[key] for k, v in self.items()}) + + @property + def _shape_tuple(self) -> tuple[gt.Shape, ComponentShapeTuple]: + """Batch and component shapes.""" + # TODO: speed up + batch_shape = qnp.broadcast_shapes(*[psp.shape for psp in self.values()]) + batch_shape = (*batch_shape[:-1], len(self) * batch_shape[-1]) + shape = first(self.values())._shape_tuple[1] # noqa: SLF001 + return batch_shape, shape + + # ========================================================================== + # Convenience methods + + def to_units(self, units: Any) -> "Self": + return type(self)(**{k: v.to_units(units) for k, v in self.items()}) + + # =============================================================== + # Collection methods + + @property + def shapes(self) -> Mapping[str, tuple[int, ...]]: + """Get the shapes of the components.""" + return MappingProxyType({k: v.shape for k, v in dataclass_items(self)}) + + +# ============================================================================= +# helper functions + + +# Register AbstractCompositePhaseSpacePosition with `coordinax.represent_as` +@dispatch # type: ignore[misc] +def represent_as( + psp: AbstractCompositePhaseSpacePosition, + position_cls: type[cx.AbstractVectorBase], + /, + differential: type[cx.AbstractVectorDifferential] | None = None, +) -> AbstractCompositePhaseSpacePosition: + """Return with the components transformed. + + Parameters + ---------- + psp : :class:`~galax.coordinates.AbstractCompositePhaseSpacePosition` + The phase-space position. + position_cls : type[:class:`~vector.AbstractVectorBase`] + The target position class. + differential : type[:class:`~vector.AbstractVectorDifferential`], optional + The target differential class. If `None` (default), the differential + class of the target position class is used. + + Examples + -------- + TODO + + """ + differential_cls = ( + position_cls.differential_cls if differential is None else differential + ) + return type(psp)( + **{k: represent_as(v, position_cls, differential_cls) for k, v in psp.items()} + ) diff --git a/src/galax/coordinates/_psp/base_psp.py b/src/galax/coordinates/_psp/base_psp.py new file mode 100644 index 00000000..9488d86a --- /dev/null +++ b/src/galax/coordinates/_psp/base_psp.py @@ -0,0 +1,126 @@ +"""galax: Galactic Dynamics in Jax.""" + +__all__ = ["AbstractPhaseSpacePosition"] + +from dataclasses import replace +from typing import TYPE_CHECKING, Any + +from plum import dispatch + +import coordinax as cx +from unxt import unitsystem + +from .base import AbstractBasePhaseSpacePosition +from .utils import getitem_broadscalartime_index + +if TYPE_CHECKING: + from typing import Self + + +class AbstractPhaseSpacePosition(AbstractBasePhaseSpacePosition): + r"""Abstract base class of phase-space positions. + + The phase-space position is a point in the 3+3+1-dimensional phase space + :math:`\mathbb{R}^7` of a dynamical system. It is composed of the position + :math:`\boldsymbol{q}\in\mathbb{R}^3`, the conjugate momentum + :math:`\boldsymbol{p}\in\mathbb{R}^3`, and the time + :math:`t\in\mathbb{R}^1`. + + Parameters + ---------- + q : :class:`~vector.Abstract3DVector` + Positions. + p : :class:`~vector.Abstract3DVectorDifferential` + Conjugate momenta at positions ``q``. + t : :class:`~unxt.Quantity` + Time corresponding to the positions and momenta. + """ + + # ========================================================================== + # Array properties + + def __getitem__(self, index: Any) -> "Self": + """Return a new object with the given slice applied.""" + # Compute subindex + subindex = getitem_broadscalartime_index(index, self.t) + # Apply slice + return replace(self, q=self.q[index], p=self.p[index], t=self.t[subindex]) + + # ========================================================================== + # Convenience methods + + def to_units(self, units: Any) -> "Self": + usys = unitsystem(units) + return replace( + self, + q=self.q.to_units(usys), + p=self.p.to_units(usys), + t=self.t.to_units(usys["time"]) if self.t is not None else None, + ) + + +# ============================================================================= +# helper functions + + +# ----------------------------------------------- +# Register AbstractPhaseSpacePosition with `coordinax.represent_as` +@dispatch # type: ignore[misc] +def represent_as( + psp: AbstractPhaseSpacePosition, + position_cls: type[cx.AbstractVectorBase], + /, + differential: type[cx.AbstractVectorDifferential] | None = None, +) -> AbstractPhaseSpacePosition: + """Return with the components transformed. + + Parameters + ---------- + psp : :class:`~galax.coordinates.AbstractPhaseSpacePosition` + The phase-space position. + position_cls : type[:class:`~vector.AbstractVectorBase`] + The target position class. + differential : type[:class:`~vector.AbstractVectorDifferential`], optional + The target differential class. If `None` (default), the differential + class of the target position class is used. + + Examples + -------- + With the following imports: + + >>> from unxt import Quantity + >>> import coordinax as cx + >>> from galax.coordinates import PhaseSpacePosition + + We can create a phase-space position and convert it to a 6-vector: + + >>> psp = PhaseSpacePosition(q=Quantity([1, 2, 3], "kpc"), + ... p=Quantity([4, 5, 6], "km/s"), + ... t=Quantity(0, "Gyr")) + >>> psp.w(units="galactic") + Array([1. , 2. , 3. , 0.00409085, 0.00511356, 0.00613627], dtype=float64) + + We can also convert it to a different representation: + + >>> psp.represent_as(cx.CylindricalVector) + PhaseSpacePosition( q=CylindricalVector(...), + p=CylindricalDifferential(...), + t=Quantity[...](value=f64[], unit=Unit("Gyr")) ) + + We can also convert it to a different representation with a different + differential class: + + >>> psp.represent_as(cx.LonLatSphericalVector, cx.LonCosLatSphericalDifferential) + PhaseSpacePosition( q=LonLatSphericalVector(...), + p=LonCosLatSphericalDifferential(...), + t=Quantity[...](value=f64[], unit=Unit("Gyr")) ) + + """ + differential_cls = ( + position_cls.differential_cls if differential is None else differential + ) + return replace( + psp, + q=psp.q.represent_as(position_cls), + p=psp.p.represent_as(differential_cls, psp.q), + ) diff --git a/src/galax/coordinates/_psp/core.py b/src/galax/coordinates/_psp/core.py index bfd3c634..d65800b2 100644 --- a/src/galax/coordinates/_psp/core.py +++ b/src/galax/coordinates/_psp/core.py @@ -12,7 +12,7 @@ from unxt import Quantity import galax.typing as gt -from .base import AbstractPhaseSpacePosition +from .base_psp import AbstractPhaseSpacePosition from .utils import _p_converter, _q_converter from galax.utils._shape import batched_shape, expand_batch_dims, vector_batched_shape diff --git a/src/galax/coordinates/_psp/interp.py b/src/galax/coordinates/_psp/interp.py index 11fb5695..095c575b 100644 --- a/src/galax/coordinates/_psp/interp.py +++ b/src/galax/coordinates/_psp/interp.py @@ -11,7 +11,8 @@ from unxt import AbstractUnitSystem, Quantity import galax.typing as gt -from .base import AbstractPhaseSpacePosition, ComponentShapeTuple +from .base import ComponentShapeTuple +from .base_psp import AbstractPhaseSpacePosition from .core import PhaseSpacePosition from .utils import _p_converter, _q_converter from galax.utils._shape import batched_shape, expand_batch_dims, vector_batched_shape diff --git a/src/galax/coordinates/_psp/operator_compat.py b/src/galax/coordinates/_psp/operator_compat.py index 70ef032b..871faa06 100644 --- a/src/galax/coordinates/_psp/operator_compat.py +++ b/src/galax/coordinates/_psp/operator_compat.py @@ -21,7 +21,7 @@ from coordinax.operators._base import op_call_dispatch from unxt import Quantity -from galax.coordinates._psp.base import AbstractPhaseSpacePosition +from .base_psp import AbstractPhaseSpacePosition vec_matmul = quaxify(jnp.vectorize(jnp.matmul, signature="(3,3),(3)->(3)")) diff --git a/src/galax/coordinates/_psp/utils.py b/src/galax/coordinates/_psp/utils.py index 4c3d6874..07073ab8 100644 --- a/src/galax/coordinates/_psp/utils.py +++ b/src/galax/coordinates/_psp/utils.py @@ -2,20 +2,51 @@ __all__: list[str] = [] -from functools import singledispatch +from collections.abc import Sequence +from functools import partial, singledispatch from typing import Any, Protocol, cast, runtime_checkable import astropy.coordinates as apyc +import jax +from jaxtyping import Array, Shaped import coordinax as cx import quaxed.array_api as xp +from unxt import Quantity import galax.typing as gt +@partial(jax.jit, static_argnames="axis") +def interleave_concat( + arrays: Sequence[Shaped[Array, "shape"]] | Sequence[Shaped[Quantity, "shape"]], + /, + axis: int, +) -> Shaped[Array, "..."] | Shaped[Quantity, "..."]: # TODO: shape hint + # Check if input is a non-empty list + if not arrays or not isinstance(arrays, Sequence): + msg = "Input should be a non-empty sequence of arrays." + raise ValueError(msg) + + # Ensure all arrays have the same shape + shape0 = arrays[0].shape + if not all(arr.shape == shape0 for arr in arrays): + msg = "All arrays must have the same shape." + raise ValueError(msg) + + # Stack the arrays along a new axis to prepare for interleaving + axis = axis % len(shape0) # allows for negative axis + stacked = xp.stack(arrays, axis=axis + 1) + + # Flatten the new axis by interleaving values + return xp.reshape( + stacked, (*shape0[:axis], len(arrays) * shape0[axis], *shape0[axis + 1 :]) + ) + + @runtime_checkable class HasShape(Protocol): - """Protocol for a shaped object.""" + """Protocol for an object with a shape attribute.""" shape: gt.Shape diff --git a/src/galax/coordinates/operators/_rotating.py b/src/galax/coordinates/operators/_rotating.py index 14b0b39f..36af117e 100644 --- a/src/galax/coordinates/operators/_rotating.py +++ b/src/galax/coordinates/operators/_rotating.py @@ -15,7 +15,7 @@ from coordinax.operators._base import op_call_dispatch from unxt import Quantity -from galax.coordinates._psp.base import AbstractPhaseSpacePosition +from galax.coordinates._psp.base_psp import AbstractPhaseSpacePosition def rot_z( diff --git a/src/galax/dynamics/__init__.pyi b/src/galax/dynamics/__init__.pyi index 890dd7b1..bfef77e6 100644 --- a/src/galax/dynamics/__init__.pyi +++ b/src/galax/dynamics/__init__.pyi @@ -12,6 +12,7 @@ __all__ = [ # integrate "evaluate_orbit", # mockstream + "MockStreamArm", "MockStream", "MockStreamGenerator", # mockstream.df @@ -22,6 +23,6 @@ __all__ = [ from ._dynamics import integrate, mockstream from ._dynamics.base import AbstractOrbit from ._dynamics.integrate._funcs import evaluate_orbit -from ._dynamics.mockstream import MockStream, MockStreamGenerator +from ._dynamics.mockstream import MockStream, MockStreamArm, MockStreamGenerator from ._dynamics.mockstream.df import AbstractStreamDF, FardalStreamDF from ._dynamics.orbit import InterpolatedOrbit, Orbit diff --git a/src/galax/dynamics/_compat.py b/src/galax/dynamics/_compat.py index 549f4c59..cc9fb90c 100644 --- a/src/galax/dynamics/_compat.py +++ b/src/galax/dynamics/_compat.py @@ -40,13 +40,13 @@ def constructor(_: type[gdx.Orbit], obj: gd.Orbit, /) -> gdx.Orbit: # MockStream -@conversion_method(type_from=gd.MockStream, type_to=gdx.MockStream) # type: ignore[misc] -def gala_mockstream_to_galax_mockstream(obj: gd.MockStream, /) -> gdx.MockStream: - """`gala.dynamics.MockStream` -> `galax.dynamics.MockStream`.""" - return gdx.MockStream(q=obj.pos, p=obj.vel, release_time=obj.release_time) +@conversion_method(type_from=gd.MockStream, type_to=gdx.MockStreamArm) # type: ignore[misc] +def gala_mockstream_to_galax_mockstream(obj: gd.MockStream, /) -> gdx.MockStreamArm: + """`gala.dynamics.MockStreamArm` -> `galax.dynamics.MockStreamArm`.""" + return gdx.MockStreamArm(q=obj.pos, p=obj.vel, release_time=obj.release_time) -@gdx.MockStream.constructor._f.register # type: ignore[misc] # noqa: SLF001 -def constructor(_: type[gdx.MockStream], obj: gd.MockStream, /) -> gdx.MockStream: - """Construct a :mod:`galax` MockStream from a :mod:`gala` one.""" - return cast(gdx.MockStream, gala_mockstream_to_galax_mockstream(obj)) +@gdx.MockStreamArm.constructor._f.register # type: ignore[misc] # noqa: SLF001 +def constructor(_: type[gdx.MockStreamArm], obj: gd.MockStream, /) -> gdx.MockStreamArm: + """Construct a :mod:`galax` MockStreamArm from a :mod:`gala` one.""" + return cast(gdx.MockStreamArm, gala_mockstream_to_galax_mockstream(obj)) diff --git a/src/galax/dynamics/_dynamics/mockstream/core.py b/src/galax/dynamics/_dynamics/mockstream/core.py index 6e5a1e2d..efd5b2e2 100644 --- a/src/galax/dynamics/_dynamics/mockstream/core.py +++ b/src/galax/dynamics/_dynamics/mockstream/core.py @@ -1,22 +1,31 @@ -"""galax: Galactic Dynamix in Jax.""" +"""Mock stellar streams.""" -__all__ = ["MockStream"] +__all__ = ["MockStreamArm", "MockStream"] from dataclasses import replace from typing import TYPE_CHECKING, Any, final import equinox as eqx import jax.numpy as jnp +import jax.tree_util as jtu +from jaxtyping import Array, Shaped +import coordinax as cx +import quaxed.array_api as xp from coordinax import Abstract3DVector, Abstract3DVectorDifferential from unxt import Quantity import galax.typing as gt -from galax.coordinates import AbstractPhaseSpacePosition, ComponentShapeTuple +from galax.coordinates import ( + AbstractCompositePhaseSpacePosition, + AbstractPhaseSpacePosition, + ComponentShapeTuple, +) from galax.coordinates._psp.utils import ( _p_converter, _q_converter, getitem_vec1time_index, + interleave_concat, ) from galax.utils._shape import batched_shape, vector_batched_shape @@ -25,8 +34,8 @@ @final -class MockStream(AbstractPhaseSpacePosition): - """Mock stream object. +class MockStreamArm(AbstractPhaseSpacePosition): + """Component of a mock stream object. Parameters ---------- @@ -76,3 +85,55 @@ def __getitem__(self, index: Any) -> "Self": t=self.t[subindex], release_time=self.release_time[subindex], ) + + +############################################################################## + + +@final +class MockStream(AbstractCompositePhaseSpacePosition): + _time_sorter: Shaped[Array, "alltimes"] + + def __init__( + self, + psps: dict[str, MockStreamArm] | tuple[tuple[str, MockStreamArm], ...] = (), + /, + **kwargs: MockStreamArm, + ) -> None: + super().__init__(psps, **kwargs) + + # TODO: check up on the shapes + + # Construct time sorter + ts = xp.concat([psp.t for psp in self.values()], axis=0) + self._time_sorter = xp.argsort(ts) + + @property + def q(self) -> cx.Abstract3DVector: + """Positions.""" + # TODO: interleave by time + # TODO: get AbstractVector to work with `stack` directly + return jtu.tree_map( + lambda *x: interleave_concat(x, axis=-1), *(x.q for x in self.values()) + ) + + @property + def p(self) -> cx.Abstract3DVector: + """Conjugate momenta.""" + # TODO: get AbstractVector to work with `stack` directly + return jtu.tree_map( + lambda *x: xp.concat(x, axis=-1)[..., self._time_sorter], + *(x.p for x in self.values()), + ) + + @property + def t(self) -> Shaped[Quantity["time"], "..."]: + """Times.""" + return xp.concat([psp.t for psp in self.values()], axis=0)[self._time_sorter] + + @property + def release_time(self) -> Shaped[Quantity["time"], "..."]: + """Release times.""" + return xp.concat([psp.release_time for psp in self.values()], axis=0)[ + self._time_sorter + ] diff --git a/src/galax/dynamics/_dynamics/mockstream/df/_base.py b/src/galax/dynamics/_dynamics/mockstream/df/_base.py index bc85b70d..870f6732 100644 --- a/src/galax/dynamics/_dynamics/mockstream/df/_base.py +++ b/src/galax/dynamics/_dynamics/mockstream/df/_base.py @@ -17,23 +17,13 @@ import galax.typing as gt from ._progenitor import ConstantMassProtenitor, ProgenitorMassCallable -from galax.dynamics._dynamics.mockstream.core import MockStream +from galax.dynamics._dynamics.mockstream.core import MockStreamArm from galax.dynamics._dynamics.orbit import Orbit from galax.potential import AbstractPotentialBase -Wif: TypeAlias = tuple[ - gt.LengthVec3, - gt.LengthVec3, - gt.SpeedVec3, - gt.SpeedVec3, -] +Wif: TypeAlias = tuple[gt.LengthVec3, gt.LengthVec3, gt.SpeedVec3, gt.SpeedVec3] Carry: TypeAlias = tuple[ - int, - jr.PRNG, - gt.LengthVec3, - gt.LengthVec3, - gt.SpeedVec3, - gt.SpeedVec3, + int, jr.PRNG, gt.LengthVec3, gt.LengthVec3, gt.SpeedVec3, gt.SpeedVec3 ] @@ -58,7 +48,7 @@ def sample( # /> /, prog_mass: gt.MassScalar | ProgenitorMassCallable, - ) -> tuple[MockStream, MockStream]: + ) -> tuple[MockStreamArm, MockStreamArm]: """Generate stream particle initial conditions. Parameters @@ -75,7 +65,7 @@ def sample( Returns ------- - mock_lead, mock_trail : MockStream + mock_lead, mock_trail : MockStreamArm Positions and velocities of the leading and trailing tails. """ # Progenitor positions and times. The orbit times are used as the @@ -110,13 +100,13 @@ def scan_fn(carry: Carry, t: gt.FloatQScalar) -> tuple[Carry, Wif]: ) x_lead, x_trail, v_lead, v_trail = jax.lax.scan(scan_fn, init_carry, ts)[1] - mock_lead = MockStream( + mock_lead = MockStreamArm( q=x_lead.to_units(pot.units["length"]), p=v_lead.to_units(pot.units["speed"]), t=ts.to_units(pot.units["time"]), release_time=ts.to_units(pot.units["time"]), ) - mock_trail = MockStream( + mock_trail = MockStreamArm( q=x_trail.to_units(pot.units["length"]), p=v_trail.to_units(pot.units["speed"]), t=ts.to_units(pot.units["time"]), diff --git a/src/galax/dynamics/_dynamics/mockstream/mockstream_generator.py b/src/galax/dynamics/_dynamics/mockstream/mockstream_generator.py index 4582c0fc..8d614edc 100644 --- a/src/galax/dynamics/_dynamics/mockstream/mockstream_generator.py +++ b/src/galax/dynamics/_dynamics/mockstream/mockstream_generator.py @@ -17,9 +17,9 @@ import galax.coordinates as gc import galax.typing as gt -from .core import MockStream +from .core import MockStream, MockStreamArm from .df import AbstractStreamDF, ProgenitorMassCallable -from .utils import cond_reverse, interleave_concat +from .utils import cond_reverse from galax.dynamics._dynamics.integrate._api import Integrator from galax.dynamics._dynamics.integrate._funcs import ( _default_integrator, @@ -75,7 +75,10 @@ def _progenitor_trajectory( @partial(jax.jit) def _run_scan( # TODO: output shape depends on the input shape - self, ts: gt.QVecTime, mock0_lead: MockStream, mock0_trail: MockStream + self, + ts: gt.QVecTime, + mock0_lead: MockStreamArm, + mock0_trail: MockStreamArm, ) -> tuple[gt.BatchVec6, gt.BatchVec6]: """Generate stellar stream by scanning over the release model/integration. @@ -121,7 +124,10 @@ def integ_ics(ics: gt.Vec6) -> gt.VecN: @partial(jax.jit) def _run_vmap( # TODO: output shape depends on the input shape - self, ts: gt.QVecTime, mock0_lead: MockStream, mock0_trail: MockStream + self, + ts: gt.QVecTime, + mock0_lead: MockStreamArm, + mock0_trail: MockStreamArm, ) -> tuple[gt.BatchVec6, gt.BatchVec6]: """Generate stellar stream by vmapping over the release model/integration. @@ -157,7 +163,7 @@ def run( prog_mass: gt.FloatQScalar | ProgenitorMassCallable, *, vmapped: bool | None = None, - ) -> tuple[MockStream, gc.PhaseSpacePosition]: + ) -> tuple[MockStreamArm, gc.PhaseSpacePosition]: """Generate mock stellar stream. Parameters @@ -196,7 +202,7 @@ def run( Returns ------- - mockstream : :class:`galax.dynamcis.MockStream` + mockstream : :class:`galax.dynamcis.MockStreamArm` Leading and/or trailing arms of the mock stream. prog_o : :class:`galax.coordinates.PhaseSpacePosition` The final phase-space(+time) position of the progenitor. @@ -225,6 +231,7 @@ def run( # Integrate the progenitor orbit, evaluating at the stripping times prog_o = self._progenitor_trajectory(w0, ts) + # TODO: here sep out lead vs trailing # Generate initial conditions from the DF, along the integrated # progenitor orbit. The release times are the stripping times. mock0_lead, mock0_trail = self.df.sample(rng, self.potential, prog_o, prog_mass) @@ -236,34 +243,20 @@ def run( t = xp.ones_like(ts) * ts.value[-1] # TODO: ensure this time is correct - # TODO: have a composite Stream object that has components, e.g. leading - # and trailing. - # TODO: move the leading vs trailing logic to the DF - if self.df.lead and self.df.trail: - axis = len(trail_arm_w.shape) - 2 - q = interleave_concat(trail_arm_w[:, 0:3], lead_arm_w[:, 0:3], axis=axis) - p = interleave_concat(trail_arm_w[:, 3:6], lead_arm_w[:, 3:6], axis=axis) - t = interleave_concat(t, t, axis=0) - release_time = interleave_concat( - mock0_lead.release_time, mock0_trail.release_time, axis=0 + comps = {} + if self.df.lead: + comps["lead"] = MockStreamArm( + q=Quantity(lead_arm_w[:, 0:3], self.units["length"]), + p=Quantity(lead_arm_w[:, 3:6], self.units["speed"]), + t=t, + release_time=mock0_lead.release_time, + ) + if self.df.trail: + comps["trail"] = MockStreamArm( + q=Quantity(trail_arm_w[:, 0:3], self.units["length"]), + p=Quantity(trail_arm_w[:, 3:6], self.units["speed"]), + t=t, + release_time=mock0_trail.release_time, ) - elif self.df.lead: - q = lead_arm_w[:, 0:3] - p = lead_arm_w[:, 3:6] - release_time = mock0_lead.release_time - elif self.df.trail: - q = trail_arm_w[:, 0:3] - p = trail_arm_w[:, 3:6] - release_time = mock0_trail.release_time - else: - msg = "You must generate either leading or trailing tails (or both!)" - raise ValueError(msg) - - mockstream = MockStream( - q=Quantity(q, self.units["length"]), - p=Quantity(p, self.units["speed"]), - t=t, - release_time=release_time, - ) - return mockstream, prog_o[-1] + return MockStream(comps), prog_o[-1] diff --git a/src/galax/dynamics/_dynamics/mockstream/utils.py b/src/galax/dynamics/_dynamics/mockstream/utils.py index cecff524..0c8a5af7 100644 --- a/src/galax/dynamics/_dynamics/mockstream/utils.py +++ b/src/galax/dynamics/_dynamics/mockstream/utils.py @@ -2,31 +2,15 @@ __all__: list[str] = [] -from functools import partial from typing import Any, Protocol, TypeVar, cast, runtime_checkable import jax -from jaxtyping import Array, Bool, Shaped - -import quaxed.array_api as xp +from jaxtyping import Array, Bool T = TypeVar("T") T_co = TypeVar("T_co", covariant=True) -@partial(jax.jit, static_argnames="axis") -def interleave_concat( - a: Shaped[Array, "..."], b: Shaped[Array, "..."], /, axis: int -) -> Shaped[Array, "..."]: - a_shp = a.shape - return xp.stack((a, b), axis=axis + 1).reshape( - *a_shp[:axis], 2 * a_shp[axis], *a_shp[axis + 1 :] - ) - - -# ------------------------------------------------------------------- - - @runtime_checkable class SupportsGetItem(Protocol[T_co]): """Protocol for types that support the `__getitem__` method.""" diff --git a/tests/smoke/coordinates/test_package.py b/tests/smoke/coordinates/test_package.py index f1946728..4d29af03 100644 --- a/tests/smoke/coordinates/test_package.py +++ b/tests/smoke/coordinates/test_package.py @@ -1,7 +1,7 @@ """Testing :mod:`galax.dynamics` module.""" import galax.coordinates as gc -from galax.coordinates._psp import base, core, interp, utils +from galax.coordinates._psp import base, base_composite, base_psp, core, interp, utils def test_all() -> None: @@ -9,6 +9,8 @@ def test_all() -> None: assert set(gc.__all__) == { "operators", *base.__all__, + *base_psp.__all__, + *base_composite.__all__, *core.__all__, *interp.__all__, *utils.__all__, diff --git a/tests/smoke/dynamics/test_package.py b/tests/smoke/dynamics/test_package.py index 662ee888..cf56d47c 100644 --- a/tests/smoke/dynamics/test_package.py +++ b/tests/smoke/dynamics/test_package.py @@ -14,6 +14,7 @@ def test_all() -> None: "Orbit", "InterpolatedOrbit", "evaluate_orbit", + "MockStreamArm", "MockStream", "MockStreamGenerator", "AbstractStreamDF", diff --git a/tests/unit/coordinates/psp/test_base.py b/tests/unit/coordinates/psp/test_base_psp.py similarity index 100% rename from tests/unit/coordinates/psp/test_base.py rename to tests/unit/coordinates/psp/test_base_psp.py diff --git a/tests/unit/coordinates/psp/test_psp.py b/tests/unit/coordinates/psp/test_psp.py index 792b60f4..af778d3e 100644 --- a/tests/unit/coordinates/psp/test_psp.py +++ b/tests/unit/coordinates/psp/test_psp.py @@ -2,7 +2,7 @@ import pytest -from .test_base import AbstractPhaseSpacePosition_Test +from .test_base_psp import AbstractPhaseSpacePosition_Test from galax.coordinates import PhaseSpacePosition diff --git a/tests/unit/dynamics/test_orbit.py b/tests/unit/dynamics/test_orbit.py index acb7e1ce..4168ad94 100644 --- a/tests/unit/dynamics/test_orbit.py +++ b/tests/unit/dynamics/test_orbit.py @@ -12,7 +12,7 @@ from unxt.unitsystems import galactic import galax.typing as gt -from ..coordinates.psp.test_base import AbstractPhaseSpacePosition_Test, return_keys +from ..coordinates.psp.test_base_psp import AbstractPhaseSpacePosition_Test, return_keys from galax.coordinates import PhaseSpacePosition from galax.dynamics import Orbit from galax.potential import AbstractPotentialBase, MilkyWayPotential