Skip to content

Commit

Permalink
feat(psp): Composite PSP (GalacticDynamics#301)
Browse files Browse the repository at this point in the history
* feat: composite MockStream class

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed May 21, 2024
1 parent aca045a commit d305016
Show file tree
Hide file tree
Showing 22 changed files with 513 additions and 198 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ filterwarnings = [
"error",
"ignore:ast\\.Str is deprecated:DeprecationWarning",
"ignore:numpy\\.ndarray size changed:RuntimeWarning",
"ignore:Passing arguments 'a':DeprecationWarning", # upstream: diffrax
"ignore:Passing arguments 'a':DeprecationWarning", # TODO: from diffrax
]
log_cli_level = "INFO"
markers = [
Expand Down
4 changes: 4 additions & 0 deletions src/galax/coordinates/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ __all__ = [
# Modules
"operators",
# Phase-space positions
"AbstractBasePhaseSpacePosition",
"AbstractPhaseSpacePosition",
"AbstractCompositePhaseSpacePosition",
"PhaseSpacePosition",
"InterpolatedPhaseSpacePosition",
"PhaseSpacePositionInterpolant",
Expand All @@ -14,6 +16,8 @@ __all__ = [

from . import operators
from ._psp import (
AbstractBasePhaseSpacePosition,
AbstractCompositePhaseSpacePosition,
AbstractPhaseSpacePosition,
ComponentShapeTuple,
InterpolatedPhaseSpacePosition,
Expand Down
6 changes: 6 additions & 0 deletions src/galax/coordinates/_psp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,25 @@

from . import (
base,
base_composite,
base_psp,
compat_apy, # noqa: F401
core,
interp,
operator_compat, # noqa: F401
utils,
)
from .base import *
from .base_composite import *
from .base_psp import *
from .core import *
from .interp import *
from .utils import *

__all__: list[str] = []
__all__ += base.__all__
__all__ += base_psp.__all__
__all__ += base_composite.__all__
__all__ += core.__all__
__all__ += interp.__all__
__all__ += utils.__all__
128 changes: 23 additions & 105 deletions src/galax/coordinates/_psp/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
"""galax: Galactic Dynamics in Jax."""
"""ABC for phase-space positions."""

__all__ = ["AbstractPhaseSpacePosition", "ComponentShapeTuple"]
__all__ = ["AbstractBasePhaseSpacePosition", "ComponentShapeTuple"]

from abc import abstractmethod
from collections.abc import Mapping
from dataclasses import replace
from functools import partial
from typing import TYPE_CHECKING, Any, NamedTuple, cast

Expand All @@ -18,14 +17,11 @@
from unxt import Quantity, unitsystem

import galax.typing as gt
from .utils import getitem_broadscalartime_index
from galax.utils.dataclasses import dataclass_items

if TYPE_CHECKING:
from typing import Self

from galax.potential import AbstractPotentialBase


class ComponentShapeTuple(NamedTuple):
"""Component shape of the phase-space position."""
Expand All @@ -40,22 +36,14 @@ class ComponentShapeTuple(NamedTuple):
"""Shape of the time component."""


class AbstractPhaseSpacePosition(eqx.Module, strict=True): # type: ignore[call-arg, misc]
r"""Abstract base class of phase-space positions.
class AbstractBasePhaseSpacePosition(eqx.Module, strict=True): # type: ignore[call-arg, misc]
r"""ABC underlying phase-space positions and their composites.
The phase-space position is a point in the 7-dimensional phase space
The phase-space position is a point in the 3+3+1-dimensional phase space
:math:`\mathbb{R}^7` of a dynamical system. It is composed of the position
:math:`\boldsymbol{q}`, the conjugate momentum :math:`\boldsymbol{p}`, and
the time :math:`t`.
Parameters
----------
q : :class:`~vector.Abstract3DVector`
Positions.
p : :class:`~vector.Abstract3DVectorDifferential`
Conjugate momenta at positions ``q``.
t : :class:`~unxt.Quantity`
Time corresponding to the positions and momenta.
:math:`\boldsymbol{q}\in\mathbb{R}^3`, the conjugate momentum
:math:`\boldsymbol{p}\in\mathbb{R}^3`, and the time
:math:`t\in\mathbb{R}^1`.
"""

q: eqx.AbstractVar[cx.Abstract3DVector]
Expand All @@ -73,20 +61,20 @@ class AbstractPhaseSpacePosition(eqx.Module, strict=True): # type: ignore[call-
@classmethod
@dispatch # type: ignore[misc]
def constructor(
cls: "type[AbstractPhaseSpacePosition]", obj: Mapping[str, Any], /
) -> "AbstractPhaseSpacePosition":
cls: "type[AbstractBasePhaseSpacePosition]", obj: Mapping[str, Any], /
) -> "AbstractBasePhaseSpacePosition":
"""Construct from a mapping.
Parameters
----------
cls : type[:class:`~galax.coordinates.AbstractPhaseSpacePosition`]
cls : type[:class:`~galax.coordinates.AbstractBasePhaseSpacePosition`]
The class to construct.
obj : Mapping[str, Any]
The mapping from which to construct.
Returns
-------
:class:`~galax.coordinates.AbstractPhaseSpacePosition`
:class:`~galax.coordinates.AbstractBasePhaseSpacePosition`
The constructed phase-space position.
Examples
Expand Down Expand Up @@ -173,12 +161,10 @@ def __len__(self) -> int:
"""Return the number of particles."""
return self.shape[0]

@abstractmethod
def __getitem__(self, index: Any) -> "Self":
"""Return a new object with the given slice applied."""
# Compute subindex
subindex = getitem_broadscalartime_index(index, self.t)
# Apply slice
return replace(self, q=self.q[index], p=self.p[index], t=self.t[subindex])
...

# ==========================================================================
# Further Array properties
Expand Down Expand Up @@ -348,6 +334,7 @@ class of the target position class is used.
""" # noqa: E501
return cast("Self", cx.represent_as(self, position_cls, differential_cls))

@abstractmethod
def to_units(self, units: Any) -> "Self":
"""Return with the components transformed to the given unit system.
Expand Down Expand Up @@ -385,13 +372,7 @@ def to_units(self, units: Any) -> "Self":
t=Quantity[...](value=f64[], unit=Unit("yr"))
)
"""
usys = unitsystem(units)
return replace(
self,
q=self.q.to_units(usys),
p=self.p.to_units(usys),
t=self.t.to_units(usys["time"]) if self.t is not None else None,
)
...

# ==========================================================================
# Dynamical quantities
Expand Down Expand Up @@ -584,22 +565,22 @@ def angular_momentum(self) -> gt.BatchQVec3:
# Register additional constructors


@AbstractPhaseSpacePosition.constructor._f.register # type: ignore[misc] # noqa: SLF001
@AbstractBasePhaseSpacePosition.constructor._f.register # type: ignore[misc] # noqa: SLF001
def constructor(
cls: type[AbstractPhaseSpacePosition], obj: AbstractPhaseSpacePosition, /
) -> AbstractPhaseSpacePosition:
"""Construct from a `AbstractPhaseSpacePosition`.
cls: type[AbstractBasePhaseSpacePosition], obj: AbstractBasePhaseSpacePosition, /
) -> AbstractBasePhaseSpacePosition:
"""Construct from a `AbstractBasePhaseSpacePosition`.
Parameters
----------
cls : type[:class:`~galax.coordinates.AbstractPhaseSpacePosition`]
cls : type[:class:`~galax.coordinates.AbstractBasePhaseSpacePosition`]
The class to construct.
obj : :class:`~galax.coordinates.AbstractPhaseSpacePosition`
obj : :class:`~galax.coordinates.AbstractBasePhaseSpacePosition`
The phase-space position object from which to construct.
Returns
-------
:class:`~galax.coordinates.AbstractPhaseSpacePosition`
:class:`~galax.coordinates.AbstractBasePhaseSpacePosition`
The constructed phase-space position.
Raises
Expand Down Expand Up @@ -644,66 +625,3 @@ def constructor(
return obj

return cls(**dict(dataclass_items(obj)))


# -----------------------------------------------
# Register AbstractPhaseSpacePosition with `coordinax.represent_as`
@dispatch # type: ignore[misc]
def represent_as(
psp: AbstractPhaseSpacePosition,
position_cls: type[cx.AbstractVectorBase],
/,
differential: type[cx.AbstractVectorDifferential] | None = None,
) -> AbstractPhaseSpacePosition:
"""Return with the components transformed.
Parameters
----------
psp : :class:`~galax.coordinates.AbstractPhaseSpacePosition`
The phase-space position.
position_cls : type[:class:`~vector.AbstractVectorBase`]
The target position class.
differential : type[:class:`~vector.AbstractVectorDifferential`], optional
The target differential class. If `None` (default), the differential
class of the target position class is used.
Examples
--------
With the following imports:
>>> from unxt import Quantity
>>> import coordinax as cx
>>> from galax.coordinates import PhaseSpacePosition
We can create a phase-space position and convert it to a 6-vector:
>>> psp = PhaseSpacePosition(q=Quantity([1, 2, 3], "kpc"),
... p=Quantity([4, 5, 6], "km/s"),
... t=Quantity(0, "Gyr"))
>>> psp.w(units="galactic")
Array([1. , 2. , 3. , 0.00409085, 0.00511356, 0.00613627], dtype=float64)
We can also convert it to a different representation:
>>> psp.represent_as(cx.CylindricalVector)
PhaseSpacePosition( q=CylindricalVector(...),
p=CylindricalDifferential(...),
t=Quantity[...](value=f64[], unit=Unit("Gyr")) )
We can also convert it to a different representation with a different
differential class:
>>> psp.represent_as(cx.LonLatSphericalVector, cx.LonCosLatSphericalDifferential)
PhaseSpacePosition( q=LonLatSphericalVector(...),
p=LonCosLatSphericalDifferential(...),
t=Quantity[...](value=f64[], unit=Unit("Gyr")) )
"""
differential_cls = (
position_cls.differential_cls if differential is None else differential
)
return replace(
psp,
q=psp.q.represent_as(position_cls),
p=psp.p.represent_as(differential_cls, psp.q),
)
Loading

0 comments on commit d305016

Please sign in to comment.