Skip to content

Commit

Permalink
refactor: consolidate integrate (GalacticDynamics#346)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jun 12, 2024
1 parent 1cf28c4 commit 191595d
Show file tree
Hide file tree
Showing 4 changed files with 307 additions and 282 deletions.
4 changes: 4 additions & 0 deletions src/galax/coordinates/_psp/interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ class PhaseSpacePositionInterpolant(Protocol):
units: AbstractUnitSystem
"""The unit system for the interpolation."""

added_ndim: int
"""The number of dimensions added to the input time."""
# TODO: not require this for Diffrax

def __call__(self, t: gt.QVecTime) -> PhaseSpacePosition:
"""Evaluate the interpolation.
Expand Down
109 changes: 104 additions & 5 deletions src/galax/dynamics/_dynamics/integrate/_base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
__all__ = ["AbstractIntegrator"]

import abc
from typing import Literal
from typing import Any, Literal, TypeVar

import equinox as eqx
from plum import overload

from unxt import AbstractUnitSystem
import quaxed.array_api as xp
from unxt import AbstractUnitSystem, Quantity, to_units_value

import galax.coordinates as gc
import galax.typing as gt
from ._api import VectorField
from ._api import SaveT, VectorField

Interp = TypeVar("Interp")


class AbstractIntegrator(eqx.Module, strict=True): # type: ignore[call-arg, misc]
Expand All @@ -24,8 +28,66 @@ class AbstractIntegrator(eqx.Module, strict=True): # type: ignore[call-arg, mis
motion. They must not be stateful since they are used in a functional way.
"""

# TODO: shape hint of the return type
InterpolantClass: eqx.AbstractClassVar[type[gc.PhaseSpacePositionInterpolant]]

@abc.abstractmethod
def _call_implementation(
self,
F: VectorField,
w0: gt.BatchVec6,
t0: gt.FloatScalar,
t1: gt.FloatScalar,
ts: gt.BatchVecTime,
/,
interpolated: Literal[False, True],
) -> tuple[gt.BatchVecTime7, Any | None]: # TODO: type hint Interpolant
"""Integrator implementation."""
...

def _process_interpolation(
self, interp: Interp, w0: gt.BatchVec6
) -> tuple[Interp, int]:
"""Process the interpolation.
This is the default implementation and will probably need to be
overridden in a subclass.
"""
# Determine if an extra dimension was added to the output
added_ndim = int(w0.shape[:-1] in ((), (1,)))
# Return the interpolation and the number of added dimensions
return interp, added_ndim

# ------------------------------------------------------------------------

@overload
def __call__(
self,
F: VectorField,
w0: gc.AbstractPhaseSpacePosition | gt.BatchVec6,
t0: gt.FloatQScalar | gt.FloatScalar,
t1: gt.FloatQScalar | gt.FloatScalar,
/,
saveat: SaveT | None = None,
*,
units: AbstractUnitSystem,
interpolated: Literal[False] = False,
) -> gc.PhaseSpacePosition: ...

@overload
def __call__(
self,
F: VectorField,
w0: gc.AbstractPhaseSpacePosition | gt.BatchVec6,
t0: gt.FloatQScalar | gt.FloatScalar,
t1: gt.FloatQScalar | gt.FloatScalar,
/,
saveat: SaveT | None = None,
*,
units: AbstractUnitSystem,
interpolated: Literal[True],
) -> gc.InterpolatedPhaseSpacePosition: ...

# TODO: shape hint of the return type
def __call__(
self,
F: VectorField,
Expand Down Expand Up @@ -131,4 +193,41 @@ def __call__(
(2,)
"""
...
# Parse inputs
w0_: gt.Vec6 = (
w0.w(units=units) if isinstance(w0, gc.AbstractPhaseSpacePosition) else w0
)
t0_: gt.VecTime = to_units_value(t0, units["time"])
t1_: gt.VecTime = to_units_value(t1, units["time"])
# Either save at `saveat` or at the final time. The final time is
# a scalar and the saveat is a vector, so a dimension is added.
saveat_ = (
xp.asarray([t1_])
if saveat is None
else to_units_value(saveat, units["time"])
)

# Perform the integration
w, interp = self._call_implementation(F, w0_, t0_, t1_, saveat_, interpolated)
w = w[..., -1, :] if saveat is None else w # get rid of added dimension

# Return
if interpolated:
interp, added_ndim = self._process_interpolation(interp, w0_)

out = gc.InterpolatedPhaseSpacePosition( # shape = (*batch, T)
q=Quantity(w[..., 0:3], units["length"]),
p=Quantity(w[..., 3:6], units["speed"]),
t=Quantity(saveat_, units["time"]),
interpolant=self.InterpolantClass(
interp, units=units, added_ndim=added_ndim
),
)
else:
out = gc.PhaseSpacePosition( # shape = (*batch, T)
q=Quantity(w[..., 0:3], units["length"]),
p=Quantity(w[..., 3:6], units["speed"]),
t=Quantity(w[..., -1], units["time"]),
)

return out
Loading

0 comments on commit 191595d

Please sign in to comment.