Skip to content

Commit

Permalink
feat: interpolated integration (GalacticDynamics#212)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Apr 3, 2024
1 parent 0b2cc67 commit 3bb912d
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 26 deletions.
5 changes: 3 additions & 2 deletions src/galax/dynamics/_dynamics/integrate/_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__all__ = ["Integrator"]

from typing import Any, Protocol, TypeAlias, runtime_checkable
from typing import Any, Literal, Protocol, TypeAlias, runtime_checkable

from unxt import AbstractUnitSystem

Expand Down Expand Up @@ -59,7 +59,8 @@ def __call__(
savet: SaveT | None = None,
*,
units: AbstractUnitSystem,
) -> gc.PhaseSpacePosition:
interpolated: Literal[False, True] = False,
) -> gc.PhaseSpacePosition | gc.InterpolatedPhaseSpacePosition:
"""Integrate.
Parameters
Expand Down
4 changes: 3 additions & 1 deletion src/galax/dynamics/_dynamics/integrate/_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
__all__ = ["AbstractIntegrator"]

import abc
from typing import Literal

import equinox as eqx

Expand Down Expand Up @@ -37,7 +38,8 @@ def __call__(
) = None,
*,
units: AbstractUnitSystem,
) -> gc.PhaseSpacePosition:
interpolated: Literal[False, True] = False,
) -> gc.PhaseSpacePosition | gc.InterpolatedPhaseSpacePosition:
"""Run the integrator.
Parameters
Expand Down
169 changes: 156 additions & 13 deletions src/galax/dynamics/_dynamics/integrate/_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@
from collections.abc import Callable, Mapping
from dataclasses import KW_ONLY
from functools import partial
from typing import Any, ParamSpec, TypeVar, final
from typing import Any, Literal, ParamSpec, TypeVar, final

import diffrax
import equinox as eqx
import jax
import jax.numpy as jnp
from diffrax import DenseInterpolation
from jax._src.numpy.vectorize import _parse_gufunc_signature, _parse_input_dimensions
from plum import overload

import quaxed.array_api as xp
from unxt import AbstractUnitSystem, Quantity, to_units_value
from unxt import AbstractUnitSystem, Quantity, to_units_value, unitsystem

import galax.coordinates as gc
import galax.typing as gt
Expand Down Expand Up @@ -117,7 +119,10 @@ class DiffraxIntegrator(AbstractIntegrator):
default=(("scan_kind", "bounded"),), static=True, converter=ImmutableDict
)

@partial(jax.jit, static_argnums=(0, 1))
# =====================================================
# Call

@partial(eqx.filter_jit)
def _call_implementation(
self,
F: FCallable,
Expand All @@ -126,9 +131,12 @@ def _call_implementation(
t1: gt.FloatScalar,
ts: gt.BatchVecTime,
/,
) -> tuple[gt.BatchVecTime7, None]:
interpolated: Literal[False, True],
) -> tuple[gt.BatchVecTime7, DenseInterpolation | None]:
# TODO: less awkward munging of the diffrax API
kw = dict(self.diffeq_kw)
if interpolated and kw.get("max_steps") is None:
kw.pop("max_steps")

terms = diffrax.ODETerm(F)
solver = self.Solver(**self.solver_kw)
Expand All @@ -146,13 +154,13 @@ def solve_diffeq(
y0=w0,
dt0=None,
args=(),
saveat=diffrax.SaveAt(t0=False, t1=False, ts=ts, dense=False),
saveat=diffrax.SaveAt(t0=False, t1=False, ts=ts, dense=interpolated),
stepsize_controller=self.stepsize_controller,
**kw,
)

# Perform the integration
solution = solve_diffeq(w0, t0, t1, ts)
solution = solve_diffeq(w0, t0, t1, jnp.atleast_2d(ts))

# Parse the solution
w = jnp.concat((solution.ys, solution.ts[..., None]), axis=-1)
Expand All @@ -161,6 +169,21 @@ def solve_diffeq(

return w, interp

@overload
def __call__(
self,
F: FCallable,
w0: gc.AbstractPhaseSpacePosition | gt.BatchVec6,
t0: gt.FloatQScalar | gt.FloatScalar,
t1: gt.FloatQScalar | gt.FloatScalar,
/,
savet: SaveT | None = None,
*,
units: AbstractUnitSystem,
interpolated: Literal[False] = False,
) -> gc.PhaseSpacePosition: ...

@overload
def __call__(
self,
F: FCallable,
Expand All @@ -171,7 +194,21 @@ def __call__(
savet: SaveT | None = None,
*,
units: AbstractUnitSystem,
) -> gc.PhaseSpacePosition:
interpolated: Literal[True],
) -> gc.InterpolatedPhaseSpacePosition: ...

def __call__(
self,
F: FCallable,
w0: gc.AbstractPhaseSpacePosition | gt.BatchVec6,
t0: gt.FloatQScalar | gt.FloatScalar,
t1: gt.FloatQScalar | gt.FloatScalar,
/,
savet: SaveT | None = None,
*,
units: AbstractUnitSystem,
interpolated: Literal[False, True] = False,
) -> gc.PhaseSpacePosition | gc.InterpolatedPhaseSpacePosition:
"""Run the integrator.
Parameters
Expand All @@ -189,6 +226,8 @@ def __call__(
units : `unxt.AbstractUnitSystem`
The unit system to use.
interpolated : bool, keyword-only
Whether to return an interpolated solution.
Returns
-------
Expand Down Expand Up @@ -261,6 +300,49 @@ def __call__(
>>> ws.shape
(2,)
A cool feature of the integrator is that it can return an interpolated
solution.
>>> w = integrator(pot._integrator_F, w0, t0, t1, savet=ts, units=usx.galactic,
... interpolated=True)
>>> type(w)
<class 'galax.coordinates...InterpolatedPhaseSpacePosition'>
The interpolated solution can be evaluated at any time in the domain to get
the phase-space position at that time:
>>> t = Quantity(xp.e, "Gyr")
>>> w(t)
PhaseSpacePosition(
q=Cartesian3DVector( ... ),
p=CartesianDifferential3D( ... ),
t=Quantity[PhysicalType('time')](value=f64[1,1], unit=Unit("Gyr"))
)
The interpolant is vectorized:
>>> t = Quantity(xp.linspace(0, 1, 100), "Gyr")
>>> w(t)
PhaseSpacePosition(
q=Cartesian3DVector( ... ),
p=CartesianDifferential3D( ... ),
t=Quantity[PhysicalType('time')](value=f64[1,100], unit=Unit("Gyr"))
)
And it works on batches:
>>> w0 = gc.PhaseSpacePosition(q=Quantity([[10., 0, 0], [11., 0, 0]], "kpc"),
... p=Quantity([[0, 200, 0], [0, 210, 0]], "km/s"))
>>> ws = integrator(pot._integrator_F, w0, t0, t1, units=usx.galactic,
... interpolated=True)
>>> ws.shape
(2,)
>>> w(t)
PhaseSpacePosition(
q=Cartesian3DVector( ... ),
p=CartesianDifferential3D( ... ),
t=Quantity[PhysicalType('time')](value=f64[1,100], unit=Unit("Gyr"))
)
"""
# Parse inputs
w0_: gt.Vec6 = (
Expand All @@ -273,12 +355,73 @@ def __call__(
)

# Perform the integration
w, interp = self._call_implementation(F, w0_, t0_, t1_, savet_)
w = w[..., -1, :] if savet is None else w
w, interp = self._call_implementation(F, w0_, t0_, t1_, savet_, interpolated)
w = w[..., -1, :] if savet is None else w # TODO: undo this

# Return
return gc.PhaseSpacePosition( # shape = (*batch, T)
q=Quantity(w[..., 0:3], units["length"]),
p=Quantity(w[..., 3:6], units["speed"]),
t=Quantity(w[..., -1], units["time"]),
if interpolated:
# Determine if an extra dimension was added to the output
added_ndim = int(w0_.shape[:-1] == () or w0_.shape[0] == 1)
# If one was, then the interpolant must be reshaped since the input
# was squeezed beforehand and the dimension must be added back.
if added_ndim == 1:
arr, narr = eqx.partition(interp, eqx.is_array)
arr = jax.tree_util.tree_map(lambda x: x[None], arr)
interp = eqx.combine(arr, narr)

out = gc.InterpolatedPhaseSpacePosition( # shape = (*batch, T)
q=Quantity(w[..., 0:3], units["length"]),
p=Quantity(w[..., 3:6], units["speed"]),
t=Quantity(savet_, units["time"]),
interpolant=DiffraxInterpolant(
interp, units=units, added_ndim=added_ndim
),
)
else:
out = gc.PhaseSpacePosition( # shape = (*batch, T)
q=Quantity(w[..., 0:3], units["length"]),
p=Quantity(w[..., 3:6], units["speed"]),
t=Quantity(w[..., -1], units["time"]),
)

return out


class DiffraxInterpolant(eqx.Module): # type: ignore[misc]#
"""Wrapper for ``diffrax.DenseInterpolation``."""

interpolant: DenseInterpolation
""":class:`diffrax.DenseInterpolation` object.
This object is the result of the integration and can be used to evaluate the
interpolated solution at any time. However it does not understand units, so
the input is the time in ``units["time"]``. The output is a 6-vector of
(q, p) values in the units of the integrator.
"""

units: AbstractUnitSystem = eqx.field(static=True, converter=unitsystem)
"""The :class:`unxt.AbstractUnitSystem`.
This is used to convert the time input to the interpolant and the phase-space
position output.
"""

added_ndim: tuple[int, ...] = eqx.field(static=True)
"""The number of dimensions added to the output of the interpolation.
This is used to reshape the output of the interpolation to match the batch
shape of the input to the integrator. The means of vectorizing the
interpolation means that the input must always be a batched array, resulting
in an extra dimension when the integration was on a scalar input.
"""

def __call__(self, t: gt.QVecTime, **_: Any) -> gc.PhaseSpacePosition:
"""Evaluate the interpolation."""
t_ = jnp.atleast_1d(t.to_units_value(self.units["time"]))
ys = jax.vmap(lambda s: jax.vmap(s.evaluate)(t_))(self.interpolant)
ys = ys[(0,) * (ys.ndim - 3 + self.added_ndim)]
return gc.PhaseSpacePosition(
q=Quantity(ys[..., 0:3], self.units["length"]),
p=Quantity(ys[..., 3:6], self.units["speed"]),
t=t,
)
32 changes: 24 additions & 8 deletions src/galax/dynamics/_dynamics/integrate/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from dataclasses import replace
from functools import partial
from typing import Literal

import jax
import jax.numpy as jnp
Expand All @@ -16,7 +17,7 @@
from ._api import Integrator
from ._builtin import DiffraxIntegrator
from galax.coordinates import PhaseSpacePosition
from galax.dynamics._dynamics.orbit import Orbit
from galax.dynamics._dynamics.orbit import InterpolatedOrbit, Orbit
from galax.potential._potential.base import AbstractPotentialBase

##############################################################################
Expand All @@ -29,20 +30,21 @@
_select_w0 = jnp.vectorize(jax.lax.select, signature="(),(6),(6)->(6)")


@partial(jax.jit, static_argnames=("integrator",))
@partial(jax.jit, static_argnames=("integrator", "interpolated"))
def evaluate_orbit(
pot: AbstractPotentialBase,
w0: PhaseSpacePosition | gt.BatchVec6,
t: gt.QVecTime | gt.VecTime | APYQuantity,
*,
integrator: Integrator | None = None,
) -> Orbit:
interpolated: Literal[True, False] = False,
) -> Orbit | InterpolatedOrbit:
"""Compute an orbit in a potential.
:class:`~galax.coordinates.PhaseSpacePosition` includes a time in
addition to the position (and velocity) information, enabling the orbit to
be evaluated over a time range that is different from the initial time of
the position.
:class:`~galax.coordinates.PhaseSpacePosition` includes a time in addition
to the position (and velocity) information, enabling the orbit to be
evaluated over a time range that is different from the initial time of the
position.
Parameters
----------
Expand Down Expand Up @@ -82,6 +84,10 @@ def evaluate_orbit(
is used twice: once to integrate from `w0.t` to `t[0]` and then from
`t[0]` to `t[1]`.
interpolated: bool, optional keyword-only
If `True`, return an interpolated orbit. If `False`, return the orbit
at the requested times. Default is `False`.
Returns
-------
orbit : :class:`~galax.dynamics.Orbit`
Expand Down Expand Up @@ -225,7 +231,17 @@ def evaluate_orbit(
t[-1],
savet=t,
units=units,
interpolated=interpolated,
)
wt = t

# Construct the orbit object
return Orbit(q=ws.q, p=ws.p, t=t, potential=pot)
# TODO: easier construction from the (Interpolated)PhaseSpacePosition
if interpolated:
out = InterpolatedOrbit(
q=ws.q, p=ws.p, t=wt, interpolant=ws.interpolant, potential=pot
)
else:
out = Orbit(q=ws.q, p=ws.p, t=wt, potential=pot)

return out
15 changes: 13 additions & 2 deletions src/galax/potential/_potential/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import KW_ONLY, fields
from functools import partial
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, cast
from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeAlias, cast

import equinox as eqx
import jax
Expand Down Expand Up @@ -1896,6 +1896,7 @@ def evaluate_orbit(
t: gt.QVecTime | gt.VecTime | APYQuantity, # TODO: must be a Quantity
*,
integrator: "Integrator | None" = None,
interpolated: Literal[True, False] = False,
) -> "Orbit":
"""Compute an orbit in a potential.
Expand Down Expand Up @@ -1943,6 +1944,11 @@ def evaluate_orbit(
Integrator to use. If `None`, the default integrator
:class:`~galax.integrator.DiffraxIntegrator` is used.
interpolated: bool, optional keyword-only
If `True`, return an interpolated orbit. If `False`, return the orbit
at the requested times. Default is `False`.
Returns
-------
orbit : :class:`~galax.dynamics.Orbit`
Expand All @@ -1956,4 +1962,9 @@ def evaluate_orbit(
"""
from galax.dynamics import evaluate_orbit

return cast("Orbit", evaluate_orbit(self, w0, t, integrator=integrator))
return cast(
"Orbit",
evaluate_orbit(
self, w0, t, integrator=integrator, interpolated=interpolated
),
)

0 comments on commit 3bb912d

Please sign in to comment.