Skip to content

Commit

Permalink
feat: CompositePSP and cleanup (GalacticDynamics#343)
Browse files Browse the repository at this point in the history
* fix: mockstream q interleaving
* feat: CompositePhaseSpacePosition
* feat: df sample returns CompositePSP
* refactor: cleanup

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jun 11, 2024
1 parent a153b15 commit 4a7fb9d
Show file tree
Hide file tree
Showing 7 changed files with 231 additions and 83 deletions.
2 changes: 2 additions & 0 deletions src/galax/coordinates/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ __all__ = [
"AbstractPhaseSpacePosition",
"AbstractCompositePhaseSpacePosition",
"PhaseSpacePosition",
"CompositePhaseSpacePosition",
"InterpolatedPhaseSpacePosition",
"PhaseSpacePositionInterpolant",
"ComponentShapeTuple",
Expand All @@ -20,6 +21,7 @@ from ._psp import (
AbstractCompositePhaseSpacePosition,
AbstractPhaseSpacePosition,
ComponentShapeTuple,
CompositePhaseSpacePosition,
InterpolatedPhaseSpacePosition,
PhaseSpacePosition,
PhaseSpacePositionInterpolant,
Expand Down
29 changes: 28 additions & 1 deletion src/galax/coordinates/_psp/base_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,39 @@ class of the target position class is used.
Examples
--------
TODO
>>> from unxt import Quantity
>>> import coordinax as cx
>>> import galax.coordinates as gc
We define a composite phase-space position with two components.
Every component is a phase-space position in Cartesian coordinates.
>>> psp1 = gc.PhaseSpacePosition(q=Quantity([1, 2, 3], "m"),
... p=Quantity([4, 5, 6], "m/s"),
... t=Quantity(7.0, "s"))
>>> psp2 = gc.PhaseSpacePosition(q=Quantity([1.5, 2.5, 3.5], "m"),
... p=Quantity([4.5, 5.5, 6.5], "m/s"),
... t=Quantity(6.0, "s"))
>>> cpsp = gc.CompositePhaseSpacePosition(psp1=psp1, psp2=psp2)
We can transform the composite phase-space position to a new position class.
>>> cx.represent_as(cpsp, cx.CylindricalPosition)
CompositePhaseSpacePosition({'psp1': PhaseSpacePosition(
q=CylindricalPosition( ... ),
p=CylindricalVelocity( ... ),
t=Quantity...
),
'psp2': PhaseSpacePosition(
q=CylindricalPosition( ... ),
p=CylindricalVelocity( ... ),
t=...
)})
"""
differential_cls = (
position_cls.differential_cls if differential is None else differential
)
# TODO: can we use `replace`?
return type(psp)(
**{k: represent_as(v, position_cls, differential_cls) for k, v in psp.items()}
)
194 changes: 172 additions & 22 deletions src/galax/coordinates/_psp/core.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
"""galax: Galactic Dynamics in Jax."""

__all__ = ["PhaseSpacePosition"]
__all__ = ["PhaseSpacePosition", "CompositePhaseSpacePosition"]

from collections.abc import Iterable
from typing import Any, NamedTuple, final

import equinox as eqx
import jax.numpy as jnp
import jax.tree_util as jtu
from jaxtyping import Array, Int, PyTree, Shaped
from typing_extensions import override

from coordinax import AbstractPosition3D, AbstractVelocity3D
import coordinax as cx
import quaxed.array_api as xp
import quaxed.numpy as jnp
from unxt import Quantity

import galax.typing as gt
from .base_composite import AbstractCompositePhaseSpacePosition
from .base_psp import AbstractPhaseSpacePosition
from .utils import _p_converter, _q_converter
from galax.utils._shape import batched_shape, expand_batch_dims, vector_batched_shape
Expand Down Expand Up @@ -46,18 +51,18 @@ class PhaseSpacePosition(AbstractPhaseSpacePosition):
Parameters
----------
q : :class:`~vector.AbstractPosition3D`
q : :class:`~coordinax.AbstractPosition3D`
A 3-vector of the positions, allowing for batched inputs. This
parameter accepts any 3-vector, e.g. :class:`~vector.SphericalPosition`,
parameter accepts any 3-vector, e.g. :class:`~coordinax.SphericalPosition`,
or any input that can be used to make a
:class:`~vector.CartesianPosition3D` via
:meth:`vector.AbstractPosition3D.constructor`.
p : :class:`~vector.AbstractVelocity3D`
:class:`~coordinax.CartesianPosition3D` via
:meth:`coordinax.AbstractPosition3D.constructor`.
p : :class:`~coordinax.AbstractVelocity3D`
A 3-vector of the conjugate specific momenta at positions ``q``,
allowing for batched inputs. This parameter accepts any 3-vector
differential, e.g. :class:`~vector.SphericalVelocity`, or any input
that can be used to make a :class:`~vector.CartesianVelocity3D` via
:meth:`vector.CartesianVelocity3D.constructor`.
differential, e.g. :class:`~coordinax.SphericalVelocity`, or any input
that can be used to make a :class:`~coordinax.CartesianVelocity3D` via
:meth:`coordinax.CartesianVelocity3D.constructor`.
t : Quantity[float, (*batch,), 'time'] | None
The time corresponding to the positions.
Expand All @@ -70,18 +75,18 @@ class PhaseSpacePosition(AbstractPhaseSpacePosition):
We assume the following imports:
>>> from unxt import Quantity
>>> from coordinax import CartesianPosition3D, CartesianVelocity3D
>>> from galax.coordinates import PhaseSpacePosition
>>> import coordinax as cx
>>> import galax.coordinates as gc
We can create a phase-space position:
>>> q = CartesianPosition3D(x=Quantity(1, "m"), y=Quantity(2, "m"),
... z=Quantity(3, "m"))
>>> p = CartesianVelocity3D(d_x=Quantity(4, "m/s"), d_y=Quantity(5, "m/s"),
... d_z=Quantity(6, "m/s"))
>>> q = cx.CartesianPosition3D(x=Quantity(1, "m"), y=Quantity(2, "m"),
... z=Quantity(3, "m"))
>>> p = cx.CartesianVelocity3D(d_x=Quantity(4, "m/s"), d_y=Quantity(5, "m/s"),
... d_z=Quantity(6, "m/s"))
>>> t = Quantity(7.0, "s")
>>> psp = PhaseSpacePosition(q=q, p=p, t=t)
>>> psp = gc.PhaseSpacePosition(q=q, p=p, t=t)
>>> psp
PhaseSpacePosition(
q=CartesianPosition3D(
Expand All @@ -99,8 +104,8 @@ class PhaseSpacePosition(AbstractPhaseSpacePosition):
Note that both `q` and `p` have convenience converters, allowing them to
accept a variety of inputs when constructing a
:class:`~vector.CartesianPosition3D` or
:class:`~vector.CartesianVelocity3D`, respectively. For example,
:class:`~coordinax.CartesianPosition3D` or
:class:`~coordinax.CartesianVelocity3D`, respectively. For example,
>>> psp2 = PhaseSpacePosition(q=Quantity([1, 2, 3], "m"),
... p=Quantity([4, 5, 6], "m/s"), t=t)
Expand All @@ -109,13 +114,13 @@ class PhaseSpacePosition(AbstractPhaseSpacePosition):
"""

q: AbstractPosition3D = eqx.field(converter=_q_converter)
q: cx.AbstractPosition3D = eqx.field(converter=_q_converter)
"""Positions, e.g CartesianPosition3D.
This is a 3-vector with a batch shape allowing for vector inputs.
"""

p: AbstractVelocity3D = eqx.field(converter=_p_converter)
p: cx.AbstractVelocity3D = eqx.field(converter=_p_converter)
r"""Conjugate momenta, e.g. CartesianVelocity3D.
This is a 3-vector with a batch shape allowing for vector inputs.
Expand Down Expand Up @@ -164,3 +169,148 @@ def wt(self, *, units: Any) -> gt.BatchVec7:
self.t, self.t is None, "No time defined for phase-space position"
)
return super().wt(units=units)


##############################################################################


def _concat(values: Iterable[PyTree], time_sorter: Int[Array, "..."]) -> PyTree:
return jtu.tree_map(
lambda *xs: xp.concat(tuple(jnp.atleast_1d(x) for x in xs), axis=-1)[
..., time_sorter
],
*values,
)


@final
class CompositePhaseSpacePosition(AbstractCompositePhaseSpacePosition):
r"""Composite Phase-Space Position with time.
The phase-space position is a point in the 7-dimensional phase space
:math:`\mathbb{R}^7` of a dynamical system. It is composed of the position
:math:`\boldsymbol{q}`, the time :math:`t`, and the conjugate momentum
:math:`\boldsymbol{p}`.
This class has the same constructor semantics as `dict`.
Parameters
----------
psps: dict | tuple, optional positional-only
initialize from a (key, value) mapping or tuple.
**kwargs : AbstractPhaseSpacePosition
The name=value pairs of the phase-space positions.
Notes
-----
- `q`, `p`, and `t` are a concatenation of all the constituent phase-space
positions, sorted by `t`.
- The batch shape of `q`, `p`, and `t` are broadcast together.
Examples
--------
We assume the following imports:
>>> from unxt import Quantity
>>> import coordinax as cx
>>> import galax.coordinates as gc
We can create a phase-space position. Here we will use the convenience
constructors for Cartesian positions and velocities. To see the full
constructor, see :class:`~galax.coordinates.PhaseSpacePosition`.
>>> psp1 = gc.PhaseSpacePosition(q=Quantity([1, 2, 3], "m"),
... p=Quantity([4, 5, 6], "m/s"),
... t=Quantity(7.0, "s"))
>>> psp2 = gc.PhaseSpacePosition(q=Quantity([1.5, 2.5, 3.5], "m"),
... p=Quantity([4.5, 5.5, 6.5], "m/s"),
... t=Quantity(6.0, "s"))
We can create a composite phase-space position from these two phase-space
positions:
>>> cpsp = gc.CompositePhaseSpacePosition(psp1=psp1, psp2=psp2)
>>> cpsp
CompositePhaseSpacePosition({'psp1': PhaseSpacePosition(
q=CartesianPosition3D( ... ),
p=CartesianVelocity3D( ... ),
t=Quantity...
),
'psp2': PhaseSpacePosition(
q=CartesianPosition3D( ... ),
p=CartesianVelocity3D( ... ),
t=Quantity...
)})
The individual phase-space positions can be accessed via the keys:
>>> cpsp["psp1"]
PhaseSpacePosition(
q=CartesianPosition3D( ... ),
p=CartesianVelocity3D( ... ),
t=Quantity...
)
The ``q``, ``p``, and ``t`` attributes are the concatenation of the
constituent phase-space positions, sorted by ``t``. Note that in this
example, the time of ``psp2`` is earlier than ``psp1``.
>>> cpsp.t
Quantity['time'](Array([6., 7.], dtype=float64), unit='s')
>>> cpsp.q.x
Quantity['length'](Array([1.5, 1. ], dtype=float64), unit='m')
>>> cpsp.p.d_x
Quantity['speed'](Array([4.5, 4. ], dtype=float64), unit='m / s')
We can transform the composite phase-space position to a new position class.
>>> cx.represent_as(cpsp, cx.CylindricalPosition)
CompositePhaseSpacePosition({'psp1': PhaseSpacePosition(
q=CylindricalPosition( ... ),
p=CylindricalVelocity( ... ),
t=Quantity...
),
'psp2': PhaseSpacePosition(
q=CylindricalPosition( ... ),
p=CylindricalVelocity( ... ),
t=...
)})
"""

_time_sorter: Shaped[Array, "alltimes"]

def __init__(
self,
psps: dict[str, AbstractPhaseSpacePosition]
| tuple[tuple[str, AbstractPhaseSpacePosition], ...] = (),
/,
**kwargs: AbstractPhaseSpacePosition,
) -> None:
super().__init__(psps, **kwargs)

# TODO: check up on the shapes

# Construct time sorter
ts = xp.concat([jnp.atleast_1d(psp.t) for psp in self.values()], axis=0)
self._time_sorter = xp.argsort(ts)

@property
def q(self) -> cx.AbstractPosition3D:
"""Positions."""
# TODO: get AbstractPosition to work with `stack` directly
return _concat((x.q for x in self.values()), self._time_sorter)

@property
def p(self) -> cx.AbstractVelocity3D:
"""Conjugate momenta."""
# TODO: get AbstractPosition to work with `stack` directly
return _concat((x.p for x in self.values()), self._time_sorter)

@property
def t(self) -> Shaped[Quantity["time"], "..."]:
"""Times."""
return xp.concat([jnp.atleast_1d(psp.t) for psp in self.values()], axis=0)[
self._time_sorter
]
33 changes: 1 addition & 32 deletions src/galax/coordinates/_psp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,48 +2,17 @@

__all__: list[str] = []

from collections.abc import Sequence
from functools import partial, singledispatch
from functools import singledispatch
from typing import Any, Protocol, cast, runtime_checkable

import astropy.coordinates as apyc
import jax
from jaxtyping import Array, Shaped

import coordinax as cx
import quaxed.array_api as xp
from unxt import Quantity

import galax.typing as gt


@partial(jax.jit, static_argnames="axis")
def interleave_concat(
arrays: Sequence[Shaped[Array, "shape"]] | Sequence[Shaped[Quantity, "shape"]],
/,
axis: int,
) -> Shaped[Array, "..."] | Shaped[Quantity, "..."]: # TODO: shape hint
# Check if input is a non-empty list
if not arrays or not isinstance(arrays, Sequence):
msg = "Input should be a non-empty sequence of arrays."
raise ValueError(msg)

# Ensure all arrays have the same shape
shape0 = arrays[0].shape
if not all(arr.shape == shape0 for arr in arrays):
msg = "All arrays must have the same shape."
raise ValueError(msg)

# Stack the arrays along a new axis to prepare for interleaving
axis = axis % len(shape0) # allows for negative axis
stacked = xp.stack(arrays, axis=axis + 1)

# Flatten the new axis by interleaving values
return xp.reshape(
stacked, (*shape0[:axis], len(arrays) * shape0[axis], *shape0[axis + 1 :])
)


@runtime_checkable
class HasShape(Protocol):
"""Protocol for an object with a shape attribute."""
Expand Down
5 changes: 2 additions & 3 deletions src/galax/dynamics/_dynamics/mockstream/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
_p_converter,
_q_converter,
getitem_vec1time_index,
interleave_concat,
)
from galax.utils._shape import batched_shape, vector_batched_shape

Expand Down Expand Up @@ -111,10 +110,10 @@ def __init__(
@property
def q(self) -> cx.AbstractPosition3D:
"""Positions."""
# TODO: interleave by time
# TODO: get AbstractPosition to work with `stack` directly
return jtu.tree_map(
lambda *x: interleave_concat(x, axis=-1), *(x.q for x in self.values())
lambda *x: xp.concat(x, axis=-1)[..., self._time_sorter],
*(x.q for x in self.values()),
)

@property
Expand Down
Loading

0 comments on commit 4a7fb9d

Please sign in to comment.